reflex/vectordb/
mock.rs

1use crate::vectordb::{SearchResult, VectorDbClient, VectorDbError, VectorPoint, WriteConsistency};
2use std::collections::HashMap;
3
4#[derive(Default)]
5/// In-memory mock implementation of [`VectorDbClient`].
6pub struct MockVectorDbClient {
7    collections: std::sync::RwLock<HashMap<String, MockCollection>>,
8}
9
10#[derive(Default, Clone)]
11struct MockCollection {
12    vector_size: u64,
13    points: HashMap<u64, MockStoredPoint>,
14}
15
/// Vector and metadata kept for a single point; the point id itself is the
/// key of the owning map and is not repeated here.
#[derive(Clone)]
struct MockStoredPoint {
    /// The embedding vector as supplied on upsert.
    vector: Vec<f32>,
    /// Tenant the point belongs to; used for tenant-filtered searches.
    tenant_id: u64,
    /// Context hash carried through unchanged to search results.
    context_hash: u64,
    /// Timestamp carried through unchanged to search results.
    timestamp: i64,
    /// Optional storage key carried through unchanged to search results.
    storage_key: Option<String>,
}
24
25impl MockVectorDbClient {
26    /// Creates an empty mock client.
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    /// Returns the number of points currently stored in `collection`.
32    pub fn point_count(&self, collection: &str) -> Option<usize> {
33        self.collections
34            .read()
35            .ok()?
36            .get(collection)
37            .map(|c| c.points.len())
38    }
39
40    /// Poisons the internal RwLock for testing error handling paths.
41    /// This method is only available in test builds.
42    #[cfg(test)]
43    pub fn poison_lock(&self) {
44        use std::thread;
45
46        let collections_ptr = &self.collections as *const _ as usize;
47        let handle = thread::spawn(move || {
48            // SAFETY: We're in test code, the pointer is valid for the duration
49            let collections: &std::sync::RwLock<HashMap<String, MockCollection>> =
50                unsafe { &*(collections_ptr as *const _) };
51            let _guard = collections.write().unwrap();
52            panic!("Intentional panic to poison lock for testing");
53        });
54        // Wait for the thread to panic, which poisons the lock
55        let _ = handle.join();
56    }
57}
58
59/// Computes cosine similarity between two f32 vectors.
60impl VectorDbClient for MockVectorDbClient {
61    async fn ensure_collection(&self, name: &str, vector_size: u64) -> Result<(), VectorDbError> {
62        let mut collections =
63            self.collections
64                .write()
65                .map_err(|_| VectorDbError::CreateCollectionFailed {
66                    collection: name.to_string(),
67                    message: "lock poisoned".to_string(),
68                })?;
69
70        collections
71            .entry(name.to_string())
72            .or_insert(MockCollection {
73                vector_size,
74                points: HashMap::new(),
75            });
76
77        Ok(())
78    }
79
80    async fn upsert_points(
81        &self,
82        collection: &str,
83        points: Vec<VectorPoint>,
84        _consistency: WriteConsistency,
85    ) -> Result<(), VectorDbError> {
86        let mut collections =
87            self.collections
88                .write()
89                .map_err(|_| VectorDbError::UpsertFailed {
90                    collection: collection.to_string(),
91                    message: "lock poisoned".to_string(),
92                })?;
93
94        let coll =
95            collections
96                .get_mut(collection)
97                .ok_or_else(|| VectorDbError::CollectionNotFound {
98                    collection: collection.to_string(),
99                })?;
100
101        for point in points {
102            if point.vector.len() as u64 != coll.vector_size {
103                return Err(VectorDbError::InvalidDimension {
104                    expected: coll.vector_size as usize,
105                    actual: point.vector.len(),
106                });
107            }
108
109            coll.points.insert(
110                point.id,
111                MockStoredPoint {
112                    vector: point.vector,
113                    tenant_id: point.tenant_id,
114                    context_hash: point.context_hash,
115                    timestamp: point.timestamp,
116                    storage_key: point.storage_key,
117                },
118            );
119        }
120
121        Ok(())
122    }
123
124    async fn search(
125        &self,
126        collection: &str,
127        query: Vec<f32>,
128        limit: u64,
129        tenant_filter: Option<u64>,
130    ) -> Result<Vec<SearchResult>, VectorDbError> {
131        let collections = self
132            .collections
133            .read()
134            .map_err(|_| VectorDbError::SearchFailed {
135                collection: collection.to_string(),
136                message: "lock poisoned".to_string(),
137            })?;
138
139        let coll =
140            collections
141                .get(collection)
142                .ok_or_else(|| VectorDbError::CollectionNotFound {
143                    collection: collection.to_string(),
144                })?;
145
146        let mut results: Vec<SearchResult> = coll
147            .points
148            .iter()
149            .filter(|(_, p)| tenant_filter.is_none() || tenant_filter == Some(p.tenant_id))
150            .map(|(&id, p)| {
151                let score = cosine_similarity(&query, &p.vector);
152                SearchResult {
153                    id,
154                    score,
155                    tenant_id: p.tenant_id,
156                    context_hash: p.context_hash,
157                    timestamp: p.timestamp,
158                    storage_key: p.storage_key.clone(),
159                }
160            })
161            .collect();
162
163        results.sort_by(|a, b| {
164            b.score
165                .partial_cmp(&a.score)
166                .unwrap_or(std::cmp::Ordering::Equal)
167        });
168
169        results.truncate(limit as usize);
170        Ok(results)
171    }
172
173    async fn delete_points(&self, collection: &str, ids: Vec<u64>) -> Result<(), VectorDbError> {
174        let mut collections =
175            self.collections
176                .write()
177                .map_err(|_| VectorDbError::DeleteFailed {
178                    collection: collection.to_string(),
179                    message: "lock poisoned".to_string(),
180                })?;
181
182        let coll =
183            collections
184                .get_mut(collection)
185                .ok_or_else(|| VectorDbError::CollectionNotFound {
186                    collection: collection.to_string(),
187                })?;
188
189        for id in ids {
190            coll.points.remove(&id);
191        }
192
193        Ok(())
194    }
195}
196
/// Computes cosine similarity between two f32 vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// vector has zero magnitude. Otherwise the result lies in `[-1.0, 1.0]`
/// (or is NaN if an input component is NaN or infinite).
///
/// The sums are accumulated in `f64`: squaring an `f32` can overflow to
/// infinity or underflow to zero, which would otherwise turn valid inputs
/// into NaN or a spurious `0.0`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the quotient marginally outside [-1, 1]; clamp so
    // callers can rely on the mathematical range.
    let cosine = dot / (sq_a.sqrt() * sq_b.sqrt());
    cosine.clamp(-1.0, 1.0) as f32
}