1use crate::rag::utils::Cache;
7use crate::rag::{
8 config::{EmbeddingProvider, RagConfig},
9 storage::DocumentStorage,
10};
11use crate::schema::SchemaDefinition;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Duration;
19use tracing::debug;
20
/// A contiguous piece of a source document together with its embedding and
/// provenance. Chunks are the unit of storage and retrieval in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Unique chunk id; `create_chunks` formats it as
    /// `"<document_id>_chunk_<index>"`.
    pub id: String,
    /// The chunk's text content.
    pub content: String,
    /// Free-form string metadata copied from the parent document.
    pub metadata: HashMap<String, String>,
    /// Embedding vector; empty until `generate_embeddings` fills it in.
    pub embedding: Vec<f32>,
    /// Identifier of the document this chunk was cut from.
    pub document_id: String,
    /// Sequential chunk index within the document (assigned by the chunker).
    pub position: usize,
    /// Chunk length in words at chunking time (note: `size()` is bytes).
    pub length: usize,
}
39
40impl DocumentChunk {
41 pub fn new(
43 id: String,
44 content: String,
45 metadata: HashMap<String, String>,
46 embedding: Vec<f32>,
47 document_id: String,
48 position: usize,
49 length: usize,
50 ) -> Self {
51 Self {
52 id,
53 content,
54 metadata,
55 embedding,
56 document_id,
57 position,
58 length,
59 }
60 }
61
62 pub fn size(&self) -> usize {
64 self.content.len()
65 }
66
67 pub fn is_empty(&self) -> bool {
69 self.content.is_empty()
70 }
71
72 pub fn get_metadata(&self, key: &str) -> Option<&String> {
74 self.metadata.get(key)
75 }
76
77 pub fn set_metadata(&mut self, key: String, value: String) {
79 self.metadata.insert(key, value);
80 }
81
82 pub fn similarity(&self, other: &DocumentChunk) -> f32 {
84 cosine_similarity(&self.embedding, &other.embedding)
85 }
86
87 pub fn preview(&self) -> String {
89 if self.content.len() > 100 {
90 format!("{}...", &self.content[..100])
91 } else {
92 self.content.clone()
93 }
94 }
95}
96
/// A single retrieval hit: a chunk plus its relevance score and list position.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matched chunk.
    pub chunk: DocumentChunk,
    /// Relevance score (cosine similarity, keyword score, or weighted blend).
    pub score: f32,
    /// Position of this result in the result list.
    pub rank: usize,
}
107
impl SearchResult {
    /// Wraps a chunk with its relevance score and rank.
    pub fn new(chunk: DocumentChunk, score: f32, rank: usize) -> Self {
        Self { chunk, score, rank }
    }
}
114
// Results compare by score only; chunk identity is deliberately ignored.
impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // NOTE(review): f32 equality — a NaN score compares unequal to
        // itself, which breaks the reflexivity `Eq` below promises.
        self.score == other.score
    }
}

