1use crate::core::error::{GraphRAGError, Result};
7use crate::core::traits::*;
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
/// Deterministic, in-memory stand-in for an embedding model, intended for tests.
///
/// Texts registered via [`MockEmbedder::with_embedding`] return their preset
/// vector; any other text gets a pseudo-random vector derived from a hash of the
/// text, so the same input always yields the same output.
///
/// Clones share the same preset table, since it lives behind an `Arc<Mutex<..>>`.
#[derive(Clone)]
pub struct MockEmbedder {
    /// Length of every generated embedding.
    dimension: usize,
    /// Preset embeddings keyed by the exact input text.
    embeddings: Arc<Mutex<HashMap<String, Vec<f32>>>>,
}

impl MockEmbedder {
    /// Creates an embedder producing `dimension`-length vectors with no presets.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            embeddings: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Registers a fixed `embedding` to return for exactly `text` (builder style).
    ///
    /// The vector is stored as-is; its length is not checked against the
    /// configured dimension. Registering the same text twice keeps the last one.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn with_embedding(self, text: impl Into<String>, embedding: Vec<f32>) -> Self {
        self.embeddings
            .lock()
            .unwrap()
            .insert(text.into(), embedding);
        self
    }

    /// Derives a deterministic pseudo-random embedding of length `self.dimension`
    /// for `text`.
    ///
    /// Every component lies in `[0.0, 1.0)` with a resolution of `0.001`. The
    /// output depends on `DefaultHasher`, which is fixed for a given toolchain
    /// but is not guaranteed to be stable across Rust releases.
    fn generate_embedding(&self, text: &str) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // SplitMix64 increment ("golden gamma").
        const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let seed = hasher.finish();

        (0..self.dimension)
            .map(|i| {
                // Mix every component independently (SplitMix64 finalizer).
                // The earlier `(seed + i) % 1000` produced a linear ramp, which
                // made vectors for different texts almost collinear (cosine ~1),
                // so similarity-based tests could not tell texts apart.
                let mut z =
                    seed.wrapping_add((i as u64).wrapping_add(1).wrapping_mul(GOLDEN_GAMMA));
                z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
                z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
                z ^= z >> 31;
                (z % 1000) as f32 / 1000.0
            })
            .collect()
    }
}
55
56#[async_trait]
57impl AsyncEmbedder for MockEmbedder {
58 type Error = GraphRAGError;
59
60 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
61 if let Some(embedding) = self.embeddings.lock().unwrap().get(text) {
63 return Ok(embedding.clone());
64 }
65
66 Ok(self.generate_embedding(text))
68 }
69
70 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
71 let mut results = Vec::with_capacity(texts.len());
72 for text in texts {
73 results.push(self.embed(text).await?);
74 }
75 Ok(results)
76 }
77
78 fn dimension(&self) -> usize {
79 self.dimension
80 }
81
82 async fn is_ready(&self) -> bool {
83 true
84 }
85}
86
/// Scripted language model for tests.
///
/// Returns a canned response for each registered prompt (exact match) and a
/// configurable default for every other prompt.
#[derive(Clone)]
pub struct MockLanguageModel {
    /// Prompt -> response table; shared between clones.
    responses: Arc<Mutex<HashMap<String, String>>>,
    /// Returned for prompts that have no entry in `responses`.
    default_response: String,
}

impl MockLanguageModel {
    /// Creates a model with no scripted prompts and the default response
    /// `"Mock response"`.
    pub fn new() -> Self {
        let responses = Arc::new(Mutex::new(HashMap::new()));
        let default_response = String::from("Mock response");
        Self {
            responses,
            default_response,
        }
    }

    /// Scripts `response` for the exact `prompt` (builder style). A later call
    /// with the same prompt overwrites the earlier response.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn with_response(self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
        {
            let mut table = self.responses.lock().unwrap();
            table.insert(prompt.into(), response.into());
        }
        self
    }

    /// Replaces the fallback text returned for unscripted prompts (builder style).
    pub fn with_default_response(self, response: impl Into<String>) -> Self {
        Self {
            default_response: response.into(),
            ..self
        }
    }
}

