hehe-store 0.0.1

Unified storage abstraction layer for hehe AI Agent framework
Documentation
use crate::error::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VectorRecord {
    pub id: String,
    pub vector: Vec<f32>,
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

impl VectorRecord {
    pub fn new(id: impl Into<String>, vector: Vec<f32>) -> Self {
        Self {
            id: id.into(),
            vector,
            metadata: HashMap::new(),
            content: None,
        }
    }

    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    pub fn with_content(mut self, content: impl Into<String>) -> Self {
        self.content = Some(content.into());
        self
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SearchResult {
    pub id: String,
    pub score: f32,
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Clone, Debug, Default)]
pub struct VectorFilter {
    pub conditions: Vec<FilterCondition>,
}

#[derive(Clone, Debug)]
pub enum FilterCondition {
    Eq(String, Value),
    Ne(String, Value),
    Gt(String, Value),
    Gte(String, Value),
    Lt(String, Value),
    Lte(String, Value),
    In(String, Vec<Value>),
    Contains(String, String),
}

impl VectorFilter {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn eq(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
        self.conditions
            .push(FilterCondition::Eq(field.into(), value.into()));
        self
    }

    pub fn ne(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
        self.conditions
            .push(FilterCondition::Ne(field.into(), value.into()));
        self
    }

    pub fn gt(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
        self.conditions
            .push(FilterCondition::Gt(field.into(), value.into()));
        self
    }

    pub fn lt(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
        self.conditions
            .push(FilterCondition::Lt(field.into(), value.into()));
        self
    }

    pub fn contains(mut self, field: impl Into<String>, value: impl Into<String>) -> Self {
        self.conditions
            .push(FilterCondition::Contains(field.into(), value.into()));
        self
    }

    pub fn is_empty(&self) -> bool {
        self.conditions.is_empty()
    }
}

#[derive(Clone, Debug)]
pub struct CollectionInfo {
    pub name: String,
    pub dimension: usize,
    pub count: usize,
}

#[async_trait]
pub trait VectorStore: Send + Sync {
    async fn create_collection(&self, name: &str, dimension: usize) -> Result<()>;

    async fn delete_collection(&self, name: &str) -> Result<()>;

    async fn list_collections(&self) -> Result<Vec<CollectionInfo>>;

    async fn collection_exists(&self, name: &str) -> Result<bool>;

    async fn upsert(&self, collection: &str, records: &[VectorRecord]) -> Result<usize>;

    async fn search(
        &self,
        collection: &str,
        query: &[f32],
        limit: usize,
    ) -> Result<Vec<SearchResult>>;

    async fn search_with_filter(
        &self,
        collection: &str,
        query: &[f32],
        filter: &VectorFilter,
        limit: usize,
    ) -> Result<Vec<SearchResult>>;

    async fn get(&self, collection: &str, id: &str) -> Result<Option<VectorRecord>>;

    async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize>;

    async fn count(&self, collection: &str) -> Result<usize>;

    fn backend_name(&self) -> &'static str;
}

pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}

pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }

    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_record() {
        let record = VectorRecord::new("id1", vec![0.1, 0.2, 0.3])
            .with_metadata("key", "value")
            .with_content("some text");

        assert_eq!(record.id, "id1");
        assert_eq!(record.vector.len(), 3);
        assert!(record.metadata.contains_key("key"));
        assert_eq!(record.content, Some("some text".into()));
    }

    #[test]
    fn test_vector_filter() {
        let filter = VectorFilter::new()
            .eq("type", "article")
            .gt("score", 0.5);

        assert_eq!(filter.conditions.len(), 2);
        assert!(!filter.is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.0001);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.0001);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &d) - (-1.0)).abs() < 0.0001);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((euclidean_distance(&a, &b) - 1.0).abs() < 0.0001);

        let c = vec![3.0, 4.0, 0.0];
        assert!((euclidean_distance(&a, &c) - 5.0).abs() < 0.0001);
    }
}