Source: mockforge_data/rag/engine.rs
1//! Core RAG engine and retrieval logic
2//!
3//! This module contains the main RAG engine implementation,
4//! including document processing, query handling, and response generation.
5
6use crate::rag::utils::Cache;
7use crate::rag::{
8    config::{EmbeddingProvider, RagConfig},
9    storage::DocumentStorage,
10};
11use crate::schema::SchemaDefinition;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Duration;
19use tracing::debug;
20
/// A contiguous slice of a source document, carrying its text, metadata,
/// and (once generated) its embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Unique chunk ID (the engine formats this as "<document_id>_chunk_<index>")
    pub id: String,
    /// Raw chunk text
    pub content: String,
    /// Free-form metadata copied from the parent document
    pub metadata: HashMap<String, String>,
    /// Embedding vector (empty until embeddings are generated for the chunk)
    pub embedding: Vec<f32>,
    /// ID of the document this chunk was split from
    pub document_id: String,
    /// Chunk sequence number within the document (0-based)
    pub position: usize,
    /// Chunk length in words at split time (not byte length)
    pub length: usize,
}
39
40impl DocumentChunk {
41    /// Create a new document chunk
42    pub fn new(
43        id: String,
44        content: String,
45        metadata: HashMap<String, String>,
46        embedding: Vec<f32>,
47        document_id: String,
48        position: usize,
49        length: usize,
50    ) -> Self {
51        Self {
52            id,
53            content,
54            metadata,
55            embedding,
56            document_id,
57            position,
58            length,
59        }
60    }
61
62    /// Get chunk size
63    pub fn size(&self) -> usize {
64        self.content.len()
65    }
66
67    /// Check if chunk is empty
68    pub fn is_empty(&self) -> bool {
69        self.content.is_empty()
70    }
71
72    /// Get metadata value
73    pub fn get_metadata(&self, key: &str) -> Option<&String> {
74        self.metadata.get(key)
75    }
76
77    /// Set metadata value
78    pub fn set_metadata(&mut self, key: String, value: String) {
79        self.metadata.insert(key, value);
80    }
81
82    /// Calculate similarity with another chunk
83    pub fn similarity(&self, other: &DocumentChunk) -> f32 {
84        cosine_similarity(&self.embedding, &other.embedding)
85    }
86
87    /// Get content preview (first 100 characters)
88    pub fn preview(&self) -> String {
89        if self.content.len() > 100 {
90            format!("{}...", &self.content[..100])
91        } else {
92            self.content.clone()
93        }
94    }
95}
96
/// Search result with relevance score
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matched document chunk
    pub chunk: DocumentChunk,
    /// Relevance score (0.0 to 1.0)
    pub score: f32,
    /// 0-based rank within the candidate list at scoring time
    pub rank: usize,
}
107
108impl SearchResult {
109    /// Create a new search result
110    pub fn new(chunk: DocumentChunk, score: f32, rank: usize) -> Self {
111        Self { chunk, score, rank }
112    }
113}
114
impl PartialEq for SearchResult {
    // Equality is by score only; the chunk and rank are ignored.
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score
    }
}

