use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::collections::HashMap;
use crate::index::SparseVector;
/// Search-time parameters for an HNSW index, serialized as `{"ef": <u32>}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HNSWSearchParams {
    /// The `ef` search parameter (NOTE(review): presumably the HNSW candidate-list
    /// size at query time — confirm).
    pub ef: u32,
}
impl HNSWSearchParams {
pub fn new(ef: u32) -> Self {
Self { ef }
}
}
/// Optional per-request search parameters.
///
/// Every field is an `Option` and is omitted from the serialized JSON when `None`,
/// so only the parameters that were explicitly set are sent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParams {
/// `ef` search parameter (NOTE(review): presumably the HNSW candidate-list size — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub ef: Option<u32>,
/// `nprobe` search parameter (NOTE(review): presumably the number of IVF lists probed — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub nprobe: Option<u32>,
/// `radius` search parameter (NOTE(review): presumably a range-search distance bound — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub radius: Option<f64>,
}
impl SearchParams {
pub fn new() -> Self {
Self {
ef: None,
nprobe: None,
radius: None,
}
}
pub fn with_ef(mut self, ef: u32) -> Self {
self.ef = Some(ef);
self
}
pub fn with_nprobe(mut self, nprobe: u32) -> Self {
self.nprobe = Some(nprobe);
self
}
pub fn with_radius(mut self, radius: f64) -> Self {
self.radius = Some(radius);
self
}
}
impl Default for SearchParams {
fn default() -> Self {
Self::new()
}
}
/// A vector (ANN) search request.
///
/// Serialized with the keys `fieldName`, `documentIds`, `data`, `params` and `limit`;
/// any field that is `None` is omitted from the JSON. `AnnSearch::new` targets the
/// `"vector"` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnnSearch {
/// Field the search targets (JSON key `fieldName`).
#[serde(rename = "fieldName", skip_serializing_if = "Option::is_none")]
pub field_name: Option<String>,
/// Document ids (JSON key `documentIds`). NOTE(review): presumably restricts the
/// search to these documents — confirm.
#[serde(rename = "documentIds", skip_serializing_if = "Option::is_none")]
pub document_ids: Option<Vec<String>>,
/// Query payload: vectors or text, see `AnnSearchData`.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<AnnSearchData>,
/// Optional search tuning parameters.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<SearchParams>,
/// Result limit (NOTE(review): presumably the maximum number of hits — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u32>,
}
/// Query payload of an `AnnSearch`.
///
/// `untagged`: each variant serializes as its bare inner value (an array of number
/// arrays, a string, or an array of strings). When deserializing, serde tries the
/// variants in declaration order and keeps the first that fits, so an empty JSON
/// array decodes as `Vectors`, not `TextList`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AnnSearchData {
/// One or more query vectors.
Vectors(Vec<Vec<f64>>),
/// A single text query.
Text(String),
/// Several text queries.
TextList(Vec<String>),
}
impl AnnSearch {
pub fn new() -> Self {
Self {
field_name: Some("vector".to_string()),
document_ids: None,
data: None,
params: None,
limit: None,
}
}
pub fn with_field_name(mut self, field_name: impl Into<String>) -> Self {
self.field_name = Some(field_name.into());
self
}
pub fn with_document_ids(mut self, document_ids: Vec<String>) -> Self {
self.document_ids = Some(document_ids);
self
}
pub fn with_data(mut self, data: Vec<Vec<f64>>) -> Self {
self.data = Some(AnnSearchData::Vectors(data));
self
}
pub fn with_text(mut self, text: impl Into<String>) -> Self {
self.data = Some(AnnSearchData::Text(text.into()));
self
}
pub fn with_text_list(mut self, texts: Vec<String>) -> Self {
self.data = Some(AnnSearchData::TextList(texts));
self
}
pub fn with_params(mut self, params: SearchParams) -> Self {
self.params = Some(params);
self
}
pub fn with_limit(mut self, limit: u32) -> Self {
self.limit = Some(limit);
self
}
}
impl Default for AnnSearch {
fn default() -> Self {
Self::new()
}
}
/// A keyword search request whose query data is a list of sparse vectors.
///
/// Serialized with the keys `fieldName`, `data`, `limit`, `terminateAfter` and
/// `cutoffFrequency`; any field that is `None` is omitted from the JSON.
/// `KeywordSearch::new` targets the `"sparse_vector"` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeywordSearch {
/// Field the search targets (JSON key `fieldName`).
#[serde(rename = "fieldName", skip_serializing_if = "Option::is_none")]
pub field_name: Option<String>,
/// Sparse query vectors.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Vec<SparseVector>>,
/// Result limit (NOTE(review): presumably the maximum number of hits — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u32>,
/// JSON key `terminateAfter`. NOTE(review): presumably stops collecting after this
/// many matches — confirm.
#[serde(rename = "terminateAfter", skip_serializing_if = "Option::is_none")]
pub terminate_after: Option<u32>,
/// JSON key `cutoffFrequency`. NOTE(review): presumably a term-frequency cutoff —
/// confirm semantics and valid range.
#[serde(rename = "cutoffFrequency", skip_serializing_if = "Option::is_none")]
pub cutoff_frequency: Option<f64>,
}
impl KeywordSearch {
pub fn new() -> Self {
Self {
field_name: Some("sparse_vector".to_string()),
data: None,
limit: None,
terminate_after: None,
cutoff_frequency: None,
}
}
pub fn with_field_name(mut self, field_name: impl Into<String>) -> Self {
self.field_name = Some(field_name.into());
self
}
pub fn with_data(mut self, data: Vec<SparseVector>) -> Self {
self.data = Some(data);
self
}
pub fn with_limit(mut self, limit: u32) -> Self {
self.limit = Some(limit);
self
}
pub fn with_terminate_after(mut self, terminate_after: u32) -> Self {
self.terminate_after = Some(terminate_after);
self
}
pub fn with_cutoff_frequency(mut self, cutoff_frequency: f64) -> Self {
self.cutoff_frequency = Some(cutoff_frequency);
self
}
}
impl Default for KeywordSearch {
fn default() -> Self {
Self::new()
}
}
/// Rerank strategy.
///
/// Internally tagged by a `method` key, so the variants serialize as
/// `{"method":"weighted","fieldList":[..],"weight":[..]}` and
/// `{"method":"rrf","k":..}`; `None` fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method")]
pub enum Rerank {
/// Weighted method. NOTE(review): presumably `weight[i]` applies to
/// `field_list[i]` — confirm.
#[serde(rename = "weighted")]
Weighted {
#[serde(rename = "fieldList", skip_serializing_if = "Option::is_none")]
field_list: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
weight: Option<Vec<f64>>,
},
/// `rrf` method with an optional `k` parameter.
#[serde(rename = "rrf")]
RRF {
#[serde(skip_serializing_if = "Option::is_none")]
k: Option<u32>,
},
}
impl Rerank {
    /// Builds a `weighted` rerank over `field_list` with the given `weight`s.
    ///
    /// The weights are normalized to sum to 1 when that is meaningful (see
    /// `normalize_weights`); otherwise they are stored as given. The lengths of
    /// `field_list` and `weight` are not checked against each other here.
    pub fn weighted(field_list: Vec<String>, weight: Vec<f64>) -> Self {
        Self::Weighted {
            field_list: Some(field_list),
            weight: Some(Self::normalize_weights(weight)),
        }
    }