// NOTE(review): `Eq` is asserted despite the f32 field; this is sound only
// if scores are never NaN — confirm upstream scoring cannot produce NaN.
impl Eq for SearchResult {}
122
impl PartialOrd for SearchResult {
    /// Delegates to the total order defined by `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
128
129impl Ord for SearchResult {
130 fn cmp(&self, other: &Self) -> Ordering {
131 self.partial_cmp(other).unwrap_or(Ordering::Equal)
132 }
133}
134
/// Retrieval-augmented generation engine: chunks and embeds documents, runs
/// semantic/keyword/hybrid search over stored chunks, and calls an LLM
/// provider to generate answers or synthetic datasets.
pub struct RagEngine {
    /// Provider, model, chunking, and search configuration.
    config: RagConfig,
    /// Pluggable chunk store used for persistence and similarity search.
    storage: Arc<dyn DocumentStorage>,
    /// Shared HTTP client for all provider API calls.
    client: reqwest::Client,
    /// Sum of recorded request durations, in milliseconds.
    total_response_time_ms: f64,
    /// Number of recorded requests (denominator for the average).
    response_count: usize,
    /// Query-text -> embedding cache used to skip provider round-trips.
    embedding_cache: Cache<String, Vec<f32>>,
}
150
// Manual Debug impl: storage, client, and cache are rendered as placeholder
// strings rather than their contents.
impl std::fmt::Debug for RagEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RagEngine")
            .field("config", &self.config)
            .field("storage", &"<DocumentStorage>")
            .field("client", &"<reqwest::Client>")
            .field("total_response_time_ms", &self.total_response_time_ms)
            .field("response_count", &self.response_count)
            .field("embedding_cache", &"<Cache>")
            .finish()
    }
}
163
164impl RagEngine {
165 pub fn new(config: RagConfig, storage: Arc<dyn DocumentStorage>) -> Result<Self> {
167 let client = reqwest::ClientBuilder::new().timeout(config.timeout_duration()).build()?;
168
169 let cache_ttl = config.cache_ttl_duration().as_secs();
170
171 Ok(Self {
172 config,
173 storage,
174 client,
175 total_response_time_ms: 0.0,
176 response_count: 0,
177 embedding_cache: Cache::new(cache_ttl, 1000), })
179 }
180
181 fn record_response_time(&mut self, duration: Duration) {
183 let ms = duration.as_millis() as f64;
184 self.total_response_time_ms += ms;
185 self.response_count += 1;
186 }
187
    /// Read-only access to the engine configuration.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }
192
    /// The storage backend used for chunk persistence and similarity search.
    pub fn storage(&self) -> &Arc<dyn DocumentStorage> {
        &self.storage
    }
197
    /// Replaces the configuration after validating it; on validation failure
    /// the existing configuration is left untouched.
    pub fn update_config(&mut self, config: RagConfig) -> Result<()> {
        config.validate()?;
        self.config = config;
        Ok(())
    }
204
205 pub async fn add_document(
207 &self,
208 document_id: String,
209 content: String,
210 metadata: HashMap<String, String>,
211 ) -> Result<()> {
212 debug!("Adding document: {}", document_id);
213
214 let chunks = self.create_chunks(document_id.clone(), content, metadata).await?;
216
217 let chunks_with_embeddings = self.generate_embeddings(chunks).await?;
219
220 self.storage.store_chunks(chunks_with_embeddings).await?;
222
223 debug!("Successfully added document: {}", document_id);
224 Ok(())
225 }
226
    /// Retrieves the chunks most relevant to `query`, using hybrid
    /// (semantic + keyword) or purely semantic scoring per configuration.
    ///
    /// NOTE(review): the `top_k` argument only controls how many candidates
    /// are fetched from storage (2x over-fetch); the final truncation inside
    /// `semantic_search` / `hybrid_search` uses `config.top_k`, so passing a
    /// different `top_k` here does not change how many results are returned.
    /// Confirm whether that is intended.
    pub async fn search(&mut self, query: &str, top_k: Option<usize>) -> Result<Vec<SearchResult>> {
        let start = tokio::time::Instant::now();
        // Fall back to the configured default when the caller passes None.
        let top_k = top_k.unwrap_or(self.config.top_k);
        debug!("Searching for: {} (top_k: {})", query, top_k);

        // Query embeddings are cached, so repeated queries skip the provider.
        let query_embedding = self.generate_query_embedding(query).await?;

        // Fetch twice as many candidates as requested so the scoring step
        // has headroom before threshold filtering and truncation.
        let candidates = self.storage.search_similar(&query_embedding, top_k * 2).await?;
        let results = if self.config.hybrid_search {
            self.hybrid_search(query, &query_embedding, candidates).await?
        } else {
            self.semantic_search(&query_embedding, candidates).await?
        };

        debug!("Found {} relevant chunks", results.len());
        let duration = start.elapsed();
        self.record_response_time(duration);
        Ok(results)
    }
251
    /// Answers `query` with the configured LLM, grounding it in retrieved
    /// chunks plus an optional caller-supplied context string.
    ///
    /// NOTE(review): this method records its own wall-clock duration and also
    /// calls `search`, which records its duration separately — one request is
    /// therefore counted twice in the average-response-time stats. Confirm
    /// whether that double counting is intended.
    pub async fn generate(&mut self, query: &str, context: Option<&str>) -> Result<String> {
        let start = tokio::time::Instant::now();
        debug!("Generating response for query: {}", query);

        // Retrieve supporting chunks with the default top_k.
        let search_results = self.search(query, None).await?;

        // Flatten hits (and any extra context) into a prompt context blob.
        let rag_context = self.build_context(&search_results, context);

        let response = self.generate_with_llm(query, &rag_context).await?;

        debug!("Generated response ({} chars)", response.len());
        let duration = start.elapsed();
        self.record_response_time(duration);
        Ok(response)
    }
271
    /// Generates `count` synthetic rows conforming to `schema` by prompting
    /// the LLM and parsing its JSON output into row maps.
    ///
    /// NOTE(review): response time is recorded here and again inside the
    /// nested `generate` / `search` calls, inflating the stats — confirm.
    pub async fn generate_dataset(
        &mut self,
        schema: &SchemaDefinition,
        count: usize,
        context: Option<&str>,
    ) -> Result<Vec<HashMap<String, Value>>> {
        let start = tokio::time::Instant::now();
        debug!("Generating dataset with {} rows using schema: {}", count, schema.name);

        let prompt = self.create_generation_prompt(schema, count, context);

        // Full RAG pipeline: retrieval + LLM completion.
        let response = self.generate(&prompt, None).await?;

        // Tolerates prose-wrapped JSON; schema is currently unused in parsing.
        let dataset = self.parse_dataset_response(&response, schema)?;

        debug!("Generated dataset with {} rows", dataset.len());
        let duration = start.elapsed();
        self.record_response_time(duration);
        Ok(dataset)
    }
296
297 pub async fn get_stats(&self) -> Result<RagStats> {
299 let storage_stats = self.storage.get_stats().await?;
300
301 let average_response_time_ms = if self.response_count > 0 {
302 (self.total_response_time_ms / self.response_count as f64) as f32
303 } else {
304 0.0
305 };
306
307 Ok(RagStats {
308 total_documents: storage_stats.total_documents,
309 total_chunks: storage_stats.total_chunks,
310 index_size_bytes: storage_stats.index_size_bytes,
311 last_updated: storage_stats.last_updated,
312 cache_hit_rate: self.embedding_cache.hit_rate(),
313 average_response_time_ms,
314 })
315 }
316
317 async fn create_chunks(
319 &self,
320 document_id: String,
321 content: String,
322 metadata: HashMap<String, String>,
323 ) -> Result<Vec<DocumentChunk>> {
324 let mut chunks = Vec::new();
325 let words: Vec<&str> = content.split_whitespace().collect();
326 let chunk_size = self.config.chunk_size;
327 let overlap = self.config.chunk_overlap;
328
329 for (i, chunk_start) in (0..words.len()).step_by(chunk_size - overlap).enumerate() {
330 let chunk_end = (chunk_start + chunk_size).min(words.len());
331 let chunk_words: Vec<&str> = words[chunk_start..chunk_end].to_vec();
332 let chunk_content = chunk_words.join(" ");
333
334 if !chunk_content.is_empty() {
335 let chunk_id = format!("{}_chunk_{}", document_id, i);
336
337 chunks.push(DocumentChunk::new(
338 chunk_id,
339 chunk_content,
340 metadata.clone(),
341 Vec::new(), document_id.clone(),
343 i,
344 chunk_words.len(),
345 ));
346 }
347 }
348
349 Ok(chunks)
350 }
351
352 async fn generate_embeddings(&self, chunks: Vec<DocumentChunk>) -> Result<Vec<DocumentChunk>> {
354 let mut chunks_with_embeddings = Vec::new();
355
356 for chunk in chunks {
357 let embedding = self.generate_embedding(&chunk.content).await?;
358 let mut chunk_with_embedding = chunk;
359 chunk_with_embedding.embedding = embedding;
360 chunks_with_embeddings.push(chunk_with_embedding);
361 }
362
363 Ok(chunks_with_embeddings)
364 }
365
366 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
368 let provider = &self.config.embedding_provider;
369 let model = &self.config.embedding_model;
370
371 match provider {
372 EmbeddingProvider::OpenAI => self.generate_openai_embedding(text, model).await,
373 EmbeddingProvider::OpenAICompatible => {
374 self.generate_openai_compatible_embedding(text, model).await
375 }
376 EmbeddingProvider::Ollama => {
377 self.generate_openai_compatible_embedding(text, model).await
379 }
380 }
381 }
382
383 async fn generate_query_embedding(&mut self, query: &str) -> Result<Vec<f32>> {
385 if let Some(embedding) = self.embedding_cache.get(&query.to_string()) {
387 return Ok(embedding);
388 }
389
390 let embedding = self.generate_embedding(query).await?;
392
393 self.embedding_cache.put(query.to_string(), embedding.clone());
395
396 Ok(embedding)
397 }
398
399 async fn semantic_search(
401 &self,
402 query_embedding: &[f32],
403 candidates: Vec<DocumentChunk>,
404 ) -> Result<Vec<SearchResult>> {
405 let mut results = Vec::new();
406
407 for (rank, chunk) in candidates.iter().enumerate() {
409 let score = cosine_similarity(query_embedding, &chunk.embedding);
410
411 results.push(SearchResult::new(chunk.clone(), score, rank));
412 }
413
414 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
416 results.retain(|r| r.score >= self.config.similarity_threshold);
417
418 results.truncate(self.config.top_k);
420
421 Ok(results)
422 }
423
424 async fn hybrid_search(
426 &self,
427 query: &str,
428 query_embedding: &[f32],
429 candidates: Vec<DocumentChunk>,
430 ) -> Result<Vec<SearchResult>> {
431 let mut results = Vec::new();
432
433 let semantic_results = self.semantic_search(query_embedding, candidates.clone()).await?;
435
436 let keyword_results = self.keyword_search(query, &candidates).await?;
438
439 let semantic_weight = self.config.semantic_weight;
441 let keyword_weight = self.config.keyword_weight;
442
443 for (rank, chunk) in candidates.iter().enumerate() {
444 let semantic_score = semantic_results
445 .iter()
446 .find(|r| r.chunk.id == chunk.id)
447 .map(|r| r.score)
448 .unwrap_or(0.0);
449
450 let keyword_score = keyword_results
451 .iter()
452 .find(|r| r.chunk.id == chunk.id)
453 .map(|r| r.score)
454 .unwrap_or(0.0);
455
456 let combined_score = semantic_score * semantic_weight + keyword_score * keyword_weight;
457
458 results.push(SearchResult::new(chunk.clone(), combined_score, rank));
459 }
460
461 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
463 results.retain(|r| r.score >= self.config.similarity_threshold);
464 results.truncate(self.config.top_k);
465
466 Ok(results)
467 }
468
469 async fn keyword_search(
475 &self,
476 query: &str,
477 candidates: &[DocumentChunk],
478 ) -> Result<Vec<SearchResult>> {
479 let query_lower = query.to_lowercase();
480 let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
481
482 if query_terms.is_empty() {
483 return Ok(Vec::new());
484 }
485
486 let num_candidates = candidates.len();
487 let mut results = Vec::new();
488
489 for (rank, chunk) in candidates.iter().enumerate() {
490 let content_lower = chunk.content.to_lowercase();
491 let content_words: Vec<&str> = content_lower.split_whitespace().collect();
492
493 if content_words.is_empty() {
494 continue;
495 }
496
497 let matching_terms = query_terms
499 .iter()
500 .filter(|term| content_words.iter().any(|w| w.contains(*term)))
501 .count();
502 let tf_score = matching_terms as f32 / query_terms.len() as f32;
503
504 let mut idf_sum = 0.0f32;
507 for term in &query_terms {
508 let docs_with_term = candidates
509 .iter()
510 .filter(|c| c.content.to_lowercase().contains(term))
511 .count()
512 .max(1);
513 idf_sum += ((num_candidates as f32) / docs_with_term as f32).ln() + 1.0;
514 }
515 let idf_score = idf_sum / query_terms.len() as f32;
516
517 let phrase_bonus = if query_terms.len() > 1 && content_lower.contains(&query_lower) {
519 0.3
520 } else {
521 0.0
522 };
523
524 let score = (tf_score * idf_score + phrase_bonus).min(1.0);
525
526 if score > 0.0 {
527 results.push(SearchResult::new(chunk.clone(), score, rank));
528 }
529 }
530
531 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
533
534 Ok(results)
535 }
536
537 fn build_context(
539 &self,
540 search_results: &[SearchResult],
541 additional_context: Option<&str>,
542 ) -> String {
543 let mut context_parts = Vec::new();
544
545 for result in search_results {
547 context_parts
548 .push(format!("Content: {}\nRelevance: {:.2}", result.chunk.content, result.score));
549 }
550
551 if let Some(context) = additional_context {
553 context_parts.push(format!("Additional Context: {}", context));
554 }
555
556 context_parts.join("\n\n")
557 }
558
559 async fn generate_with_llm(&self, query: &str, context: &str) -> Result<String> {
561 let provider = &self.config.provider;
562 let model = &self.config.model;
563
564 match provider {
565 crate::rag::config::LlmProvider::OpenAI => {
566 self.generate_openai_response(query, context, model).await
567 }
568 crate::rag::config::LlmProvider::Anthropic => {
569 self.generate_anthropic_response(query, context, model).await
570 }
571 crate::rag::config::LlmProvider::OpenAICompatible => {
572 self.generate_openai_compatible_response(query, context, model).await
573 }
574 crate::rag::config::LlmProvider::Ollama => {
575 self.generate_ollama_response(query, context, model).await
576 }
577 }
578 }
579
580 fn create_generation_prompt(
582 &self,
583 schema: &SchemaDefinition,
584 count: usize,
585 context: Option<&str>,
586 ) -> String {
587 let mut prompt = format!(
588 "Generate {} rows of sample data following this schema:\n\n{:?}\n\n",
589 count, schema
590 );
591
592 if let Some(context) = context {
593 prompt.push_str(&format!("Additional context: {}\n\n", context));
594 }
595
596 prompt.push_str("Please generate the data in JSON format as an array of objects.");
597 prompt
598 }
599
600 fn parse_dataset_response(
602 &self,
603 response: &str,
604 _schema: &SchemaDefinition,
605 ) -> Result<Vec<HashMap<String, Value>>> {
606 match serde_json::from_str::<Vec<HashMap<String, Value>>>(response) {
608 Ok(data) => Ok(data),
609 Err(_) => {
610 if let Some(json_start) = response.find('[') {
612 if let Some(json_end) = response.rfind(']') {
613 let json_part = &response[json_start..=json_end];
614 serde_json::from_str(json_part).map_err(|e| {
615 crate::Error::generic(format!("Failed to parse JSON: {}", e))
616 })
617 } else {
618 Err(crate::Error::generic("No closing bracket found in response"))
619 }
620 } else {
621 Err(crate::Error::generic("No JSON array found in response"))
622 }
623 }
624 }
625 }
626
    /// Requests an embedding from the OpenAI `/v1/embeddings` endpoint.
    ///
    /// # Errors
    /// Fails when no API key is configured, on transport errors, on non-2xx
    /// status codes, or when the body lacks `data[0].embedding`.
    async fn generate_openai_embedding(&self, text: &str, model: &str) -> Result<Vec<f32>> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;

        let response = self
            .client
            .post("https://api.openai.com/v1/embeddings")
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "input": text,
                "model": model
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;

        // Non-numeric elements are coerced to 0.0 rather than rejected.
        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
    }
