Skip to main content

graphrag_core/core/
test_utils.rs

1//! Test utilities and mock implementations for testing
2//!
3//! This module provides mock implementations of core traits for unit testing
4//! without requiring real services or external dependencies.
5
6use crate::core::error::{GraphRAGError, Result};
7use crate::core::traits::*;
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
/// In-memory embedder double for unit tests.
///
/// Holds a fixed output `dimension` and a table of caller-registered
/// embeddings keyed by input text. The table sits behind an `Arc<Mutex<..>>`,
/// so clones of a `MockEmbedder` share the same registered embeddings.
#[derive(Clone)]
pub struct MockEmbedder {
    dimension: usize,
    embeddings: Arc<Mutex<HashMap<String, Vec<f32>>>>,
}

impl MockEmbedder {
    /// Build an embedder that reports `dimension` and starts with no
    /// registered embeddings.
    pub fn new(dimension: usize) -> Self {
        let embeddings = Arc::new(Mutex::new(HashMap::new()));
        Self {
            dimension,
            embeddings,
        }
    }

    /// Register `embedding` as the fixed vector for `text` and hand the
    /// embedder back for chaining.
    ///
    /// The vector's length is not checked against the configured dimension.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn with_embedding(self, text: impl Into<String>, embedding: Vec<f32>) -> Self {
        {
            let mut table = self.embeddings.lock().unwrap();
            table.insert(text.into(), embedding);
        }
        self
    }

    /// Derive a repeatable vector of `self.dimension` values from `text`.
    ///
    /// The text is hashed once with `DefaultHasher::new()`; component `i` is
    /// `(hash + i) mod 1000`, scaled by 1/1000.
    fn generate_embedding(&self, text: &str) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let base = {
            let mut state = DefaultHasher::new();
            text.hash(&mut state);
            state.finish()
        };

        let mut values = Vec::with_capacity(self.dimension);
        for offset in 0..self.dimension {
            // Wrapping add keeps hashes near u64::MAX from overflowing.
            let bucket = base.wrapping_add(offset as u64) % 1000;
            values.push(bucket as f32 / 1000.0);
        }
        values
    }
}
55
56#[async_trait]
57impl AsyncEmbedder for MockEmbedder {
58    type Error = GraphRAGError;
59
60    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
61        // Check if we have a pre-populated embedding
62        if let Some(embedding) = self.embeddings.lock().unwrap().get(text) {
63            return Ok(embedding.clone());
64        }
65
66        // Otherwise generate one
67        Ok(self.generate_embedding(text))
68    }
69
70    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
71        let mut results = Vec::with_capacity(texts.len());
72        for text in texts {
73            results.push(self.embed(text).await?);
74        }
75        Ok(results)
76    }
77
78    fn dimension(&self) -> usize {
79        self.dimension
80    }
81
82    async fn is_ready(&self) -> bool {
83        true
84    }
85}
86
/// Language-model double that answers from a table of canned responses.
///
/// Prompts are looked up in `responses`; `default_response` is the fallback
/// text. The table is behind an `Arc<Mutex<..>>`, so clones share it, while
/// each clone carries its own copy of `default_response`.
#[derive(Clone)]
pub struct MockLanguageModel {
    responses: Arc<Mutex<HashMap<String, String>>>,
    default_response: String,
}

impl MockLanguageModel {
    /// Build a model with no canned responses and the default reply
    /// `"Mock response"`.
    pub fn new() -> Self {
        let responses = Arc::new(Mutex::new(HashMap::new()));
        let default_response = String::from("Mock response");
        Self {
            responses,
            default_response,
        }
    }

    /// Register `response` as the reply for `prompt`, replacing any earlier
    /// entry for the same prompt, and return the model for chaining.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn with_response(self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
        {
            let mut canned = self.responses.lock().unwrap();
            canned.insert(prompt.into(), response.into());
        }
        self
    }

    /// Replace the fallback reply and return the model for chaining.
    pub fn with_default_response(self, response: impl Into<String>) -> Self {
        Self {
            default_response: response.into(),
            ..self
        }
    }
}

