hehe_store/traits/
vector.rs

1use crate::error::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7#[derive(Clone, Debug, Serialize, Deserialize)]
8pub struct VectorRecord {
9    pub id: String,
10    pub vector: Vec<f32>,
11    #[serde(default)]
12    pub metadata: HashMap<String, Value>,
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub content: Option<String>,
15}
16
17impl VectorRecord {
18    pub fn new(id: impl Into<String>, vector: Vec<f32>) -> Self {
19        Self {
20            id: id.into(),
21            vector,
22            metadata: HashMap::new(),
23            content: None,
24        }
25    }
26
27    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
28        self.metadata.insert(key.into(), value.into());
29        self
30    }
31
32    pub fn with_content(mut self, content: impl Into<String>) -> Self {
33        self.content = Some(content.into());
34        self
35    }
36}
37
38#[derive(Clone, Debug, Serialize, Deserialize)]
39pub struct SearchResult {
40    pub id: String,
41    pub score: f32,
42    #[serde(default)]
43    pub metadata: HashMap<String, Value>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub content: Option<String>,
46}
47
/// An ordered list of metadata conditions passed to filtered searches.
///
/// The default value has no conditions.
// NOTE(review): how multiple conditions combine (presumably logical AND) is decided by the
// store backend evaluating the filter — confirm against implementations.
#[derive(Clone, Debug, Default)]
pub struct VectorFilter {
    /// Conditions in the order they were added.
    pub conditions: Vec<FilterCondition>,
}
52
53#[derive(Clone, Debug)]
54pub enum FilterCondition {
55    Eq(String, Value),
56    Ne(String, Value),
57    Gt(String, Value),
58    Gte(String, Value),
59    Lt(String, Value),
60    Lte(String, Value),
61    In(String, Vec<Value>),
62    Contains(String, String),
63}
64
65impl VectorFilter {
66    pub fn new() -> Self {
67        Self::default()
68    }
69
70    pub fn eq(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
71        self.conditions
72            .push(FilterCondition::Eq(field.into(), value.into()));
73        self
74    }
75
76    pub fn ne(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
77        self.conditions
78            .push(FilterCondition::Ne(field.into(), value.into()));
79        self
80    }
81
82    pub fn gt(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
83        self.conditions
84            .push(FilterCondition::Gt(field.into(), value.into()));
85        self
86    }
87
88    pub fn lt(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
89        self.conditions
90            .push(FilterCondition::Lt(field.into(), value.into()));
91        self
92    }
93
94    pub fn contains(mut self, field: impl Into<String>, value: impl Into<String>) -> Self {
95        self.conditions
96            .push(FilterCondition::Contains(field.into(), value.into()));
97        self
98    }
99
100    pub fn is_empty(&self) -> bool {
101        self.conditions.is_empty()
102    }
103}
104
/// Summary of a collection, as returned by `VectorStore::list_collections`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CollectionInfo {
    /// Name of the collection.
    pub name: String,
    /// Vector dimension associated with the collection.
    pub dimension: usize,
    /// Count reported for the collection (presumably the number of stored records — the
    /// value is filled in by the backend).
    pub count: usize,
}
111
/// Asynchronous interface to a vector storage backend.
///
/// Implementations must be `Send + Sync`. Every fallible method returns the crate's
/// `Result`; which error conditions apply is up to each backend and is not specified here.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Creates a collection called `name` for vectors of the given `dimension`.
    // NOTE(review): behavior when the collection already exists is backend-defined — confirm.
    async fn create_collection(&self, name: &str, dimension: usize) -> Result<()>;

    /// Deletes the collection called `name`.
    async fn delete_collection(&self, name: &str) -> Result<()>;

    /// Lists the known collections as [`CollectionInfo`] values.
    async fn list_collections(&self) -> Result<Vec<CollectionInfo>>;

    /// Reports whether a collection called `name` exists.
    async fn collection_exists(&self, name: &str) -> Result<bool>;

    /// Inserts or updates `records` in `collection`.
    ///
    /// Returns a count — presumably the number of records written; confirm against
    /// implementations.
    async fn upsert(&self, collection: &str, records: &[VectorRecord]) -> Result<usize>;

    /// Searches `collection` for records matching the `query` vector, returning at most
    /// `limit` results (assuming the backend honors `limit` as an upper bound).
    async fn search(
        &self,
        collection: &str,
        query: &[f32],
        limit: usize,
    ) -> Result<Vec<SearchResult>>;

    /// Like `search`, but additionally takes a [`VectorFilter`] to restrict the candidates.
    async fn search_with_filter(
        &self,
        collection: &str,
        query: &[f32],
        filter: &VectorFilter,
        limit: usize,
    ) -> Result<Vec<SearchResult>>;

    /// Fetches the record with the given `id` from `collection`; `None` is available to
    /// signal that no such record exists.
    async fn get(&self, collection: &str, id: &str) -> Result<Option<VectorRecord>>;

    /// Deletes the records with the given `ids` from `collection`.
    ///
    /// Returns a count — presumably the number of records removed; confirm against
    /// implementations.
    async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize>;

    /// Returns a count for `collection` (presumably the number of stored records).
    async fn count(&self, collection: &str) -> Result<usize>;

    /// Static name identifying the backend implementation.
    fn backend_name(&self) -> &'static str;
}
147
/// Computes the cosine similarity of `a` and `b`.
///
/// The result is always within `[-1.0, 1.0]` for finite inputs.
///
/// Returns `0.0` when the slices differ in length, when either is empty, or when either has
/// zero magnitude. These degenerate cases are reported as "no similarity" rather than as an
/// error, so callers that need to tell them apart must check the inputs themselves.
///
/// If any component is `NaN` or infinite, the result is `NaN`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate in f64 and in a single pass. Squaring an f32 component above ~1.8e19
    // overflows to infinity (yielding NaN), and squaring one below ~1e-23 underflows to zero
    // (making a non-zero vector look like a zero vector). f64 has ample range for both.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio marginally past ±1, which would break callers that feed the
    // result to `acos`; clamp it back into the mathematically valid range.
    let similarity = (dot / (sq_a.sqrt() * sq_b.sqrt())).clamp(-1.0, 1.0);

    // Narrowing a value in [-1, 1] to f32 only rounds; it cannot overflow.
    similarity as f32
}
163
/// Computes the Euclidean (L2) distance between `a` and `b`.
///
/// Returns `f32::MAX` when the slices differ in length; this sentinel is kept so existing
/// callers that rank by distance keep treating mismatched vectors as "farthest away".
/// Two empty slices have distance `0.0`.
///
/// If the true distance exceeds the `f32` range the result is infinity; `NaN` components
/// produce `NaN`.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }

    // Difference and square are taken in f64: in f32 a squared difference above ~1.8e19
    // overflows to infinity and one below ~1e-23 underflows to zero, even though the final
    // square root would be perfectly representable.
    let sum_of_squares: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = f64::from(x) - f64::from(y);
            diff * diff
        })
        .sum();

    // Narrowing rounds to the nearest f32 (or to infinity when out of range).
    sum_of_squares.sqrt() as f32
}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating-point comparisons below.
    const TOLERANCE: f32 = 0.0001;

    /// Returns `true` when `actual` lies within [`TOLERANCE`] of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// The builder sets id, vector, metadata, and content.
    #[test]
    fn test_vector_record() {
        let base = VectorRecord::new("id1", vec![0.1, 0.2, 0.3]);
        let built = base.with_metadata("key", "value").with_content("some text");

        assert_eq!(built.id, "id1");
        assert_eq!(built.vector.len(), 3);
        assert!(built.metadata.contains_key("key"));
        assert_eq!(built.content.as_deref(), Some("some text"));
    }

    /// Chained builder calls each append one condition.
    #[test]
    fn test_vector_filter() {
        let by_type = VectorFilter::new().eq("type", "article");
        let by_type_and_score = by_type.gt("score", 0.5);

        assert_eq!(by_type_and_score.conditions.len(), 2);
        assert!(!by_type_and_score.is_empty());
    }

    /// Identical, orthogonal, and opposite unit vectors give 1, 0, and -1.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = [1.0_f32, 0.0, 0.0];
        let same = [1.0_f32, 0.0, 0.0];
        let unit_y = [0.0_f32, 1.0, 0.0];
        let opposite = [-1.0_f32, 0.0, 0.0];

        assert!(close(cosine_similarity(&unit_x, &same), 1.0));
        assert!(close(cosine_similarity(&unit_x, &unit_y), 0.0));
        assert!(close(cosine_similarity(&unit_x, &opposite), -1.0));
    }

    /// Distances from the origin to (1,0,0) and (3,4,0) are 1 and 5.
    #[test]
    fn test_euclidean_distance() {
        let origin = [0.0_f32, 0.0, 0.0];
        let unit_x = [1.0_f32, 0.0, 0.0];
        let three_four = [3.0_f32, 4.0, 0.0];

        assert!(close(euclidean_distance(&origin, &unit_x), 1.0));
        assert!(close(euclidean_distance(&origin, &three_four), 5.0));
    }
}