1use crate::core::error::{GraphRAGError, Result};
7use crate::core::traits::*;
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
/// Test double for an embedder.
///
/// Texts registered through [`MockEmbedder::with_embedding`] map to a fixed
/// vector; any other text can be turned into a hash-derived vector by
/// `generate_embedding`. Clones share the same registry because it is stored
/// behind an `Arc`.
#[derive(Clone)]
pub struct MockEmbedder {
    /// Number of components in every generated embedding.
    dimension: usize,
    /// Explicit text -> embedding overrides, shared between clones.
    embeddings: Arc<Mutex<HashMap<String, Vec<f32>>>>,
}

impl MockEmbedder {
    /// Creates an embedder with an empty registry that generates vectors of
    /// `dimension` components.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            embeddings: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Registers a fixed `embedding` for `text`, replacing any previous entry,
    /// and returns the embedder so calls can be chained.
    ///
    /// The length of `embedding` is not checked against the configured
    /// dimension.
    ///
    /// # Panics
    ///
    /// Panics with a message starting with `lock poisoned` if the registry
    /// mutex was poisoned by a panic in another holder.
    pub fn with_embedding(self, text: impl Into<String>, embedding: Vec<f32>) -> Self {
        // Same poisoning policy and message as every other lock site in this file.
        self.embeddings
            .lock()
            .expect("lock poisoned")
            .insert(text.into(), embedding);
        self
    }

    /// Derives a pseudo-embedding of `self.dimension` components from the hash
    /// of `text`.
    ///
    /// Component `i` is `((hash + i) mod 1000) / 1000`, so every value lies in
    /// `[0.0, 1.0)` and the same text yields the same vector on repeated calls.
    /// The hash comes from `DefaultHasher`, whose algorithm is not specified
    /// across Rust releases, so concrete values should not be hard-coded.
    fn generate_embedding(&self, text: &str) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        (0..self.dimension)
            .map(|i| {
                // Wrapping add keeps the offset well-defined even for hashes near u64::MAX.
                let seed = hash.wrapping_add(i as u64);
                (seed % 1000) as f32 / 1000.0
            })
            .collect()
    }
}
55
56#[async_trait]
57impl AsyncEmbedder for MockEmbedder {
58 type Error = GraphRAGError;
59
60 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
61 if let Some(embedding) = self.embeddings.lock().expect("lock poisoned").get(text) {
63 return Ok(embedding.clone());
64 }
65
66 Ok(self.generate_embedding(text))
68 }
69
70 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
71 let mut results = Vec::with_capacity(texts.len());
72 for text in texts {
73 results.push(self.embed(text).await?);
74 }
75 Ok(results)
76 }
77
78 fn dimension(&self) -> usize {
79 self.dimension
80 }
81
82 async fn is_ready(&self) -> bool {
83 true
84 }
85}
86
/// Test double for a language model.
///
/// Prompts registered through [`MockLanguageModel::with_response`] are stored
/// with a canned reply; `default_response` holds the fallback text. Clones
/// share the same prompt table because it is stored behind an `Arc`, while
/// `default_response` is copied per clone.
#[derive(Clone)]
pub struct MockLanguageModel {
    /// Exact-match prompt -> response table, shared between clones.
    responses: Arc<Mutex<HashMap<String, String>>>,
    /// Fallback reply for prompts without a registered response.
    default_response: String,
}

impl MockLanguageModel {
    /// Creates a model with no canned responses and the default reply
    /// `"Mock response"`.
    pub fn new() -> Self {
        Self {
            responses: Arc::new(Mutex::new(HashMap::new())),
            default_response: "Mock response".to_string(),
        }
    }

    /// Registers `response` as the reply for `prompt`, replacing any previous
    /// entry, and returns the model so calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics with a message starting with `lock poisoned` if the response
    /// table mutex was poisoned by a panic in another holder.
    pub fn with_response(self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
        // Same poisoning policy and message as every other lock site in this file.
        self.responses
            .lock()
            .expect("lock poisoned")
            .insert(prompt.into(), response.into());
        self
    }

    /// Replaces the fallback reply and returns the model so calls can be
    /// chained.
    pub fn with_default_response(mut self, response: impl Into<String>) -> Self {
        self.default_response = response.into();
        self
    }
}

