use serde::Serialize;
use crate::{search::*, util::*};
/// A k-nearest-neighbour (kNN) search clause, attached to a search request
/// through `Search::knn`.
///
/// Build one with [`Knn::query_vector`] (explicit vector) or
/// [`Knn::query_vector_builder`] (vector produced by a [`QueryVectorBuilder`]).
/// Each constructor sets exactly one of the two vector sources and leaves every
/// other optional setting unset. Unset options are left out of the serialized
/// JSON, and fields serialize in declaration order.
// NOTE(review): only `serde::Serialize` is imported at the top of this file; the
// `Deserialize` derive presumably resolves through a crate-wide import such as
// `#[macro_use] extern crate serde` — confirm.
// NOTE(review): the derived `Deserialize` does not enforce the constructor
// invariant. Both vector-source fields carry `#[serde(default)]`, so input with
// neither (or both) of `query_vector` / `query_vector_builder` still deserializes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Knn {
    /// Target field of the clause; always serialized as `field`.
    field: String,
    /// Explicit query vector, serialized as `query_vector`; set by [`Knn::query_vector`].
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    query_vector: Option<Vec<f32>>,
    /// Alternative vector source, serialized as `query_vector_builder`; set by
    /// [`Knn::query_vector_builder`].
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    query_vector_builder: Option<QueryVectorBuilder>,
    /// Value of the `k` option; set with [`Knn::k`].
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    k: Option<u32>,
    /// Value of the `num_candidates` option; set with [`Knn::num_candidates`].
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    num_candidates: Option<u32>,
    /// Query serialized under `filter`; set with [`Knn::filter`].
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Box<Query>>,
    /// Value of the `similarity` option; set with [`Knn::similarity`].
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    similarity: Option<f32>,
    /// Value of the `boost` option; presumably set through the `boost` method
    /// generated by `add_boost_and_name!`.
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    boost: Option<f32>,
    /// Clause name, serialized as `_name`; set through the `name` builder method.
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    _name: Option<String>,
}
impl Knn {
    // Presumably generates the `boost` and `name` builder methods (the tests call
    // `.boost(..)` / `.name(..)`, and neither is defined elsewhere in this impl).
    // Those methods are not covered by the `#[must_use]` markers below.
    add_boost_and_name!();

    /// Creates a kNN clause that searches `field` with an explicit query vector.
    ///
    /// Every optional setting (`k`, `num_candidates`, `filter`, `similarity`,
    /// boost, name) starts unset and stays out of the serialized clause until it
    /// is configured with the builder methods.
    #[must_use]
    pub fn query_vector<T>(field: T, query_vector: Vec<f32>) -> Self
    where
        T: ToString,
    {
        Self::with_vector_source(field.to_string(), Some(query_vector), None)
    }

    /// Creates a kNN clause that searches `field` with a vector produced by a
    /// [`QueryVectorBuilder`] (for example a [`TextEmbedding`]) instead of an
    /// explicit vector.
    ///
    /// As with [`Knn::query_vector`], every optional setting starts unset.
    #[must_use]
    pub fn query_vector_builder<T, U>(field: T, query_vector_builder: U) -> Self
    where
        T: ToString,
        U: Into<QueryVectorBuilder>,
    {
        Self::with_vector_source(field.to_string(), None, Some(query_vector_builder.into()))
    }

    /// Shared body of the public constructors: stores the field and whichever
    /// vector source was supplied, leaving every optional setting `None`.
    fn with_vector_source(
        field: String,
        query_vector: Option<Vec<f32>>,
        query_vector_builder: Option<QueryVectorBuilder>,
    ) -> Self {
        Self {
            field,
            query_vector,
            query_vector_builder,
            k: None,
            num_candidates: None,
            filter: None,
            similarity: None,
            boost: None,
            _name: None,
        }
    }

    /// Sets the `k` option of the clause. Per the Elasticsearch kNN docs, this is
    /// the number of nearest neighbours to return.
    ///
    /// Consumes and returns the builder; dropping the result discards the setting.
    #[must_use]
    pub fn k(mut self, k: u32) -> Self {
        self.k = Some(k);
        self
    }