    /// Builds an `rrf` rerank with the given `k`.
    pub fn rrf(k: u32) -> Self {
        Self::RRF { k: Some(k) }
    }

    /// Scales non-negative `weights` so that they sum to 1.
    ///
    /// The input is returned unchanged when normalization is not meaningful:
    /// - any weight is negative;
    /// - the sum is zero (this covers both the empty and the all-zero case);
    /// - the sum is not finite (a NaN or infinite weight, or a sum that overflows).
    fn normalize_weights(mut weights: Vec<f64>) -> Vec<f64> {
        if weights.iter().any(|&w| w < 0.0) {
            return weights;
        }
        let total: f64 = weights.iter().sum();
        // Dividing by a NaN or infinite total would silently turn every weight into
        // NaN or 0.0, so pass such input through instead.
        if total == 0.0 || !total.is_finite() {
            return weights;
        }
        weights.iter_mut().for_each(|w| *w /= total);
        weights
    }
}
/// A schemaless document: a JSON object of named fields, plus an optional score
/// held outside that object.
#[derive(Clone, Debug, Default)]
pub struct Document {
    /// Field name → JSON value.
    data: Map<String, Value>,
    /// Score kept separately from `data`; the `From` and `Deserialize` conversions
    /// for `Document` fill it from a top-level `"score"` key.
    score: Option<f64>,
}
impl Document {
pub fn new() -> Self {
Self::default()
}
pub fn with_id(mut self, id: impl Into<String>) -> Self {
self.data.insert("id".to_string(), Value::String(id.into()));
self
}
pub fn with_vector(mut self, vector: Vec<f64>) -> Self {
self.data.insert("vector".to_string(), serde_json::to_value(vector).unwrap());
self
}
pub fn with_field(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
self.data.insert(key.into(), value.into());
self
}
pub fn with_score(mut self, score: f64) -> Self {
self.score = Some(score);
self
}
pub fn get(&self, key: &str) -> Option<&Value> {
self.data.get(key)
}
pub fn get_id(&self) -> Option<&str> {
self.data.get("id")?.as_str()
}
pub fn get_vector(&self) -> Option<Vec<f64>> {
let vector_value = self.data.get("vector")?;
serde_json::from_value(vector_value.clone()).ok()
}
pub fn get_score(&self) -> Option<f64> {
self.score
}
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<Value>) {
self.data.insert(key.into(), value.into());
}
pub fn remove(&mut self, key: &str) -> Option<Value> {
self.data.remove(key)
}
pub fn keys(&self) -> impl Iterator<Item = &String> {
self.data.keys()
}
pub fn values(&self) -> impl Iterator<Item = &Value> {
self.data.values()
}
pub fn iter(&self) -> impl Iterator<Item = (&String, &Value)> {
self.data.iter()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn len(&self) -> usize {
self.data.len()
}
}
impl Serialize for Document {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.data.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Document {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut data: Map<String, Value> = Map::deserialize(deserializer)?;
let score = data.remove("score").and_then(|v| v.as_f64());
Ok(Self { data, score })
}
}
impl From<HashMap<String, Value>> for Document {
fn from(map: HashMap<String, Value>) -> Self {
let mut data = Map::new();
let mut score = None;
for (k, v) in map {
if k == "score" {
score = v.as_f64();
} else {
data.insert(k, v);
}
}
Self { data, score }
}
}
impl From<Map<String, Value>> for Document {
fn from(mut data: Map<String, Value>) -> Self {
let score = data.remove("score").and_then(|v| v.as_f64());
Self { data, score }
}
}