// hehe-store 0.0.1
//
// Unified storage abstraction layer for the hehe AI Agent framework.
use crate::error::{Result, StoreError};
use crate::traits::{
    cosine_similarity, CollectionInfo, SearchResult, VectorFilter, VectorRecord, VectorStore,
};
use async_trait::async_trait;
use parking_lot::RwLock;
use serde_json::Value;
use std::collections::HashMap;

/// One named collection held by [`MemoryVectorStore`]: its records plus the vector
/// dimension they are required to share.
struct Collection {
    /// Stored records keyed by `VectorRecord::id`; inserting an existing id replaces it.
    records: HashMap<String, VectorRecord>,
    /// Length that every stored vector and every search query must have.
    dimension: usize,
}

/// Purely in-process vector store.
///
/// Every collection lives in a `HashMap` behind a single `parking_lot::RwLock`; nothing is
/// written anywhere else, so all data disappears when the store is dropped.
pub struct MemoryVectorStore {
    /// Collection name -> collection contents.
    collections: RwLock<HashMap<String, Collection>>,
}

impl MemoryVectorStore {
    /// Creates a store that holds no collections yet.
    pub fn new() -> Self {
        let collections = RwLock::new(HashMap::new());
        Self { collections }
    }
}

impl Default for MemoryVectorStore {
    /// Same as [`MemoryVectorStore::new`].
    fn default() -> Self {
        MemoryVectorStore::new()
    }
}

fn matches_filter(record: &VectorRecord, filter: &VectorFilter) -> bool {
    use crate::traits::vector::FilterCondition;

    for condition in &filter.conditions {
        let matched = match condition {
            FilterCondition::Eq(field, value) => {
                record.metadata.get(field).map(|v| v == value).unwrap_or(false)
            }
            FilterCondition::Ne(field, value) => {
                record.metadata.get(field).map(|v| v != value).unwrap_or(true)
            }
            FilterCondition::Gt(field, value) => match (record.metadata.get(field), value) {
                (Some(Value::Number(a)), Value::Number(b)) => {
                    a.as_f64().unwrap_or(0.0) > b.as_f64().unwrap_or(0.0)
                }
                _ => false,
            },
            FilterCondition::Gte(field, value) => match (record.metadata.get(field), value) {
                (Some(Value::Number(a)), Value::Number(b)) => {
                    a.as_f64().unwrap_or(0.0) >= b.as_f64().unwrap_or(0.0)
                }
                _ => false,
            },
            FilterCondition::Lt(field, value) => match (record.metadata.get(field), value) {
                (Some(Value::Number(a)), Value::Number(b)) => {
                    a.as_f64().unwrap_or(0.0) < b.as_f64().unwrap_or(0.0)
                }
                _ => false,
            },
            FilterCondition::Lte(field, value) => match (record.metadata.get(field), value) {
                (Some(Value::Number(a)), Value::Number(b)) => {
                    a.as_f64().unwrap_or(0.0) <= b.as_f64().unwrap_or(0.0)
                }
                _ => false,
            },
            FilterCondition::In(field, values) => record
                .metadata
                .get(field)
                .map(|v| values.contains(v))
                .unwrap_or(false),
            FilterCondition::Contains(field, substr) => record
                .metadata
                .get(field)
                .and_then(|v| v.as_str())
                .map(|s| s.contains(substr))
                .unwrap_or(false),
        };

        if !matched {
            return false;
        }
    }

    true
}

/// Brute-force implementation: every search scores all (filtered) records of the
/// collection with `cosine_similarity` while holding the read lock.
#[async_trait]
impl VectorStore for MemoryVectorStore {
    /// Creates an empty collection whose vectors must all have `dimension` components.
    ///
    /// # Errors
    /// `StoreError::AlreadyExists` if a collection named `name` already exists.
    async fn create_collection(&self, name: &str, dimension: usize) -> Result<()> {
        use std::collections::hash_map::Entry;

        let mut collections = self.collections.write();
        // One lookup for both the existence check and the insert.
        match collections.entry(name.to_string()) {
            Entry::Occupied(_) => Err(StoreError::AlreadyExists(format!(
                "Collection '{}'",
                name
            ))),
            Entry::Vacant(slot) => {
                slot.insert(Collection {
                    dimension,
                    records: HashMap::new(),
                });
                Ok(())
            }
        }
    }

    /// Removes the collection `name` and all of its records.
    ///
    /// # Errors
    /// A not-found error if no such collection exists.
    async fn delete_collection(&self, name: &str) -> Result<()> {
        let mut collections = self.collections.write();
        if collections.remove(name).is_none() {
            return Err(StoreError::not_found(format!("Collection '{}'", name)));
        }
        Ok(())
    }