658
    /// Requests an embedding from an OpenAI-compatible server rooted at
    /// `config.api_endpoint` (this path is also used for Ollama).
    ///
    /// # Errors
    /// Fails when no API key is configured, on transport errors, on non-2xx
    /// status codes, or when the body lacks `data[0].embedding`.
    async fn generate_openai_compatible_embedding(
        &self,
        text: &str,
        model: &str,
    ) -> Result<Vec<f32>> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("API key not configured"))?;

        let response = self
            .client
            .post(format!("{}/embeddings", self.config.api_endpoint))
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "input": text,
                "model": model
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;

        // Non-numeric elements are coerced to 0.0 rather than rejected.
        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
    }
694
    /// Sends a chat-completion request to OpenAI: a fixed system prompt plus
    /// a single user message containing the context and question.
    ///
    /// # Errors
    /// Fails when no API key is configured, on transport errors, on non-2xx
    /// status codes, or when `choices[0].message.content` is missing.
    async fn generate_openai_response(
        &self,
        query: &str,
        context: &str,
        model: &str,
    ) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;

        let messages = vec![
            serde_json::json!({
                "role": "system",
                "content": "You are a helpful assistant. Use the provided context to answer questions accurately."
            }),
            serde_json::json!({
                "role": "user",
                "content": format!("Context: {}\n\nQuestion: {}", context, query)
            }),
        ];

        let response = self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "model": model,
                "messages": messages,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "top_p": self.config.top_p
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }
745
    /// Sends a messages request to the Anthropic API rooted at
    /// `config.api_endpoint`, using `x-api-key` auth and pinning
    /// `anthropic-version: 2023-06-01`.
    ///
    /// # Errors
    /// Fails when no API key is configured, on transport errors, on non-2xx
    /// status codes, or when `content[0].text` is missing from the body.
    async fn generate_anthropic_response(
        &self,
        query: &str,
        context: &str,
        model: &str,
    ) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("Anthropic API key not configured"))?;

        let response = self
            .client
            .post(format!("{}/messages", self.config.api_endpoint))
            .header("x-api-key", api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "model": model,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "messages": [{
                    "role": "user",
                    "content": format!("Context: {}\n\nQuestion: {}", context, query)
                }]
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!(
                "Anthropic API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        // Anthropic returns a list of content blocks; take the first one's text.
        let text = json["content"]
            .as_array()
            .and_then(|content| content.first())
            .and_then(|entry| entry["text"].as_str())
            .ok_or_else(|| crate::Error::generic("Invalid Anthropic response format"))?;

        Ok(text.to_string())
    }
793
    /// Sends a chat-completion request to an OpenAI-compatible server rooted
    /// at `config.api_endpoint`; payload shape mirrors the OpenAI path.
    ///
    /// # Errors
    /// Fails when no API key is configured, on transport errors, on non-2xx
    /// status codes, or when `choices[0].message.content` is missing.
    async fn generate_openai_compatible_response(
        &self,
        query: &str,
        context: &str,
        model: &str,
    ) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("API key not configured"))?;

        let messages = vec![
            serde_json::json!({
                "role": "system",
                "content": "You are a helpful assistant. Use the provided context to answer questions accurately."
            }),
            serde_json::json!({
                "role": "user",
                "content": format!("Context: {}\n\nQuestion: {}", context, query)
            }),
        ];

        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.api_endpoint))
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "model": model,
                "messages": messages,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "top_p": self.config.top_p
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }
844
    /// Sends a non-streaming generate request to an Ollama server rooted at
    /// `config.api_endpoint`, passing temperature/top_p via `options`.
    /// Unlike the other providers, no API key header is sent.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx status codes, or when the body
    /// lacks a string `response` field.
    async fn generate_ollama_response(
        &self,
        query: &str,
        context: &str,
        model: &str,
    ) -> Result<String> {
        let response = self
            .client
            .post(format!("{}/generate", self.config.api_endpoint))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "model": model,
                "prompt": format!("Context: {}\n\nQuestion: {}", context, query),
                "stream": false,
                "options": {
                    "temperature": self.config.temperature,
                    "top_p": self.config.top_p
                }
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("Ollama API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let content = json["response"]
            .as_str()
            .ok_or_else(|| crate::Error::generic("Invalid Ollama response format"))?;

        Ok(content.to_string())
    }
879}
880
impl Default for RagEngine {
    /// Builds an engine with default configuration and in-memory storage.
    ///
    /// Panics if `RagEngine::new` fails (e.g. the HTTP client cannot be
    /// constructed), since `Default` cannot report a `Result`.
    fn default() -> Self {
        use crate::rag::storage::InMemoryStorage;

        let config = RagConfig::default();
        let storage = Arc::new(InMemoryStorage::default());

        Self::new(config, storage).expect("Failed to create default RagEngine")
    }
}
894
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 when the lengths differ or when either vector has zero
/// magnitude, so callers never divide by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (mut dot, mut norm_a_sq, mut norm_b_sq) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
911
/// Aggregate engine statistics returned by `RagEngine::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagStats {
    /// Number of documents known to the storage backend.
    pub total_documents: usize,
    /// Number of stored chunks across all documents.
    pub total_chunks: usize,
    /// Index size in bytes, as reported by the storage backend.
    pub index_size_bytes: u64,
    /// When the storage index was last updated.
    pub last_updated: chrono::DateTime<chrono::Utc>,
    /// Embedding-cache hit rate, as reported by the cache.
    pub cache_hit_rate: f32,
    /// Mean recorded request duration in milliseconds (0.0 before any request).
    pub average_response_time_ms: f32,
}
928
impl Default for RagStats {
    /// Zeroed statistics; `last_updated` is set to the current time, which
    /// is why this impl cannot simply be derived.
    fn default() -> Self {
        Self {
            total_documents: 0,
            total_chunks: 0,
            index_size_bytes: 0,
            last_updated: chrono::Utc::now(),
            cache_hit_rate: 0.0,
            average_response_time_ms: 0.0,
        }
    }
}
941
/// Statistics reported by a `DocumentStorage` backend.
#[derive(Debug, Clone)]
pub struct StorageStats {
    /// Number of documents held by the backend.
    pub total_documents: usize,
    /// Number of stored chunks across all documents.
    pub total_chunks: usize,
    /// Index size in bytes, as reported by the backend.
    pub index_size_bytes: u64,
    /// When the index was last updated.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
954
impl Default for StorageStats {
    /// Zeroed statistics; `last_updated` is set to the current time, which
    /// is why this impl cannot simply be derived.
    fn default() -> Self {
        Self {
            total_documents: 0,
            total_chunks: 0,
            index_size_bytes: 0,
            last_updated: chrono::Utc::now(),
        }
    }
}
965
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal chunk (empty metadata, empty embedding) for the
    /// keyword-search tests below.
    fn make_chunk(id: &str, content: &str) -> DocumentChunk {
        DocumentChunk::new(
            id.to_string(),
            content.to_string(),
            HashMap::new(),
            Vec::new(),
            "doc1".to_string(),
            0,
            content.split_whitespace().count(),
        )
    }