impl Default for MockLanguageModel {
    /// Same as [`MockLanguageModel::new`].
    fn default() -> Self {
        MockLanguageModel::new()
    }
}
124
125#[async_trait]
126impl AsyncLanguageModel for MockLanguageModel {
127    type Error = GraphRAGError;
128
129    async fn complete(&self, prompt: &str) -> Result<String> {
130        if let Some(response) = self.responses.lock().unwrap().get(prompt) {
131            Ok(response.clone())
132        } else {
133            Ok(self.default_response.clone())
134        }
135    }
136
137    async fn complete_with_params(
138        &self,
139        prompt: &str,
140        _params: GenerationParams,
141    ) -> Result<String> {
142        self.complete(prompt).await
143    }
144
145    async fn is_available(&self) -> bool {
146        true
147    }
148
149    async fn model_info(&self) -> ModelInfo {
150        ModelInfo {
151            name: "mock-model".to_string(),
152            version: Some("1.0.0".to_string()),
153            max_context_length: Some(4096),
154            supports_streaming: false,
155        }
156    }
157
158    async fn get_usage_stats(&self) -> Result<ModelUsageStats> {
159        Ok(ModelUsageStats {
160            total_requests: 0,
161            total_tokens_processed: 0,
162            average_response_time_ms: 0.0,
163            error_rate: 0.0,
164        })
165    }
166}
167
/// Vector-store double backed by an in-memory map from id to vector.
///
/// `dimension` is the vector length the store was created for.
pub struct MockVectorStore {
    vectors: Arc<Mutex<HashMap<String, Vec<f32>>>>,
    dimension: usize,
}

impl MockVectorStore {
    /// Build an empty store for vectors of length `dimension`.
    pub fn new(dimension: usize) -> Self {
        let vectors = Arc::new(Mutex::new(HashMap::new()));
        Self { vectors, dimension }
    }

    /// Insert `vector` under `id` (replacing any previous entry) and return
    /// the store for chaining.
    ///
    /// The vector's length is not checked against the store's dimension here.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn with_vector(self, id: impl Into<String>, vector: Vec<f32>) -> Self {
        {
            let mut table = self.vectors.lock().unwrap();
            table.insert(id.into(), vector);
        }
        self
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns `0.0` when either vector has zero magnitude. If the slices
    /// differ in length, the dot product covers only the shared prefix while
    /// each magnitude uses the full slice.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        /// Euclidean length of `v`.
        fn magnitude(v: &[f32]) -> f32 {
            v.iter().map(|x| x * x).sum::<f32>().sqrt()
        }

        let (mag_a, mag_b) = (magnitude(a), magnitude(b));
        if mag_a == 0.0 || mag_b == 0.0 {
            return 0.0;
        }

        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        dot / (mag_a * mag_b)
    }
}
202
203#[async_trait]
204impl AsyncVectorStore for MockVectorStore {
205    type Error = GraphRAGError;
206
207    async fn add_vector(
208        &mut self,
209        id: String,
210        vector: Vec<f32>,
211        _metadata: VectorMetadata,
212    ) -> Result<()> {
213        if vector.len() != self.dimension {
214            return Err(GraphRAGError::Embedding {
215                message: format!(
216                    "Vector dimension mismatch: expected {}, got {}",
217                    self.dimension,
218                    vector.len()
219                ),
220            });
221        }
222        self.vectors.lock().unwrap().insert(id, vector);
223        Ok(())
224    }
225
226    async fn add_vectors_batch(&mut self, vectors: VectorBatch) -> Result<()> {
227        for (id, vector, metadata) in vectors {
228            self.add_vector(id, vector, metadata).await?;
229        }
230        Ok(())
231    }
232
233    async fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<SearchResult>> {
234        if query_vector.len() != self.dimension {
235            return Err(GraphRAGError::Embedding {
236                message: format!(
237                    "Query vector dimension mismatch: expected {}, got {}",
238                    self.dimension,
239                    query_vector.len()
240                ),
241            });
242        }
243
244        let vectors = self.vectors.lock().unwrap();
245        let mut results: Vec<_> = vectors
246            .iter()
247            .map(|(id, vector)| {
248                let similarity = Self::cosine_similarity(query_vector, vector);
249                SearchResult {
250                    id: id.clone(),
251                    distance: 1.0 - similarity, // Convert similarity to distance
252                    metadata: None,
253                }
254            })
255            .collect();
256
257        // Sort by distance (ascending)
258        results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
259
260        // Take top k
261        Ok(results.into_iter().take(k).collect())
262    }
263
264    async fn search_with_threshold(
265        &self,
266        query_vector: &[f32],
267        k: usize,
268        threshold: f32,
269    ) -> Result<Vec<SearchResult>> {
270        let results = self.search(query_vector, k).await?;
271        Ok(results
272            .into_iter()
273            .filter(|r| r.distance <= threshold)
274            .collect())
275    }
276
277    async fn remove_vector(&mut self, id: &str) -> Result<bool> {
278        Ok(self.vectors.lock().unwrap().remove(id).is_some())
279    }
280
281    async fn len(&self) -> usize {
282        self.vectors.lock().unwrap().len()
283    }
284}
285
/// Retriever double that serves a fixed, ordered list of strings.
pub struct MockRetriever {
    results: Arc<Mutex<Vec<String>>>,
}

