use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use hnsw_rs::prelude::*;
8
/// Number of `f32` components every embedding handled by this module must have.
///
/// Vectors of any other length are rejected with
/// `VectorStoreError::InvalidDimension`.
pub const EMBEDDING_DIM: usize = 384;
11
12#[derive(Debug, thiserror::Error)]
14pub enum VectorStoreError {
15 #[error("HNSW error: {0}")]
16 Hnsw(String),
17
18 #[error("Not found: {0}")]
19 NotFound(String),
20
21 #[error("Invalid embedding dimension: expected {EMBEDDING_DIM}, got {0}")]
22 InvalidDimension(usize),
23
24 #[error("IO error: {0}")]
25 Io(#[from] std::io::Error),
26}
27
28pub type Result<T> = std::result::Result<T, VectorStoreError>;
29
/// A piece of document text stored alongside its embedding.
#[derive(Debug, Clone)]
pub struct DocumentChunk {
    /// Caller-supplied identifier of the chunk. This is independent of the
    /// numeric id the store assigns when the chunk is added.
    pub id: String,

    /// Text content of the chunk.
    pub content: String,

    /// Where the chunk came from (presumably a file name or path; the store
    /// itself never interprets this value).
    pub source: String,

    /// Optional free-form metadata; the store does not interpret it.
    pub metadata: Option<String>,
}
45
46#[derive(Debug, Clone)]
48pub struct SearchResult {
49 pub chunk: DocumentChunk,
51
52 pub score: f32,
54}
55
56pub struct VectorStore {
58 hnsw: Arc<RwLock<Hnsw<'static, f32, DistL2>>>,
60
61 documents: Arc<RwLock<HashMap<usize, DocumentChunk>>>,
63
64 next_id: Arc<RwLock<usize>>,
66
67 config: VectorStoreConfig,
69}
70
/// Tuning parameters for a `VectorStore`.
///
/// The values are forwarded to the HNSW index; see the hnsw_rs documentation
/// for their precise effect.
#[derive(Debug, Clone)]
pub struct VectorStoreConfig {
    /// Maximum number of graph connections per point (first argument of
    /// `Hnsw::new`).
    pub max_connections: usize,

    /// Candidate-list width used while inserting points (`ef_construction`
    /// argument of `Hnsw::new`).
    pub ef_construction: usize,

    /// Number of elements the index is sized for (`max_elements` argument of
    /// `Hnsw::new`).
    pub max_elements: usize,

    /// Candidate-list width passed to the index on every search.
    pub ef_search: usize,
}
86
87impl Default for VectorStoreConfig {
88 fn default() -> Self {
89 Self {
90 max_connections: 16,
91 ef_construction: 200,
92 max_elements: 100_000,
93 ef_search: 100,
94 }
95 }
96}
97
98impl VectorStore {
99 pub fn new(config: VectorStoreConfig) -> Result<Self> {
101 let hnsw = Hnsw::<f32, DistL2>::new(
102 config.max_connections,
103 config.max_elements,
104 EMBEDDING_DIM,
105 config.ef_construction,
106 DistL2 {},
107 );
108
109 Ok(Self {
110 hnsw: Arc::new(RwLock::new(hnsw)),
111 documents: Arc::new(RwLock::new(HashMap::new())),
112 next_id: Arc::new(RwLock::new(0)),
113 config,
114 })
115 }
116
117 pub fn add_chunk(&self, chunk: DocumentChunk, embedding: &[f32]) -> Result<usize> {
119 if embedding.len() != EMBEDDING_DIM {
120 return Err(VectorStoreError::InvalidDimension(embedding.len()));
121 }
122
123 let mut next_id = self.next_id.write().unwrap();
125 let id = *next_id;
126 *next_id += 1;
127 drop(next_id);
128
129 let hnsw = self.hnsw.write().unwrap();
131 hnsw.insert((embedding, id));
132 drop(hnsw);
133
134 let mut documents = self.documents.write().unwrap();
136 documents.insert(id, chunk);
137
138 Ok(id)
139 }
140
141 pub fn search(&self, query_embedding: &[f32], top_k: usize) -> Result<Vec<SearchResult>> {
143 if query_embedding.len() != EMBEDDING_DIM {
144 return Err(VectorStoreError::InvalidDimension(query_embedding.len()));
145 }
146
147 let results = {
149 let hnsw = self.hnsw.read().unwrap();
150 hnsw.search(query_embedding, top_k, self.config.ef_search)
151 };
152
153 let documents = self.documents.read().unwrap();
155 let search_results: Vec<SearchResult> = results
156 .iter()
157 .filter_map(|neighbor| {
158 let id = neighbor.d_id;
159 documents.get(&id).map(|chunk| SearchResult {
160 chunk: chunk.clone(),
161 score: 1.0 / (1.0 + neighbor.distance), })
163 })
164 .collect();
165
166 Ok(search_results)
167 }
168
169 pub fn get_chunk(&self, id: usize) -> Result<DocumentChunk> {
171 let documents = self.documents.read().unwrap();
172 documents
173 .get(&id)
174 .cloned()
175 .ok_or_else(|| VectorStoreError::NotFound(format!("Document ID {}", id)))
176 }
177
178 pub fn len(&self) -> usize {
180 self.documents.read().unwrap().len()
181 }
182
183 pub fn is_empty(&self) -> bool {
185 self.len() == 0
186 }
187
188 pub fn clear(&self) {
190 let mut documents = self.documents.write().unwrap();
191 documents.clear();
192
193 let mut next_id = self.next_id.write().unwrap();
194 *next_id = 0;
195
196 let mut hnsw = self.hnsw.write().unwrap();
198 *hnsw = Hnsw::<f32, DistL2>::new(
199 self.config.max_connections,
200 self.config.max_elements,
201 EMBEDDING_DIM,
202 self.config.ef_construction,
203 DistL2 {},
204 );
205 }
206
207 pub fn stats(&self) -> VectorStoreStats {
209 let documents = self.documents.read().unwrap();
210 VectorStoreStats {
211 num_chunks: documents.len(),
212 embedding_dim: EMBEDDING_DIM,
213 }
214 }
215}
216
/// Summary figures reported by `VectorStore::stats`.
#[derive(Debug, Clone)]
pub struct VectorStoreStats {
    /// Number of chunks held by the store when the snapshot was taken.
    pub num_chunks: usize,
    /// Number of components per embedding (`EMBEDDING_DIM`).
    pub embedding_dim: usize,
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test embedding: component `i` is `sin(i * seed)`.
    fn create_embedding(seed: f32) -> Vec<f32> {
        (0..EMBEDDING_DIM)
            .map(|i| (i as f32 * seed).sin())
            .collect()
    }

    /// Builds a chunk with the given id and content, a fixed source and no
    /// metadata.
    fn make_chunk(id: &str, content: &str) -> DocumentChunk {
        DocumentChunk {
            id: id.to_string(),
            content: content.to_string(),
            source: "test.txt".to_string(),
            metadata: None,
        }
    }

    /// Creates a store with the default configuration.
    fn new_store() -> VectorStore {
        VectorStore::new(VectorStoreConfig::default()).unwrap()
    }

    #[test]
    fn test_vector_store_creation() {
        let store = new_store();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }

    #[test]
    fn test_add_chunk() {
        let store = new_store();

        let chunk = make_chunk("chunk1", "This is a test document");
        let embedding = create_embedding(1.0);
        let id = store.add_chunk(chunk, &embedding).unwrap();

        assert_eq!(store.len(), 1);

        let retrieved = store.get_chunk(id).unwrap();
        assert_eq!(retrieved.id, "chunk1");
        assert_eq!(retrieved.content, "This is a test document");
    }

    #[test]
    fn test_invalid_embedding_dimension() {
        let store = new_store();

        let chunk = make_chunk("chunk1", "Test");

        // Deliberately shorter than EMBEDDING_DIM.
        let wrong_embedding = vec![0.0; 128];
        let result = store.add_chunk(chunk, &wrong_embedding);

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            VectorStoreError::InvalidDimension(_)
        ));
    }

    #[test]
    fn test_search_invalid_dimension() {
        let store = new_store();

        let wrong_query = vec![0.0_f32; 128];
        let result = store.search(&wrong_query, 3);

        assert!(matches!(
            result,
            Err(VectorStoreError::InvalidDimension(128))
        ));
    }

    #[test]
    fn test_get_chunk_not_found() {
        let store = new_store();

        assert!(matches!(
            store.get_chunk(42),
            Err(VectorStoreError::NotFound(_))
        ));
    }

    #[test]
    fn test_search() {
        let store = new_store();

        for i in 0..5 {
            let chunk = make_chunk(&format!("chunk{}", i), &format!("Document number {}", i));
            let embedding = create_embedding(i as f32);
            store.add_chunk(chunk, &embedding).unwrap();
        }

        // Identical to the embedding stored for "chunk2".
        let query = create_embedding(2.0);
        let results = store.search(&query, 3).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        assert_eq!(results[0].chunk.id, "chunk2");
    }

    #[test]
    fn test_clear() {
        let store = new_store();

        for i in 0..3 {
            let chunk = make_chunk(&format!("chunk{}", i), &format!("Document {}", i));
            let embedding = create_embedding(i as f32);
            store.add_chunk(chunk, &embedding).unwrap();
        }

        assert_eq!(store.len(), 3);

        store.clear();

        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }

    #[test]
    fn test_stats() {
        let store = new_store();

        for i in 0..10 {
            let chunk = make_chunk(&format!("chunk{}", i), &format!("Document {}", i));
            let embedding = create_embedding(i as f32);
            store.add_chunk(chunk, &embedding).unwrap();
        }

        let stats = store.stats();
        assert_eq!(stats.num_chunks, 10);
        assert_eq!(stats.embedding_dim, EMBEDDING_DIM);
    }

    #[test]
    fn test_search_ordering() {
        let store = new_store();

        for i in 0..5 {
            let chunk = make_chunk(&format!("chunk{}", i), &format!("Document {}", i));
            let embedding = create_embedding(i as f32 * 10.0);
            store.add_chunk(chunk, &embedding).unwrap();
        }

        let query = create_embedding(30.0);
        let results = store.search(&query, 5).unwrap();

        assert!(!results.is_empty());

        // `windows(2)` yields nothing for fewer than two hits, so there is no
        // `len() - 1` underflow to worry about.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }
}