// NOTE(review): `Eq` on an f32-based comparison is only sound if scores are
// never NaN (NaN != NaN breaks reflexivity). Scores here come from cosine
// similarity and weighted sums — confirm that invariant or drop this impl.
impl Eq for SearchResult {}
122
123impl PartialOrd for SearchResult {
124    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
125        Some(self.cmp(other))
126    }
127}
128
129impl Ord for SearchResult {
130    fn cmp(&self, other: &Self) -> Ordering {
131        self.partial_cmp(other).unwrap_or(Ordering::Equal)
132    }
133}
134
/// RAG engine for document retrieval and generation
pub struct RagEngine {
    /// RAG configuration (providers, models, chunking and search parameters)
    config: RagConfig,
    /// Document storage backend (trait object so backends are pluggable)
    storage: Arc<dyn DocumentStorage>,
    /// HTTP client for embedding/LLM API calls, built with the configured timeout
    client: reqwest::Client,
    /// Running sum of recorded response times in milliseconds, for the average stat
    total_response_time_ms: f64,
    /// Number of recorded responses, for the average stat
    response_count: usize,
    /// TTL cache mapping query text to its embedding, capped at 1000 entries
    embedding_cache: Cache<String, Vec<f32>>,
}
150
151impl std::fmt::Debug for RagEngine {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        f.debug_struct("RagEngine")
154            .field("config", &self.config)
155            .field("storage", &"<DocumentStorage>")
156            .field("client", &"<reqwest::Client>")
157            .field("total_response_time_ms", &self.total_response_time_ms)
158            .field("response_count", &self.response_count)
159            .field("embedding_cache", &"<Cache>")
160            .finish()
161    }
162}
163
164impl RagEngine {
165    /// Create a new RAG engine
166    pub fn new(config: RagConfig, storage: Arc<dyn DocumentStorage>) -> Result<Self> {
167        let client = reqwest::ClientBuilder::new().timeout(config.timeout_duration()).build()?;
168
169        let cache_ttl = config.cache_ttl_duration().as_secs();
170
171        Ok(Self {
172            config,
173            storage,
174            client,
175            total_response_time_ms: 0.0,
176            response_count: 0,
177            embedding_cache: Cache::new(cache_ttl, 1000), // Cache up to 1000 embeddings
178        })
179    }
180
181    /// Record response time for stats
182    fn record_response_time(&mut self, duration: Duration) {
183        let ms = duration.as_millis() as f64;
184        self.total_response_time_ms += ms;
185        self.response_count += 1;
186    }
187
    /// Borrow the current RAG configuration.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Borrow the document storage backend.
    pub fn storage(&self) -> &Arc<dyn DocumentStorage> {
        &self.storage
    }

    /// Replace the configuration after validating it; on validation failure the
    /// existing configuration is left untouched and the error is returned.
    pub fn update_config(&mut self, config: RagConfig) -> Result<()> {
        config.validate()?;
        self.config = config;
        Ok(())
    }
204
205    /// Add document to the knowledge base
206    pub async fn add_document(
207        &self,
208        document_id: String,
209        content: String,
210        metadata: HashMap<String, String>,
211    ) -> Result<()> {
212        debug!("Adding document: {}", document_id);
213
214        // Split document into chunks
215        let chunks = self.create_chunks(document_id.clone(), content, metadata).await?;
216
217        // Generate embeddings for chunks
218        let chunks_with_embeddings = self.generate_embeddings(chunks).await?;
219
220        // Store chunks
221        self.storage.store_chunks(chunks_with_embeddings).await?;
222
223        debug!("Successfully added document: {}", document_id);
224        Ok(())
225    }
226
227    /// Search for relevant documents
228    pub async fn search(&mut self, query: &str, top_k: Option<usize>) -> Result<Vec<SearchResult>> {
229        let start = tokio::time::Instant::now();
230        let top_k = top_k.unwrap_or(self.config.top_k);
231        debug!("Searching for: {} (top_k: {})", query, top_k);
232
233        // Generate embedding for query
234        let query_embedding = self.generate_query_embedding(query).await?;
235
236        // Search for similar chunks
237        let candidates = self.storage.search_similar(&query_embedding, top_k * 2).await?; // Get more candidates for reranking
238
239        // Rerank results if needed
240        let results = if self.config.hybrid_search {
241            self.hybrid_search(query, &query_embedding, candidates).await?
242        } else {
243            self.semantic_search(&query_embedding, candidates).await?
244        };
245
246        debug!("Found {} relevant chunks", results.len());
247        let duration = start.elapsed();
248        self.record_response_time(duration);
249        Ok(results)
250    }
251
252    /// Generate response using RAG
253    pub async fn generate(&mut self, query: &str, context: Option<&str>) -> Result<String> {
254        let start = tokio::time::Instant::now();
255        debug!("Generating response for query: {}", query);
256
257        // Search for relevant context
258        let search_results = self.search(query, None).await?;
259
260        // Build context from search results
261        let rag_context = self.build_context(&search_results, context);
262
263        // Generate response using LLM
264        let response = self.generate_with_llm(query, &rag_context).await?;
265
266        debug!("Generated response ({} chars)", response.len());
267        let duration = start.elapsed();
268        self.record_response_time(duration);
269        Ok(response)
270    }
271
272    /// Generate enhanced dataset using RAG
273    pub async fn generate_dataset(
274        &mut self,
275        schema: &SchemaDefinition,
276        count: usize,
277        context: Option<&str>,
278    ) -> Result<Vec<HashMap<String, Value>>> {
279        let start = tokio::time::Instant::now();
280        debug!("Generating dataset with {} rows using schema: {}", count, schema.name);
281
282        // Create generation prompt
283        let prompt = self.create_generation_prompt(schema, count, context);
284
285        // Generate response
286        let response = self.generate(&prompt, None).await?;
287
288        // Parse response into structured data
289        let dataset = self.parse_dataset_response(&response, schema)?;
290
291        debug!("Generated dataset with {} rows", dataset.len());
292        let duration = start.elapsed();
293        self.record_response_time(duration);
294        Ok(dataset)
295    }
296
297    /// Get engine statistics
298    pub async fn get_stats(&self) -> Result<RagStats> {
299        let storage_stats = self.storage.get_stats().await?;
300
301        let average_response_time_ms = if self.response_count > 0 {
302            (self.total_response_time_ms / self.response_count as f64) as f32
303        } else {
304            0.0
305        };
306
307        Ok(RagStats {
308            total_documents: storage_stats.total_documents,
309            total_chunks: storage_stats.total_chunks,
310            index_size_bytes: storage_stats.index_size_bytes,
311            last_updated: storage_stats.last_updated,
312            cache_hit_rate: self.embedding_cache.hit_rate(),
313            average_response_time_ms,
314        })
315    }
316
317    /// Create chunks from document
318    async fn create_chunks(
319        &self,
320        document_id: String,
321        content: String,
322        metadata: HashMap<String, String>,
323    ) -> Result<Vec<DocumentChunk>> {
324        let mut chunks = Vec::new();
325        let words: Vec<&str> = content.split_whitespace().collect();
326        let chunk_size = self.config.chunk_size;
327        let overlap = self.config.chunk_overlap;
328
329        for (i, chunk_start) in (0..words.len()).step_by(chunk_size - overlap).enumerate() {
330            let chunk_end = (chunk_start + chunk_size).min(words.len());
331            let chunk_words: Vec<&str> = words[chunk_start..chunk_end].to_vec();
332            let chunk_content = chunk_words.join(" ");
333
334            if !chunk_content.is_empty() {
335                let chunk_id = format!("{}_chunk_{}", document_id, i);
336
337                chunks.push(DocumentChunk::new(
338                    chunk_id,
339                    chunk_content,
340                    metadata.clone(),
341                    Vec::new(), // Embedding will be generated separately
342                    document_id.clone(),
343                    i,
344                    chunk_words.len(),
345                ));
346            }
347        }
348
349        Ok(chunks)
350    }
351
352    /// Generate embeddings for chunks
353    async fn generate_embeddings(&self, chunks: Vec<DocumentChunk>) -> Result<Vec<DocumentChunk>> {
354        let mut chunks_with_embeddings = Vec::new();
355
356        for chunk in chunks {
357            let embedding = self.generate_embedding(&chunk.content).await?;
358            let mut chunk_with_embedding = chunk;
359            chunk_with_embedding.embedding = embedding;
360            chunks_with_embeddings.push(chunk_with_embedding);
361        }
362
363        Ok(chunks_with_embeddings)
364    }
365
366    /// Generate embedding for text
367    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
368        let provider = &self.config.embedding_provider;
369        let model = &self.config.embedding_model;
370
371        match provider {
372            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text, model).await,
373            EmbeddingProvider::OpenAICompatible => {
374                self.generate_openai_compatible_embedding(text, model).await
375            }
376            EmbeddingProvider::Ollama => {
377                // Ollama uses OpenAI-compatible API for embeddings
378                self.generate_openai_compatible_embedding(text, model).await
379            }
380        }
381    }
382
383    /// Generate query embedding
384    async fn generate_query_embedding(&mut self, query: &str) -> Result<Vec<f32>> {
385        // Check cache first
386        if let Some(embedding) = self.embedding_cache.get(&query.to_string()) {
387            return Ok(embedding);
388        }
389
390        // Generate new embedding
391        let embedding = self.generate_embedding(query).await?;
392
393        // Cache the result
394        self.embedding_cache.put(query.to_string(), embedding.clone());
395
396        Ok(embedding)
397    }
398
399    /// Perform semantic search
400    async fn semantic_search(
401        &self,
402        query_embedding: &[f32],
403        candidates: Vec<DocumentChunk>,
404    ) -> Result<Vec<SearchResult>> {
405        let mut results = Vec::new();
406
407        // Calculate similarity scores
408        for (rank, chunk) in candidates.iter().enumerate() {
409            let score = cosine_similarity(query_embedding, &chunk.embedding);
410
411            results.push(SearchResult::new(chunk.clone(), score, rank));
412        }
413
414        // Sort by score and filter by threshold
415        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
416        results.retain(|r| r.score >= self.config.similarity_threshold);
417
418        // Take top-k results
419        results.truncate(self.config.top_k);
420
421        Ok(results)
422    }
423
424    /// Perform hybrid search (semantic + keyword)
425    async fn hybrid_search(
426        &self,
427        query: &str,
428        query_embedding: &[f32],
429        candidates: Vec<DocumentChunk>,
430    ) -> Result<Vec<SearchResult>> {
431        let mut results = Vec::new();
432
433        // Perform semantic search
434        let semantic_results = self.semantic_search(query_embedding, candidates.clone()).await?;
435
436        // Perform keyword search (placeholder)
437        let keyword_results = self.keyword_search(query, &candidates).await?;
438
439        // Combine results using weights
440        let semantic_weight = self.config.semantic_weight;
441        let keyword_weight = self.config.keyword_weight;
442
443        for (rank, chunk) in candidates.iter().enumerate() {
444            let semantic_score = semantic_results
445                .iter()
446                .find(|r| r.chunk.id == chunk.id)
447                .map(|r| r.score)
448                .unwrap_or(0.0);
449
450            let keyword_score = keyword_results
451                .iter()
452                .find(|r| r.chunk.id == chunk.id)
453                .map(|r| r.score)
454                .unwrap_or(0.0);
455
456            let combined_score = semantic_score * semantic_weight + keyword_score * keyword_weight;
457
458            results.push(SearchResult::new(chunk.clone(), combined_score, rank));
459        }
460
461        // Sort and filter results
462        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
463        results.retain(|r| r.score >= self.config.similarity_threshold);
464        results.truncate(self.config.top_k);
465
466        Ok(results)
467    }
468
469    /// Perform keyword search using TF-IDF-like scoring
470    ///
471    /// Scores each candidate chunk based on:
472    /// - Term frequency (how often query terms appear in the chunk)
473    /// - Exact phrase bonus (query appears as a contiguous substring)
474    async fn keyword_search(
475        &self,
476        query: &str,
477        candidates: &[DocumentChunk],
478    ) -> Result<Vec<SearchResult>> {
479        let query_lower = query.to_lowercase();
480        let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
481
482        if query_terms.is_empty() {
483            return Ok(Vec::new());
484        }
485
486        let num_candidates = candidates.len();
487        let mut results = Vec::new();
488
489        for (rank, chunk) in candidates.iter().enumerate() {
490            let content_lower = chunk.content.to_lowercase();
491            let content_words: Vec<&str> = content_lower.split_whitespace().collect();
492
493            if content_words.is_empty() {
494                continue;
495            }
496
497            // Term frequency: fraction of query terms found in the chunk
498            let matching_terms = query_terms
499                .iter()
500                .filter(|term| content_words.iter().any(|w| w.contains(*term)))
501                .count();
502            let tf_score = matching_terms as f32 / query_terms.len() as f32;
503
504            // Inverse document frequency approximation: boost terms that appear
505            // in fewer candidate chunks
506            let mut idf_sum = 0.0f32;
507            for term in &query_terms {
508                let docs_with_term = candidates
509                    .iter()
510                    .filter(|c| c.content.to_lowercase().contains(term))
511                    .count()
512                    .max(1);
513                idf_sum += ((num_candidates as f32) / docs_with_term as f32).ln() + 1.0;
514            }
515            let idf_score = idf_sum / query_terms.len() as f32;
516
517            // Exact phrase bonus
518            let phrase_bonus = if query_terms.len() > 1 && content_lower.contains(&query_lower) {
519                0.3
520            } else {
521                0.0
522            };
523
524            let score = (tf_score * idf_score + phrase_bonus).min(1.0);
525
526            if score > 0.0 {
527                results.push(SearchResult::new(chunk.clone(), score, rank));
528            }
529        }
530
531        // Sort by score descending
532        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
533
534        Ok(results)
535    }
536
537    /// Build context from search results
538    fn build_context(
539        &self,
540        search_results: &[SearchResult],
541        additional_context: Option<&str>,
542    ) -> String {
543        let mut context_parts = Vec::new();
544
545        // Add search results
546        for result in search_results {
547            context_parts
548                .push(format!("Content: {}\nRelevance: {:.2}", result.chunk.content, result.score));
549        }
550
551        // Add additional context if provided
552        if let Some(context) = additional_context {
553            context_parts.push(format!("Additional Context: {}", context));
554        }
555
556        context_parts.join("\n\n")
557    }
558
559    /// Generate response using LLM
560    async fn generate_with_llm(&self, query: &str, context: &str) -> Result<String> {
561        let provider = &self.config.provider;
562        let model = &self.config.model;
563
564        match provider {
565            crate::rag::config::LlmProvider::OpenAI => {
566                self.generate_openai_response(query, context, model).await
567            }
568            crate::rag::config::LlmProvider::Anthropic => {
569                self.generate_anthropic_response(query, context, model).await
570            }
571            crate::rag::config::LlmProvider::OpenAICompatible => {
572                self.generate_openai_compatible_response(query, context, model).await
573            }
574            crate::rag::config::LlmProvider::Ollama => {
575                self.generate_ollama_response(query, context, model).await
576            }
577        }
578    }
579
580    /// Create generation prompt for dataset creation
581    fn create_generation_prompt(
582        &self,
583        schema: &SchemaDefinition,
584        count: usize,
585        context: Option<&str>,
586    ) -> String {
587        let mut prompt = format!(
588            "Generate {} rows of sample data following this schema:\n\n{:?}\n\n",
589            count, schema
590        );
591
592        if let Some(context) = context {
593            prompt.push_str(&format!("Additional context: {}\n\n", context));
594        }
595
596        prompt.push_str("Please generate the data in JSON format as an array of objects.");
597        prompt
598    }
599
600    /// Parse dataset response from LLM
601    fn parse_dataset_response(
602        &self,
603        response: &str,
604        _schema: &SchemaDefinition,
605    ) -> Result<Vec<HashMap<String, Value>>> {
606        // Try to parse as JSON array
607        match serde_json::from_str::<Vec<HashMap<String, Value>>>(response) {
608            Ok(data) => Ok(data),
609            Err(_) => {
610                // Try to extract JSON from response text
611                if let Some(json_start) = response.find('[') {
612                    if let Some(json_end) = response.rfind(']') {
613                        let json_part = &response[json_start..=json_end];
614                        serde_json::from_str(json_part).map_err(|e| {
615                            crate::Error::generic(format!("Failed to parse JSON: {}", e))
616                        })
617                    } else {
618                        Err(crate::Error::generic("No closing bracket found in response"))
619                    }
620                } else {
621                    Err(crate::Error::generic("No JSON array found in response"))
622                }
623            }
624        }
625    }
626
627    /// Generate OpenAI embedding
628    async fn generate_openai_embedding(&self, text: &str, model: &str) -> Result<Vec<f32>> {
629        let api_key = self
630            .config
631            .api_key
632            .as_ref()
633            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;
634
635        let response = self
636            .client
637            .post("https://api.openai.com/v1/embeddings")
638            .header("Authorization", format!("Bearer {}", api_key))
639            .header("Content-Type", "application/json")
640            .json(&serde_json::json!({
641                "input": text,
642                "model": model
643            }))
644            .send()
645            .await?;
646
647        if !response.status().is_success() {
648            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
649        }
650
651        let json: Value = response.json().await?;
652        let embedding = json["data"][0]["embedding"]
653            .as_array()
654            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;
655
656        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
657    }
658
659    /// Generate OpenAI compatible embedding
660    async fn generate_openai_compatible_embedding(
661        &self,
662        text: &str,
663        model: &str,
664    ) -> Result<Vec<f32>> {
665        let api_key = self
666            .config
667            .api_key
668            .as_ref()
669            .ok_or_else(|| crate::Error::generic("API key not configured"))?;
670
671        let response = self
672            .client
673            .post(format!("{}/embeddings", self.config.api_endpoint))
674            .header("Authorization", format!("Bearer {}", api_key))
675            .header("Content-Type", "application/json")
676            .json(&serde_json::json!({
677                "input": text,
678                "model": model
679            }))
680            .send()
681            .await?;
682
683        if !response.status().is_success() {
684            return Err(crate::Error::generic(format!("API error: {}", response.status())));
685        }
686
687        let json: Value = response.json().await?;
688        let embedding = json["data"][0]["embedding"]
689            .as_array()
690            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;
691
692        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
693    }
694
695    /// Generate OpenAI response
696    async fn generate_openai_response(
697        &self,
698        query: &str,
699        context: &str,
700        model: &str,
701    ) -> Result<String> {
702        let api_key = self
703            .config
704            .api_key
705            .as_ref()
706            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;
707
708        let messages = vec![
709            serde_json::json!({
710                "role": "system",
711                "content": "You are a helpful assistant. Use the provided context to answer questions accurately."
712            }),
713            serde_json::json!({
714                "role": "user",
715                "content": format!("Context: {}\n\nQuestion: {}", context, query)
716            }),
717        ];
718
719        let response = self
720            .client
721            .post("https://api.openai.com/v1/chat/completions")
722            .header("Authorization", format!("Bearer {}", api_key))
723            .header("Content-Type", "application/json")
724            .json(&serde_json::json!({
725                "model": model,
726                "messages": messages,
727                "max_tokens": self.config.max_tokens,
728                "temperature": self.config.temperature,
729                "top_p": self.config.top_p
730            }))
731            .send()
732            .await?;
733
734        if !response.status().is_success() {
735            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
736        }
737
738        let json: Value = response.json().await?;
739        let content = json["choices"][0]["message"]["content"]
740            .as_str()
741            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
742
743        Ok(content.to_string())
744    }
745
746    /// Generate Anthropic response
747    async fn generate_anthropic_response(
748        &self,
749        query: &str,
750        context: &str,
751        model: &str,
752    ) -> Result<String> {
753        let api_key = self
754            .config
755            .api_key
756            .as_ref()
757            .ok_or_else(|| crate::Error::generic("Anthropic API key not configured"))?;
758
759        let response = self
760            .client
761            .post(format!("{}/messages", self.config.api_endpoint))
762            .header("x-api-key", api_key)
763            .header("anthropic-version", "2023-06-01")
764            .header("Content-Type", "application/json")
765            .json(&serde_json::json!({
766                "model": model,
767                "max_tokens": self.config.max_tokens,
768                "temperature": self.config.temperature,
769                "messages": [{
770                    "role": "user",
771                    "content": format!("Context: {}\n\nQuestion: {}", context, query)
772                }]
773            }))
774            .send()
775            .await?;
776
777        if !response.status().is_success() {
778            return Err(crate::Error::generic(format!(
779                "Anthropic API error: {}",
780                response.status()
781            )));
782        }
783
784        let json: Value = response.json().await?;
785        let text = json["content"]
786            .as_array()
787            .and_then(|content| content.first())
788            .and_then(|entry| entry["text"].as_str())
789            .ok_or_else(|| crate::Error::generic("Invalid Anthropic response format"))?;
790
791        Ok(text.to_string())
792    }
793
794    /// Generate OpenAI compatible response
795    async fn generate_openai_compatible_response(
796        &self,
797        query: &str,
798        context: &str,
799        model: &str,
800    ) -> Result<String> {
801        let api_key = self
802            .config
803            .api_key
804            .as_ref()
805            .ok_or_else(|| crate::Error::generic("API key not configured"))?;
806
807        let messages = vec![
808            serde_json::json!({
809                "role": "system",
810                "content": "You are a helpful assistant. Use the provided context to answer questions accurately."
811            }),
812            serde_json::json!({
813                "role": "user",
814                "content": format!("Context: {}\n\nQuestion: {}", context, query)
815            }),
816        ];
817
818        let response = self
819            .client
820            .post(format!("{}/chat/completions", self.config.api_endpoint))
821            .header("Authorization", format!("Bearer {}", api_key))
822            .header("Content-Type", "application/json")
823            .json(&serde_json::json!({
824                "model": model,
825                "messages": messages,
826                "max_tokens": self.config.max_tokens,
827                "temperature": self.config.temperature,
828                "top_p": self.config.top_p
829            }))
830            .send()
831            .await?;
832
833        if !response.status().is_success() {
834            return Err(crate::Error::generic(format!("API error: {}", response.status())));
835        }
836
837        let json: Value = response.json().await?;
838        let content = json["choices"][0]["message"]["content"]
839            .as_str()
840            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
841
842        Ok(content.to_string())
843    }
844
845    /// Generate Ollama response
846    async fn generate_ollama_response(
847        &self,
848        query: &str,
849        context: &str,
850        model: &str,
851    ) -> Result<String> {
852        let response = self
853            .client
854            .post(format!("{}/generate", self.config.api_endpoint))
855            .header("Content-Type", "application/json")
856            .json(&serde_json::json!({
857                "model": model,
858                "prompt": format!("Context: {}\n\nQuestion: {}", context, query),
859                "stream": false,
860                "options": {
861                    "temperature": self.config.temperature,
862                    "top_p": self.config.top_p
863                }
864            }))
865            .send()
866            .await?;
867
868        if !response.status().is_success() {
869            return Err(crate::Error::generic(format!("Ollama API error: {}", response.status())));
870        }
871
872        let json: Value = response.json().await?;
873        let content = json["response"]
874            .as_str()
875            .ok_or_else(|| crate::Error::generic("Invalid Ollama response format"))?;
876
877        Ok(content.to_string())
878    }
879}
880
881impl Default for RagEngine {
882    fn default() -> Self {
883        use crate::rag::storage::InMemoryStorage;
884
885        // Create a default RAG engine with in-memory storage
886        // This is primarily for testing and compatibility purposes
887        let config = RagConfig::default();
888        let storage = Arc::new(InMemoryStorage::default());
889
890        // We can unwrap here since default config should be valid
891        Self::new(config, storage).expect("Failed to create default RagEngine")
892    }
893}
894
/// Cosine similarity between two embedding vectors.
///
/// Returns a value in `[-1.0, 1.0]` for well-formed inputs. Degenerate cases
/// — mismatched lengths or a zero-magnitude vector — yield `0.0` instead of
/// an error so ranking code needs no special handling.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate dot product and squared norms in a single pass. The
    // per-component additions happen in the same order as separate sums
    // would, so results are bit-identical to a three-pass computation.
    let (mut dot, mut norm_a_sq, mut norm_b_sq) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
