Skip to main content

graphrag_core/core/
test_utils.rs

1//! Test utilities and mock implementations for testing
2//!
3//! This module provides mock implementations of core traits for unit testing
4//! without requiring real services or external dependencies.
5
6use crate::core::error::{GraphRAGError, Result};
7use crate::core::traits::*;
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
/// Mock embedder for testing.
///
/// Texts registered through [`MockEmbedder::with_embedding`] are kept in a
/// lookup table; any other text can be turned into a vector derived from a
/// hash of the text. The table lives behind an `Arc<Mutex<..>>`, so clones of
/// an embedder share the same registered embeddings.
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    dimension: usize,
    embeddings: Arc<Mutex<HashMap<String, Vec<f32>>>>,
}

impl MockEmbedder {
    /// Create a new mock embedder with the given dimension and no
    /// pre-registered embeddings.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            embeddings: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Pre-populate with a known embedding for testing.
    ///
    /// The vector is stored exactly as given: its length is not checked
    /// against the configured dimension. Registering the same text again
    /// replaces the earlier vector.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn with_embedding(self, text: impl Into<String>, embedding: Vec<f32>) -> Self {
        self.embeddings
            .lock()
            .expect("lock poisoned")
            .insert(text.into(), embedding);
        self
    }

    /// Generate an embedding of `self.dimension` values from a hash of `text`.
    ///
    /// Each component is `((hash + i) % 1000) / 1000`, so every value lies in
    /// `[0.0, 1.0)` and the same text always maps to the same vector within
    /// one build. The standard library does not promise that `DefaultHasher`
    /// output stays the same across Rust releases, so tests should not
    /// hard-code the generated numbers.
    fn generate_embedding(&self, text: &str) -> Vec<f32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        // Offset the hash by the component index so components differ;
        // wrapping_add keeps the arithmetic panic-free near u64::MAX.
        (0..self.dimension)
            .map(|i| {
                let seed = hash.wrapping_add(i as u64);
                (seed % 1000) as f32 / 1000.0
            })
            .collect()
    }
}
55
56#[async_trait]
57impl AsyncEmbedder for MockEmbedder {
58    type Error = GraphRAGError;
59
60    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
61        // Check if we have a pre-populated embedding
62        if let Some(embedding) = self.embeddings.lock().expect("lock poisoned").get(text) {
63            return Ok(embedding.clone());
64        }
65
66        // Otherwise generate one
67        Ok(self.generate_embedding(text))
68    }
69
70    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
71        let mut results = Vec::with_capacity(texts.len());
72        for text in texts {
73            results.push(self.embed(text).await?);
74        }
75        Ok(results)
76    }
77
78    fn dimension(&self) -> usize {
79        self.dimension
80    }
81
82    async fn is_ready(&self) -> bool {
83        true
84    }
85}
86
/// Mock language model for testing.
///
/// Holds a table of canned prompt/response pairs plus a default response.
/// The table lives behind an `Arc<Mutex<..>>`, so clones share canned
/// responses; the default response is a plain `String` and is copied per
/// clone.
#[derive(Debug, Clone)]
pub struct MockLanguageModel {
    responses: Arc<Mutex<HashMap<String, String>>>,
    default_response: String,
}

impl MockLanguageModel {
    /// Create a new mock language model with no canned responses and the
    /// default response `"Mock response"`.
    pub fn new() -> Self {
        Self {
            responses: Arc::new(Mutex::new(HashMap::new())),
            default_response: "Mock response".to_string(),
        }
    }

    /// Set a specific response for a prompt.
    ///
    /// Registering the same prompt again replaces the earlier response.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn with_response(self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
        self.responses
            .lock()
            .expect("lock poisoned")
            .insert(prompt.into(), response.into());
        self
    }

    /// Set the default response for unmatched prompts.
    pub fn with_default_response(mut self, response: impl Into<String>) -> Self {
        self.default_response = response.into();
        self
    }
}

