// graphrag_core/generation/async_mock_llm.rs
1//! Async implementation of MockLLM demonstrating async trait patterns
2//!
3//! This module provides an async version of MockLLM that implements the AsyncLanguageModel trait,
4//! showcasing how to migrate synchronous implementations to async patterns.
5
6use crate::core::traits::{AsyncLanguageModel, GenerationParams, ModelInfo, ModelUsageStats};
7use crate::core::{GraphRAGError, Result};
8use crate::generation::LLMInterface;
9use crate::text::TextProcessor;
10use async_trait::async_trait;
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::sync::RwLock;
16
/// Async version of MockLLM that implements AsyncLanguageModel trait
///
/// All shared state lives behind `Arc`s, so clones are cheap and share the
/// same templates and statistics (see the manual `Clone` impl in this file).
#[derive(Debug)]
pub struct AsyncMockLLM {
    // Named response templates keyed by id (e.g. "default"); behind an async
    // RwLock so they can be read while other tasks hold the lock.
    response_templates: Arc<RwLock<HashMap<String, String>>>,
    // Sentence/keyword extraction helper, shared across clones.
    text_processor: Arc<TextProcessor>,
    // Aggregate request/token/error counters, shared across clones.
    stats: Arc<AsyncLLMStats>,
    // Optional artificial latency applied before each request; `None` disables.
    simulate_delay: Option<Duration>,
}
25
/// Statistics tracking for the async LLM
///
/// Counters are lock-free atomics; only the accumulated response time sits
/// behind the async `RwLock`, since a `Duration` has no atomic equivalent.
#[derive(Debug, Default)]
struct AsyncLLMStats {
    total_requests: AtomicU64,         // every request, success or failure
    total_tokens_processed: AtomicU64, // estimated tokens, successful requests only
    total_response_time: Arc<RwLock<Duration>>, // summed wall-clock latency
    error_count: AtomicU64,            // failed requests
}
34
35impl AsyncMockLLM {
36    /// Create a new async mock LLM
37    pub async fn new() -> Result<Self> {
38        let mut templates = HashMap::new();
39
40        // Default response templates
41        templates.insert(
42            "default".to_string(),
43            "Based on the provided context, here is what I found: {context}".to_string(),
44        );
45        templates.insert(
46            "not_found".to_string(),
47            "I could not find specific information about this in the provided context.".to_string(),
48        );
49        templates.insert(
50            "insufficient_context".to_string(),
51            "The available context is insufficient to provide a complete answer.".to_string(),
52        );
53
54        let text_processor = TextProcessor::new(1000, 100)?;
55
56        Ok(Self {
57            response_templates: Arc::new(RwLock::new(templates)),
58            text_processor: Arc::new(text_processor),
59            stats: Arc::new(AsyncLLMStats::default()),
60            simulate_delay: Some(Duration::from_millis(100)), // Simulate realistic delay
61        })
62    }
63
64    /// Create with custom templates
65    pub async fn with_templates(templates: HashMap<String, String>) -> Result<Self> {
66        let text_processor = TextProcessor::new(1000, 100)?;
67
68        Ok(Self {
69            response_templates: Arc::new(RwLock::new(templates)),
70            text_processor: Arc::new(text_processor),
71            stats: Arc::new(AsyncLLMStats::default()),
72            simulate_delay: Some(Duration::from_millis(100)),
73        })
74    }
75
    /// Set artificial delay to simulate network latency
    ///
    /// Pass `None` to disable the delay entirely (useful in tests).
    pub fn set_simulate_delay(&mut self, delay: Option<Duration>) {
        self.simulate_delay = delay;
    }
80
81    /// Generate extractive answer from context with improved relevance scoring
82    async fn generate_extractive_answer(&self, context: &str, query: &str) -> Result<String> {
83        // Simulate processing delay
84        if let Some(delay) = self.simulate_delay {
85            tokio::time::sleep(delay).await;
86        }
87
88        let sentences = self.text_processor.extract_sentences(context);
89        if sentences.is_empty() {
90            return Ok("No relevant context found.".to_string());
91        }
92
93        // Enhanced scoring with partial word matching and named entity recognition
94        let query_lower = query.to_lowercase();
95        let query_words: Vec<&str> = query_lower
96            .split_whitespace()
97            .filter(|w| w.len() > 2) // Filter out short words
98            .collect();
99
100        if query_words.is_empty() {
101            return Ok("Query too short or contains no meaningful words.".to_string());
102        }
103
104        let mut sentence_scores: Vec<(usize, f32)> = sentences
105            .iter()
106            .enumerate()
107            .map(|(i, sentence)| {
108                let sentence_lower = sentence.to_lowercase();
109                let mut total_score = 0.0;
110                let mut matches = 0;
111
112                for word in &query_words {
113                    // Exact word match (highest score)
114                    if sentence_lower.contains(word) {
115                        total_score += 2.0;
116                        matches += 1;
117                    }
118                    // Partial match for longer words
119                    else if word.len() > 4 {
120                        for sentence_word in sentence_lower.split_whitespace() {
121                            if sentence_word.contains(word) || word.contains(sentence_word) {
122                                total_score += 1.0;
123                                matches += 1;
124                                break;
125                            }
126                        }
127                    }
128                }
129
130                // Boost score for sentences with multiple matches
131                let coverage_bonus = (matches as f32 / query_words.len() as f32) * 0.5;
132                let final_score = total_score + coverage_bonus;
133
134                (i, final_score)
135            })
136            .collect();
137
138        // Sort by relevance
139        sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
140
141        // Select top sentences with a minimum relevance threshold
142        let mut answer_sentences = Vec::new();
143        for (idx, score) in sentence_scores.iter().take(5) {
144            if *score > 0.5 {
145                // Higher threshold for better quality
146                answer_sentences.push(format!(
147                    "{} (relevance: {:.1})",
148                    sentences[*idx].trim(),
149                    score
150                ));
151            }
152        }
153
154        if answer_sentences.is_empty() {
155            // If no high-quality matches, provide the best available with lower threshold
156            for (idx, score) in sentence_scores.iter().take(2) {
157                if *score > 0.0 {
158                    answer_sentences.push(format!(
159                        "{} (low confidence: {:.1})",
160                        sentences[*idx].trim(),
161                        score
162                    ));
163                }
164            }
165        }
166
167        if answer_sentences.is_empty() {
168            Ok("No directly relevant information found in the context.".to_string())
169        } else {
170            Ok(answer_sentences.join("\n\n"))
171        }
172    }
173
174    /// Generate smart contextual answer
175    async fn generate_smart_answer(&self, context: &str, question: &str) -> Result<String> {
176        // First try extractive approach
177        let extractive_result = self.generate_extractive_answer(context, question).await?;
178
179        // If extractive failed, generate a contextual response
180        if extractive_result.contains("No relevant") || extractive_result.contains("No directly") {
181            return self.generate_contextual_response(context, question).await;
182        }
183
184        Ok(extractive_result)
185    }
186
187    /// Generate contextual response when direct extraction fails
188    async fn generate_contextual_response(&self, context: &str, question: &str) -> Result<String> {
189        let question_lower = question.to_lowercase();
190        let context_lower = context.to_lowercase();
191
192        // Pattern matching for common question types
193        if question_lower.contains("who") && question_lower.contains("friend") {
194            // Look for character names and relationships
195            let names = self.extract_character_names(&context_lower).await;
196            if !names.is_empty() {
197                return Ok(format!("Based on the context, the main characters mentioned include: {}. These appear to be friends and companions in the story.", names.join(", ")));
198            }
199        }
200
201        if question_lower.contains("what")
202            && (question_lower.contains("adventure") || question_lower.contains("happen"))
203        {
204            let events = self.extract_key_events(&context_lower).await;
205            if !events.is_empty() {
206                return Ok(format!(
207                    "The context describes several events: {}",
208                    events.join(", ")
209                ));
210            }
211        }
212
213        if question_lower.contains("where") {
214            let locations = self.extract_locations(&context_lower).await;
215            if !locations.is_empty() {
216                return Ok(format!(
217                    "The story takes place in locations such as: {}",
218                    locations.join(", ")
219                ));
220            }
221        }
222
223        // Fallback: provide a summary of the context
224        let summary = self.generate_summary_async(context, 150).await?;
225        Ok(format!("Based on the available context: {summary}"))
226    }
227
228    /// Generate response for direct questions
229    async fn generate_question_response(&self, question: &str) -> Result<String> {
230        let question_lower = question.to_lowercase();
231
232        // Generic pattern-based responses for common query types
233        if question_lower.contains("friend") || question_lower.contains("relationship") {
234            return Ok("The text describes various character relationships and friendships throughout the narrative.".to_string());
235        }
236
237        if question_lower.contains("main character") || question_lower.contains("protagonist") {
238            return Ok(
239                "The text features several important characters who drive the narrative forward."
240                    .to_string(),
241            );
242        }
243
244        if question_lower.contains("event") || question_lower.contains("scene") {
245            return Ok(
246                "The text contains various significant events and scenes that advance the story."
247                    .to_string(),
248            );
249        }
250
251        Ok(
252            "I need more specific context to provide a detailed answer to this question."
253                .to_string(),
254        )
255    }
256
257    /// Extract capitalized words that might be names from text
258    async fn extract_character_names(&self, text: &str) -> Vec<String> {
259        let mut found_names = Vec::new();
260
261        // Extract capitalized words as potential names
262        for word in text.split_whitespace() {
263            let clean_word = word.trim_matches(|c: char| !c.is_alphabetic());
264            if clean_word.len() > 2
265                && clean_word.chars().next().unwrap().is_uppercase()
266                && clean_word.chars().all(|c| c.is_alphabetic())
267            {
268                found_names.push(clean_word.to_lowercase());
269            }
270        }
271
272        found_names
273    }
274
275    /// Extract key events/actions from text
276    async fn extract_key_events(&self, text: &str) -> Vec<String> {
277        let event_keywords = [
278            "adventure",
279            "treasure",
280            "cave",
281            "island",
282            "painting",
283            "school",
284            "church",
285            "graveyard",
286            "river",
287        ];
288        let mut found_events = Vec::new();
289
290        for event in &event_keywords {
291            if text.contains(event) {
292                found_events.push(format!("events involving {event}"));
293            }
294        }
295
296        found_events
297    }
298
299    /// Extract locations from text
300    async fn extract_locations(&self, text: &str) -> Vec<String> {
301        let locations = [
302            "village",
303            "mississippi",
304            "river",
305            "cave",
306            "island",
307            "town",
308            "church",
309            "school",
310            "house",
311        ];
312        let mut found_locations = Vec::new();
313
314        for location in &locations {
315            if text.contains(location) {
316                found_locations.push(location.to_string());
317            }
318        }
319
320        found_locations
321    }
322
323    /// Generate summary asynchronously
324    async fn generate_summary_async(&self, content: &str, max_length: usize) -> Result<String> {
325        let sentences = self.text_processor.extract_sentences(content);
326        if sentences.is_empty() {
327            return Ok(String::new());
328        }
329
330        let mut summary = String::new();
331        for sentence in sentences.iter().take(3) {
332            if summary.len() + sentence.len() > max_length {
333                break;
334            }
335            if !summary.is_empty() {
336                summary.push(' ');
337            }
338            summary.push_str(sentence);
339        }
340
341        Ok(summary)
342    }
343
344    /// Update statistics after a request
345    async fn update_stats(&self, tokens: usize, response_time: Duration, is_error: bool) {
346        self.stats.total_requests.fetch_add(1, Ordering::Relaxed);
347
348        if is_error {
349            self.stats.error_count.fetch_add(1, Ordering::Relaxed);
350        } else {
351            self.stats
352                .total_tokens_processed
353                .fetch_add(tokens as u64, Ordering::Relaxed);
354        }
355
356        let mut total_time = self.stats.total_response_time.write().await;
357        *total_time += response_time;
358    }
359}
360
361#[async_trait]
362impl AsyncLanguageModel for AsyncMockLLM {
363    type Error = GraphRAGError;
364
365    async fn complete(&self, prompt: &str) -> Result<String> {
366        let start_time = Instant::now();
367
368        // Simulate processing delay
369        if let Some(delay) = self.simulate_delay {
370            tokio::time::sleep(delay).await;
371        }
372
373        let result = self.generate_response_internal(prompt).await;
374        let response_time = start_time.elapsed();
375
376        // Estimate tokens (rough approximation)
377        let tokens = prompt.len() / 4;
378        self.update_stats(tokens, response_time, result.is_err())
379            .await;
380
381        result
382    }
383
    async fn complete_with_params(
        &self,
        prompt: &str,
        _params: GenerationParams,
    ) -> Result<String> {
        // For the mock LLM, generation parameters are ignored and the call
        // simply delegates to the basic `complete`
        self.complete(prompt).await
    }
392
393    async fn complete_batch(&self, prompts: &[&str]) -> Result<Vec<String>> {
394        // Process prompts concurrently for better performance
395        let mut handles = Vec::new();
396
397        for prompt in prompts {
398            let prompt_owned = prompt.to_string();
399            let self_clone = self.clone();
400            handles.push(tokio::spawn(async move {
401                self_clone.complete(&prompt_owned).await
402            }));
403        }
404
405        let mut results = Vec::with_capacity(prompts.len());
406        for handle in handles {
407            match handle.await {
408                Ok(result) => results.push(result?),
409                Err(e) => {
410                    return Err(GraphRAGError::Generation {
411                        message: format!("Task join error: {e}"),
412                    })
413                },
414            }
415        }
416
417        Ok(results)
418    }
419
    async fn is_available(&self) -> bool {
        // The mock backend is always reachable.
        true
    }
423
    async fn model_info(&self) -> ModelInfo {
        // Static metadata describing this mock model.
        ModelInfo {
            name: "AsyncMockLLM".to_string(),
            version: Some("1.0.0".to_string()),
            max_context_length: Some(4096),
            supports_streaming: true,
        }
    }
432
433    async fn get_usage_stats(&self) -> Result<ModelUsageStats> {
434        let total_requests = self.stats.total_requests.load(Ordering::Relaxed);
435        let total_tokens = self.stats.total_tokens_processed.load(Ordering::Relaxed);
436        let error_count = self.stats.error_count.load(Ordering::Relaxed);
437        let total_time = *self.stats.total_response_time.read().await;
438
439        let average_response_time_ms = if total_requests > 0 {
440            total_time.as_millis() as f64 / total_requests as f64
441        } else {
442            0.0
443        };
444
445        let error_rate = if total_requests > 0 {
446            error_count as f64 / total_requests as f64
447        } else {
448            0.0
449        };
450
451        Ok(ModelUsageStats {
452            total_requests,
453            total_tokens_processed: total_tokens,
454            average_response_time_ms,
455            error_rate,
456        })
457    }
458
    async fn estimate_tokens(&self, prompt: &str) -> Result<usize> {
        // Simple estimation: ~4 characters (bytes) per token; matches the
        // estimate used by `complete` when recording stats
        Ok(prompt.len() / 4)
    }
463}
464
465impl AsyncMockLLM {
466    /// Internal response generation method
467    async fn generate_response_internal(&self, prompt: &str) -> Result<String> {
468        let prompt_lower = prompt.to_lowercase();
469
470        // Handle Q&A format prompts
471        if prompt_lower.contains("context:") && prompt_lower.contains("question:") {
472            if let Some(context_start) = prompt.find("Context:") {
473                let context_section = &prompt[context_start + 8..];
474                if let Some(question_start) = context_section.find("Question:") {
475                    let context = context_section[..question_start].trim();
476                    let question_section = context_section[question_start + 9..].trim();
477
478                    return self.generate_smart_answer(context, question_section).await;
479                }
480            }
481        }
482
483        // Handle direct questions about specific topics
484        if prompt_lower.contains("who")
485            || prompt_lower.contains("what")
486            || prompt_lower.contains("where")
487            || prompt_lower.contains("when")
488            || prompt_lower.contains("how")
489            || prompt_lower.contains("why")
490        {
491            return self.generate_question_response(prompt).await;
492        }
493
494        // Fallback to template
495        let templates = self.response_templates.read().await;
496        Ok(templates
497            .get("default")
498            .unwrap_or(&"I cannot provide a response based on the given prompt.".to_string())
499            .replace("{context}", &prompt[..prompt.len().min(200)]))
500    }
501}
502
// Implement Clone for AsyncMockLLM
//
// Cheap: the templates, text processor, and stats are shared via `Arc`
// refcount bumps (so clones see the same stats), and only the
// `Option<Duration>` delay is copied by value.
impl Clone for AsyncMockLLM {
    fn clone(&self) -> Self {
        Self {
            response_templates: Arc::clone(&self.response_templates),
            text_processor: Arc::clone(&self.text_processor),
            stats: Arc::clone(&self.stats),
            simulate_delay: self.simulate_delay,
        }
    }
}
514
/// Synchronous LLMInterface implementation for backward compatibility
///
/// Bridges the async methods to a blocking API. NOTE(review):
/// `tokio::task::block_in_place` panics on a current-thread (single-threaded)
/// tokio runtime, so the in-runtime branches below assume a multi-threaded
/// runtime — confirm callers either run on one or call these methods from
/// fully synchronous code (which takes the new-runtime branch instead).
#[async_trait]
impl LLMInterface for AsyncMockLLM {
    /// Blocking wrapper over the async `complete`.
    fn generate_response(&self, prompt: &str) -> Result<String> {
        // For sync compatibility, use tokio's block_in_place if we're in a tokio context
        if tokio::runtime::Handle::try_current().is_ok() {
            tokio::task::block_in_place(|| {
                tokio::runtime::Handle::current().block_on(self.complete(prompt))
            })
        } else {
            // If not in async context, create a new runtime
            let rt = tokio::runtime::Runtime::new().map_err(|e| GraphRAGError::Generation {
                message: format!("Failed to create async runtime: {e}"),
            })?;
            rt.block_on(self.complete(prompt))
        }
    }