    /// Sets the `num_candidates` option of the clause. Per the Elasticsearch kNN
    /// docs, this is the number of candidates considered per shard.
    #[must_use]
    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
        self.num_candidates = Some(num_candidates);
        self
    }

    /// Sets the query serialized under the clause's `filter` key. Anything
    /// convertible into a `Query` is accepted.
    #[must_use]
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Query>,
    {
        self.filter = Some(Box::new(filter.into()));
        self
    }

    /// Sets the `similarity` option of the clause. Per the Elasticsearch kNN docs,
    /// this is the minimum similarity a document needs in order to match.
    #[must_use]
    pub fn similarity(mut self, similarity: f32) -> Self {
        self.similarity = Some(similarity);
        self
    }
}
/// Source from which the query vector of a [`Knn`] clause is built, used in
/// place of an explicit vector (see [`Knn::query_vector_builder`]).
///
/// Serialized as an externally tagged object whose key is the `snake_case`
/// variant name, e.g. `{"text_embedding": {"model_id": ..., "model_text": ...}}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueryVectorBuilder {
    /// Vector built by an embedding model from text; serialized under `text_embedding`.
    TextEmbedding(TextEmbedding),
}
/// Parameters of the `text_embedding` query vector builder: a model identifier
/// and the text handed to that model. Per the Elasticsearch docs, the model turns
/// the text into the query vector.
///
/// Create with [`TextEmbedding::new`]. It converts into a [`QueryVectorBuilder`],
/// so it can be passed straight to [`Knn::query_vector_builder`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TextEmbedding {
    /// Identifier of the model; serialized as `model_id`.
    model_id: String,
    /// Text given to the model; serialized as `model_text`.
    model_text: String,
}
impl From<TextEmbedding> for QueryVectorBuilder {
fn from(embedding: TextEmbedding) -> Self {
Self::TextEmbedding(embedding)
}
}
impl TextEmbedding {
    /// Builds a text-embedding vector source from a model identifier and the text
    /// to hand to that model.
    ///
    /// Both arguments go through [`ToString`], so `&str`, `String`, or any other
    /// `Display` type is accepted. The identifier is converted first.
    pub fn new<T: ToString, U: ToString>(model_id: T, model_text: U) -> Self {
        let model_id = model_id.to_string();
        let model_text = model_text.to_string();
        Self { model_id, model_text }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Serializes one search carrying four kNN clauses and compares the result
    // with the expected JSON. The clauses cover both vector sources (explicit
    // vector, vector builder), each once with only the mandatory parts and once
    // with every optional setting.
    #[test]
    fn serialization() {
        assert_serialize(
            Search::new()
                // Mandatory parts only: the expected JSON carries no optional keys.
                .knn(Knn::query_vector("test1", vec![1.0, 2.0, 3.0]))
                // Every optional setting, including `name` (serialized as `_name`).
                .knn(
                    Knn::query_vector("test2", vec![4.0, 5.0, 6.0])
                        .k(3)
                        .num_candidates(100)
                        .filter(Query::term("field", "value"))
                        .similarity(0.5)
                        .boost(2.0)
                        .name("test2"),
                )
                // Builder-based clause, mandatory parts only.
                .knn(Knn::query_vector_builder(
                    "test3",
                    TextEmbedding::new("my-text-embedding-model", "The opposite of pink"),
                ))
                // Builder-based clause with every optional setting.
                .knn(
                    Knn::query_vector_builder(
                        "test4",
                        TextEmbedding::new("my-text-embedding-model", "The opposite of blue"),
                    )
                    .k(5)
                    .num_candidates(200)
                    .filter(Query::term("field", "value"))
                    // NOTE(review): 0.7 and 2.1 are not exactly representable as f32.
                    // They only equal the f64 literals in the `json!` expectation if
                    // `assert_serialize` compares through a shortest-round-trip
                    // string (e.g. `serde_json::to_string` then re-parse) rather than
                    // `serde_json::to_value` — confirm.
                    .similarity(0.7)
                    .boost(2.1)
                    .name("test4"),
                ),
            json!({
                "knn": [
                    {
                        "field": "test1",
                        "query_vector": [1.0, 2.0, 3.0]
                    },
                    {
                        "field": "test2",
                        "query_vector": [4.0, 5.0, 6.0],
                        "k": 3,
                        "num_candidates": 100,
                        "filter": {
                            "term": {
                                "field": {
                                    "value": "value"
                                }
                            }
                        },
                        "similarity": 0.5,
                        "boost": 2.0,
                        "_name": "test2"
                    },
                    {
                        "field": "test3",
                        "query_vector_builder": {
                            "text_embedding": {
                                "model_id": "my-text-embedding-model",
                                "model_text": "The opposite of pink"
                            }
                        }
                    },
                    {
                        "field": "test4",
                        "query_vector_builder": {
                            "text_embedding": {
                                "model_id": "my-text-embedding-model",
                                "model_text": "The opposite of blue"
                            }
                        },
                        "k": 5,
                        "num_candidates": 200,
                        "filter": {
                            "term": {
                                "field": {
                                    "value": "value"
                                }
                            }
                        },
                        "similarity": 0.7,
                        "boost": 2.1,
                        "_name": "test4"
                    }
                ]
            }),
        );
    }
}