//! Asynchronous mock LLM for testing (graphrag_core/generation/async_mock_llm.rs).
use crate::core::traits::{AsyncLanguageModel, GenerationParams, ModelInfo, ModelUsageStats};
use crate::core::{GraphRAGError, Result};
use crate::generation::LLMInterface;
use crate::text::TextProcessor;
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
16
/// Asynchronous mock language model used to exercise the generation
/// pipeline without calling a real LLM backend.
#[derive(Debug)]
pub struct AsyncMockLLM {
    // Named response templates; "{context}" inside a template is replaced
    // with a prefix of the prompt. Behind an async RwLock so templates can
    // be read concurrently.
    response_templates: Arc<RwLock<HashMap<String, String>>>,
    // Sentence/keyword extraction helper used for extractive answers.
    text_processor: Arc<TextProcessor>,
    // Usage counters, shared across clones of this mock.
    stats: Arc<AsyncLLMStats>,
    // Optional artificial latency injected per request; `None` disables it.
    simulate_delay: Option<Duration>,
}
25
/// Internal usage counters for `AsyncMockLLM`.
///
/// Counts are lock-free atomics; only the accumulated response time needs
/// an async `RwLock` because `Duration` cannot be updated atomically.
#[derive(Debug, Default)]
struct AsyncLLMStats {
    total_requests: AtomicU64,
    total_tokens_processed: AtomicU64,
    total_response_time: Arc<RwLock<Duration>>,
    error_count: AtomicU64,
}
34
35impl AsyncMockLLM {
36 pub async fn new() -> Result<Self> {
38 let mut templates = HashMap::new();
39
40 templates.insert(
42 "default".to_string(),
43 "Based on the provided context, here is what I found: {context}".to_string(),
44 );
45 templates.insert(
46 "not_found".to_string(),
47 "I could not find specific information about this in the provided context.".to_string(),
48 );
49 templates.insert(
50 "insufficient_context".to_string(),
51 "The available context is insufficient to provide a complete answer.".to_string(),
52 );
53
54 let text_processor = TextProcessor::new(1000, 100)?;
55
56 Ok(Self {
57 response_templates: Arc::new(RwLock::new(templates)),
58 text_processor: Arc::new(text_processor),
59 stats: Arc::new(AsyncLLMStats::default()),
60 simulate_delay: Some(Duration::from_millis(100)), })
62 }
63
64 pub async fn with_templates(templates: HashMap<String, String>) -> Result<Self> {
66 let text_processor = TextProcessor::new(1000, 100)?;
67
68 Ok(Self {
69 response_templates: Arc::new(RwLock::new(templates)),
70 text_processor: Arc::new(text_processor),
71 stats: Arc::new(AsyncLLMStats::default()),
72 simulate_delay: Some(Duration::from_millis(100)),
73 })
74 }
75
    /// Overrides the artificial per-request latency; pass `None` to disable it.
    pub fn set_simulate_delay(&mut self, delay: Option<Duration>) {
        self.simulate_delay = delay;
    }
80
81 async fn generate_extractive_answer(&self, context: &str, query: &str) -> Result<String> {
83 if let Some(delay) = self.simulate_delay {
85 tokio::time::sleep(delay).await;
86 }
87
88 let sentences = self.text_processor.extract_sentences(context);
89 if sentences.is_empty() {
90 return Ok("No relevant context found.".to_string());
91 }
92
93 let query_lower = query.to_lowercase();
95 let query_words: Vec<&str> = query_lower
96 .split_whitespace()
97 .filter(|w| w.len() > 2) .collect();
99
100 if query_words.is_empty() {
101 return Ok("Query too short or contains no meaningful words.".to_string());
102 }
103
104 let mut sentence_scores: Vec<(usize, f32)> = sentences
105 .iter()
106 .enumerate()
107 .map(|(i, sentence)| {
108 let sentence_lower = sentence.to_lowercase();
109 let mut total_score = 0.0;
110 let mut matches = 0;
111
112 for word in &query_words {
113 if sentence_lower.contains(word) {
115 total_score += 2.0;
116 matches += 1;
117 }
118 else if word.len() > 4 {
120 for sentence_word in sentence_lower.split_whitespace() {
121 if sentence_word.contains(word) || word.contains(sentence_word) {
122 total_score += 1.0;
123 matches += 1;
124 break;
125 }
126 }
127 }
128 }
129
130 let coverage_bonus = (matches as f32 / query_words.len() as f32) * 0.5;
132 let final_score = total_score + coverage_bonus;
133
134 (i, final_score)
135 })
136 .collect();
137
138 sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
140
141 let mut answer_sentences = Vec::new();
143 for (idx, score) in sentence_scores.iter().take(5) {
144 if *score > 0.5 {
145 answer_sentences.push(format!(
147 "{} (relevance: {:.1})",
148 sentences[*idx].trim(),
149 score
150 ));
151 }
152 }
153
154 if answer_sentences.is_empty() {
155 for (idx, score) in sentence_scores.iter().take(2) {
157 if *score > 0.0 {
158 answer_sentences.push(format!(
159 "{} (low confidence: {:.1})",
160 sentences[*idx].trim(),
161 score
162 ));
163 }
164 }
165 }
166
167 if answer_sentences.is_empty() {
168 Ok("No directly relevant information found in the context.".to_string())
169 } else {
170 Ok(answer_sentences.join("\n\n"))
171 }
172 }
173
174 async fn generate_smart_answer(&self, context: &str, question: &str) -> Result<String> {
176 let extractive_result = self.generate_extractive_answer(context, question).await?;
178
179 if extractive_result.contains("No relevant") || extractive_result.contains("No directly") {
181 return self.generate_contextual_response(context, question).await;
182 }
183
184 Ok(extractive_result)
185 }
186
187 async fn generate_contextual_response(&self, context: &str, question: &str) -> Result<String> {
189 let question_lower = question.to_lowercase();
190 let context_lower = context.to_lowercase();
191
192 if question_lower.contains("who") && question_lower.contains("friend") {
194 let names = self.extract_character_names(&context_lower).await;
196 if !names.is_empty() {
197 return Ok(format!("Based on the context, the main characters mentioned include: {}. These appear to be friends and companions in the story.", names.join(", ")));
198 }
199 }
200
201 if question_lower.contains("what")
202 && (question_lower.contains("adventure") || question_lower.contains("happen"))
203 {
204 let events = self.extract_key_events(&context_lower).await;
205 if !events.is_empty() {
206 return Ok(format!(
207 "The context describes several events: {}",
208 events.join(", ")
209 ));
210 }
211 }
212
213 if question_lower.contains("where") {
214 let locations = self.extract_locations(&context_lower).await;
215 if !locations.is_empty() {
216 return Ok(format!(
217 "The story takes place in locations such as: {}",
218 locations.join(", ")
219 ));
220 }
221 }
222
223 let summary = self.generate_summary_async(context, 150).await?;
225 Ok(format!("Based on the available context: {summary}"))
226 }
227
228 async fn generate_question_response(&self, question: &str) -> Result<String> {
230 let question_lower = question.to_lowercase();
231
232 if question_lower.contains("friend") || question_lower.contains("relationship") {
234 return Ok("The text describes various character relationships and friendships throughout the narrative.".to_string());
235 }
236
237 if question_lower.contains("main character") || question_lower.contains("protagonist") {
238 return Ok(
239 "The text features several important characters who drive the narrative forward."
240 .to_string(),
241 );
242 }
243
244 if question_lower.contains("event") || question_lower.contains("scene") {
245 return Ok(
246 "The text contains various significant events and scenes that advance the story."
247 .to_string(),
248 );
249 }
250
251 Ok(
252 "I need more specific context to provide a detailed answer to this question."
253 .to_string(),
254 )
255 }
256
257 async fn extract_character_names(&self, text: &str) -> Vec<String> {
259 let mut found_names = Vec::new();
260
261 for word in text.split_whitespace() {
263 let clean_word = word.trim_matches(|c: char| !c.is_alphabetic());
264 if clean_word.len() > 2
265 && clean_word.chars().next().unwrap().is_uppercase()
266 && clean_word.chars().all(|c| c.is_alphabetic())
267 {
268 found_names.push(clean_word.to_lowercase());
269 }
270 }
271
272 found_names
273 }
274
275 async fn extract_key_events(&self, text: &str) -> Vec<String> {
277 let event_keywords = [
278 "adventure",
279 "treasure",
280 "cave",
281 "island",
282 "painting",
283 "school",
284 "church",
285 "graveyard",
286 "river",
287 ];
288 let mut found_events = Vec::new();
289
290 for event in &event_keywords {
291 if text.contains(event) {
292 found_events.push(format!("events involving {event}"));
293 }
294 }
295
296 found_events
297 }
298
299 async fn extract_locations(&self, text: &str) -> Vec<String> {
301 let locations = [
302 "village",
303 "mississippi",
304 "river",
305 "cave",
306 "island",
307 "town",
308 "church",
309 "school",
310 "house",
311 ];
312 let mut found_locations = Vec::new();
313
314 for location in &locations {
315 if text.contains(location) {
316 found_locations.push(location.to_string());
317 }
318 }
319
320 found_locations
321 }
322
323 async fn generate_summary_async(&self, content: &str, max_length: usize) -> Result<String> {
325 let sentences = self.text_processor.extract_sentences(content);
326 if sentences.is_empty() {
327 return Ok(String::new());
328 }
329
330 let mut summary = String::new();
331 for sentence in sentences.iter().take(3) {
332 if summary.len() + sentence.len() > max_length {
333 break;
334 }
335 if !summary.is_empty() {
336 summary.push(' ');
337 }
338 summary.push_str(sentence);
339 }
340
341 Ok(summary)
342 }
343
344 async fn update_stats(&self, tokens: usize, response_time: Duration, is_error: bool) {
346 self.stats.total_requests.fetch_add(1, Ordering::Relaxed);
347
348 if is_error {
349 self.stats.error_count.fetch_add(1, Ordering::Relaxed);
350 } else {
351 self.stats
352 .total_tokens_processed
353 .fetch_add(tokens as u64, Ordering::Relaxed);
354 }
355
356 let mut total_time = self.stats.total_response_time.write().await;
357 *total_time += response_time;
358 }
359}
360
361#[async_trait]
362impl AsyncLanguageModel for AsyncMockLLM {
363 type Error = GraphRAGError;
364
365 async fn complete(&self, prompt: &str) -> Result<String> {
366 let start_time = Instant::now();
367
368 if let Some(delay) = self.simulate_delay {
370 tokio::time::sleep(delay).await;
371 }
372
373 let result = self.generate_response_internal(prompt).await;
374 let response_time = start_time.elapsed();
375
376 let tokens = prompt.len() / 4;
378 self.update_stats(tokens, response_time, result.is_err())
379 .await;
380
381 result
382 }
383
    /// Generates a response for `prompt`; generation parameters are accepted
    /// for interface compatibility but ignored by this mock.
    async fn complete_with_params(
        &self,
        prompt: &str,
        _params: GenerationParams,
    ) -> Result<String> {
        self.complete(prompt).await
    }
392
393 async fn complete_batch(&self, prompts: &[&str]) -> Result<Vec<String>> {
394 let mut handles = Vec::new();
396
397 for prompt in prompts {
398 let prompt_owned = prompt.to_string();
399 let self_clone = self.clone();
400 handles.push(tokio::spawn(async move {
401 self_clone.complete(&prompt_owned).await
402 }));
403 }
404
405 let mut results = Vec::with_capacity(prompts.len());
406 for handle in handles {
407 match handle.await {
408 Ok(result) => results.push(result?),
409 Err(e) => {
410 return Err(GraphRAGError::Generation {
411 message: format!("Task join error: {e}"),
412 })
413 },
414 }
415 }
416
417 Ok(results)
418 }
419
    /// The mock backend is always available.
    async fn is_available(&self) -> bool {
        true
    }
423
    /// Static metadata describing this mock model.
    async fn model_info(&self) -> ModelInfo {
        ModelInfo {
            name: "AsyncMockLLM".to_string(),
            version: Some("1.0.0".to_string()),
            max_context_length: Some(4096),
            supports_streaming: true,
        }
    }
432
433 async fn get_usage_stats(&self) -> Result<ModelUsageStats> {
434 let total_requests = self.stats.total_requests.load(Ordering::Relaxed);
435 let total_tokens = self.stats.total_tokens_processed.load(Ordering::Relaxed);
436 let error_count = self.stats.error_count.load(Ordering::Relaxed);
437 let total_time = *self.stats.total_response_time.read().await;
438
439 let average_response_time_ms = if total_requests > 0 {
440 total_time.as_millis() as f64 / total_requests as f64
441 } else {
442 0.0
443 };
444
445 let error_rate = if total_requests > 0 {
446 error_count as f64 / total_requests as f64
447 } else {
448 0.0
449 };
450
451 Ok(ModelUsageStats {
452 total_requests,
453 total_tokens_processed: total_tokens,
454 average_response_time_ms,
455 error_rate,
456 })
457 }
458
    /// Estimates the token count as roughly one token per four bytes of
    /// prompt text.
    async fn estimate_tokens(&self, prompt: &str) -> Result<usize> {
        Ok(prompt.len() / 4)
    }
463}
464
465impl AsyncMockLLM {
466 async fn generate_response_internal(&self, prompt: &str) -> Result<String> {
468 let prompt_lower = prompt.to_lowercase();
469
470 if prompt_lower.contains("context:") && prompt_lower.contains("question:") {
472 if let Some(context_start) = prompt.find("Context:") {
473 let context_section = &prompt[context_start + 8..];
474 if let Some(question_start) = context_section.find("Question:") {
475 let context = context_section[..question_start].trim();
476 let question_section = context_section[question_start + 9..].trim();
477
478 return self.generate_smart_answer(context, question_section).await;
479 }
480 }
481 }
482
483 if prompt_lower.contains("who")
485 || prompt_lower.contains("what")
486 || prompt_lower.contains("where")
487 || prompt_lower.contains("when")
488 || prompt_lower.contains("how")
489 || prompt_lower.contains("why")
490 {
491 return self.generate_question_response(prompt).await;
492 }
493
494 let templates = self.response_templates.read().await;
496 Ok(templates
497 .get("default")
498 .unwrap_or(&"I cannot provide a response based on the given prompt.".to_string())
499 .replace("{context}", &prompt[..prompt.len().min(200)]))
500 }
501}
502
// Manual `Clone`: every field is either an `Arc` (shared by the clone) or
// `Copy`, so clones share the templates, text processor, and stats counters.
// NOTE(review): `#[derive(Clone)]` on the struct would generate an identical
// impl; kept manual to avoid touching the struct definition.
impl Clone for AsyncMockLLM {
    fn clone(&self) -> Self {
        Self {
            response_templates: Arc::clone(&self.response_templates),
            text_processor: Arc::clone(&self.text_processor),
            stats: Arc::clone(&self.stats),
            simulate_delay: self.simulate_delay,
        }
    }
}
514
515#[async_trait]
517impl LLMInterface for AsyncMockLLM {
    /// Synchronous bridge over the async `complete`.
    ///
    /// Inside an existing Tokio runtime it uses `block_in_place` +
    /// `block_on`; otherwise it spins up a throwaway runtime.
    ///
    /// NOTE(review): `block_in_place` panics on a `current_thread` runtime —
    /// callers must be on a multi-thread runtime; confirm against usage.
    fn generate_response(&self, prompt: &str) -> Result<String> {
        if tokio::runtime::Handle::try_current().is_ok() {
            tokio::task::block_in_place(|| {
                tokio::runtime::Handle::current().block_on(self.complete(prompt))
            })
        } else {
            let rt = tokio::runtime::Runtime::new().map_err(|e| GraphRAGError::Generation {
                message: format!("Failed to create async runtime: {e}"),
            })?;
            rt.block_on(self.complete(prompt))
        }
    }
532
    /// Synchronous bridge over `generate_summary_async`, using the same
    /// runtime-detection strategy as `generate_response` (see the
    /// `block_in_place` caveat there).
    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String> {
        if tokio::runtime::Handle::try_current().is_ok() {
            tokio::task::block_in_place(|| {
                tokio::runtime::Handle::current()
                    .block_on(self.generate_summary_async(content, max_length))
            })
        } else {
            let rt = tokio::runtime::Runtime::new().map_err(|e| GraphRAGError::Generation {
                message: format!("Failed to create async runtime: {e}"),
            })?;
            rt.block_on(self.generate_summary_async(content, max_length))
        }
    }
546
547 fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>> {
548 let keywords = self
549 .text_processor
550 .extract_keywords(content, num_points * 2);
551 let sentences = self.text_processor.extract_sentences(content);
552
553 let mut key_points = Vec::new();
554 for keyword in keywords.iter().take(num_points) {
555 if let Some(sentence) = sentences
557 .iter()
558 .find(|s| s.to_lowercase().contains(&keyword.to_lowercase()))
559 {
560 key_points.push(sentence.clone());
561 } else {
562 key_points.push(format!("Key concept: {keyword}"));
563 }
564 }
565
566 Ok(key_points)
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with the default template set should succeed.
    #[tokio::test]
    async fn test_async_mock_llm_creation() {
        let llm = AsyncMockLLM::new().await;
        assert!(llm.is_ok());
    }

    // A plain prompt must produce some response without erroring.
    #[tokio::test]
    async fn test_async_completion() {
        let llm = AsyncMockLLM::new().await.unwrap();
        let result = llm.complete("Hello, world!").await;
        assert!(result.is_ok());
    }

    // Batch completion must return exactly one result per input prompt.
    #[tokio::test]
    async fn test_async_batch_completion() {
        let llm = AsyncMockLLM::new().await.unwrap();
        let prompts = vec!["Hello", "World", "Test"];
        let results = llm.complete_batch(&prompts).await;
        assert!(results.is_ok());
        assert_eq!(results.unwrap().len(), 3);
    }

    // Stats should count every request; the default 100 ms simulated delay
    // makes the average response time strictly positive.
    #[tokio::test]
    async fn test_async_usage_stats() {
        let llm = AsyncMockLLM::new().await.unwrap();

        let _ = llm.complete("Test prompt 1").await;
        let _ = llm.complete("Test prompt 2").await;

        let stats = llm.get_usage_stats().await.unwrap();
        assert_eq!(stats.total_requests, 2);
        assert!(stats.average_response_time_ms > 0.0);
    }

    // The mock always reports itself available.
    #[tokio::test]
    async fn test_async_model_availability() {
        let llm = AsyncMockLLM::new().await.unwrap();
        let is_available = llm.is_available().await;
        assert!(is_available);
    }

    // Metadata must match the hard-coded mock model info.
    #[tokio::test]
    async fn test_async_model_info() {
        let llm = AsyncMockLLM::new().await.unwrap();
        let info = llm.model_info().await;
        assert_eq!(info.name, "AsyncMockLLM");
        assert_eq!(info.version, Some("1.0.0".to_string()));
        assert!(info.supports_streaming);
    }

    // Token estimation (~len/4) must be positive for a non-trivial prompt.
    #[tokio::test]
    async fn test_token_estimation() {
        let llm = AsyncMockLLM::new().await.unwrap();
        let tokens = llm.estimate_tokens("This is a test prompt").await.unwrap();
        assert!(tokens > 0);
    }
}