911
/// RAG engine statistics
///
/// Snapshot of knowledge-base size and runtime performance counters.
/// Serializable so it can be reported through monitoring or admin APIs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagStats {
    /// Total number of documents in the knowledge base
    pub total_documents: usize,
    /// Total number of chunks across all documents
    pub total_chunks: usize,
    /// Index size in bytes
    pub index_size_bytes: u64,
    /// Timestamp of the last statistics/index update (UTC)
    pub last_updated: chrono::DateTime<chrono::Utc>,
    /// Cache hit rate (0.0 to 1.0)
    pub cache_hit_rate: f32,
    /// Average response time in milliseconds
    pub average_response_time_ms: f32,
}
928
929impl Default for RagStats {
930    fn default() -> Self {
931        Self {
932            total_documents: 0,
933            total_chunks: 0,
934            index_size_bytes: 0,
935            last_updated: chrono::Utc::now(),
936            cache_hit_rate: 0.0,
937            average_response_time_ms: 0.0,
938        }
939    }
940}
941
/// Storage statistics
///
/// Size counters reported by a `DocumentStorage` backend. Mirrors the size
/// fields of `RagStats` but is internal-only.
// NOTE(review): unlike RagStats this type derives neither Serialize nor
// Deserialize — presumably intentional since it never crosses an API
// boundary; confirm before exposing it.
#[derive(Debug, Clone)]
pub struct StorageStats {
    /// Total number of documents
    pub total_documents: usize,
    /// Total number of chunks
    pub total_chunks: usize,
    /// Index size in bytes
    pub index_size_bytes: u64,
    /// Timestamp of the last storage update (UTC)
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
954
955impl Default for StorageStats {
956    fn default() -> Self {
957        Self {
958            total_documents: 0,
959            total_chunks: 0,
960            index_size_bytes: 0,
961            last_updated: chrono::Utc::now(),
962        }
963    }
964}
965
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal chunk with the given id/content. The embedding is left
    /// empty and `length` is the whitespace-separated token count.
    fn make_chunk(id: &str, content: &str) -> DocumentChunk {
        DocumentChunk::new(
            id.to_string(),
            content.to_string(),
            HashMap::new(),
            Vec::new(),
            "doc1".to_string(),
            0,
            content.split_whitespace().count(),
        )
    }