impl Default for MockLanguageModel {
    /// Same as [`MockLanguageModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
124
125#[async_trait]
126impl AsyncLanguageModel for MockLanguageModel {
127 type Error = GraphRAGError;
128
129 async fn complete(&self, prompt: &str) -> Result<String> {
130 if let Some(response) = self.responses.lock().unwrap().get(prompt) {
131 Ok(response.clone())
132 } else {
133 Ok(self.default_response.clone())
134 }
135 }
136
137 async fn complete_with_params(
138 &self,
139 prompt: &str,
140 _params: GenerationParams,
141 ) -> Result<String> {
142 self.complete(prompt).await
143 }
144
145 async fn is_available(&self) -> bool {
146 true
147 }
148
149 async fn model_info(&self) -> ModelInfo {
150 ModelInfo {
151 name: "mock-model".to_string(),
152 version: Some("1.0.0".to_string()),
153 max_context_length: Some(4096),
154 supports_streaming: false,
155 }
156 }
157
158 async fn get_usage_stats(&self) -> Result<ModelUsageStats> {
159 Ok(ModelUsageStats {
160 total_requests: 0,
161 total_tokens_processed: 0,
162 average_response_time_ms: 0.0,
163 error_rate: 0.0,
164 })
165 }
166}
167
/// In-memory vector store for tests.
///
/// Entries are ranked by cosine distance (`1 - cosine_similarity`). Vectors are
/// kept behind an `Arc<Mutex<..>>`.
pub struct MockVectorStore {
    /// Stored vectors keyed by id.
    vectors: Arc<Mutex<HashMap<String, Vec<f32>>>>,
    /// Required length of every stored vector and of every query vector.
    dimension: usize,
}

impl MockVectorStore {
    /// Creates an empty store that accepts vectors of length `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self {
            vectors: Arc::new(Mutex::new(HashMap::new())),
            dimension,
        }
    }

    /// Seeds the store with `vector` under `id` (builder style), replacing any
    /// existing entry with the same id.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `vector.len()` differs from the store's dimension. `add_vector` rejects
    ///   the same condition with an error; silently accepting it here would make
    ///   `search` compare vectors of different lengths.
    /// - The internal mutex is poisoned.
    pub fn with_vector(self, id: impl Into<String>, vector: Vec<f32>) -> Self {
        assert_eq!(
            vector.len(),
            self.dimension,
            "MockVectorStore::with_vector: vector dimension mismatch"
        );
        self.vectors.lock().unwrap().insert(id.into(), vector);
        self
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns `0.0` when either vector has zero magnitude. Callers are expected
    /// to pass equal-length slices; `zip` would otherwise silently drop the tail
    /// of the longer one.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len(), "cosine_similarity: length mismatch");

        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if mag_a == 0.0 || mag_b == 0.0 {
            0.0
        } else {
            dot / (mag_a * mag_b)
        }
    }
}
202
203#[async_trait]
204impl AsyncVectorStore for MockVectorStore {
205 type Error = GraphRAGError;
206
207 async fn add_vector(
208 &mut self,
209 id: String,
210 vector: Vec<f32>,
211 _metadata: VectorMetadata,
212 ) -> Result<()> {
213 if vector.len() != self.dimension {
214 return Err(GraphRAGError::Embedding {
215 message: format!(
216 "Vector dimension mismatch: expected {}, got {}",
217 self.dimension,
218 vector.len()
219 ),
220 });
221 }
222 self.vectors.lock().unwrap().insert(id, vector);
223 Ok(())
224 }
225
226 async fn add_vectors_batch(&mut self, vectors: VectorBatch) -> Result<()> {
227 for (id, vector, metadata) in vectors {
228 self.add_vector(id, vector, metadata).await?;
229 }
230 Ok(())
231 }
232
233 async fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<SearchResult>> {
234 if query_vector.len() != self.dimension {
235 return Err(GraphRAGError::Embedding {
236 message: format!(
237 "Query vector dimension mismatch: expected {}, got {}",
238 self.dimension,
239 query_vector.len()
240 ),
241 });
242 }
243
244 let vectors = self.vectors.lock().unwrap();
245 let mut results: Vec<_> = vectors
246 .iter()
247 .map(|(id, vector)| {
248 let similarity = Self::cosine_similarity(query_vector, vector);
249 SearchResult {
250 id: id.clone(),
251 distance: 1.0 - similarity, metadata: None,
253 }
254 })
255 .collect();
256
257 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
259
260 Ok(results.into_iter().take(k).collect())
262 }
263
264 async fn search_with_threshold(
265 &self,
266 query_vector: &[f32],
267 k: usize,
268 threshold: f32,
269 ) -> Result<Vec<SearchResult>> {
270 let results = self.search(query_vector, k).await?;
271 Ok(results
272 .into_iter()
273 .filter(|r| r.distance <= threshold)
274 .collect())
275 }
276
277 async fn remove_vector(&mut self, id: &str) -> Result<bool> {
278 Ok(self.vectors.lock().unwrap().remove(id).is_some())
279 }
280
281 async fn len(&self) -> usize {
282 self.vectors.lock().unwrap().len()
283 }
284}
285
/// Retriever stub for tests that ignores the query and hands back a preset,
/// ordered list of results.
pub struct MockRetriever {
    /// Results returned by every search, in order.
    results: Arc<Mutex<Vec<String>>>,
}

impl MockRetriever {
    /// Creates a retriever with no results.
    pub fn new() -> Self {
        let results = Arc::new(Mutex::new(Vec::new()));
        Self { results }
    }

    /// Replaces the preset results (builder style).
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn with_results(self, results: Vec<String>) -> Self {
        {
            let mut slot = self.results.lock().unwrap();
            *slot = results;
        }
        self
    }
}