    /// Lists every collection with its dimension and current record count.
    async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
        let collections = self.collections.read();
        Ok(collections
            .iter()
            .map(|(name, col)| CollectionInfo {
                name: name.clone(),
                dimension: col.dimension,
                count: col.records.len(),
            })
            .collect())
    }

    /// Reports whether a collection named `name` exists.
    async fn collection_exists(&self, name: &str) -> Result<bool> {
        Ok(self.collections.read().contains_key(name))
    }

    /// Inserts or replaces (by id) every record in `records`.
    ///
    /// The batch is all-or-nothing: every vector is validated before any record is
    /// written. Returns `records.len()` on success (duplicate ids in one batch are each
    /// counted; the last one wins).
    ///
    /// # Errors
    /// Not-found if the collection is missing; invalid-input if any vector's length
    /// differs from the collection dimension (the first offender is reported).
    async fn upsert(&self, collection: &str, records: &[VectorRecord]) -> Result<usize> {
        let mut collections = self.collections.write();
        let col = collections
            .get_mut(collection)
            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;

        // Validate the whole batch first so a bad record cannot leave the collection
        // half-updated while the caller receives an error.
        let dimension = col.dimension;
        if let Some(bad) = records.iter().find(|r| r.vector.len() != dimension) {
            return Err(StoreError::invalid_input(format!(
                "Vector dimension mismatch: expected {}, got {}",
                dimension,
                bad.vector.len()
            )));
        }

        for record in records {
            col.records.insert(record.id.clone(), record.clone());
        }

        Ok(records.len())
    }

    /// Unfiltered search; see `search_with_filter`.
    async fn search(
        &self,
        collection: &str,
        query: &[f32],
        limit: usize,
    ) -> Result<Vec<SearchResult>> {
        self.search_with_filter(collection, query, &VectorFilter::default(), limit)
            .await
    }

    /// Returns up to `limit` records matching `filter`, best cosine score first.
    ///
    /// Equal scores are ordered by id so results are deterministic regardless of
    /// `HashMap` iteration order.
    ///
    /// # Errors
    /// Not-found if the collection is missing; invalid-input if `query.len()` differs
    /// from the collection dimension.
    async fn search_with_filter(
        &self,
        collection: &str,
        query: &[f32],
        filter: &VectorFilter,
        limit: usize,
    ) -> Result<Vec<SearchResult>> {
        /// Best-first score ordering that is a total order even with NaN present.
        ///
        /// NaN ranks last. `partial_cmp(..).unwrap_or(Equal)` is not a total order when
        /// NaN appears, and `sort_by` may panic on such comparators since Rust 1.81.
        /// Whether `cosine_similarity` can yield NaN (e.g. for a zero vector) is not
        /// visible here — this guards against it either way.
        fn best_first(a: f32, b: f32) -> std::cmp::Ordering {
            match (a.is_nan(), b.is_nan()) {
                (false, false) => b.total_cmp(&a),
                (a_nan, b_nan) => a_nan.cmp(&b_nan),
            }
        }

        let collections = self.collections.read();
        let col = collections
            .get(collection)
            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;

        if query.len() != col.dimension {
            return Err(StoreError::invalid_input(format!(
                "Query dimension mismatch: expected {}, got {}",
                col.dimension,
                query.len()
            )));
        }

        // Score by reference; only the records that survive truncation get cloned.
        let mut scored: Vec<(&VectorRecord, f32)> = col
            .records
            .values()
            .filter(|r| filter.is_empty() || matches_filter(r, filter))
            .map(|r| (r, cosine_similarity(query, &r.vector)))
            .collect();

        scored.sort_by(|(rec_a, score_a), (rec_b, score_b)| {
            best_first(*score_a, *score_b).then_with(|| rec_a.id.cmp(&rec_b.id))
        });
        scored.truncate(limit);

        Ok(scored
            .into_iter()
            .map(|(record, score)| SearchResult {
                id: record.id.clone(),
                score,
                metadata: record.metadata.clone(),
                content: record.content.clone(),
            })
            .collect())
    }

    /// Fetches a copy of the record `id`, or `None` if it is not stored.
    ///
    /// # Errors
    /// Not-found if the collection is missing.
    async fn get(&self, collection: &str, id: &str) -> Result<Option<VectorRecord>> {
        let collections = self.collections.read();
        let col = collections
            .get(collection)
            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;

        Ok(col.records.get(id).cloned())
    }

    /// Removes the given ids; returns how many were actually present. Unknown ids are
    /// ignored.
    ///
    /// # Errors
    /// Not-found if the collection is missing.
    async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
        let mut collections = self.collections.write();
        let col = collections
            .get_mut(collection)
            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;

        let mut count = 0;
        for id in ids {
            if col.records.remove(id).is_some() {
                count += 1;
            }
        }

        Ok(count)
    }

    /// Number of records currently stored in `collection`.
    ///
    /// # Errors
    /// Not-found if the collection is missing.
    async fn count(&self, collection: &str) -> Result<usize> {
        let collections = self.collections.read();
        let col = collections
            .get(collection)
            .ok_or_else(|| StoreError::not_found(format!("Collection '{}'", collection)))?;

        Ok(col.records.len())
    }

    /// Identifier for this backend.
    fn backend_name(&self) -> &'static str {
        "memory"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh store that already contains one empty collection.
    async fn store_with(name: &str, dimension: usize) -> MemoryVectorStore {
        let store = MemoryVectorStore::new();
        store.create_collection(name, dimension).await.unwrap();
        store
    }

    /// Existence is tracked across create, duplicate create and delete.
    #[tokio::test]
    async fn test_collection_lifecycle() {
        let store = MemoryVectorStore::new();
        assert!(!store.collection_exists("test").await.unwrap());

        store.create_collection("test", 3).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());

        // Re-creating an existing name must be rejected.
        assert!(store.create_collection("test", 3).await.is_err());

        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
    }

    /// Upsert reports the batch size and an axis-aligned query finds its own vector first.
    #[tokio::test]
    async fn test_upsert_and_search() {
        let store = store_with("docs", 3).await;

        let batch = vec![
            VectorRecord::new("doc1", vec![1.0, 0.0, 0.0])
                .with_metadata("category", "a")
                .with_content("Document one"),
            VectorRecord::new("doc2", vec![0.0, 1.0, 0.0])
                .with_metadata("category", "b")
                .with_content("Document two"),
            VectorRecord::new("doc3", vec![0.707, 0.707, 0.0])
                .with_metadata("category", "a")
                .with_content("Document three"),
        ];

        let inserted = store.upsert("docs", &batch).await.unwrap();
        assert_eq!(inserted, 3);
        assert_eq!(store.count("docs").await.unwrap(), 3);

        let hits = store.search("docs", &[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(hits.len(), 2);

        let best = &hits[0];
        assert_eq!(best.id, "doc1");
        assert!((best.score - 1.0).abs() < 0.001);
    }

    /// An equality filter restricts results to records carrying that metadata value.
    #[tokio::test]
    async fn test_search_with_filter() {
        let store = store_with("items", 2).await;

        store
            .upsert(
                "items",
                &[
                    VectorRecord::new("item1", vec![1.0, 0.0]).with_metadata("type", "book"),
                    VectorRecord::new("item2", vec![0.9, 0.1]).with_metadata("type", "book"),
                    VectorRecord::new("item3", vec![0.8, 0.2]).with_metadata("type", "video"),
                ],
            )
            .await
            .unwrap();

        let books_only = VectorFilter::new().eq("type", "book");
        let hits = store
            .search_with_filter("items", &[1.0, 0.0], &books_only, 10)
            .await
            .unwrap();

        assert_eq!(hits.len(), 2);
        let book = Value::String("book".into());
        assert!(hits
            .iter()
            .all(|hit| hit.metadata.get("type") == Some(&book)));
    }

    /// A stored record can be fetched, deleted, and is gone afterwards.
    #[tokio::test]
    async fn test_get_and_delete() {
        let store = store_with("test", 2).await;

        store
            .upsert("test", &[VectorRecord::new("id1", vec![1.0, 0.0])])
            .await
            .unwrap();

        let found = store.get("test", "id1").await.unwrap();
        assert_eq!(found.as_ref().map(|r| r.id.as_str()), Some("id1"));

        let removed = store
            .delete("test", &["id1".to_string()])
            .await
            .unwrap();
        assert_eq!(removed, 1);

        assert!(store.get("test", "id1").await.unwrap().is_none());
    }

    /// Both writes and queries must match the collection's dimension (3 here).
    #[tokio::test]
    async fn test_dimension_validation() {
        let store = store_with("test", 3).await;

        let short_upsert = store
            .upsert("test", &[VectorRecord::new("id1", vec![1.0, 0.0])])
            .await;
        assert!(short_upsert.is_err());

        let short_query = store.search("test", &[1.0, 0.0], 10).await;
        assert!(short_query.is_err());
    }

    /// Every created collection shows up in the listing.
    #[tokio::test]
    async fn test_list_collections() {
        let store = MemoryVectorStore::new();

        for (name, dimension) in [("col1", 10), ("col2", 20)].iter() {
            store.create_collection(name, *dimension).await.unwrap();
        }

        let listed = store.list_collections().await.unwrap();
        assert_eq!(listed.len(), 2);

        for expected in ["col1", "col2"].iter() {
            assert!(listed.iter().any(|info| info.name == *expected));
        }
    }

    /// Upserting an existing id replaces the vector and metadata without adding a record.
    #[tokio::test]
    async fn test_upsert_updates_existing() {
        let store = store_with("test", 2).await;

        store
            .upsert(
                "test",
                &[VectorRecord::new("id1", vec![1.0, 0.0]).with_metadata("version", 1)],
            )
            .await
            .unwrap();
        store
            .upsert(
                "test",
                &[VectorRecord::new("id1", vec![0.0, 1.0]).with_metadata("version", 2)],
            )
            .await
            .unwrap();

        assert_eq!(store.count("test").await.unwrap(), 1);

        let latest = store.get("test", "id1").await.unwrap().unwrap();
        assert_eq!(latest.vector, vec![0.0, 1.0]);
        assert_eq!(
            latest.metadata.get("version"),
            Some(&Value::Number(2.into()))
        );
    }
}