    #[tokio::test]
    async fn test_keyword_search_basic_term_matching() {
        let engine = RagEngine::default();
        let candidates = vec![
            make_chunk("c1", "rust programming language systems"),
            make_chunk("c2", "python scripting language web"),
            make_chunk("c3", "java enterprise applications"),
        ];

        let results = engine.keyword_search("rust language", &candidates).await.unwrap();
        assert!(!results.is_empty());
        // First result should be c1 since it contains both "rust" and "language"
        assert_eq!(results[0].chunk.id, "c1");
    }

    #[tokio::test]
    async fn test_keyword_search_phrase_bonus() {
        let engine = RagEngine::default();
        let candidates = vec![
            make_chunk("c1", "mock api server for testing mock endpoints"),
            make_chunk("c2", "this is a mock api server that works well"),
        ];

        let results = engine.keyword_search("mock api server", &candidates).await.unwrap();
        assert!(!results.is_empty());
        // Both chunks contain the exact contiguous phrase "mock api server",
        // so both should score and be returned.
        assert!(results.len() >= 2);
    }

    #[tokio::test]
    async fn test_keyword_search_empty_query() {
        let engine = RagEngine::default();
        let candidates = vec![make_chunk("c1", "some content here")];
        let results = engine.keyword_search("", &candidates).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_keyword_search_no_match() {
        let engine = RagEngine::default();
        let candidates = vec![make_chunk("c1", "rust programming")];
        let results = engine.keyword_search("python django", &candidates).await.unwrap();
        assert!(results.is_empty());
    }

    // The cosine-similarity tests exercise a pure synchronous helper; plain
    // #[test] avoids spinning up a tokio runtime for no reason.
    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }
}
1043}