impl Default for MockLanguageModel {
    /// Equivalent to [`MockLanguageModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
124
125#[async_trait]
126impl AsyncLanguageModel for MockLanguageModel {
127 type Error = GraphRAGError;
128
129 async fn complete(&self, prompt: &str) -> Result<String> {
130 if let Some(response) = self.responses.lock().expect("lock poisoned").get(prompt) {
131 Ok(response.clone())
132 } else {
133 Ok(self.default_response.clone())
134 }
135 }
136
137 async fn complete_with_params(
138 &self,
139 prompt: &str,
140 _params: GenerationParams,
141 ) -> Result<String> {
142 self.complete(prompt).await
143 }
144
145 async fn is_available(&self) -> bool {
146 true
147 }
148
149 async fn model_info(&self) -> ModelInfo {
150 ModelInfo {
151 name: "mock-model".to_string(),
152 version: Some("1.0.0".to_string()),
153 max_context_length: Some(4096),
154 supports_streaming: false,
155 }
156 }
157
158 async fn get_usage_stats(&self) -> Result<ModelUsageStats> {
159 Ok(ModelUsageStats {
160 total_requests: 0,
161 total_tokens_processed: 0,
162 average_response_time_ms: 0.0,
163 error_rate: 0.0,
164 })
165 }
166}
167
/// In-memory vector store used as a test double.
///
/// Vectors are kept in a mutex-protected map keyed by id.
pub struct MockVectorStore {
    /// Stored vectors keyed by id.
    vectors: Arc<Mutex<HashMap<String, Vec<f32>>>>,
    /// Dimension the store was constructed with.
    dimension: usize,
}

impl MockVectorStore {
    /// Creates an empty store configured for vectors of `dimension` components.
    pub fn new(dimension: usize) -> Self {
        let vectors = Arc::new(Mutex::new(HashMap::new()));
        Self { vectors, dimension }
    }

    /// Inserts `vector` under `id`, replacing any previous entry, and returns
    /// the store so calls can be chained.
    ///
    /// The vector length is not checked here. Panics with "lock poisoned" if
    /// the map mutex is poisoned.
    pub fn with_vector(self, id: impl Into<String>, vector: Vec<f32>) -> Self {
        {
            let mut stored = self.vectors.lock().expect("lock poisoned");
            stored.insert(id.into(), vector);
        }
        self
    }

    /// Cosine similarity of `a` and `b`; `0.0` when either has zero magnitude.
    ///
    /// The dot product pairs components positionally and stops at the shorter
    /// slice, while each magnitude uses the full slice.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
        let (norm_a, norm_b) = (norm(a), norm(b));

        // A zero vector has no direction; report no similarity instead of dividing by zero.
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        let dot_product: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        dot_product / (norm_a * norm_b)
    }
}
205
206#[async_trait]
207impl AsyncVectorStore for MockVectorStore {
208 type Error = GraphRAGError;
209
210 async fn add_vector(
211 &mut self,
212 id: String,
213 vector: Vec<f32>,
214 _metadata: VectorMetadata,
215 ) -> Result<()> {
216 if vector.len() != self.dimension {
217 return Err(GraphRAGError::Embedding {
218 message: format!(
219 "Vector dimension mismatch: expected {}, got {}",
220 self.dimension,
221 vector.len()
222 ),
223 });
224 }
225 self.vectors
226 .lock()
227 .expect("lock poisoned")
228 .insert(id, vector);
229 Ok(())
230 }
231
232 async fn add_vectors_batch(&mut self, vectors: VectorBatch) -> Result<()> {
233 for (id, vector, metadata) in vectors {
234 self.add_vector(id, vector, metadata).await?;
235 }
236 Ok(())
237 }
238
239 async fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<SearchResult>> {
240 if query_vector.len() != self.dimension {
241 return Err(GraphRAGError::Embedding {
242 message: format!(
243 "Query vector dimension mismatch: expected {}, got {}",
244 self.dimension,
245 query_vector.len()
246 ),
247 });
248 }
249
250 let vectors = self.vectors.lock().expect("lock poisoned");
251 let mut results: Vec<_> = vectors
252 .iter()
253 .map(|(id, vector)| {
254 let similarity = Self::cosine_similarity(query_vector, vector);
255 SearchResult {
256 id: id.clone(),
257 distance: 1.0 - similarity, metadata: None,
259 }
260 })
261 .collect();
262
263 results.sort_by(|a, b| {
265 a.distance
266 .partial_cmp(&b.distance)
267 .unwrap_or(std::cmp::Ordering::Equal)
268 });
269
270 Ok(results.into_iter().take(k).collect())
272 }
273
274 async fn search_with_threshold(
275 &self,
276 query_vector: &[f32],
277 k: usize,
278 threshold: f32,
279 ) -> Result<Vec<SearchResult>> {
280 let results = self.search(query_vector, k).await?;
281 Ok(results
282 .into_iter()
283 .filter(|r| r.distance <= threshold)
284 .collect())
285 }
286
287 async fn remove_vector(&mut self, id: &str) -> Result<bool> {
288 Ok(self
289 .vectors
290 .lock()
291 .expect("lock poisoned")
292 .remove(id)
293 .is_some())
294 }
295
296 async fn len(&self) -> usize {
297 self.vectors.lock().expect("lock poisoned").len()
298 }
299}
300
/// Test double for a retriever backed by a fixed, mutex-protected list of
/// result strings.
pub struct MockRetriever {
    /// Results held by the mock, in insertion order.
    results: Arc<Mutex<Vec<String>>>,
}

