1use crate::rag::utils::Cache;
7use crate::rag::{
8 config::{EmbeddingProvider, RagConfig},
9 storage::DocumentStorage,
10};
11use crate::schema::SchemaDefinition;
12use mockforge_core::Result;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Duration;
19use tracing::debug;
20
/// A chunk of a source document together with its embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Unique chunk identifier (formatted as `"{document_id}_chunk_{index}"` by `create_chunks`).
    pub id: String,
    /// Raw text content of the chunk.
    pub content: String,
    /// Key/value metadata copied from the parent document.
    pub metadata: HashMap<String, String>,
    /// Embedding vector for `content`; empty until embeddings are generated.
    pub embedding: Vec<f32>,
    /// Identifier of the document this chunk was split from.
    pub document_id: String,
    /// Chunk index within the document.
    pub position: usize,
    /// Chunk length as recorded at creation (word count in `create_chunks`).
    pub length: usize,
}
39
40impl DocumentChunk {
41 pub fn new(
43 id: String,
44 content: String,
45 metadata: HashMap<String, String>,
46 embedding: Vec<f32>,
47 document_id: String,
48 position: usize,
49 length: usize,
50 ) -> Self {
51 Self {
52 id,
53 content,
54 metadata,
55 embedding,
56 document_id,
57 position,
58 length,
59 }
60 }
61
62 pub fn size(&self) -> usize {
64 self.content.len()
65 }
66
67 pub fn is_empty(&self) -> bool {
69 self.content.is_empty()
70 }
71
72 pub fn get_metadata(&self, key: &str) -> Option<&String> {
74 self.metadata.get(key)
75 }
76
77 pub fn set_metadata(&mut self, key: String, value: String) {
79 self.metadata.insert(key, value);
80 }
81
82 pub fn similarity(&self, other: &DocumentChunk) -> f32 {
84 cosine_similarity(&self.embedding, &other.embedding)
85 }
86
87 pub fn preview(&self) -> String {
89 if self.content.len() > 100 {
90 format!("{}...", &self.content[..100])
91 } else {
92 self.content.clone()
93 }
94 }
95}
96
/// A single search hit: the matched chunk plus its relevance score and rank.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matched document chunk.
    pub chunk: DocumentChunk,
    /// Relevance score (higher is more relevant).
    pub score: f32,
    /// Rank assigned by the search pipeline.
    pub rank: usize,
}
107
108impl SearchResult {
109 pub fn new(chunk: DocumentChunk, score: f32, rank: usize) -> Self {
111 Self { chunk, score, rank }
112 }
113}
114
impl PartialEq for SearchResult {
    // Equality is defined purely by score; the chunk and rank fields are not
    // compared. This is an exact `f32` comparison — scores differing by any
    // amount compare unequal.
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score
    }
}
120
// NOTE(review): `Eq` asserts total equality, but `eq` compares `f32` scores;
// that is only sound if scores are never NaN — TODO confirm embeddings cannot
// contain NaN (a NaN component would propagate through cosine_similarity).
impl Eq for SearchResult {}
122
123impl PartialOrd for SearchResult {
124 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
125 Some(self.cmp(other))
126 }
127}
128
129impl Ord for SearchResult {
130 fn cmp(&self, other: &Self) -> Ordering {
131 self.partial_cmp(other).unwrap_or(Ordering::Equal)
132 }
133}
134
/// Retrieval-augmented generation engine: document chunking, embedding,
/// similarity search, and LLM-backed generation over a pluggable store.
pub struct RagEngine {
    /// Engine configuration (providers, models, chunking, weights, timeouts).
    config: RagConfig,
    /// Pluggable persistence/search backend for document chunks.
    storage: Arc<dyn DocumentStorage>,
    /// HTTP client used for all provider API calls.
    client: reqwest::Client,
    /// Sum of recorded operation durations, in milliseconds.
    total_response_time_ms: f64,
    /// Number of recorded operations (denominator for the average latency).
    response_count: usize,
    /// TTL cache of query text -> embedding, avoiding repeat API calls.
    embedding_cache: Cache<String, Vec<f32>>,
}
150
impl std::fmt::Debug for RagEngine {
    /// Manual `Debug` impl: `storage` is a trait object and the remaining
    /// handles are rendered as opaque placeholders rather than derived.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RagEngine")
            .field("config", &self.config)
            .field("storage", &"<DocumentStorage>")
            .field("client", &"<reqwest::Client>")
            .field("total_response_time_ms", &self.total_response_time_ms)
            .field("response_count", &self.response_count)
            .field("embedding_cache", &"<Cache>")
            .finish()
    }
}
163
164impl RagEngine {
165 pub fn new(config: RagConfig, storage: Arc<dyn DocumentStorage>) -> Result<Self> {
167 let client = reqwest::ClientBuilder::new().timeout(config.timeout_duration()).build()?;
168
169 let cache_ttl = config.cache_ttl_duration().as_secs();
170
171 Ok(Self {
172 config,
173 storage,
174 client,
175 total_response_time_ms: 0.0,
176 response_count: 0,
177 embedding_cache: Cache::new(cache_ttl, 1000), })
179 }
180
181 fn record_response_time(&mut self, duration: Duration) {
183 let ms = duration.as_millis() as f64;
184 self.total_response_time_ms += ms;
185 self.response_count += 1;
186 }
187
188 pub fn config(&self) -> &RagConfig {
190 &self.config
191 }
192
193 pub fn storage(&self) -> &Arc<dyn DocumentStorage> {
195 &self.storage
196 }
197
198 pub fn update_config(&mut self, config: RagConfig) -> Result<()> {
200 config.validate()?;
201 self.config = config;
202 Ok(())
203 }
204
205 pub async fn add_document(
207 &self,
208 document_id: String,
209 content: String,
210 metadata: HashMap<String, String>,
211 ) -> Result<()> {
212 debug!("Adding document: {}", document_id);
213
214 let chunks = self.create_chunks(document_id.clone(), content, metadata).await?;
216
217 let chunks_with_embeddings = self.generate_embeddings(chunks).await?;
219
220 self.storage.store_chunks(chunks_with_embeddings).await?;
222
223 debug!("Successfully added document: {}", document_id);
224 Ok(())
225 }
226
227 pub async fn search(&mut self, query: &str, top_k: Option<usize>) -> Result<Vec<SearchResult>> {
229 let start = tokio::time::Instant::now();
230 let top_k = top_k.unwrap_or(self.config.top_k);
231 debug!("Searching for: {} (top_k: {})", query, top_k);
232
233 let query_embedding = self.generate_query_embedding(query).await?;
235
236 let candidates = self.storage.search_similar(&query_embedding, top_k * 2).await?; let results = if self.config.hybrid_search {
241 self.hybrid_search(query, &query_embedding, candidates).await?
242 } else {
243 self.semantic_search(&query_embedding, candidates).await?
244 };
245
246 debug!("Found {} relevant chunks", results.len());
247 let duration = start.elapsed();
248 self.record_response_time(duration);
249 Ok(results)
250 }
251
252 pub async fn generate(&mut self, query: &str, context: Option<&str>) -> Result<String> {
254 let start = tokio::time::Instant::now();
255 debug!("Generating response for query: {}", query);
256
257 let search_results = self.search(query, None).await?;
259
260 let rag_context = self.build_context(&search_results, context);
262
263 let response = self.generate_with_llm(query, &rag_context).await?;
265
266 debug!("Generated response ({} chars)", response.len());
267 let duration = start.elapsed();
268 self.record_response_time(duration);
269 Ok(response)
270 }
271
272 pub async fn generate_dataset(
274 &mut self,
275 schema: &SchemaDefinition,
276 count: usize,
277 context: Option<&str>,
278 ) -> Result<Vec<HashMap<String, Value>>> {
279 let start = tokio::time::Instant::now();
280 debug!("Generating dataset with {} rows using schema: {}", count, schema.name);
281
282 let prompt = self.create_generation_prompt(schema, count, context);
284
285 let response = self.generate(&prompt, None).await?;
287
288 let dataset = self.parse_dataset_response(&response, schema)?;
290
291 debug!("Generated dataset with {} rows", dataset.len());
292 let duration = start.elapsed();
293 self.record_response_time(duration);
294 Ok(dataset)
295 }
296
297 pub async fn get_stats(&self) -> Result<RagStats> {
299 let storage_stats = self.storage.get_stats().await?;
300
301 let average_response_time_ms = if self.response_count > 0 {
302 (self.total_response_time_ms / self.response_count as f64) as f32
303 } else {
304 0.0
305 };
306
307 Ok(RagStats {
308 total_documents: storage_stats.total_documents,
309 total_chunks: storage_stats.total_chunks,
310 index_size_bytes: storage_stats.index_size_bytes,
311 last_updated: storage_stats.last_updated,
312 cache_hit_rate: self.embedding_cache.hit_rate(),
313 average_response_time_ms,
314 })
315 }
316
317 async fn create_chunks(
319 &self,
320 document_id: String,
321 content: String,
322 metadata: HashMap<String, String>,
323 ) -> Result<Vec<DocumentChunk>> {
324 let mut chunks = Vec::new();
325 let words: Vec<&str> = content.split_whitespace().collect();
326 let chunk_size = self.config.chunk_size;
327 let overlap = self.config.chunk_overlap;
328
329 for (i, chunk_start) in (0..words.len()).step_by(chunk_size - overlap).enumerate() {
330 let chunk_end = (chunk_start + chunk_size).min(words.len());
331 let chunk_words: Vec<&str> = words[chunk_start..chunk_end].to_vec();
332 let chunk_content = chunk_words.join(" ");
333
334 if !chunk_content.is_empty() {
335 let chunk_id = format!("{}_chunk_{}", document_id, i);
336
337 chunks.push(DocumentChunk::new(
338 chunk_id,
339 chunk_content,
340 metadata.clone(),
341 Vec::new(), document_id.clone(),
343 i,
344 chunk_words.len(),
345 ));
346 }
347 }
348
349 Ok(chunks)
350 }
351
352 async fn generate_embeddings(&self, chunks: Vec<DocumentChunk>) -> Result<Vec<DocumentChunk>> {
354 let mut chunks_with_embeddings = Vec::new();
355
356 for chunk in chunks {
357 let embedding = self.generate_embedding(&chunk.content).await?;
358 let mut chunk_with_embedding = chunk;
359 chunk_with_embedding.embedding = embedding;
360 chunks_with_embeddings.push(chunk_with_embedding);
361 }
362
363 Ok(chunks_with_embeddings)
364 }
365
366 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
368 let provider = &self.config.embedding_provider;
369 let model = &self.config.embedding_model;
370
371 match provider {
372 EmbeddingProvider::OpenAI => self.generate_openai_embedding(text, model).await,
373 EmbeddingProvider::OpenAICompatible => {
374 self.generate_openai_compatible_embedding(text, model).await
375 }
376 }
377 }
378
379 async fn generate_query_embedding(&mut self, query: &str) -> Result<Vec<f32>> {
381 if let Some(embedding) = self.embedding_cache.get(&query.to_string()) {
383 return Ok(embedding);
384 }
385
386 let embedding = self.generate_embedding(query).await?;
388
389 self.embedding_cache.put(query.to_string(), embedding.clone());
391
392 Ok(embedding)
393 }
394
395 async fn semantic_search(
397 &self,
398 query_embedding: &[f32],
399 candidates: Vec<DocumentChunk>,
400 ) -> Result<Vec<SearchResult>> {
401 let mut results = Vec::new();
402
403 for (rank, chunk) in candidates.iter().enumerate() {
405 let score = cosine_similarity(query_embedding, &chunk.embedding);
406
407 results.push(SearchResult::new(chunk.clone(), score, rank));
408 }
409
410 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
412 results.retain(|r| r.score >= self.config.similarity_threshold);
413
414 results.truncate(self.config.top_k);
416
417 Ok(results)
418 }
419
420 async fn hybrid_search(
422 &self,
423 query: &str,
424 query_embedding: &[f32],
425 candidates: Vec<DocumentChunk>,
426 ) -> Result<Vec<SearchResult>> {
427 let mut results = Vec::new();
428
429 let semantic_results = self.semantic_search(query_embedding, candidates.clone()).await?;
431
432 let keyword_results = self.keyword_search(query, &candidates).await?;
434
435 let semantic_weight = self.config.semantic_weight;
437 let keyword_weight = self.config.keyword_weight;
438
439 for (rank, chunk) in candidates.iter().enumerate() {
440 let semantic_score = semantic_results
441 .iter()
442 .find(|r| r.chunk.id == chunk.id)
443 .map(|r| r.score)
444 .unwrap_or(0.0);
445
446 let keyword_score = keyword_results
447 .iter()
448 .find(|r| r.chunk.id == chunk.id)
449 .map(|r| r.score)
450 .unwrap_or(0.0);
451
452 let combined_score = semantic_score * semantic_weight + keyword_score * keyword_weight;
453
454 results.push(SearchResult::new(chunk.clone(), combined_score, rank));
455 }
456
457 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
459 results.retain(|r| r.score >= self.config.similarity_threshold);
460 results.truncate(self.config.top_k);
461
462 Ok(results)
463 }
464
465 async fn keyword_search(
467 &self,
468 _query: &str,
469 _candidates: &[DocumentChunk],
470 ) -> Result<Vec<SearchResult>> {
471 Ok(Vec::new())
473 }
474
475 fn build_context(
477 &self,
478 search_results: &[SearchResult],
479 additional_context: Option<&str>,
480 ) -> String {
481 let mut context_parts = Vec::new();
482
483 for result in search_results {
485 context_parts
486 .push(format!("Content: {}\nRelevance: {:.2}", result.chunk.content, result.score));
487 }
488
489 if let Some(context) = additional_context {
491 context_parts.push(format!("Additional Context: {}", context));
492 }
493
494 context_parts.join("\n\n")
495 }
496
497 async fn generate_with_llm(&self, query: &str, context: &str) -> Result<String> {
499 let provider = &self.config.provider;
500 let model = &self.config.model;
501
502 match provider {
503 crate::rag::config::LlmProvider::OpenAI => {
504 self.generate_openai_response(query, context, model).await
505 }
506 crate::rag::config::LlmProvider::Anthropic => {
507 self.generate_anthropic_response(query, context, model).await
508 }
509 crate::rag::config::LlmProvider::OpenAICompatible => {
510 self.generate_openai_compatible_response(query, context, model).await
511 }
512 crate::rag::config::LlmProvider::Ollama => {
513 self.generate_ollama_response(query, context, model).await
514 }
515 }
516 }
517
518 fn create_generation_prompt(
520 &self,
521 schema: &SchemaDefinition,
522 count: usize,
523 context: Option<&str>,
524 ) -> String {
525 let mut prompt = format!(
526 "Generate {} rows of sample data following this schema:\n\n{:?}\n\n",
527 count, schema
528 );
529
530 if let Some(context) = context {
531 prompt.push_str(&format!("Additional context: {}\n\n", context));
532 }
533
534 prompt.push_str("Please generate the data in JSON format as an array of objects.");
535 prompt
536 }
537
538 fn parse_dataset_response(
540 &self,
541 response: &str,
542 _schema: &SchemaDefinition,
543 ) -> Result<Vec<HashMap<String, Value>>> {
544 match serde_json::from_str::<Vec<HashMap<String, Value>>>(response) {
546 Ok(data) => Ok(data),
547 Err(_) => {
548 if let Some(json_start) = response.find('[') {
550 if let Some(json_end) = response.rfind(']') {
551 let json_part = &response[json_start..=json_end];
552 serde_json::from_str(json_part).map_err(|e| {
553 mockforge_core::Error::generic(format!("Failed to parse JSON: {}", e))
554 })
555 } else {
556 Err(mockforge_core::Error::generic("No closing bracket found in response"))
557 }
558 } else {
559 Err(mockforge_core::Error::generic("No JSON array found in response"))
560 }
561 }
562 }
563 }
564
565 async fn generate_openai_embedding(&self, text: &str, model: &str) -> Result<Vec<f32>> {
567 let api_key = self
568 .config
569 .api_key
570 .as_ref()
571 .ok_or_else(|| mockforge_core::Error::generic("OpenAI API key not configured"))?;
572
573 let response = self
574 .client
575 .post("https://api.openai.com/v1/embeddings")
576 .header("Authorization", format!("Bearer {}", api_key))
577 .header("Content-Type", "application/json")
578 .json(&serde_json::json!({
579 "input": text,
580 "model": model
581 }))
582 .send()
583 .await?;
584
585 if !response.status().is_success() {
586 return Err(mockforge_core::Error::generic(format!(
587 "OpenAI API error: {}",
588 response.status()
589 )));
590 }
591
592 let json: Value = response.json().await?;
593 let embedding = json["data"][0]["embedding"]
594 .as_array()
595 .ok_or_else(|| mockforge_core::Error::generic("Invalid embedding response format"))?;
596
597 Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
598 }
599
600 async fn generate_openai_compatible_embedding(
602 &self,
603 text: &str,
604 model: &str,
605 ) -> Result<Vec<f32>> {
606 let api_key = self
607 .config
608 .api_key
609 .as_ref()
610 .ok_or_else(|| mockforge_core::Error::generic("API key not configured"))?;
611
612 let response = self
613 .client
614 .post(format!("{}/embeddings", self.config.api_endpoint))
615 .header("Authorization", format!("Bearer {}", api_key))
616 .header("Content-Type", "application/json")
617 .json(&serde_json::json!({
618 "input": text,
619 "model": model
620 }))
621 .send()
622 .await?;
623
624 if !response.status().is_success() {
625 return Err(mockforge_core::Error::generic(format!(
626 "API error: {}",
627 response.status()
628 )));
629 }
630
631 let json: Value = response.json().await?;
632 let embedding = json["data"][0]["embedding"]
633 .as_array()
634 .ok_or_else(|| mockforge_core::Error::generic("Invalid embedding response format"))?;
635
636 Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
637 }
638
639 async fn generate_openai_response(
641 &self,
642 query: &str,
643 context: &str,
644 model: &str,
645 ) -> Result<String> {
646 let api_key = self
647 .config
648 .api_key
649 .as_ref()
650 .ok_or_else(|| mockforge_core::Error::generic("OpenAI API key not configured"))?;
651
652 let messages = vec![
653 serde_json::json!({
654 "role": "system",
655 "content": "You are a helpful assistant. Use the provided context to answer questions accurately."
656 }),
657 serde_json::json!({
658 "role": "user",
659 "content": format!("Context: {}\n\nQuestion: {}", context, query)
660 }),
661 ];
662
663 let response = self
664 .client
665 .post("https://api.openai.com/v1/chat/completions")
666 .header("Authorization", format!("Bearer {}", api_key))
667 .header("Content-Type", "application/json")
668 .json(&serde_json::json!({
669 "model": model,
670 "messages": messages,
671 "max_tokens": self.config.max_tokens,
672 "temperature": self.config.temperature,
673 "top_p": self.config.top_p
674 }))
675 .send()
676 .await?;
677
678 if !response.status().is_success() {
679 return Err(mockforge_core::Error::generic(format!(
680 "OpenAI API error: {}",
681 response.status()
682 )));
683 }
684
685 let json: Value = response.json().await?;
686 let content = json["choices"][0]["message"]["content"]
687 .as_str()
688 .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;
689
690 Ok(content.to_string())
691 }
692
693 async fn generate_anthropic_response(
695 &self,
696 _query: &str,
697 _context: &str,
698 _model: &str,
699 ) -> Result<String> {
700 Ok("Anthropic response placeholder".to_string())
702 }
703
704 async fn generate_openai_compatible_response(
706 &self,
707 query: &str,
708 context: &str,
709 model: &str,
710 ) -> Result<String> {
711 let api_key = self
712 .config
713 .api_key
714 .as_ref()
715 .ok_or_else(|| mockforge_core::Error::generic("API key not configured"))?;
716
717 let messages = vec![
718 serde_json::json!({
719 "role": "system",
720 "content": "You are a helpful assistant. Use the provided context to answer questions accurately."
721 }),
722 serde_json::json!({
723 "role": "user",
724 "content": format!("Context: {}\n\nQuestion: {}", context, query)
725 }),
726 ];
727
728 let response = self
729 .client
730 .post(format!("{}/chat/completions", self.config.api_endpoint))
731 .header("Authorization", format!("Bearer {}", api_key))
732 .header("Content-Type", "application/json")
733 .json(&serde_json::json!({
734 "model": model,
735 "messages": messages,
736 "max_tokens": self.config.max_tokens,
737 "temperature": self.config.temperature,
738 "top_p": self.config.top_p
739 }))
740 .send()
741 .await?;
742
743 if !response.status().is_success() {
744 return Err(mockforge_core::Error::generic(format!(
745 "API error: {}",
746 response.status()
747 )));
748 }
749
750 let json: Value = response.json().await?;
751 let content = json["choices"][0]["message"]["content"]
752 .as_str()
753 .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;
754
755 Ok(content.to_string())
756 }
757
758 async fn generate_ollama_response(
760 &self,
761 _query: &str,
762 _context: &str,
763 _model: &str,
764 ) -> Result<String> {
765 Ok("Ollama response placeholder".to_string())
767 }
768}
769
770impl Default for RagEngine {
771 fn default() -> Self {
772 use crate::rag::storage::InMemoryStorage;
773
774 let config = crate::rag::config::RagConfig::default();
777 let storage = Arc::new(InMemoryStorage::default());
778
779 Self::new(config, storage).expect("Failed to create default RagEngine")
781 }
782}
783
/// Cosine similarity between two vectors.
///
/// Returns 0.0 when the slices differ in length or when either vector has
/// zero magnitude, so callers never see a division by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate dot product and squared norms in a single pass; summation
    // order matches the element order, as before.
    let mut dot = 0.0f32;
    let mut sum_sq_a = 0.0f32;
    let mut sum_sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sum_sq_a += x * x;
        sum_sq_b += y * y;
    }

    let norm_a = sum_sq_a.sqrt();
    let norm_b = sum_sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