impl Default for MockRetriever {
    /// Same as [`MockRetriever::new`].
    fn default() -> Self {
        Self::new()
    }
}
311
312#[async_trait]
313impl AsyncRetriever for MockRetriever {
314 type Query = String;
315 type Result = String;
316 type Error = GraphRAGError;
317
318 async fn search(&self, _query: Self::Query, k: usize) -> Result<Vec<Self::Result>> {
319 let results = self.results.lock().unwrap();
320 Ok(results.iter().take(k).cloned().collect())
321 }
322
323 async fn search_with_context(
324 &self,
325 query: Self::Query,
326 _context: &str,
327 k: usize,
328 ) -> Result<Vec<Self::Result>> {
329 self.search(query, k).await
330 }
331
332 async fn update(&mut self, content: Vec<String>) -> Result<()> {
333 *self.results.lock().unwrap() = content;
334 Ok(())
335 }
336
337 async fn health_check(&self) -> Result<bool> {
338 Ok(true)
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Presets are returned verbatim; unknown texts still get a full-length vector.
    #[tokio::test]
    async fn test_mock_embedder() {
        let embedder = MockEmbedder::new(128).with_embedding("test", vec![0.5; 128]);

        let result = embedder.embed("test").await.unwrap();
        assert_eq!(result.len(), 128);
        assert_eq!(result[0], 0.5);

        let result2 = embedder.embed("unknown").await.unwrap();
        assert_eq!(result2.len(), 128);
    }

    /// Batch embedding preserves order and matches single-text embedding.
    #[tokio::test]
    async fn test_mock_embedder_batch_and_determinism() {
        let embedder = MockEmbedder::new(8).with_embedding("preset", vec![0.25; 8]);
        assert_eq!(embedder.dimension(), 8);
        assert!(embedder.is_ready().await);

        let batch = embedder.embed_batch(&["preset", "other"]).await.unwrap();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0], vec![0.25; 8]);
        assert_eq!(batch[1].len(), 8);

        // Unknown texts must map to the same vector on every call.
        assert_eq!(batch[1], embedder.embed("other").await.unwrap());
        assert!(embedder.embed_batch(&[]).await.unwrap().is_empty());
    }

    /// Scripted prompts win; everything else falls back to the default.
    #[tokio::test]
    async fn test_mock_language_model() {
        let llm = MockLanguageModel::new()
            .with_response("Hello", "Hi there!")
            .with_default_response("Default response");

        assert_eq!(llm.complete("Hello").await.unwrap(), "Hi there!");
        assert_eq!(llm.complete("Unknown").await.unwrap(), "Default response");
    }

    /// Out-of-the-box defaults and fixed metadata.
    #[tokio::test]
    async fn test_mock_language_model_metadata() {
        let llm = MockLanguageModel::default();
        assert_eq!(llm.complete("anything").await.unwrap(), "Mock response");
        assert!(llm.is_available().await);
        assert_eq!(llm.model_info().await.name, "mock-model");
        assert_eq!(llm.get_usage_stats().await.unwrap().total_requests, 0);
    }

    /// Nearest-neighbour ordering and removal.
    #[tokio::test]
    async fn test_mock_vector_store() {
        let mut store = MockVectorStore::new(3)
            .with_vector("vec1", vec![1.0, 0.0, 0.0])
            .with_vector("vec2", vec![0.0, 1.0, 0.0]);

        assert_eq!(store.len().await, 2);

        let results = store.search(&[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(results[0].id, "vec1");

        assert!(store.remove_vector("vec1").await.unwrap());
        assert_eq!(store.len().await, 1);
    }

    /// Threshold filtering, query-dimension validation and edge cases.
    #[tokio::test]
    async fn test_mock_vector_store_threshold_and_errors() {
        let mut store = MockVectorStore::new(3)
            .with_vector("vec1", vec![1.0, 0.0, 0.0])
            .with_vector("vec2", vec![0.0, 1.0, 0.0]);

        // vec1 has distance 0, vec2 has distance 1.
        let close = store
            .search_with_threshold(&[1.0, 0.0, 0.0], 2, 0.5)
            .await
            .unwrap();
        assert_eq!(close.len(), 1);
        assert_eq!(close[0].id, "vec1");

        assert!(store.search(&[1.0, 0.0], 1).await.is_err());
        assert!(store.search(&[1.0, 0.0, 0.0], 0).await.unwrap().is_empty());
        assert!(!store.remove_vector("missing").await.unwrap());
    }

    /// Only the first `k` preset results are returned.
    #[tokio::test]
    async fn test_mock_retriever() {
        let retriever = MockRetriever::new().with_results(vec![
            "result1".to_string(),
            "result2".to_string(),
            "result3".to_string(),
        ]);

        let results = retriever.search("query".to_string(), 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], "result1");
    }

    /// `update` replaces results; context is ignored; health is always ok.
    #[tokio::test]
    async fn test_mock_retriever_update_and_context() {
        let mut retriever = MockRetriever::default();
        assert!(retriever.search("q".to_string(), 5).await.unwrap().is_empty());

        retriever
            .update(vec!["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        let all = retriever
            .search_with_context("q".to_string(), "ctx", 10)
            .await
            .unwrap();
        assert_eq!(all, vec!["a".to_string(), "b".to_string()]);
        assert!(retriever.health_check().await.unwrap());
    }
}