    /// Blocking wrapper over `generate_summary_async`; same runtime caveats
    /// as `generate_response`.
    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String> {
        if tokio::runtime::Handle::try_current().is_ok() {
            tokio::task::block_in_place(|| {
                tokio::runtime::Handle::current()
                    .block_on(self.generate_summary_async(content, max_length))
            })
        } else {
            let rt = tokio::runtime::Runtime::new().map_err(|e| GraphRAGError::Generation {
                message: format!("Failed to create async runtime: {e}"),
            })?;
            rt.block_on(self.generate_summary_async(content, max_length))
        }
    }

    /// Fully synchronous: pairs each extracted keyword with the first
    /// sentence that mentions it, yielding up to `num_points` key points.
    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>> {
        // Fetches 2x keywords, though only the first `num_points` are used below.
        let keywords = self
            .text_processor
            .extract_keywords(content, num_points * 2);
        let sentences = self.text_processor.extract_sentences(content);

        let mut key_points = Vec::new();
        for keyword in keywords.iter().take(num_points) {
            // Find a sentence containing this keyword
            if let Some(sentence) = sentences
                .iter()
                .find(|s| s.to_lowercase().contains(&keyword.to_lowercase()))
            {
                key_points.push(sentence.clone());
            } else {
                // No sentence mentions the keyword; emit a stub instead.
                key_points.push(format!("Key concept: {keyword}"));
            }
        }

        Ok(key_points)
    }
}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction with default templates succeeds.
    #[tokio::test]
    async fn test_async_mock_llm_creation() {
        assert!(AsyncMockLLM::new().await.is_ok());
    }

    /// A basic completion returns Ok.
    #[tokio::test]
    async fn test_async_completion() {
        let model = AsyncMockLLM::new().await.unwrap();
        assert!(model.complete("Hello, world!").await.is_ok());
    }

    /// Batch completion returns one response per prompt.
    #[tokio::test]
    async fn test_async_batch_completion() {
        let model = AsyncMockLLM::new().await.unwrap();
        let batch = model.complete_batch(&["Hello", "World", "Test"]).await;
        let responses = batch.expect("batch completion should succeed");
        assert_eq!(responses.len(), 3);
    }

    /// Stats reflect the request count and accumulate latency.
    #[tokio::test]
    async fn test_async_usage_stats() {
        let model = AsyncMockLLM::new().await.unwrap();

        // Issue two requests so the counters have something to show.
        let _ = model.complete("Test prompt 1").await;
        let _ = model.complete("Test prompt 2").await;

        let stats = model.get_usage_stats().await.unwrap();
        assert_eq!(stats.total_requests, 2);
        assert!(stats.average_response_time_ms > 0.0);
    }

    /// The mock always reports itself as available.
    #[tokio::test]
    async fn test_async_model_availability() {
        let model = AsyncMockLLM::new().await.unwrap();
        assert!(model.is_available().await);
    }

    /// Static model metadata matches expectations.
    #[tokio::test]
    async fn test_async_model_info() {
        let model = AsyncMockLLM::new().await.unwrap();
        let info = model.model_info().await;
        assert_eq!(info.name, "AsyncMockLLM");
        assert_eq!(info.version, Some("1.0.0".to_string()));
        assert!(info.supports_streaming);
    }

    /// Token estimation yields a positive count for a non-trivial prompt.
    #[tokio::test]
    async fn test_token_estimation() {
        let model = AsyncMockLLM::new().await.unwrap();
        let tokens = model.estimate_tokens("This is a test prompt").await.unwrap();
        assert!(tokens > 0);
    }
}