use serde::{Deserialize, Serialize};
use crate::{search::*, util::*};
/// A k-nearest-neighbour (`knn`) query over a vector field.
///
/// Holds the target `field`, the `query_vector`, and a set of optional
/// parameters. Optional parameters that are left unset are omitted from the
/// serialized JSON; with nothing optional set, only `field` and
/// `query_vector` are emitted.
///
/// Build one with [`Query::knn`] and the builder methods on this type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// `remote = "Self"` makes serde generate inherent `serialize`/`deserialize`
// functions instead of trait impls. The public trait impls presumably come
// from the `serialize_with_root!` / `deserialize_with_root!` invocations in
// this file, which nest these fields under a `"knn"` key.
#[serde(remote = "Self")]
pub struct KnnQuery {
// Name of the field to search.
field: String,
// The query vector; deserializes to an empty vector when the key is absent.
#[serde(default)]
query_vector: Vec<f32>,
// Emitted as `num_candidates` only when set.
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
num_candidates: Option<u32>,
// Emitted as `filter` only when set. Boxed, presumably because `Query` can
// itself contain a `KnnQuery` (recursive type) — TODO confirm.
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
filter: Option<Box<Query>>,
// Emitted as `similarity` only when set.
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
similarity: Option<f32>,
// Emitted as `boost` only when set (see `add_boost_and_name!`).
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
boost: Option<f32>,
// Emitted as `_name` only when set (see `add_boost_and_name!`).
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
_name: Option<String>,
}
impl KnnQuery {
    add_boost_and_name!();

    /// Sets the `num_candidates` parameter and returns the updated query.
    ///
    /// NOTE(review): per the Elasticsearch kNN query docs this is the number
    /// of candidates considered per shard — confirm against the target
    /// Elasticsearch version.
    pub fn num_candidates(self, num_candidates: u32) -> Self {
        Self {
            num_candidates: Some(num_candidates),
            ..self
        }
    }

    /// Sets the `filter` query, accepting anything convertible into a
    /// [`Query`], and returns the updated query.
    ///
    /// Calling this again replaces any previously set filter.
    pub fn filter<T>(self, filter: T) -> Self
    where
        T: Into<Query>,
    {
        let converted: Query = filter.into();
        Self {
            filter: Some(Box::new(converted)),
            ..self
        }
    }

    /// Sets the `similarity` parameter and returns the updated query.
    pub fn similarity(self, similarity: f32) -> Self {
        Self {
            similarity: Some(similarity),
            ..self
        }
    }
}
// Empty impl: relies entirely on the trait's default method(s).
impl ShouldSkip for KnnQuery {}
// Public (de)serialization: the struct's fields are nested under a `"knn"`
// key, e.g. `{"knn": {"field": "...", "query_vector": [...]}}`.
serialize_with_root!("knn": KnnQuery);
deserialize_with_root!("knn": KnnQuery);
impl Query {
    /// Creates a [`KnnQuery`] on `field` using `query_vector`.
    ///
    /// `field` may be any value implementing `ToString`. Every optional
    /// parameter (`num_candidates`, `filter`, `similarity`, `boost`, `_name`)
    /// starts unset. Use the builder methods on [`KnnQuery`] to set them.
    pub fn knn<T>(field: T, query_vector: Vec<f32>) -> KnnQuery
    where
        T: ToString,
    {
        let field = field.to_string();

        KnnQuery {
            _name: None,
            boost: None,
            similarity: None,
            filter: None,
            num_candidates: None,
            query_vector,
            field,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialization() {
        // Only the required parameters: no optional keys may appear.
        let minimal = Query::knn("test", vec![1.0, 2.0, 3.0]);
        let minimal_expected = json!({
            "knn": {
                "field": "test",
                "query_vector": [1.0, 2.0, 3.0]
            }
        });
        assert_serialize_query(minimal, minimal_expected);

        // Every optional parameter set through the builder methods.
        let complete = Query::knn("test", vec![1.0, 2.0, 3.0])
            .num_candidates(100)
            .filter(Query::term("field", "value"))
            .similarity(0.5)
            .boost(2.0)
            .name("test");
        let complete_expected = json!({
            "knn": {
                "field": "test",
                "query_vector": [1.0, 2.0, 3.0],
                "num_candidates": 100,
                "filter": {
                    "term": {
                        "field": {
                            "value": "value"
                        }
                    }
                },
                "similarity": 0.5,
                "boost": 2.0,
                "_name": "test"
            }
        });
        assert_serialize_query(complete, complete_expected);
    }
}