impl MockRetriever {
    /// Creates a retriever with no results.
    pub fn new() -> Self {
        let results = Arc::new(Mutex::new(Vec::new()));
        Self { results }
    }

    /// Replaces the stored results with `results` and returns the retriever so
    /// calls can be chained.
    ///
    /// Panics with "lock poisoned" if the results mutex is poisoned.
    pub fn with_results(self, results: Vec<String>) -> Self {
        {
            let mut stored = self.results.lock().expect("lock poisoned");
            *stored = results;
        }
        self
    }
}

impl Default for MockRetriever {
    /// Equivalent to [`MockRetriever::new`].
    fn default() -> Self {
        Self::new()
    }
}
326
327#[async_trait]
328impl AsyncRetriever for MockRetriever {
329 type Query = String;
330 type Result = String;
331 type Error = GraphRAGError;
332
333 async fn search(&self, _query: Self::Query, k: usize) -> Result<Vec<Self::Result>> {
334 let results = self.results.lock().expect("lock poisoned");
335 Ok(results.iter().take(k).cloned().collect())
336 }
337
338 async fn search_with_context(
339 &self,
340 query: Self::Query,
341 _context: &str,
342 k: usize,
343 ) -> Result<Vec<Self::Result>> {
344 self.search(query, k).await
345 }
346
347 async fn update(&mut self, content: Vec<String>) -> Result<()> {
348 *self.results.lock().expect("lock poisoned") = content;
349 Ok(())
350 }
351
352 async fn health_check(&self) -> Result<bool> {
353 Ok(true)
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered text returns its fixed vector; an unregistered one gets a
    /// generated vector of the configured size.
    #[tokio::test]
    async fn test_mock_embedder() {
        let embedder = MockEmbedder::new(128).with_embedding("test", vec![0.5; 128]);

        let registered = embedder.embed("test").await.unwrap();
        assert_eq!(registered.len(), 128);
        assert_eq!(registered[0], 0.5);

        let generated = embedder.embed("unknown").await.unwrap();
        assert_eq!(generated.len(), 128);
    }

    /// A registered prompt returns its canned reply; anything else returns the
    /// configured default.
    #[tokio::test]
    async fn test_mock_language_model() {
        let model = MockLanguageModel::new()
            .with_response("Hello", "Hi there!")
            .with_default_response("Default response");

        let canned = model.complete("Hello").await.unwrap();
        assert_eq!(canned, "Hi there!");

        let fallback = model.complete("Unknown").await.unwrap();
        assert_eq!(fallback, "Default response");
    }

    /// The store counts its vectors, ranks the identical vector first, and
    /// shrinks after a removal.
    #[tokio::test]
    async fn test_mock_vector_store() {
        let mut store = MockVectorStore::new(3)
            .with_vector("vec1", vec![1.0, 0.0, 0.0])
            .with_vector("vec2", vec![0.0, 1.0, 0.0]);

        assert_eq!(store.len().await, 2);

        let query = [1.0, 0.0, 0.0];
        let hits = store.search(&query, 2).await.unwrap();
        assert_eq!(hits[0].id, "vec1");

        let removed = store.remove_vector("vec1").await.unwrap();
        assert!(removed);
        assert_eq!(store.len().await, 1);
    }

    /// The retriever returns at most `k` results, in stored order.
    #[tokio::test]
    async fn test_mock_retriever() {
        let stored = vec![
            "result1".to_string(),
            "result2".to_string(),
            "result3".to_string(),
        ];
        let retriever = MockRetriever::new().with_results(stored);

        let hits = retriever.search("query".to_string(), 2).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0], "result1");
    }
}