impl MockRetriever {
    /// Build a retriever with an empty result list.
    pub fn new() -> Self {
        let results = Arc::new(Mutex::new(Vec::new()));
        Self { results }
    }

    /// Replace the stored result list with `results` and return the
    /// retriever for chaining.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn with_results(self, results: Vec<String>) -> Self {
        {
            let mut slot = self.results.lock().unwrap();
            *slot = results;
        }
        self
    }
}

impl Default for MockRetriever {
    /// Same as [`MockRetriever::new`].
    fn default() -> Self {
        MockRetriever::new()
    }
}
311
312#[async_trait]
313impl AsyncRetriever for MockRetriever {
314    type Query = String;
315    type Result = String;
316    type Error = GraphRAGError;
317
318    async fn search(&self, _query: Self::Query, k: usize) -> Result<Vec<Self::Result>> {
319        let results = self.results.lock().unwrap();
320        Ok(results.iter().take(k).cloned().collect())
321    }
322
323    async fn search_with_context(
324        &self,
325        query: Self::Query,
326        _context: &str,
327        k: usize,
328    ) -> Result<Vec<Self::Result>> {
329        self.search(query, k).await
330    }
331
332    async fn update(&mut self, content: Vec<String>) -> Result<()> {
333        *self.results.lock().unwrap() = content;
334        Ok(())
335    }
336
337    async fn health_check(&self) -> Result<bool> {
338        Ok(true)
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered text returns its vector; any other text still yields a
    /// vector of the configured length.
    #[tokio::test]
    async fn test_mock_embedder() {
        const DIM: usize = 128;
        let mock = MockEmbedder::new(DIM).with_embedding("test", vec![0.5; DIM]);

        let known = mock.embed("test").await.unwrap();
        assert_eq!(known.len(), DIM);
        assert_eq!(known[0], 0.5);

        // Unregistered text falls back to a generated embedding.
        let generated = mock.embed("unknown").await.unwrap();
        assert_eq!(generated.len(), DIM);
    }

    /// Registered prompts get their canned reply; everything else gets the
    /// configured default.
    #[tokio::test]
    async fn test_mock_language_model() {
        let model = MockLanguageModel::new()
            .with_response("Hello", "Hi there!")
            .with_default_response("Default response");

        let matched = model.complete("Hello").await.unwrap();
        assert_eq!(matched, "Hi there!");

        let fallback = model.complete("Unknown").await.unwrap();
        assert_eq!(fallback, "Default response");
    }

    /// Pre-populated vectors are counted, searchable, and removable.
    #[tokio::test]
    async fn test_mock_vector_store() {
        let mut store = MockVectorStore::new(3)
            .with_vector("vec1", vec![1.0, 0.0, 0.0])
            .with_vector("vec2", vec![0.0, 1.0, 0.0]);

        assert_eq!(store.len().await, 2);

        let query = [1.0, 0.0, 0.0];
        let hits = store.search(&query, 2).await.unwrap();
        assert_eq!(hits[0].id, "vec1");

        let removed = store.remove_vector("vec1").await.unwrap();
        assert!(removed);
        assert_eq!(store.len().await, 1);
    }

    /// `search` returns the first `k` stored results in order.
    #[tokio::test]
    async fn test_mock_retriever() {
        let canned: Vec<String> = ["result1", "result2", "result3"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let retriever = MockRetriever::new().with_results(canned);

        let hits = retriever.search("query".to_string(), 2).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0], "result1");
    }
}