    /// The chunk matching both query terms must rank first.
    #[tokio::test]
    async fn test_keyword_search_basic_term_matching() {
        let engine = RagEngine::default();
        let candidates = vec![
            make_chunk("c1", "rust programming language systems"),
            make_chunk("c2", "python scripting language web"),
            make_chunk("c3", "java enterprise applications"),
        ];

        let results = engine.keyword_search("rust language", &candidates).await.unwrap();
        assert!(!results.is_empty());
        // c1 matches "rust" and "language"; c2 only "language".
        assert_eq!(results[0].chunk.id, "c1");
    }

    /// Both chunks contain every query term (and the exact phrase appears in
    /// one), so both must be returned with positive scores.
    #[tokio::test]
    async fn test_keyword_search_phrase_bonus() {
        let engine = RagEngine::default();
        let candidates = vec![
            make_chunk("c1", "mock api server for testing mock endpoints"),
            make_chunk("c2", "this is a mock api server that works well"),
        ];

        let results = engine.keyword_search("mock api server", &candidates).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.len() >= 2);
    }

    /// An empty query must yield no results rather than matching everything.
    #[tokio::test]
    async fn test_keyword_search_empty_query() {
        let engine = RagEngine::default();
        let candidates = vec![make_chunk("c1", "some content here")];
        let results = engine.keyword_search("", &candidates).await.unwrap();
        assert!(results.is_empty());
    }

    /// Queries with no overlapping terms must produce no results.
    #[tokio::test]
    async fn test_keyword_search_no_match() {
        let engine = RagEngine::default();
        let candidates = vec![make_chunk("c1", "rust programming")];
        let results = engine.keyword_search("python django", &candidates).await.unwrap();
        assert!(results.is_empty());
    }

    /// Identical vectors have similarity ~1.0.
    #[tokio::test]
    async fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    /// Orthogonal vectors have similarity ~0.0.
    #[tokio::test]
    async fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }
}