impl Default for MockLanguageModel {
    /// Equivalent to [`MockLanguageModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
124
125#[async_trait]
126impl AsyncLanguageModel for MockLanguageModel {
127    type Error = GraphRAGError;
128
129    async fn complete(&self, prompt: &str) -> Result<String> {
130        if let Some(response) = self.responses.lock().expect("lock poisoned").get(prompt) {
131            Ok(response.clone())
132        } else {
133            Ok(self.default_response.clone())
134        }
135    }
136
137    async fn complete_with_params(
138        &self,
139        prompt: &str,
140        _params: GenerationParams,
141    ) -> Result<String> {
142        self.complete(prompt).await
143    }
144
145    async fn is_available(&self) -> bool {
146        true
147    }
148
149    async fn model_info(&self) -> ModelInfo {
150        ModelInfo {
151            name: "mock-model".to_string(),
152            version: Some("1.0.0".to_string()),
153            max_context_length: Some(4096),
154            supports_streaming: false,
155        }
156    }
157
158    async fn get_usage_stats(&self) -> Result<ModelUsageStats> {
159        Ok(ModelUsageStats {
160            total_requests: 0,
161            total_tokens_processed: 0,
162            average_response_time_ms: 0.0,
163            error_rate: 0.0,
164        })
165    }
166}
167
/// Mock vector store for testing.
///
/// Vectors are kept in an in-memory map keyed by id. The map lives behind an
/// `Arc<Mutex<..>>`, so clones of a store share the same vectors.
#[derive(Debug, Clone)]
pub struct MockVectorStore {
    vectors: Arc<Mutex<HashMap<String, Vec<f32>>>>,
    dimension: usize,
}

impl MockVectorStore {
    /// Create a new, empty mock vector store for vectors of `dimension`
    /// components.
    pub fn new(dimension: usize) -> Self {
        Self {
            vectors: Arc::new(Mutex::new(HashMap::new())),
            dimension,
        }
    }

    /// Pre-populate with a vector for testing.
    ///
    /// The vector is stored exactly as given: its length is not checked
    /// against the store's dimension here. Re-using an id replaces the
    /// earlier vector.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn with_vector(self, id: impl Into<String>, vector: Vec<f32>) -> Self {
        self.vectors
            .lock()
            .expect("lock poisoned")
            .insert(id.into(), vector);
        self
    }

    /// Calculate cosine similarity between two vectors.
    ///
    /// Returns `0.0` when either vector has zero magnitude, which avoids a
    /// division by zero. If the slices differ in length, the dot product only
    /// covers the common prefix (`zip` stops at the shorter slice) while each
    /// magnitude still uses the full slice.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if mag_a == 0.0 || mag_b == 0.0 {
            0.0
        } else {
            dot / (mag_a * mag_b)
        }
    }
}
205
206#[async_trait]
207impl AsyncVectorStore for MockVectorStore {
208    type Error = GraphRAGError;
209
210    async fn add_vector(
211        &mut self,
212        id: String,
213        vector: Vec<f32>,
214        _metadata: VectorMetadata,
215    ) -> Result<()> {
216        if vector.len() != self.dimension {
217            return Err(GraphRAGError::Embedding {
218                message: format!(
219                    "Vector dimension mismatch: expected {}, got {}",
220                    self.dimension,
221                    vector.len()
222                ),
223            });
224        }
225        self.vectors
226            .lock()
227            .expect("lock poisoned")
228            .insert(id, vector);
229        Ok(())
230    }
231
232    async fn add_vectors_batch(&mut self, vectors: VectorBatch) -> Result<()> {
233        for (id, vector, metadata) in vectors {
234            self.add_vector(id, vector, metadata).await?;
235        }
236        Ok(())
237    }
238
239    async fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<SearchResult>> {
240        if query_vector.len() != self.dimension {
241            return Err(GraphRAGError::Embedding {
242                message: format!(
243                    "Query vector dimension mismatch: expected {}, got {}",
244                    self.dimension,
245                    query_vector.len()
246                ),
247            });
248        }
249
250        let vectors = self.vectors.lock().expect("lock poisoned");
251        let mut results: Vec<_> = vectors
252            .iter()
253            .map(|(id, vector)| {
254                let similarity = Self::cosine_similarity(query_vector, vector);
255                SearchResult {
256                    id: id.clone(),
257                    distance: 1.0 - similarity, // Convert similarity to distance
258                    metadata: None,
259                }
260            })
261            .collect();
262
263        // Sort by distance (ascending)
264        results.sort_by(|a, b| {
265            a.distance
266                .partial_cmp(&b.distance)
267                .unwrap_or(std::cmp::Ordering::Equal)
268        });
269
270        // Take top k
271        Ok(results.into_iter().take(k).collect())
272    }
273
274    async fn search_with_threshold(
275        &self,
276        query_vector: &[f32],
277        k: usize,
278        threshold: f32,
279    ) -> Result<Vec<SearchResult>> {
280        let results = self.search(query_vector, k).await?;
281        Ok(results
282            .into_iter()
283            .filter(|r| r.distance <= threshold)
284            .collect())
285    }
286
287    async fn remove_vector(&mut self, id: &str) -> Result<bool> {
288        Ok(self
289            .vectors
290            .lock()
291            .expect("lock poisoned")
292            .remove(id)
293            .is_some())
294    }
295
296    async fn len(&self) -> usize {
297        self.vectors.lock().expect("lock poisoned").len()
298    }
299}
300
/// Mock retriever for testing.
///
/// Holds a fixed list of result strings behind an `Arc<Mutex<..>>`, so clones
/// of a retriever share the same list.
#[derive(Debug, Clone)]
pub struct MockRetriever {
    results: Arc<Mutex<Vec<String>>>,
}

