1use crate::error::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7#[derive(Clone, Debug, Serialize, Deserialize)]
8pub struct VectorRecord {
9 pub id: String,
10 pub vector: Vec<f32>,
11 #[serde(default)]
12 pub metadata: HashMap<String, Value>,
13 #[serde(skip_serializing_if = "Option::is_none")]
14 pub content: Option<String>,
15}
16
17impl VectorRecord {
18 pub fn new(id: impl Into<String>, vector: Vec<f32>) -> Self {
19 Self {
20 id: id.into(),
21 vector,
22 metadata: HashMap::new(),
23 content: None,
24 }
25 }
26
27 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
28 self.metadata.insert(key.into(), value.into());
29 self
30 }
31
32 pub fn with_content(mut self, content: impl Into<String>) -> Self {
33 self.content = Some(content.into());
34 self
35 }
36}
37
38#[derive(Clone, Debug, Serialize, Deserialize)]
39pub struct SearchResult {
40 pub id: String,
41 pub score: f32,
42 #[serde(default)]
43 pub metadata: HashMap<String, Value>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub content: Option<String>,
46}
47
48#[derive(Clone, Debug, Default)]
49pub struct VectorFilter {
50 pub conditions: Vec<FilterCondition>,
51}
52
53#[derive(Clone, Debug)]
54pub enum FilterCondition {
55 Eq(String, Value),
56 Ne(String, Value),
57 Gt(String, Value),
58 Gte(String, Value),
59 Lt(String, Value),
60 Lte(String, Value),
61 In(String, Vec<Value>),
62 Contains(String, String),
63}
64
65impl VectorFilter {
66 pub fn new() -> Self {
67 Self::default()
68 }
69
70 pub fn eq(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
71 self.conditions
72 .push(FilterCondition::Eq(field.into(), value.into()));
73 self
74 }
75
76 pub fn ne(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
77 self.conditions
78 .push(FilterCondition::Ne(field.into(), value.into()));
79 self
80 }
81
82 pub fn gt(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
83 self.conditions
84 .push(FilterCondition::Gt(field.into(), value.into()));
85 self
86 }
87
88 pub fn lt(mut self, field: impl Into<String>, value: impl Into<Value>) -> Self {
89 self.conditions
90 .push(FilterCondition::Lt(field.into(), value.into()));
91 self
92 }
93
94 pub fn contains(mut self, field: impl Into<String>, value: impl Into<String>) -> Self {
95 self.conditions
96 .push(FilterCondition::Contains(field.into(), value.into()));
97 self
98 }
99
100 pub fn is_empty(&self) -> bool {
101 self.conditions.is_empty()
102 }
103}
104
/// Summary information about one collection.
#[derive(Clone, Debug)]
pub struct CollectionInfo {
    /// Name of the collection.
    pub name: String,
    /// Vector dimension associated with the collection.
    pub dimension: usize,
    /// Count reported for the collection (presumably the number of stored records — confirm).
    pub count: usize,
}
111
112#[async_trait]
113pub trait VectorStore: Send + Sync {
114 async fn create_collection(&self, name: &str, dimension: usize) -> Result<()>;
115
116 async fn delete_collection(&self, name: &str) -> Result<()>;
117
118 async fn list_collections(&self) -> Result<Vec<CollectionInfo>>;
119
120 async fn collection_exists(&self, name: &str) -> Result<bool>;
121
122 async fn upsert(&self, collection: &str, records: &[VectorRecord]) -> Result<usize>;
123
124 async fn search(
125 &self,
126 collection: &str,
127 query: &[f32],
128 limit: usize,
129 ) -> Result<Vec<SearchResult>>;
130
131 async fn search_with_filter(
132 &self,
133 collection: &str,
134 query: &[f32],
135 filter: &VectorFilter,
136 limit: usize,
137 ) -> Result<Vec<SearchResult>>;
138
139 async fn get(&self, collection: &str, id: &str) -> Result<Option<VectorRecord>>;
140
141 async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize>;
142
143 async fn count(&self, collection: &str) -> Result<usize>;
144
145 fn backend_name(&self) -> &'static str;
146}
147
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either vector has zero
/// magnitude (which includes two empty slices). Otherwise the result lies in `[-1.0, 1.0]`.
///
/// The dot product and squared norms are accumulated in `f64` in a single pass. Summing
/// in `f32` overflows to infinity for large-magnitude components (yielding NaN) and
/// underflows to zero for tiny ones (misreporting a zero vector).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_sq_a, norm_sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_sq_a + x * x, norm_sq_b + y * y)
        },
    );

    if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the quotient a hair outside [-1, 1]; clamp so callers can rely on
    // the mathematical range. The narrowing cast is lossless in range after the clamp.
    (dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt())).clamp(-1.0, 1.0) as f32
}
163
/// Computes the Euclidean (L2) distance between `a` and `b`.
///
/// Returns `f32::MAX` when the slices differ in length, and `0.0` for two empty slices.
///
/// Squared differences are accumulated in `f64` so that intermediate sums do not overflow
/// `f32` when the final distance itself is representable. A distance that genuinely exceeds
/// the `f32` range becomes infinity in the final narrowing cast.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }

    let sum_of_squares: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = f64::from(x) - f64::from(y);
            diff * diff
        })
        .sum();

    sum_of_squares.sqrt() as f32
}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const EPSILON: f32 = 1e-4;

    /// Returns `true` when `actual` is within [`EPSILON`] of `expected`.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPSILON
    }

    #[test]
    fn test_vector_record() {
        let record = VectorRecord::new("id1", vec![0.1, 0.2, 0.3])
            .with_metadata("key", "value")
            .with_content("some text");

        assert_eq!(record.id, "id1");
        assert_eq!(record.vector.len(), 3);
        assert_eq!(record.metadata.get("key"), Some(&Value::from("value")));
        assert_eq!(record.content.as_deref(), Some("some text"));
    }

    #[test]
    fn test_vector_record_defaults() {
        // A freshly constructed record carries no metadata and no content.
        let record = VectorRecord::new("id1", vec![1.0]);

        assert!(record.metadata.is_empty());
        assert!(record.content.is_none());
    }

    #[test]
    fn test_vector_filter() {
        let filter = VectorFilter::new().eq("type", "article").gt("score", 0.5);

        assert_eq!(filter.conditions.len(), 2);
        assert!(!filter.is_empty());
        // Conditions keep insertion order.
        assert!(matches!(filter.conditions[0], FilterCondition::Eq(..)));
        assert!(matches!(filter.conditions[1], FilterCondition::Gt(..)));
    }

    #[test]
    fn test_vector_filter_starts_empty() {
        assert!(VectorFilter::new().is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = [1.0_f32, 0.0, 0.0];

        // Identical, orthogonal and opposite directions.
        assert!(approx_eq(cosine_similarity(&a, &[1.0, 0.0, 0.0]), 1.0));
        assert!(approx_eq(cosine_similarity(&a, &[0.0, 1.0, 0.0]), 0.0));
        assert!(approx_eq(cosine_similarity(&a, &[-1.0, 0.0, 0.0]), -1.0));
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        let empty: [f32; 0] = [];

        // Length mismatch, zero-magnitude vector and empty input all fall back to 0.0.
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let origin = [0.0_f32, 0.0, 0.0];

        assert!(approx_eq(euclidean_distance(&origin, &[1.0, 0.0, 0.0]), 1.0));
        assert!(approx_eq(euclidean_distance(&origin, &[3.0, 4.0, 0.0]), 5.0));
        assert!(approx_eq(euclidean_distance(&origin, &origin), 0.0));
    }

    #[test]
    fn test_euclidean_distance_degenerate_inputs() {
        let empty: [f32; 0] = [];

        // Length mismatch yields the f32::MAX sentinel; empty input has distance 0.
        assert_eq!(euclidean_distance(&[1.0, 0.0], &[1.0]), f32::MAX);
        assert_eq!(euclidean_distance(&empty, &empty), 0.0);
    }
}