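//! tcvectordb 0.1.9 — Rust SDK for Tencent Cloud VectorDB.
//!
//! Documentation for the search-request types (ANN and keyword sub-queries,
//! search parameters, rerank strategies) and the `Document` record type.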
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::collections::HashMap;
use crate::index::SparseVector;

/// Search parameters carrying a single `ef` value.
///
/// Serialized as an object with one key, `"ef"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HNSWSearchParams {
    /// The `ef` value sent to the server (presumably the HNSW search candidate-list
    /// size — confirm against the service documentation).
    pub ef: u32,
}

impl HNSWSearchParams {
    /// Builds parameters holding the given `ef`.
    pub fn new(ef: u32) -> Self {
        HNSWSearchParams { ef }
    }
}

/// Optional per-query search tuning parameters.
///
/// Every field starts unset; unset fields are left out of the serialized object,
/// so only explicitly configured values reach the server.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchParams {
    /// `ef` search parameter; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ef: Option<u32>,
    /// `nprobe` search parameter; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nprobe: Option<u32>,
    /// `radius` search parameter; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub radius: Option<f64>,
}

impl SearchParams {
    /// Returns a parameter set with every field unset (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with `ef` set.
    pub fn with_ef(self, ef: u32) -> Self {
        Self { ef: Some(ef), ..self }
    }

    /// Returns `self` with `nprobe` set.
    pub fn with_nprobe(self, nprobe: u32) -> Self {
        Self {
            nprobe: Some(nprobe),
            ..self
        }
    }

    /// Returns `self` with `radius` set.
    pub fn with_radius(self, radius: f64) -> Self {
        Self {
            radius: Some(radius),
            ..self
        }
    }
}

/// One vector (ANN) sub-query of a search request.
///
/// Field names are emitted in camelCase (`fieldName`, `documentIds`, `data`,
/// `params`, `limit`), and any field left as `None` is omitted entirely.
/// Field order is significant: it fixes the key order of the serialized object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AnnSearch {
    /// Name of the field to search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub field_name: Option<String>,
    /// Document IDs sent as `documentIds` (presumably an alternative query input
    /// to `data` — confirm against the service API).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub document_ids: Option<Vec<String>>,
    /// Query payload: raw vectors, a single text, or several texts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<AnnSearchData>,
    /// Optional search tuning parameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<SearchParams>,
    /// Optional result limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<u32>,
}

/// Query payload of an [`AnnSearch`].
///
/// Untagged: serialized as the bare inner value. When deserializing, serde tries
/// the variants in declaration order, so e.g. an empty JSON array becomes
/// `Vectors(vec![])`. Do not reorder the variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AnnSearchData {
    /// One or more numeric vectors.
    Vectors(Vec<Vec<f64>>),
    /// A single text query.
    Text(String),
    /// Several text queries.
    TextList(Vec<String>),
}

impl AnnSearch {
    pub fn new() -> Self {
        Self {
            field_name: Some("vector".to_string()),
            document_ids: None,
            data: None,
            params: None,
            limit: None,
        }
    }

    pub fn with_field_name(mut self, field_name: impl Into<String>) -> Self {
        self.field_name = Some(field_name.into());
        self
    }

    pub fn with_document_ids(mut self, document_ids: Vec<String>) -> Self {
        self.document_ids = Some(document_ids);
        self
    }

    pub fn with_data(mut self, data: Vec<Vec<f64>>) -> Self {
        self.data = Some(AnnSearchData::Vectors(data));
        self
    }

    pub fn with_text(mut self, text: impl Into<String>) -> Self {
        self.data = Some(AnnSearchData::Text(text.into()));
        self
    }

    pub fn with_text_list(mut self, texts: Vec<String>) -> Self {
        self.data = Some(AnnSearchData::TextList(texts));
        self
    }

    pub fn with_params(mut self, params: SearchParams) -> Self {
        self.params = Some(params);
        self
    }

    pub fn with_limit(mut self, limit: u32) -> Self {
        self.limit = Some(limit);
        self
    }
}

impl Default for AnnSearch {
    fn default() -> Self {
        Self::new()
    }
}

/// One sparse-vector (keyword) sub-query of a search request.
///
/// Field names are emitted in camelCase (`fieldName`, `data`, `limit`,
/// `terminateAfter`, `cutoffFrequency`), and `None` fields are omitted.
/// Field order is significant: it fixes the key order of the serialized object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct KeywordSearch {
    /// Name of the sparse-vector field to search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub field_name: Option<String>,
    /// Sparse query vectors.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Vec<SparseVector>>,
    /// Optional result limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<u32>,
    /// Sent as `terminateAfter`; semantics are defined by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub terminate_after: Option<u32>,
    /// Sent as `cutoffFrequency`; semantics are defined by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cutoff_frequency: Option<f64>,
}

impl KeywordSearch {
    pub fn new() -> Self {
        Self {
            field_name: Some("sparse_vector".to_string()),
            data: None,
            limit: None,
            terminate_after: None,
            cutoff_frequency: None,
        }
    }

    pub fn with_field_name(mut self, field_name: impl Into<String>) -> Self {
        self.field_name = Some(field_name.into());
        self
    }

    pub fn with_data(mut self, data: Vec<SparseVector>) -> Self {
        self.data = Some(data);
        self
    }

    pub fn with_limit(mut self, limit: u32) -> Self {
        self.limit = Some(limit);
        self
    }

    pub fn with_terminate_after(mut self, terminate_after: u32) -> Self {
        self.terminate_after = Some(terminate_after);
        self
    }

    pub fn with_cutoff_frequency(mut self, cutoff_frequency: f64) -> Self {
        self.cutoff_frequency = Some(cutoff_frequency);
        self
    }
}

impl Default for KeywordSearch {
    fn default() -> Self {
        Self::new()
    }
}

/// Strategy for merging the results of several sub-queries.
///
/// Serialized as an internally tagged object: the `"method"` key holds the
/// lower-cased variant name (`"weighted"` or `"rrf"`), and the variant's fields
/// sit alongside it. `None` fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "lowercase")]
pub enum Rerank {
    /// Weighted reranking over the listed fields (`"method": "weighted"`).
    Weighted {
        /// Fields to combine, sent as `fieldList`.
        #[serde(rename = "fieldList", skip_serializing_if = "Option::is_none")]
        field_list: Option<Vec<String>>,
        /// Weights, presumably one per entry of `field_list` (not validated here).
        #[serde(skip_serializing_if = "Option::is_none")]
        weight: Option<Vec<f64>>,
    },
    /// Reciprocal Rank Fusion (`"method": "rrf"`).
    RRF {
        /// The fusion's `k` parameter.
        #[serde(skip_serializing_if = "Option::is_none")]
        k: Option<u32>,
    },
}

impl Rerank {
    /// Builds a [`Rerank::Weighted`] over `field_list`.
    ///
    /// `weight` is normalized so its entries sum to 1 when that is well defined;
    /// see [`Rerank::normalize_weights`] for the cases where it is passed through
    /// unchanged. The lengths of `field_list` and `weight` are not checked.
    pub fn weighted(field_list: Vec<String>, weight: Vec<f64>) -> Self {
        Self::Weighted {
            field_list: Some(field_list),
            weight: Some(Self::normalize_weights(weight)),
        }
    }

    /// Builds a [`Rerank::RRF`] with the given `k`.
    pub fn rrf(k: u32) -> Self {
        Self::RRF { k: Some(k) }
    }

    /// Scales `weights` so they sum to 1, dividing each by the total.
    ///
    /// The input is returned unchanged when normalization is not meaningful:
    /// - any weight is negative;
    /// - the total is zero (this covers empty and all-zero input);
    /// - the total is NaN or infinite (a NaN weight, an infinite weight, or a sum
    ///   that overflows). Dividing would otherwise yield NaN or 0 for every entry.
    fn normalize_weights(mut weights: Vec<f64>) -> Vec<f64> {
        if weights.iter().any(|&w| w < 0.0) {
            return weights;
        }

        let total: f64 = weights.iter().sum();
        if total == 0.0 || !total.is_finite() {
            return weights;
        }

        // Normalize in place; the caller hands us ownership, so no second Vec is needed.
        for w in weights.iter_mut() {
            *w /= total;
        }
        weights
    }
}

/// A single document: field names mapped to JSON values, plus an optional score.
///
/// The score is stored separately from the fields. The `Deserialize` and `From`
/// impls for this type move a `"score"` key out of the incoming map into the score
/// slot (a non-numeric value there is discarded). Serialization writes only the fields.
#[derive(Debug, Clone, Default)]
pub struct Document {
    /// Document fields. `with_id`/`get_id` use the `"id"` key and
    /// `with_vector`/`get_vector` use the `"vector"` key.
    data: Map<String, Value>,
    /// Optional score, kept out of `data`.
    score: Option<f64>,
}

impl Document {
    /// Creates an empty document with no fields and no score.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `"id"` field to the given string, replacing any previous value.
    pub fn with_id(mut self, id: impl Into<String>) -> Self {
        self.data.insert("id".to_string(), Value::String(id.into()));
        self
    }

    /// Sets the `"vector"` field to a JSON array of numbers, replacing any previous value.
    ///
    /// Non-finite components (NaN, ±∞) have no JSON number form and are stored as `null`.
    pub fn with_vector(mut self, vector: Vec<f64>) -> Self {
        // `Value::from` is infallible and yields the same array that
        // `serde_json::to_value(vector)` did, without an `unwrap`.
        self.data.insert("vector".to_string(), Value::from(vector));
        self
    }

    /// Sets (or replaces) an arbitrary field.
    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        self.data.insert(key.into(), value.into());
        self
    }

    /// Sets the score.
    pub fn with_score(mut self, score: f64) -> Self {
        self.score = Some(score);
        self
    }

    /// Returns the value of field `key`, if present.
    pub fn get(&self, key: &str) -> Option<&Value> {
        self.data.get(key)
    }

    /// Returns the `"id"` field if it is present and is a JSON string.
    pub fn get_id(&self) -> Option<&str> {
        self.data.get("id")?.as_str()
    }

    /// Returns the `"vector"` field as numbers.
    ///
    /// Returns `None` if the field is missing, is not an array, or contains any
    /// element that is not a JSON number (including `null`).
    pub fn get_vector(&self) -> Option<Vec<f64>> {
        // Read the array in place rather than cloning the whole `Value` into
        // `serde_json::from_value`; collecting `Option`s short-circuits on the
        // first non-numeric element.
        self.data
            .get("vector")?
            .as_array()?
            .iter()
            .map(Value::as_f64)
            .collect()
    }

    /// Returns the score, if one was set or deserialized.
    pub fn get_score(&self) -> Option<f64> {
        self.score
    }

    /// Inserts or replaces a field.
    ///
    /// A `"score"` key inserted here is stored as an ordinary field; it does not
    /// change [`Document::get_score`].
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<Value>) {
        self.data.insert(key.into(), value.into());
    }

    /// Removes field `key`, returning its previous value if any.
    pub fn remove(&mut self, key: &str) -> Option<Value> {
        self.data.remove(key)
    }

    /// Iterates over field names in the underlying map's order.
    pub fn keys(&self) -> impl Iterator<Item = &String> {
        self.data.keys()
    }

    /// Iterates over field values in the underlying map's order.
    pub fn values(&self) -> impl Iterator<Item = &Value> {
        self.data.values()
    }

    /// Iterates over `(name, value)` field pairs in the underlying map's order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &Value)> {
        self.data.iter()
    }

    /// Returns `true` if the document has no fields (the score is not considered).
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the number of fields (the score is not counted).
    pub fn len(&self) -> usize {
        self.data.len()
    }
}

impl Serialize for Document {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.data.serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for Document {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let mut data: Map<String, Value> = Map::deserialize(deserializer)?;
        let score = data.remove("score").and_then(|v| v.as_f64());
        
        Ok(Self { data, score })
    }
}

impl From<HashMap<String, Value>> for Document {
    fn from(map: HashMap<String, Value>) -> Self {
        let mut data = Map::new();
        let mut score = None;
        
        for (k, v) in map {
            if k == "score" {
                score = v.as_f64();
            } else {
                data.insert(k, v);
            }
        }
        
        Self { data, score }
    }
}

impl From<Map<String, Value>> for Document {
    fn from(mut data: Map<String, Value>) -> Self {
        let score = data.remove("score").and_then(|v| v.as_f64());
        Self { data, score }
    }
}