impl MockRetriever {
    /// Create a new mock retriever with an empty result list.
    pub fn new() -> Self {
        Self {
            results: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Pre-populate with results for testing.
    ///
    /// Replaces the whole result list rather than appending to it.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn with_results(self, results: Vec<String>) -> Self {
        *self.results.lock().expect("lock poisoned") = results;
        self
    }
}

impl Default for MockRetriever {
    /// Equivalent to [`MockRetriever::new`].
    fn default() -> Self {
        Self::new()
    }
}
326
327#[async_trait]
328impl AsyncRetriever for MockRetriever {
329    type Query = String;
330    type Result = String;
331    type Error = GraphRAGError;
332
333    async fn search(&self, _query: Self::Query, k: usize) -> Result<Vec<Self::Result>> {
334        let results = self.results.lock().expect("lock poisoned");
335        Ok(results.iter().take(k).cloned().collect())
336    }
337
338    async fn search_with_context(
339        &self,
340        query: Self::Query,
341        _context: &str,
342        k: usize,
343    ) -> Result<Vec<Self::Result>> {
344        self.search(query, k).await
345    }
346
347    async fn update(&mut self, content: Vec<String>) -> Result<()> {
348        *self.results.lock().expect("lock poisoned") = content;
349        Ok(())
350    }
351
352    async fn health_check(&self) -> Result<bool> {
353        Ok(true)
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered text yields its preset vector; an unregistered text still
    /// yields a vector of the configured dimension.
    #[tokio::test]
    async fn test_mock_embedder() {
        const DIM: usize = 128;
        let embedder = MockEmbedder::new(DIM).with_embedding("test", vec![0.5; DIM]);

        let preset = embedder.embed("test").await.unwrap();
        assert_eq!(preset.len(), DIM);
        assert_eq!(preset[0], 0.5);

        // Test unknown text gets generated embedding
        let generated = embedder.embed("unknown").await.unwrap();
        assert_eq!(generated.len(), DIM);
    }

    /// A canned prompt returns its canned answer; anything else returns the
    /// configured default.
    #[tokio::test]
    async fn test_mock_language_model() {
        let model = MockLanguageModel::new()
            .with_response("Hello", "Hi there!")
            .with_default_response("Default response");

        let canned = model.complete("Hello").await.unwrap();
        assert_eq!(canned, "Hi there!");

        let fallback = model.complete("Unknown").await.unwrap();
        assert_eq!(fallback, "Default response");
    }

    /// Pre-populated vectors are counted, the nearest one ranks first, and
    /// removal shrinks the store.
    #[tokio::test]
    async fn test_mock_vector_store() {
        let mut store = MockVectorStore::new(3)
            .with_vector("vec1", vec![1.0, 0.0, 0.0])
            .with_vector("vec2", vec![0.0, 1.0, 0.0]);

        assert_eq!(store.len().await, 2);

        let query = [1.0, 0.0, 0.0];
        let hits = store.search(&query, 2).await.unwrap();
        assert_eq!(hits[0].id, "vec1");

        let removed = store.remove_vector("vec1").await.unwrap();
        assert!(removed);
        assert_eq!(store.len().await, 1);
    }

    /// `search` returns at most `k` of the pre-populated results, in their
    /// stored order.
    #[tokio::test]
    async fn test_mock_retriever() {
        let canned: Vec<String> = vec![
            "result1".to_string(),
            "result2".to_string(),
            "result3".to_string(),
        ];
        let retriever = MockRetriever::new().with_results(canned);

        let hits = retriever.search("query".to_string(), 2).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0], "result1");
    }
}
411}