800
/// Aggregate runtime statistics for the RAG engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagStats {
    /// Number of documents known to storage.
    pub total_documents: usize,
    /// Number of stored chunks.
    pub total_chunks: usize,
    /// Size of the search index, in bytes (as reported by storage).
    pub index_size_bytes: u64,
    /// When storage was last updated.
    pub last_updated: chrono::DateTime<chrono::Utc>,
    /// Embedding-cache hit rate.
    pub cache_hit_rate: f32,
    /// Mean recorded operation latency, in milliseconds.
    pub average_response_time_ms: f32,
}
817
818impl Default for RagStats {
819 fn default() -> Self {
820 Self {
821 total_documents: 0,
822 total_chunks: 0,
823 index_size_bytes: 0,
824 last_updated: chrono::Utc::now(),
825 cache_hit_rate: 0.0,
826 average_response_time_ms: 0.0,
827 }
828 }
829}
830
/// Statistics reported by a `DocumentStorage` backend.
#[derive(Debug, Clone)]
pub struct StorageStats {
    /// Number of stored documents.
    pub total_documents: usize,
    /// Number of stored chunks.
    pub total_chunks: usize,
    /// Size of the search index, in bytes.
    pub index_size_bytes: u64,
    /// Timestamp of the last storage modification.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
843
844impl Default for StorageStats {
845 fn default() -> Self {
846 Self {
847 total_documents: 0,
848 total_chunks: 0,
849 index_size_bytes: 0,
850 last_updated: chrono::Utc::now(),
851 }
852 }
853}
854
#[cfg(test)]
mod tests {
    // NOTE(review): placeholder suite — this only proves the module compiles.
    // Pure helpers such as `cosine_similarity`, `DocumentChunk::preview`, and
    // the `SearchResult` ordering impls would benefit from real unit tests.

    #[test]
    fn test_module_compiles() {
    }
}