1use std::collections::HashMap;
16use std::collections::HashSet;
17
18use async_trait::async_trait;
19
20use cognis_core::Result;
21
22use super::{Filter, SearchResult, VectorStore};
23
/// Computes a case- and whitespace-insensitive fingerprint of `text`.
///
/// The text is canonicalised first: it is split on Unicode whitespace, each word is
/// lowercased with [`str::to_lowercase`], and the words are re-joined with single spaces.
/// As a result, leading and trailing whitespace is dropped and runs of whitespace collapse.
/// The UTF-8 bytes of that canonical form are then hashed with 128-bit FNV-1a. The result
/// is returned as 32 lowercase hex digits.
///
/// FNV-1a is fast but not cryptographic. It is suitable for duplicate detection, not for
/// resisting deliberately crafted collisions.
pub fn normalized_fingerprint(text: &str) -> String {
    // 128-bit FNV parameters: offset basis and prime.
    const FNV128_OFFSET_BASIS: u128 = 0x6c62272e07bb014262b821756295c58d;
    const FNV128_PRIME: u128 = 0x0000000001000000000000000000013b;

    // Canonical form: lowercase words separated by exactly one space.
    let mut canonical = String::with_capacity(text.len());
    for (index, word) in text.split_whitespace().enumerate() {
        if index > 0 {
            canonical.push(' ');
        }
        canonical.push_str(&word.to_lowercase());
    }

    // FNV-1a: xor in each byte, then multiply (mod 2^128).
    let digest = canonical.bytes().fold(FNV128_OFFSET_BASIS, |acc, byte| {
        (acc ^ u128::from(byte)).wrapping_mul(FNV128_PRIME)
    });

    format!("{digest:032x}")
}
55
/// A [`VectorStore`] wrapper that filters out texts whose fingerprint has already been seen.
///
/// Every incoming text is reduced to a fingerprint by `F`, which defaults to
/// [`normalized_fingerprint`]. Texts whose fingerprint is already recorded are not
/// forwarded to the inner store.
///
/// Trait bounds live on the `impl` blocks rather than on the type itself. Naming the type,
/// for example in a field or a type alias, therefore does not require restating them.
pub struct DedupVectorStore<S, F = fn(&str) -> String> {
    /// The wrapped store that receives forwarded writes and all reads.
    inner: S,
    /// Maps a text to the key used for duplicate detection.
    fingerprint_fn: F,
    /// Fingerprints recorded so far, including any supplied at construction.
    seen: HashSet<String>,
}
92
93impl<S: VectorStore> DedupVectorStore<S, fn(&str) -> String> {
96 pub fn new(inner: S) -> Self {
99 Self {
100 inner,
101 fingerprint_fn: normalized_fingerprint,
102 seen: HashSet::new(),
103 }
104 }
105
106 pub fn with_seen(inner: S, seen: impl IntoIterator<Item = String>) -> Self {
111 Self {
112 inner,
113 fingerprint_fn: normalized_fingerprint,
114 seen: seen.into_iter().collect(),
115 }
116 }
117}
118
119impl<S, F> DedupVectorStore<S, F>
122where
123 S: VectorStore,
124 F: Fn(&str) -> String + Send + Sync,
125{
126 pub fn with_fingerprint(inner: S, f: F) -> Self {
131 Self {
132 inner,
133 fingerprint_fn: f,
134 seen: HashSet::new(),
135 }
136 }
137
138 pub fn with_fingerprint_and_seen(
141 inner: S,
142 f: F,
143 seen: impl IntoIterator<Item = String>,
144 ) -> Self {
145 Self {
146 inner,
147 fingerprint_fn: f,
148 seen: seen.into_iter().collect(),
149 }
150 }
151
152 pub fn contains(&self, text: &str) -> bool {
157 self.seen.contains(&(self.fingerprint_fn)(text))
158 }
159
160 pub fn inner(&self) -> &S {
162 &self.inner
163 }
164
165 pub fn inner_mut(&mut self) -> &mut S {
168 &mut self.inner
169 }
170
171 pub fn seen_fingerprints(&self) -> impl Iterator<Item = &str> {
174 self.seen.iter().map(|s| s.as_str())
175 }
176
177 pub fn seen_count(&self) -> usize {
180 self.seen.len()
181 }
182}
183
184#[async_trait]
189impl<S, F> VectorStore for DedupVectorStore<S, F>
190where
191 S: VectorStore + Send + Sync,
192 F: Fn(&str) -> String + Send + Sync,
193{
194 async fn add_texts(
201 &mut self,
202 texts: Vec<String>,
203 metadata: Option<Vec<HashMap<String, serde_json::Value>>>,
204 ) -> Result<Vec<String>> {
205 if texts.is_empty() {
206 return Ok(Vec::new());
207 }
208
209 let mut pass_texts: Vec<String> = Vec::new();
210 let mut pass_meta: Vec<HashMap<String, serde_json::Value>> = Vec::new();
211 let mut slots: Vec<Option<String>> = Vec::with_capacity(texts.len());
213
214 for (i, text) in texts.iter().enumerate() {
215 let fp = (self.fingerprint_fn)(text);
216 if self.seen.contains(&fp) {
217 slots.push(Some(format!("dedup:skipped:{fp}")));
218 } else {
219 self.seen.insert(fp);
220 pass_texts.push(text.clone());
221 if let Some(m) = &metadata {
222 pass_meta.push(m[i].clone());
223 }
224 slots.push(None);
225 }
226 }
227
228 let real_meta = if metadata.is_some() && !pass_meta.is_empty() {
229 Some(pass_meta)
230 } else {
231 None
232 };
233
234 let mut inner_ids = if !pass_texts.is_empty() {
235 self.inner.add_texts(pass_texts, real_meta).await?
236 } else {
237 Vec::new()
238 };
239
240 let mut inner_iter = inner_ids.drain(..);
242 let ids = slots
243 .into_iter()
244 .map(|slot| match slot {
245 Some(skipped_id) => skipped_id,
246 None => inner_iter.next().unwrap_or_default(),
247 })
248 .collect();
249 Ok(ids)
250 }
251
252 async fn add_vectors(
255 &mut self,
256 vectors: Vec<Vec<f32>>,
257 texts: Vec<String>,
258 metadata: Option<Vec<HashMap<String, serde_json::Value>>>,
259 ) -> Result<Vec<String>> {
260 if texts.is_empty() {
261 return Ok(Vec::new());
262 }
263
264 let mut pass_vecs: Vec<Vec<f32>> = Vec::new();
265 let mut pass_texts: Vec<String> = Vec::new();
266 let mut pass_meta: Vec<HashMap<String, serde_json::Value>> = Vec::new();
267 let mut slots: Vec<Option<String>> = Vec::with_capacity(texts.len());
268
269 for (i, (text, vec)) in texts.iter().zip(vectors.iter()).enumerate() {
270 let fp = (self.fingerprint_fn)(text);
271 if self.seen.contains(&fp) {
272 slots.push(Some(format!("dedup:skipped:{fp}")));
273 } else {
274 self.seen.insert(fp);
275 pass_texts.push(text.clone());
276 pass_vecs.push(vec.clone());
277 if let Some(m) = &metadata {
278 pass_meta.push(m[i].clone());
279 }
280 slots.push(None);
281 }
282 }
283
284 let real_meta = if metadata.is_some() && !pass_meta.is_empty() {
285 Some(pass_meta)
286 } else {
287 None
288 };
289
290 let mut inner_ids = if !pass_texts.is_empty() {
291 self.inner
292 .add_vectors(pass_vecs, pass_texts, real_meta)
293 .await?
294 } else {
295 Vec::new()
296 };
297
298 let mut inner_iter = inner_ids.drain(..);
299 let ids = slots
300 .into_iter()
301 .map(|slot| match slot {
302 Some(skipped_id) => skipped_id,
303 None => inner_iter.next().unwrap_or_default(),
304 })
305 .collect();
306 Ok(ids)
307 }
308
309 async fn similarity_search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
310 self.inner.similarity_search(query, k).await
311 }
312
313 async fn similarity_search_by_vector(
314 &self,
315 query_vector: Vec<f32>,
316 k: usize,
317 ) -> Result<Vec<SearchResult>> {
318 self.inner
319 .similarity_search_by_vector(query_vector, k)
320 .await
321 }
322
323 async fn similarity_search_with_filter(
324 &self,
325 query: &str,
326 k: usize,
327 filter: &Filter,
328 ) -> Result<Vec<SearchResult>> {
329 self.inner
330 .similarity_search_with_filter(query, k, filter)
331 .await
332 }
333
334 async fn delete(&mut self, ids: Vec<String>) -> Result<()> {
335 self.inner.delete(ids).await
336 }
337
338 fn len(&self) -> usize {
339 self.inner.len()
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;
    use std::sync::Arc;

    /// A fresh in-memory backing store with 8-dimensional fake embeddings.
    fn inner() -> InMemoryVectorStore {
        InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8)))
    }

    #[tokio::test]
    async fn skips_duplicate_on_second_add() {
        let mut store = DedupVectorStore::new(inner());
        store
            .add_texts(vec!["the workspace uses Go".into()], None)
            .await
            .unwrap();
        store
            .add_texts(vec!["the workspace uses Go".into()], None)
            .await
            .unwrap();
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn case_and_whitespace_normalisation_deduplicates() {
        let mut store = DedupVectorStore::new(inner());
        store
            .add_texts(vec!["The workspace uses Go.".into()], None)
            .await
            .unwrap();
        store
            .add_texts(vec!["  THE   WORKSPACE\tUSES GO.\n".into()], None)
            .await
            .unwrap();
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn distinct_content_both_stored() {
        let mut store = DedupVectorStore::new(inner());
        store.add_texts(vec!["Fact A.".into()], None).await.unwrap();
        store.add_texts(vec!["Fact B.".into()], None).await.unwrap();
        assert_eq!(store.len(), 2);
    }

    #[tokio::test]
    async fn empty_batch_returns_no_ids() {
        let mut store = DedupVectorStore::new(inner());
        let ids = store.add_texts(Vec::new(), None).await.unwrap();
        assert!(ids.is_empty());
        assert_eq!(store.seen_count(), 0);
        assert_eq!(store.len(), 0);
    }

    #[tokio::test]
    async fn duplicates_within_one_batch_are_skipped() {
        let mut store = DedupVectorStore::new(inner());
        let ids = store
            .add_texts(vec!["same fact".into(), "SAME \t fact".into()], None)
            .await
            .unwrap();
        assert_eq!(ids.len(), 2);
        assert!(!ids[0].starts_with("dedup:skipped:"));
        assert!(ids[1].starts_with("dedup:skipped:"));
        assert_eq!(store.len(), 1);
        assert_eq!(store.seen_count(), 1);
    }

    #[tokio::test]
    async fn skipped_id_embeds_fingerprint() {
        let mut store = DedupVectorStore::new(inner());
        let ids = store
            .add_texts(vec!["x".into(), "x".into()], None)
            .await
            .unwrap();
        assert_eq!(
            ids[1],
            format!("dedup:skipped:{}", normalized_fingerprint("x"))
        );
    }

    #[tokio::test]
    async fn batch_add_with_mixed_duplicates() {
        let mut store = DedupVectorStore::new(inner());
        let ids1 = store
            .add_texts(vec!["unique one".into(), "unique two".into()], None)
            .await
            .unwrap();
        assert_eq!(ids1.len(), 2);
        assert!(!ids1[0].starts_with("dedup:skipped:"));
        assert!(!ids1[1].starts_with("dedup:skipped:"));

        let ids2 = store
            .add_texts(
                vec![
                    "unique one".into(),
                    "unique three".into(),
                    "unique two".into(),
                ],
                None,
            )
            .await
            .unwrap();
        assert_eq!(ids2.len(), 3);
        assert!(ids2[0].starts_with("dedup:skipped:"));
        assert!(
            !ids2[1].starts_with("dedup:skipped:"),
            "unique three should pass through"
        );
        assert!(ids2[2].starts_with("dedup:skipped:"));
        assert_eq!(store.len(), 3);
    }

    #[tokio::test]
    async fn with_seen_skips_pre_populated_fingerprints() {
        let fp = normalized_fingerprint("already known fact");
        let mut store = DedupVectorStore::with_seen(inner(), [fp]);
        let ids = store
            .add_texts(vec!["already known fact".into()], None)
            .await
            .unwrap();
        assert!(ids[0].starts_with("dedup:skipped:"));
        assert_eq!(store.len(), 0);
    }

    #[tokio::test]
    async fn read_operations_pass_through() {
        let mut store = DedupVectorStore::new(inner());
        store
            .add_texts(vec!["searchable fact".into()], None)
            .await
            .unwrap();
        let results = store.similarity_search("fact", 5).await.unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn delete_passes_through() {
        let mut store = DedupVectorStore::new(inner());
        let ids = store
            .add_texts(vec!["deletable".into()], None)
            .await
            .unwrap();
        assert_eq!(store.len(), 1);
        store.delete(ids).await.unwrap();
        assert_eq!(store.len(), 0);
    }

    #[tokio::test]
    async fn seen_count_tracks_unique_fingerprints() {
        let mut store = DedupVectorStore::new(inner());
        store
            .add_texts(vec!["a".into(), "b".into()], None)
            .await
            .unwrap();
        store.add_texts(vec!["a".into()], None).await.unwrap();
        assert_eq!(store.seen_count(), 2);
    }

    #[tokio::test]
    async fn contains_reflects_seen_set() {
        let mut store = DedupVectorStore::new(inner());
        assert!(!store.contains("new fact"));
        store
            .add_texts(vec!["new fact".into()], None)
            .await
            .unwrap();
        assert!(store.contains("new fact"));
        assert!(store.contains("NEW FACT"));
    }

    #[tokio::test]
    async fn custom_fingerprint_uses_provided_function() {
        // Fingerprint on the first word only, so both texts below collide.
        let mut store = DedupVectorStore::with_fingerprint(inner(), |text: &str| {
            text.split_whitespace().next().unwrap_or("").to_lowercase()
        });
        store
            .add_texts(vec!["rust is great".into()], None)
            .await
            .unwrap();
        store
            .add_texts(vec!["rust is also fast".into()], None)
            .await
            .unwrap();
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn add_vectors_deduplicates() {
        let mut store = DedupVectorStore::new(inner());
        let embedding = vec![0.1_f32; 8];
        store
            .add_vectors(vec![embedding.clone()], vec!["vec fact".into()], None)
            .await
            .unwrap();
        store
            .add_vectors(vec![embedding.clone()], vec!["vec fact".into()], None)
            .await
            .unwrap();
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn fingerprint_is_deterministic() {
        assert_eq!(
            normalized_fingerprint("hello world"),
            normalized_fingerprint("hello world")
        );
    }

    #[test]
    fn fingerprint_is_case_insensitive() {
        assert_eq!(
            normalized_fingerprint("Hello World"),
            normalized_fingerprint("hello world")
        );
    }

    #[test]
    fn fingerprint_collapses_whitespace() {
        // Escapes keep the whitespace run visible and immune to reformatting.
        assert_eq!(
            normalized_fingerprint("  hello \t\n  world  "),
            normalized_fingerprint("hello world")
        );
    }

    #[test]
    fn fingerprint_of_blank_text_is_fnv_offset_basis() {
        // With no bytes to hash, FNV-1a returns its 128-bit offset basis.
        assert_eq!(
            normalized_fingerprint(""),
            "6c62272e07bb014262b821756295c58d"
        );
        assert_eq!(
            normalized_fingerprint(" \t\n "),
            normalized_fingerprint("")
        );
    }

    #[test]
    fn fingerprint_distinguishes_different_content() {
        assert_ne!(
            normalized_fingerprint("hello world"),
            normalized_fingerprint("goodbye world")
        );
    }
}