use crate::rag::utils::Cache;
use crate::rag::{
    config::{EmbeddingProvider, RagConfig},
    storage::DocumentStorage,
};
use crate::schema::SchemaDefinition;
use crate::Result;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing::debug;

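/// A chunk of a source document together with its metadata and embedding.
///
/// Illustrative construction (the values here are hypothetical, not taken from
/// real data):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let chunk = DocumentChunk::new(
///     "doc1_chunk_0".to_string(), // chunk id
///     "hello world".to_string(),  // content
///     HashMap::new(),             // metadata
///     vec![0.1, 0.2, 0.3],        // embedding
///     "doc1".to_string(),         // document id
///     0,                          // position (chunk index)
///     2,                          // length in words
/// );
/// assert_eq!(chunk.size(), 11);
/// assert!(!chunk.is_empty());
/// ```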
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Unique identifier of this chunk.
    pub id: String,
    /// Text content of the chunk.
    pub content: String,
    /// Key/value metadata inherited from the source document.
    pub metadata: HashMap<String, String>,
    /// Embedding vector for the chunk content (empty until embeddings are generated).
    pub embedding: Vec<f32>,
    /// Identifier of the document this chunk belongs to.
    pub document_id: String,
    /// Position of the chunk within the document (chunk index).
    pub position: usize,
    /// Length of the chunk in words.
    pub length: usize,
}

impl DocumentChunk {
    /// Creates a new chunk from its parts.
    pub fn new(
        id: String,
        content: String,
        metadata: HashMap<String, String>,
        embedding: Vec<f32>,
        document_id: String,
        position: usize,
        length: usize,
    ) -> Self {
        Self {
            id,
            content,
            metadata,
            embedding,
            document_id,
            position,
            length,
        }
    }

    /// Returns the size of the chunk content in bytes.
    pub fn size(&self) -> usize {
        self.content.len()
    }

    /// Returns `true` if the chunk has no content.
    pub fn is_empty(&self) -> bool {
        self.content.is_empty()
    }

    /// Returns the metadata value for `key`, if present.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }

    /// Inserts or overwrites a metadata entry.
    pub fn set_metadata(&mut self, key: String, value: String) {
        self.metadata.insert(key, value);
    }

    /// Cosine similarity between this chunk's embedding and another chunk's embedding.
    pub fn similarity(&self, other: &DocumentChunk) -> f32 {
        cosine_similarity(&self.embedding, &other.embedding)
    }

    /// Returns a short preview of the content (at most 100 characters).
    pub fn preview(&self) -> String {
        // Truncate on a character boundary; slicing at a fixed byte offset
        // would panic on multi-byte UTF-8 content.
        match self.content.char_indices().nth(100) {
            Some((idx, _)) => format!("{}...", &self.content[..idx]),
            None => self.content.clone(),
        }
    }
}

/// A single search hit: a chunk plus its relevance score and rank.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matching chunk.
    pub chunk: DocumentChunk,
    /// Relevance score (higher is more relevant).
    pub score: f32,
    /// Zero-based rank of the result within the result set.
    pub rank: usize,
}

impl SearchResult {
    /// Creates a new search result.
    pub fn new(chunk: DocumentChunk, score: f32, rank: usize) -> Self {
        Self { chunk, score, rank }
    }
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Keep equality consistent with `Ord` by comparing scores with `total_cmp`.
        self.score.total_cmp(&other.score).is_eq()
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare by score; `total_cmp` gives a total order over f32 and avoids
        // the infinite recursion of delegating back to `partial_cmp`.
        self.score.total_cmp(&other.score)
    }
}

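/// Retrieval-augmented generation engine.
///
/// Chunks documents, embeds the chunks, retrieves the most relevant chunks for
/// a query, and generates grounded responses through the configured LLM
/// provider.
///
/// Sketch of intended usage (illustrative only; assumes the same in-memory
/// storage backend used by the `Default` impl below):
///
/// ```ignore
/// use std::collections::HashMap;
/// use std::sync::Arc;
///
/// let mut engine = RagEngine::new(
///     crate::rag::config::RagConfig::default(),
///     Arc::new(crate::rag::storage::InMemoryStorage::default()),
/// )?;
///
/// engine
///     .add_document("doc-1".to_string(), "some text".to_string(), HashMap::new())
///     .await?;
/// let results = engine.search("some query", Some(5)).await?;
/// ```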
pub struct RagEngine {
    /// Engine configuration.
    config: RagConfig,
    /// Backing document/chunk storage.
    storage: Arc<dyn DocumentStorage>,
    /// HTTP client used for embedding and LLM API calls.
    client: reqwest::Client,
    /// Accumulated response time across recorded operations, in milliseconds.
    total_response_time_ms: f64,
    /// Number of recorded operations.
    response_count: usize,
    /// Cache of query embeddings keyed by query text.
    embedding_cache: Cache<String, Vec<f32>>,
}

impl std::fmt::Debug for RagEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RagEngine")
            .field("config", &self.config)
            .field("storage", &"<DocumentStorage>")
            .field("client", &"<reqwest::Client>")
            .field("total_response_time_ms", &self.total_response_time_ms)
            .field("response_count", &self.response_count)
            .field("embedding_cache", &"<Cache>")
            .finish()
    }
}

impl RagEngine {
    /// Creates a new engine from a configuration and a storage backend.
    pub fn new(config: RagConfig, storage: Arc<dyn DocumentStorage>) -> Result<Self> {
        let client = reqwest::ClientBuilder::new()
            .timeout(config.timeout_duration())
            .build()?;

        let cache_ttl = config.cache_ttl_duration().as_secs();

        Ok(Self {
            config,
            storage,
            client,
            total_response_time_ms: 0.0,
            response_count: 0,
            embedding_cache: Cache::new(cache_ttl, 1000),
        })
    }

    /// Records the elapsed time of one operation for average-response-time stats.
    fn record_response_time(&mut self, duration: Duration) {
        let ms = duration.as_millis() as f64;
        self.total_response_time_ms += ms;
        self.response_count += 1;
    }

    /// Returns the current configuration.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Returns the storage backend.
    pub fn storage(&self) -> &Arc<dyn DocumentStorage> {
        &self.storage
    }

    /// Validates and applies a new configuration.
    pub fn update_config(&mut self, config: RagConfig) -> Result<()> {
        config.validate()?;
        self.config = config;
        Ok(())
    }

    /// Chunks a document, generates embeddings for each chunk, and stores the chunks.
    pub async fn add_document(
        &self,
        document_id: String,
        content: String,
        metadata: HashMap<String, String>,
    ) -> Result<()> {
        debug!("Adding document: {}", document_id);

        let chunks = self.create_chunks(document_id.clone(), content, metadata).await?;

        let chunks_with_embeddings = self.generate_embeddings(chunks).await?;

        self.storage.store_chunks(chunks_with_embeddings).await?;

        debug!("Successfully added document: {}", document_id);
        Ok(())
    }

    /// Searches stored chunks for the query and returns ranked results.
    pub async fn search(&mut self, query: &str, top_k: Option<usize>) -> Result<Vec<SearchResult>> {
        let start = tokio::time::Instant::now();
        let top_k = top_k.unwrap_or(self.config.top_k);
        debug!("Searching for: {} (top_k: {})", query, top_k);

        let query_embedding = self.generate_query_embedding(query).await?;

        // Over-fetch candidates so re-ranking has something to work with.
        let candidates = self.storage.search_similar(&query_embedding, top_k * 2).await?;

        let results = if self.config.hybrid_search {
            self.hybrid_search(query, &query_embedding, candidates).await?
        } else {
            self.semantic_search(&query_embedding, candidates).await?
        };

        debug!("Found {} relevant chunks", results.len());
        let duration = start.elapsed();
        self.record_response_time(duration);
        Ok(results)
    }

    /// Generates an LLM response to `query`, grounded in retrieved chunks and
    /// optional additional context.
    pub async fn generate(&mut self, query: &str, context: Option<&str>) -> Result<String> {
        let start = tokio::time::Instant::now();
        debug!("Generating response for query: {}", query);

        let search_results = self.search(query, None).await?;

        let rag_context = self.build_context(&search_results, context);

        let response = self.generate_with_llm(query, &rag_context).await?;

        debug!("Generated response ({} chars)", response.len());
        let duration = start.elapsed();
        self.record_response_time(duration);
        Ok(response)
    }

    /// Generates a synthetic dataset of `count` rows that follows `schema`,
    /// using retrieved context to ground the generation.
    pub async fn generate_dataset(
        &mut self,
        schema: &SchemaDefinition,
        count: usize,
        context: Option<&str>,
    ) -> Result<Vec<HashMap<String, Value>>> {
        let start = tokio::time::Instant::now();
        debug!("Generating dataset with {} rows using schema: {}", count, schema.name);

        let prompt = self.create_generation_prompt(schema, count, context);

        let response = self.generate(&prompt, None).await?;

        let dataset = self.parse_dataset_response(&response, schema)?;

        debug!("Generated dataset with {} rows", dataset.len());
        let duration = start.elapsed();
        self.record_response_time(duration);
        Ok(dataset)
    }

    /// Returns combined storage and engine statistics.
    pub async fn get_stats(&self) -> Result<RagStats> {
        let storage_stats = self.storage.get_stats().await?;

        let average_response_time_ms = if self.response_count > 0 {
            (self.total_response_time_ms / self.response_count as f64) as f32
        } else {
            0.0
        };

        Ok(RagStats {
            total_documents: storage_stats.total_documents,
            total_chunks: storage_stats.total_chunks,
            index_size_bytes: storage_stats.index_size_bytes,
            last_updated: storage_stats.last_updated,
            cache_hit_rate: self.embedding_cache.hit_rate(),
            average_response_time_ms,
        })
    }

    /// Splits `content` into overlapping word-based chunks.
    async fn create_chunks(
        &self,
        document_id: String,
        content: String,
        metadata: HashMap<String, String>,
    ) -> Result<Vec<DocumentChunk>> {
        let mut chunks = Vec::new();
        let words: Vec<&str> = content.split_whitespace().collect();
        let chunk_size = self.config.chunk_size;
        let overlap = self.config.chunk_overlap;
        // Guard against a zero step, which would make `step_by` panic when the
        // configured overlap is at least as large as the chunk size.
        let step = chunk_size.saturating_sub(overlap).max(1);

        for (i, chunk_start) in (0..words.len()).step_by(step).enumerate() {
            let chunk_end = (chunk_start + chunk_size).min(words.len());
            let chunk_words: Vec<&str> = words[chunk_start..chunk_end].to_vec();
            let chunk_content = chunk_words.join(" ");

            if !chunk_content.is_empty() {
                let chunk_id = format!("{}_chunk_{}", document_id, i);

                chunks.push(DocumentChunk::new(
                    chunk_id,
                    chunk_content,
                    metadata.clone(),
                    Vec::new(), // embedding is filled in by `generate_embeddings`
                    document_id.clone(),
                    i,
                    chunk_words.len(),
                ));
            }
        }

        Ok(chunks)
    }

    /// Generates an embedding for each chunk and attaches it.
    async fn generate_embeddings(&self, chunks: Vec<DocumentChunk>) -> Result<Vec<DocumentChunk>> {
        let mut chunks_with_embeddings = Vec::new();

        for chunk in chunks {
            let embedding = self.generate_embedding(&chunk.content).await?;
            let mut chunk_with_embedding = chunk;
            chunk_with_embedding.embedding = embedding;
            chunks_with_embeddings.push(chunk_with_embedding);
        }

        Ok(chunks_with_embeddings)
    }

    /// Generates an embedding for `text` using the configured embedding provider.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let provider = &self.config.embedding_provider;
        let model = &self.config.embedding_model;

        match provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text, model).await,
            EmbeddingProvider::OpenAICompatible => {
                self.generate_openai_compatible_embedding(text, model).await
            }
            EmbeddingProvider::Ollama => {
                // Ollama is served through the OpenAI-compatible path.
                self.generate_openai_compatible_embedding(text, model).await
            }
        }
    }

    /// Generates (or fetches from cache) the embedding for a query string.
    async fn generate_query_embedding(&mut self, query: &str) -> Result<Vec<f32>> {
        if let Some(embedding) = self.embedding_cache.get(&query.to_string()) {
            return Ok(embedding);
        }

        let embedding = self.generate_embedding(query).await?;

        self.embedding_cache.put(query.to_string(), embedding.clone());

        Ok(embedding)
    }

    /// Ranks candidate chunks by cosine similarity to the query embedding.
    async fn semantic_search(
        &self,
        query_embedding: &[f32],
        candidates: Vec<DocumentChunk>,
    ) -> Result<Vec<SearchResult>> {
        let mut results = Vec::new();

        for (rank, chunk) in candidates.iter().enumerate() {
            let score = cosine_similarity(query_embedding, &chunk.embedding);

            results.push(SearchResult::new(chunk.clone(), score, rank));
        }

        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        results.retain(|r| r.score >= self.config.similarity_threshold);
        results.truncate(self.config.top_k);

        // Re-assign ranks so they reflect the final, sorted order.
        for (rank, result) in results.iter_mut().enumerate() {
            result.rank = rank;
        }

        Ok(results)
    }

    /// Combines semantic and keyword scores into a single weighted ranking.
    async fn hybrid_search(
        &self,
        query: &str,
        query_embedding: &[f32],
        candidates: Vec<DocumentChunk>,
    ) -> Result<Vec<SearchResult>> {
        let mut results = Vec::new();

        let semantic_results = self.semantic_search(query_embedding, candidates.clone()).await?;

        let keyword_results = self.keyword_search(query, &candidates).await?;

        let semantic_weight = self.config.semantic_weight;
        let keyword_weight = self.config.keyword_weight;

        for (rank, chunk) in candidates.iter().enumerate() {
            let semantic_score = semantic_results
                .iter()
                .find(|r| r.chunk.id == chunk.id)
                .map(|r| r.score)
                .unwrap_or(0.0);

            let keyword_score = keyword_results
                .iter()
                .find(|r| r.chunk.id == chunk.id)
                .map(|r| r.score)
                .unwrap_or(0.0);

            let combined_score = semantic_score * semantic_weight + keyword_score * keyword_weight;

            results.push(SearchResult::new(chunk.clone(), combined_score, rank));
        }

        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        results.retain(|r| r.score >= self.config.similarity_threshold);
        results.truncate(self.config.top_k);

        // Re-assign ranks so they reflect the final, sorted order.
        for (rank, result) in results.iter_mut().enumerate() {
            result.rank = rank;
        }

        Ok(results)
    }

    /// Keyword-based search. Currently a placeholder that returns no results,
    /// so hybrid search effectively falls back to the semantic score alone.
    async fn keyword_search(
        &self,
        _query: &str,
        _candidates: &[DocumentChunk],
    ) -> Result<Vec<SearchResult>> {
        Ok(Vec::new())
    }

    /// Builds the prompt context from search results and optional extra context.
    fn build_context(
        &self,
        search_results: &[SearchResult],
        additional_context: Option<&str>,
    ) -> String {
        let mut context_parts = Vec::new();

        for result in search_results {
            context_parts
                .push(format!("Content: {}\nRelevance: {:.2}", result.chunk.content, result.score));
        }

        if let Some(context) = additional_context {
            context_parts.push(format!("Additional Context: {}", context));
        }

        context_parts.join("\n\n")
    }

    /// Dispatches generation to the configured LLM provider.
    async fn generate_with_llm(&self, query: &str, context: &str) -> Result<String> {
        let provider = &self.config.provider;
        let model = &self.config.model;

        match provider {
            crate::rag::config::LlmProvider::OpenAI => {
                self.generate_openai_response(query, context, model).await
            }
            crate::rag::config::LlmProvider::Anthropic => {
                self.generate_anthropic_response(query, context, model).await
            }
            crate::rag::config::LlmProvider::OpenAICompatible => {
                self.generate_openai_compatible_response(query, context, model).await
            }
            crate::rag::config::LlmProvider::Ollama => {
                self.generate_ollama_response(query, context, model).await
            }
        }
    }

    /// Builds the data-generation prompt from the schema, row count, and context.
    fn create_generation_prompt(
        &self,
        schema: &SchemaDefinition,
        count: usize,
        context: Option<&str>,
    ) -> String {
        let mut prompt = format!(
            "Generate {} rows of sample data following this schema:\n\n{:?}\n\n",
            count, schema
        );

        if let Some(context) = context {
            prompt.push_str(&format!("Additional context: {}\n\n", context));
        }

        prompt.push_str("Please generate the data in JSON format as an array of objects.");
        prompt
    }

    /// Parses the LLM response as a JSON array of rows, falling back to the
    /// first bracketed array found in the response text.
    fn parse_dataset_response(
        &self,
        response: &str,
        _schema: &SchemaDefinition,
    ) -> Result<Vec<HashMap<String, Value>>> {
        match serde_json::from_str::<Vec<HashMap<String, Value>>>(response) {
            Ok(data) => Ok(data),
            Err(_) => {
                // The model may wrap the JSON in prose; extract the outermost array.
                if let Some(json_start) = response.find('[') {
                    if let Some(json_end) = response.rfind(']') {
                        let json_part = &response[json_start..=json_end];
                        serde_json::from_str(json_part).map_err(|e| {
                            crate::Error::generic(format!("Failed to parse JSON: {}", e))
                        })
                    } else {
                        Err(crate::Error::generic("No closing bracket found in response"))
                    }
                } else {
                    Err(crate::Error::generic("No JSON array found in response"))
                }
            }
        }
    }

    /// Requests an embedding from the OpenAI embeddings API.
    async fn generate_openai_embedding(&self, text: &str, model: &str) -> Result<Vec<f32>> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;

        let response = self
            .client
            .post("https://api.openai.com/v1/embeddings")
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "input": text,
                "model": model
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;

        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
    }

    /// Requests an embedding from an OpenAI-compatible embeddings endpoint.
    async fn generate_openai_compatible_embedding(
        &self,
        text: &str,
        model: &str,
    ) -> Result<Vec<f32>> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("API key not configured"))?;

        let response = self
            .client
            .post(format!("{}/embeddings", self.config.api_endpoint))
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "input": text,
                "model": model
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;

        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
    }

    /// Generates a chat completion via the OpenAI chat completions API.
    async fn generate_openai_response(
        &self,
        query: &str,
        context: &str,
        model: &str,
    ) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;

        let messages = vec![
            serde_json::json!({
                "role": "system",
                "content": "You are a helpful assistant. Use the provided context to answer questions accurately."
            }),
            serde_json::json!({
                "role": "user",
                "content": format!("Context: {}\n\nQuestion: {}", context, query)
            }),
        ];

        let response = self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "model": model,
                "messages": messages,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "top_p": self.config.top_p
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    /// Anthropic generation is not implemented yet; returns a placeholder string.
    async fn generate_anthropic_response(
        &self,
        _query: &str,
        _context: &str,
        _model: &str,
    ) -> Result<String> {
        Ok("Anthropic response placeholder".to_string())
    }

    /// Generates a chat completion via an OpenAI-compatible endpoint.
    async fn generate_openai_compatible_response(
        &self,
        query: &str,
        context: &str,
        model: &str,
    ) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("API key not configured"))?;

        let messages = vec![
            serde_json::json!({
                "role": "system",
                "content": "You are a helpful assistant. Use the provided context to answer questions accurately."
            }),
            serde_json::json!({
                "role": "user",
                "content": format!("Context: {}\n\nQuestion: {}", context, query)
            }),
        ];

        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.api_endpoint))
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "model": model,
                "messages": messages,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "top_p": self.config.top_p
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!("API error: {}", response.status())));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    /// Ollama generation is not implemented yet; returns a placeholder string.
    async fn generate_ollama_response(
        &self,
        _query: &str,
        _context: &str,
        _model: &str,
    ) -> Result<String> {
        Ok("Ollama response placeholder".to_string())
    }
}

impl Default for RagEngine {
    fn default() -> Self {
        use crate::rag::storage::InMemoryStorage;

        let config = crate::rag::config::RagConfig::default();
        let storage = Arc::new(InMemoryStorage::default());

        Self::new(config, storage).expect("Failed to create default RagEngine")
    }
}

/// Computes the cosine similarity between two vectors.
///
/// Returns 0.0 if the vectors have different lengths or if either has zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot_product / (norm_a * norm_b)
    }
}

/// Engine-level statistics combining storage metrics with runtime metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagStats {
    /// Total number of indexed documents.
    pub total_documents: usize,
    /// Total number of stored chunks.
    pub total_chunks: usize,
    /// Size of the index, in bytes.
    pub index_size_bytes: u64,
    /// Time the index was last updated.
    pub last_updated: chrono::DateTime<chrono::Utc>,
    /// Hit rate of the embedding cache (0.0 to 1.0).
    pub cache_hit_rate: f32,
    /// Average recorded response time, in milliseconds.
    pub average_response_time_ms: f32,
}

impl Default for RagStats {
    fn default() -> Self {
        Self {
            total_documents: 0,
            total_chunks: 0,
            index_size_bytes: 0,
            last_updated: chrono::Utc::now(),
            cache_hit_rate: 0.0,
            average_response_time_ms: 0.0,
        }
    }
}

/// Statistics reported by a storage backend.
#[derive(Debug, Clone)]
pub struct StorageStats {
    /// Total number of stored documents.
    pub total_documents: usize,
    /// Total number of stored chunks.
    pub total_chunks: usize,
    /// Size of the index, in bytes.
    pub index_size_bytes: u64,
    /// Time the index was last updated.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}

impl Default for StorageStats {
    fn default() -> Self {
        Self {
            total_documents: 0,
            total_chunks: 0,
            index_size_bytes: 0,
            last_updated: chrono::Utc::now(),
        }
    }
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_module_compiles() {}
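
    // Added sanity checks for the pure helpers in this module; the expected
    // values follow directly from the cosine-similarity formula and the
    // preview truncation rules above.
    #[test]
    fn test_cosine_similarity_basic() {
        // Identical vectors are maximally similar.
        let sim = super::cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!((sim - 1.0).abs() < 1e-6);

        // Orthogonal vectors have zero similarity.
        let sim = super::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 1e-6);

        // Length mismatches and zero vectors fall back to 0.0.
        assert_eq!(super::cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(super::cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_preview_truncates_on_char_boundary() {
        let chunk = super::DocumentChunk::new(
            "c1".to_string(),
            "é".repeat(200),
            std::collections::HashMap::new(),
            Vec::new(),
            "d1".to_string(),
            0,
            1,
        );
        let preview = chunk.preview();
        assert!(preview.ends_with("..."));
        assert_eq!(preview.chars().count(), 103); // 100 content chars + "..."
    }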
}