1pub mod config;
12pub mod engine;
13pub mod providers;
14pub mod storage;
15pub mod utils;
16
17pub use config::*;
19pub use providers::*;
20pub use utils::*;
21
22pub use engine::StorageStats as EngineStorageStats;
24pub use storage::StorageStats as StorageStorageStats;
25
26use crate::Result;
28use crate::{schema::SchemaDefinition, DataConfig};
29use reqwest::{Client, ClientBuilder};
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::cmp::Ordering;
33use std::collections::HashMap;
34use std::time::Duration;
35use tokio::time::sleep;
36use tracing::{debug, warn};
37
/// Supported LLM backends for text generation.
///
/// Serialized in lowercase by serde (e.g. "openai", "anthropic").
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI chat-completions API.
    OpenAI,
    /// Anthropic messages API.
    Anthropic,
    /// Any third-party server speaking the OpenAI chat-completions protocol.
    OpenAICompatible,
    /// Local Ollama instance.
    Ollama,
}
51
/// Supported backends for embedding generation.
///
/// Serialized in lowercase by serde.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI embeddings API (requires an API key).
    OpenAI,
    /// Any server speaking the OpenAI embeddings protocol (key optional).
    OpenAICompatible,
    /// Local Ollama instance (`/api/embeddings`).
    Ollama,
}
63
/// Configuration for [`RagEngine`]: LLM provider, embedding/retrieval
/// settings, and HTTP request behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Which LLM backend to call for text generation.
    pub provider: LlmProvider,
    /// Full URL of the chat/generation endpoint.
    pub api_endpoint: String,
    /// API key; optional because Ollama/compatible servers may not need one.
    pub api_key: Option<String>,
    /// Model identifier sent to the provider.
    pub model: String,
    /// Maximum tokens requested per completion.
    pub max_tokens: usize,
    /// Sampling temperature passed through to the provider.
    pub temperature: f64,
    /// Context window size in tokens.
    // NOTE(review): not read anywhere in this module — presumably consumed by
    // callers; confirm.
    pub context_window: usize,

    /// When true, retrieval uses embedding similarity; otherwise keyword search.
    pub semantic_search_enabled: bool,
    /// Backend used to produce embeddings.
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model identifier.
    pub embedding_model: String,
    /// Optional override endpoint for embeddings; when absent, `api_endpoint`
    /// is reused (with "chat/completions" rewritten to "embeddings" for the
    /// OpenAI-style providers).
    pub embedding_endpoint: Option<String>,
    /// Minimum cosine similarity for a chunk to count as relevant.
    pub similarity_threshold: f64,
    /// Maximum chunks to retrieve.
    // NOTE(review): not referenced in this module (retrieval uses a hardcoded
    // limit of 3 in `build_generation_prompt`) — confirm intended.
    pub max_chunks: usize,

    /// Per-request HTTP timeout in seconds.
    pub request_timeout_seconds: u64,
    /// Number of additional retry attempts after the first failed LLM call.
    pub max_retries: usize,
}
100
/// OpenAI-based defaults: gpt-3.5-turbo for generation,
/// text-embedding-ada-002 for embeddings, no API key, 30s timeout, 3 retries.
impl Default for RagConfig {
    fn default() -> Self {
        Self {
            provider: LlmProvider::OpenAI,
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            api_key: None,
            model: "gpt-3.5-turbo".to_string(),
            max_tokens: 1000,
            temperature: 0.7,
            context_window: 4000,
            semantic_search_enabled: true,
            embedding_provider: EmbeddingProvider::OpenAI,
            embedding_model: "text-embedding-ada-002".to_string(),
            embedding_endpoint: None,
            similarity_threshold: 0.7,
            max_chunks: 5,
            request_timeout_seconds: 30,
            max_retries: 3,
        }
    }
}
122
/// A unit of retrievable context: text plus metadata and (optionally) its
/// embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Unique id (assigned as "chunk_<index>" by `RagEngine::add_document`).
    pub id: String,
    /// Raw text, used both in prompts and for keyword search.
    pub content: String,
    /// Arbitrary metadata; string-valued entries participate in keyword search.
    pub metadata: HashMap<String, Value>,
    /// Embedding vector; empty until `RagEngine::compute_embeddings` fills it.
    pub embedding: Vec<f32>,
}
135
/// A scored hit from semantic search, borrowing the matched chunk.
#[derive(Debug)]
pub struct SearchResult<'a> {
    /// The matched chunk (borrowed from the engine's store).
    pub chunk: &'a DocumentChunk,
    /// Cosine similarity between the query and the chunk embedding.
    pub score: f64,
}
144
/// Retrieval-augmented generation engine: stores document chunks and schema
/// facts, and calls the configured LLM/embedding providers over HTTP.
#[derive(Debug)]
pub struct RagEngine {
    /// Provider and retrieval settings.
    config: RagConfig,
    /// In-memory document store searched for prompt context.
    chunks: Vec<DocumentChunk>,
    /// Schema name -> human-readable facts, built by `add_schema`.
    schema_kb: HashMap<String, Vec<String>>,
    /// Shared HTTP client, configured with the request timeout.
    client: Client,
}
157
158impl RagEngine {
159 pub fn new(config: RagConfig) -> Self {
161 let client = ClientBuilder::new()
162 .timeout(Duration::from_secs(config.request_timeout_seconds))
163 .build()
164 .unwrap_or_else(|e| {
165 warn!("Failed to create HTTP client with timeout, using default: {}", e);
166 Client::new()
167 });
168
169 Self {
170 config,
171 chunks: Vec::new(),
172 schema_kb: HashMap::new(),
173 client,
174 }
175 }
176
177 pub fn add_document(
179 &mut self,
180 content: String,
181 metadata: HashMap<String, Value>,
182 ) -> Result<String> {
183 let id = format!("chunk_{}", self.chunks.len());
184 let chunk = DocumentChunk {
185 id: id.clone(),
186 content,
187 metadata,
188 embedding: Vec::new(), };
190
191 self.chunks.push(chunk);
192 Ok(id)
193 }
194
195 pub fn add_schema(&mut self, schema: &SchemaDefinition) -> Result<()> {
197 let mut schema_info = Vec::new();
198
199 schema_info.push(format!("Schema: {}", schema.name));
200
201 if let Some(description) = &schema.description {
202 schema_info.push(format!("Description: {}", description));
203 }
204
205 for field in &schema.fields {
206 let mut field_info = format!(
207 "Field '{}': type={}, required={}",
208 field.name, field.field_type, field.required
209 );
210
211 if let Some(description) = &field.description {
212 field_info.push_str(&format!(" - {}", description));
213 }
214
215 schema_info.push(field_info);
216 }
217
218 for (rel_name, relationship) in &schema.relationships {
219 schema_info.push(format!(
220 "Relationship '{}': {} -> {} ({:?})",
221 rel_name, schema.name, relationship.target_schema, relationship.relationship_type
222 ));
223 }
224
225 self.schema_kb.insert(schema.name.clone(), schema_info);
226 Ok(())
227 }
228
    /// Generate `config.rows` rows of data for `schema` via the LLM.
    ///
    /// Individual row failures are tolerated: each failed row is replaced by
    /// deterministic fallback data, but the whole run aborts once failures
    /// exceed a quarter of the requested rows.
    ///
    /// # Errors
    /// Returns an error if RAG is disabled in `config`, no API key is
    /// configured, or too many rows fail to generate.
    // NOTE(review): the API-key requirement applies even when the provider is
    // Ollama, which does not use a key anywhere else in this module — confirm.
    pub async fn generate_with_rag(
        &self,
        schema: &SchemaDefinition,
        config: &DataConfig,
    ) -> Result<Vec<Value>> {
        if !config.rag_enabled {
            return Err(crate::Error::generic("RAG is not enabled in config"));
        }

        if self.config.api_key.is_none() {
            return Err(crate::Error::generic(
                "RAG is enabled but no API key is configured. Please set MOCKFORGE_RAG_API_KEY or provide --rag-api-key"
            ));
        }

        let mut results = Vec::new();
        let mut failed_rows = 0;

        for i in 0..config.rows {
            match self.generate_single_row_with_rag(schema, i).await {
                Ok(data) => results.push(data),
                Err(e) => {
                    failed_rows += 1;
                    warn!("Failed to generate RAG data for row {}: {}", i, e);

                    // Abort when more than 25% of the requested rows failed.
                    // NOTE(review): integer division makes this threshold 0
                    // for `rows < 4`, so a single failure aborts immediately
                    // on tiny batches — confirm that is intended.
                    if failed_rows > config.rows / 4 {
                        return Err(crate::Error::generic(
                            format!("Too many RAG generation failures ({} out of {} rows failed). Check API configuration and network connectivity.", failed_rows, config.rows)
                        ));
                    }

                    // Keep the output the requested length by substituting
                    // schema-shaped placeholder data for the failed row.
                    let fallback_data = self.generate_fallback_data(schema);
                    results.push(fallback_data);
                }
            }
        }

        if failed_rows > 0 {
            warn!(
                "RAG generation completed with {} failed rows out of {}",
                failed_rows, config.rows
            );
        }

        Ok(results)
    }
281
    /// Generate one row: build a retrieval-augmented prompt, call the LLM,
    /// and parse its reply as JSON.
    async fn generate_single_row_with_rag(
        &self,
        schema: &SchemaDefinition,
        row_index: usize,
    ) -> Result<Value> {
        let prompt = self.build_generation_prompt(schema, row_index).await?;
        let generated_data = self.call_llm(&prompt).await?;
        self.parse_llm_response(&generated_data)
    }
292
293 fn generate_fallback_data(&self, schema: &SchemaDefinition) -> Value {
295 let mut obj = serde_json::Map::new();
296
297 for field in &schema.fields {
298 let value = match field.field_type.as_str() {
299 "string" => Value::String("sample_data".to_string()),
300 "integer" | "number" => Value::Number(42.into()),
301 "boolean" => Value::Bool(true),
302 _ => Value::String("sample_data".to_string()),
303 };
304 obj.insert(field.name.clone(), value);
305 }
306
307 Value::Object(obj)
308 }
309
310 async fn build_generation_prompt(
312 &self,
313 schema: &SchemaDefinition,
314 _row_index: usize,
315 ) -> Result<String> {
316 let mut prompt =
317 format!("Generate a single row of data for the '{}' schema.\n\n", schema.name);
318
319 if let Some(schema_info) = self.schema_kb.get(&schema.name) {
321 prompt.push_str("Schema Information:\n");
322 for info in schema_info {
323 prompt.push_str(&format!("- {}\n", info));
324 }
325 prompt.push('\n');
326 }
327
328 let relevant_chunks = self.retrieve_relevant_chunks(&schema.name, 3).await?;
330 if !relevant_chunks.is_empty() {
331 prompt.push_str("Relevant Context:\n");
332 for chunk in relevant_chunks {
333 prompt.push_str(&format!("- {}\n", chunk.content));
334 }
335 prompt.push('\n');
336 }
337
338 prompt.push_str("Instructions:\n");
340 prompt.push_str("- Generate realistic data that matches the schema\n");
341 prompt.push_str("- Ensure all required fields are present\n");
342 prompt.push_str("- Use appropriate data types and formats\n");
343 prompt.push_str("- Make relationships consistent if referenced\n");
344 prompt.push_str("- Output only valid JSON for a single object\n\n");
345
346 prompt.push_str("Generate the data:");
347
348 Ok(prompt)
349 }
350
351 async fn retrieve_relevant_chunks(
353 &self,
354 query: &str,
355 limit: usize,
356 ) -> Result<Vec<&DocumentChunk>> {
357 if self.config.semantic_search_enabled {
358 let results = self.semantic_search(query, limit).await?;
360 Ok(results.into_iter().map(|r| r.chunk).collect())
361 } else {
362 Ok(self.keyword_search(query, limit))
364 }
365 }
366
367 pub fn keyword_search(&self, query: &str, limit: usize) -> Vec<&DocumentChunk> {
369 self.chunks
370 .iter()
371 .filter(|chunk| {
372 chunk.content.to_lowercase().contains(&query.to_lowercase())
373 || chunk.metadata.values().any(|v| {
374 if let Some(s) = v.as_str() {
375 s.to_lowercase().contains(&query.to_lowercase())
376 } else {
377 false
378 }
379 })
380 })
381 .take(limit)
382 .collect()
383 }
384
385 async fn semantic_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult<'_>>> {
387 let query_embedding = self.generate_embedding(query).await?;
389
390 let mut results: Vec<SearchResult> = Vec::new();
392
393 for chunk in &self.chunks {
394 if chunk.embedding.is_empty() {
395 continue;
397 }
398
399 let score = Self::cosine_similarity(&query_embedding, &chunk.embedding);
400 if score >= self.config.similarity_threshold {
401 results.push(SearchResult { chunk, score });
402 }
403 }
404
405 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
407 results.truncate(limit);
408
409 Ok(results)
410 }
411
    /// Dispatch embedding generation to the configured backend.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        match &self.config.embedding_provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::OpenAICompatible => {
                self.generate_openai_compatible_embedding(text).await
            }
            EmbeddingProvider::Ollama => self.generate_ollama_embedding(text).await,
        }
    }
422
423 async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
425 let api_key = self
426 .config
427 .api_key
428 .as_ref()
429 .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;
430
431 let endpoint = self
432 .config
433 .embedding_endpoint
434 .as_ref()
435 .unwrap_or(&self.config.api_endpoint)
436 .replace("chat/completions", "embeddings");
437
438 let request_body = serde_json::json!({
439 "model": self.config.embedding_model,
440 "input": text
441 });
442
443 debug!("Generating embedding for text with OpenAI API");
444
445 let response = self
446 .client
447 .post(&endpoint)
448 .header("Authorization", format!("Bearer {}", api_key))
449 .header("Content-Type", "application/json")
450 .json(&request_body)
451 .send()
452 .await
453 .map_err(|e| crate::Error::generic(format!("Embedding API request failed: {}", e)))?;
454
455 if !response.status().is_success() {
456 let error_text = response.text().await.unwrap_or_default();
457 return Err(crate::Error::generic(format!("Embedding API error: {}", error_text)));
458 }
459
460 let response_json: Value = response.json().await.map_err(|e| {
461 crate::Error::generic(format!("Failed to parse embedding response: {}", e))
462 })?;
463
464 if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
465 if let Some(first_item) = data.first() {
466 if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
467 let embedding_vec: Vec<f32> =
468 embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
469 return Ok(embedding_vec);
470 }
471 }
472 }
473
474 Err(crate::Error::generic("Invalid embedding response format"))
475 }
476
477 async fn generate_openai_compatible_embedding(&self, text: &str) -> Result<Vec<f32>> {
479 let endpoint = self
480 .config
481 .embedding_endpoint
482 .as_ref()
483 .unwrap_or(&self.config.api_endpoint)
484 .replace("chat/completions", "embeddings");
485
486 let request_body = serde_json::json!({
487 "model": self.config.embedding_model,
488 "input": text
489 });
490
491 debug!("Generating embedding for text with OpenAI-compatible API");
492
493 let mut request = self
494 .client
495 .post(&endpoint)
496 .header("Content-Type", "application/json")
497 .json(&request_body);
498
499 if let Some(api_key) = &self.config.api_key {
500 request = request.header("Authorization", format!("Bearer {}", api_key));
501 }
502
503 let response = request
504 .send()
505 .await
506 .map_err(|e| crate::Error::generic(format!("Embedding API request failed: {}", e)))?;
507
508 if !response.status().is_success() {
509 let error_text = response.text().await.unwrap_or_default();
510 return Err(crate::Error::generic(format!("Embedding API error: {}", error_text)));
511 }
512
513 let response_json: Value = response.json().await.map_err(|e| {
514 crate::Error::generic(format!("Failed to parse embedding response: {}", e))
515 })?;
516
517 if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
518 if let Some(first_item) = data.first() {
519 if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
520 let embedding_vec: Vec<f32> =
521 embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
522 return Ok(embedding_vec);
523 }
524 }
525 }
526
527 Err(crate::Error::generic("Invalid embedding response format"))
528 }
529
530 async fn generate_ollama_embedding(&self, text: &str) -> Result<Vec<f32>> {
534 let base_url = self.config.embedding_endpoint.as_ref().unwrap_or(&self.config.api_endpoint);
535
536 let endpoint = if base_url.ends_with("/api/embeddings") {
538 base_url.clone()
539 } else {
540 format!("{}/api/embeddings", base_url.trim_end_matches('/'))
541 };
542
543 let model = &self.config.embedding_model;
544 let request_body = serde_json::json!({
545 "model": model,
546 "prompt": text
547 });
548
549 debug!("Generating embedding for text with Ollama (model: {})", model);
550
551 let response = self
552 .client
553 .post(&endpoint)
554 .header("Content-Type", "application/json")
555 .json(&request_body)
556 .send()
557 .await
558 .map_err(|e| {
559 crate::Error::generic(format!("Ollama embedding request failed: {}", e))
560 })?;
561
562 if !response.status().is_success() {
563 let error_text = response.text().await.unwrap_or_default();
564 return Err(crate::Error::generic(format!("Ollama embedding error: {}", error_text)));
565 }
566
567 let response_json: Value = response.json().await.map_err(|e| {
568 crate::Error::generic(format!("Failed to parse Ollama embedding response: {}", e))
569 })?;
570
571 if let Some(embedding) = response_json.get("embedding").and_then(|e| e.as_array()) {
573 let embedding_vec: Vec<f32> =
574 embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
575 return Ok(embedding_vec);
576 }
577
578 Err(crate::Error::generic("Invalid Ollama embedding response format"))
579 }
580
581 pub async fn compute_embeddings(&mut self) -> Result<()> {
583 debug!("Computing embeddings for {} chunks", self.chunks.len());
584
585 let chunks_to_embed: Vec<(usize, String)> = self
587 .chunks
588 .iter()
589 .enumerate()
590 .filter(|(_, chunk)| chunk.embedding.is_empty())
591 .map(|(idx, chunk)| (idx, chunk.content.clone()))
592 .collect();
593
594 for (idx, content) in chunks_to_embed {
596 let embedding = self.generate_embedding(&content).await?;
597 self.chunks[idx].embedding = embedding;
598 debug!("Computed embedding for chunk {}", self.chunks[idx].id);
599 }
600
601 Ok(())
602 }
603
604 fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
606 if a.len() != b.len() || a.is_empty() {
607 return 0.0;
608 }
609
610 let mut dot_product = 0.0;
611 let mut norm_a = 0.0;
612 let mut norm_b = 0.0;
613
614 for i in 0..a.len() {
615 dot_product += a[i] as f64 * b[i] as f64;
616 norm_a += (a[i] as f64).powi(2);
617 norm_b += (b[i] as f64).powi(2);
618 }
619
620 norm_a = norm_a.sqrt();
621 norm_b = norm_b.sqrt();
622
623 if norm_a == 0.0 || norm_b == 0.0 {
624 0.0
625 } else {
626 dot_product / (norm_a * norm_b)
627 }
628 }
629
    /// Call the configured LLM, retrying on failure.
    ///
    /// Makes up to `max_retries + 1` attempts with a linearly increasing
    /// delay between attempts (500ms, 1000ms, ...). Returns the last error
    /// if every attempt fails.
    async fn call_llm(&self, prompt: &str) -> Result<String> {
        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            match self.call_llm_single_attempt(prompt).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = Some(e);
                    // No delay after the final attempt — fall through to the
                    // error return below instead.
                    if attempt < self.config.max_retries {
                        let delay = Duration::from_millis(500 * (attempt + 1) as u64);
                        warn!(
                            "LLM API call failed (attempt {}), retrying in {:?}: {:?}",
                            attempt + 1,
                            delay,
                            last_error
                        );
                        sleep(delay).await;
                    }
                }
            }
        }

        // The fallback message is unreachable in practice: the loop body runs
        // at least once, so `last_error` is always Some here.
        Err(last_error
            .unwrap_or_else(|| crate::Error::generic("All LLM API retry attempts failed")))
    }
656
    /// Single non-retrying LLM call, dispatched to the configured provider.
    async fn call_llm_single_attempt(&self, prompt: &str) -> Result<String> {
        match &self.config.provider {
            LlmProvider::OpenAI => self.call_openai(prompt).await,
            LlmProvider::Anthropic => self.call_anthropic(prompt).await,
            LlmProvider::OpenAICompatible => self.call_openai_compatible(prompt).await,
            LlmProvider::Ollama => self.call_ollama(prompt).await,
        }
    }
666
667 async fn call_openai(&self, prompt: &str) -> Result<String> {
669 let api_key = self
670 .config
671 .api_key
672 .as_ref()
673 .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;
674
675 let request_body = serde_json::json!({
676 "model": self.config.model,
677 "messages": [
678 {
679 "role": "user",
680 "content": prompt
681 }
682 ],
683 "max_tokens": self.config.max_tokens,
684 "temperature": self.config.temperature
685 });
686
687 debug!("Calling OpenAI API with model: {}", self.config.model);
688
689 let response = self
690 .client
691 .post(&self.config.api_endpoint)
692 .header("Authorization", format!("Bearer {}", api_key))
693 .header("Content-Type", "application/json")
694 .json(&request_body)
695 .send()
696 .await
697 .map_err(|e| crate::Error::generic(format!("OpenAI API request failed: {}", e)))?;
698
699 if !response.status().is_success() {
700 let error_text = response.text().await.unwrap_or_default();
701 return Err(crate::Error::generic(format!("OpenAI API error: {}", error_text)));
702 }
703
704 let response_json: Value = response.json().await.map_err(|e| {
705 crate::Error::generic(format!("Failed to parse OpenAI response: {}", e))
706 })?;
707
708 if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
709 if let Some(choice) = choices.first() {
710 if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
711 if let Some(content) = message.as_str() {
712 return Ok(content.to_string());
713 }
714 }
715 }
716 }
717
718 Err(crate::Error::generic("Invalid OpenAI response format"))
719 }
720
721 async fn call_anthropic(&self, prompt: &str) -> Result<String> {
723 let api_key = self
724 .config
725 .api_key
726 .as_ref()
727 .ok_or_else(|| crate::Error::generic("Anthropic API key not configured"))?;
728
729 let request_body = serde_json::json!({
730 "model": self.config.model,
731 "max_tokens": self.config.max_tokens,
732 "temperature": self.config.temperature,
733 "messages": [
734 {
735 "role": "user",
736 "content": prompt
737 }
738 ]
739 });
740
741 debug!("Calling Anthropic API with model: {}", self.config.model);
742
743 let response = self
744 .client
745 .post(&self.config.api_endpoint)
746 .header("x-api-key", api_key)
747 .header("Content-Type", "application/json")
748 .header("anthropic-version", "2023-06-01")
749 .json(&request_body)
750 .send()
751 .await
752 .map_err(|e| crate::Error::generic(format!("Anthropic API request failed: {}", e)))?;
753
754 if !response.status().is_success() {
755 let error_text = response.text().await.unwrap_or_default();
756 return Err(crate::Error::generic(format!("Anthropic API error: {}", error_text)));
757 }
758
759 let response_json: Value = response.json().await.map_err(|e| {
760 crate::Error::generic(format!("Failed to parse Anthropic response: {}", e))
761 })?;
762
763 if let Some(content) = response_json.get("content") {
764 if let Some(content_array) = content.as_array() {
765 if let Some(first_content) = content_array.first() {
766 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
767 return Ok(text.to_string());
768 }
769 }
770 }
771 }
772
773 Err(crate::Error::generic("Invalid Anthropic response format"))
774 }
775
776 async fn call_openai_compatible(&self, prompt: &str) -> Result<String> {
778 let request_body = serde_json::json!({
779 "model": self.config.model,
780 "messages": [
781 {
782 "role": "user",
783 "content": prompt
784 }
785 ],
786 "max_tokens": self.config.max_tokens,
787 "temperature": self.config.temperature
788 });
789
790 debug!("Calling OpenAI-compatible API with model: {}", self.config.model);
791
792 let mut request = self
793 .client
794 .post(&self.config.api_endpoint)
795 .header("Content-Type", "application/json")
796 .json(&request_body);
797
798 if let Some(api_key) = &self.config.api_key {
799 request = request.header("Authorization", format!("Bearer {}", api_key));
800 }
801
802 let response = request.send().await.map_err(|e| {
803 crate::Error::generic(format!("OpenAI-compatible API request failed: {}", e))
804 })?;
805
806 if !response.status().is_success() {
807 let error_text = response.text().await.unwrap_or_default();
808 return Err(crate::Error::generic(format!(
809 "OpenAI-compatible API error: {}",
810 error_text
811 )));
812 }
813
814 let response_json: Value = response.json().await.map_err(|e| {
815 crate::Error::generic(format!("Failed to parse OpenAI-compatible response: {}", e))
816 })?;
817
818 if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
819 if let Some(choice) = choices.first() {
820 if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
821 if let Some(content) = message.as_str() {
822 return Ok(content.to_string());
823 }
824 }
825 }
826 }
827
828 Err(crate::Error::generic("Invalid OpenAI-compatible response format"))
829 }
830
831 async fn call_ollama(&self, prompt: &str) -> Result<String> {
833 let request_body = serde_json::json!({
834 "model": self.config.model,
835 "prompt": prompt,
836 "stream": false
837 });
838
839 debug!("Calling Ollama API with model: {}", self.config.model);
840
841 let response = self
842 .client
843 .post(&self.config.api_endpoint)
844 .header("Content-Type", "application/json")
845 .json(&request_body)
846 .send()
847 .await
848 .map_err(|e| crate::Error::generic(format!("Ollama API request failed: {}", e)))?;
849
850 if !response.status().is_success() {
851 let error_text = response.text().await.unwrap_or_default();
852 return Err(crate::Error::generic(format!("Ollama API error: {}", error_text)));
853 }
854
855 let response_json: Value = response.json().await.map_err(|e| {
856 crate::Error::generic(format!("Failed to parse Ollama response: {}", e))
857 })?;
858
859 if let Some(response_text) = response_json.get("response").and_then(|r| r.as_str()) {
860 return Ok(response_text.to_string());
861 }
862
863 Err(crate::Error::generic("Invalid Ollama response format"))
864 }
865
866 fn parse_llm_response(&self, response: &str) -> Result<Value> {
868 match serde_json::from_str(response) {
870 Ok(value) => Ok(value),
871 Err(e) => {
872 if let Some(start) = response.find('{') {
874 if let Some(end) = response.rfind('}') {
875 let json_str = &response[start..=end];
876 match serde_json::from_str(json_str) {
877 Ok(value) => Ok(value),
878 Err(_) => Err(crate::Error::generic(format!(
879 "Failed to parse LLM response: {}",
880 e
881 ))),
882 }
883 } else {
884 Err(crate::Error::generic(format!(
885 "No closing brace found in response: {}",
886 e
887 )))
888 }
889 } else {
890 Err(crate::Error::generic(format!("No JSON found in response: {}", e)))
891 }
892 }
893 }
894 }
895
    /// Replace the engine's configuration. Takes effect on subsequent calls;
    /// note the HTTP client (and its timeout) built in `new` is NOT rebuilt.
    pub fn update_config(&mut self, config: RagConfig) {
        self.config = config;
    }

    /// Current configuration.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Number of stored document chunks.
    pub fn chunk_count(&self) -> usize {
        self.chunks.len()
    }

    /// Number of schemas registered via `add_schema`.
    pub fn schema_count(&self) -> usize {
        self.schema_kb.len()
    }

    /// Chunk at `index` in insertion order, if any.
    pub fn get_chunk(&self, index: usize) -> Option<&DocumentChunk> {
        self.chunks.get(index)
    }

    /// Whether a schema with `name` has been registered.
    pub fn has_schema(&self, name: &str) -> bool {
        self.schema_kb.contains_key(name)
    }

    /// Free-form text generation through the configured LLM (with retries).
    pub async fn generate_text(&self, prompt: &str) -> Result<String> {
        self.call_llm(prompt).await
    }
930}
931
/// Engine built from [`RagConfig::default`] (OpenAI endpoints, no API key).
impl Default for RagEngine {
    fn default() -> Self {
        Self::new(RagConfig::default())
    }
}
937
938pub mod rag_utils {
940 use super::*;
941
942 pub fn create_business_rag_engine() -> Result<RagEngine> {
944 let mut engine = RagEngine::default();
945
946 engine.add_document(
948 "Customer data typically includes personal information like name, email, phone, and address. Customers usually have unique identifiers and account creation dates.".to_string(),
949 HashMap::from([
950 ("domain".to_string(), Value::String("customer".to_string())),
951 ("type".to_string(), Value::String("general".to_string())),
952 ]),
953 )?;
954
955 engine.add_document(
956 "Product information includes name, description, price, category, and stock status. Products should have unique SKUs or IDs.".to_string(),
957 HashMap::from([
958 ("domain".to_string(), Value::String("product".to_string())),
959 ("type".to_string(), Value::String("general".to_string())),
960 ]),
961 )?;
962
963 engine.add_document(
964 "Order data contains customer references, product lists, total amounts, status, and timestamps. Orders should maintain referential integrity with customers and products.".to_string(),
965 HashMap::from([
966 ("domain".to_string(), Value::String("order".to_string())),
967 ("type".to_string(), Value::String("general".to_string())),
968 ]),
969 )?;
970
971 Ok(engine)
972 }
973
974 pub fn create_technical_rag_engine() -> Result<RagEngine> {
976 let mut engine = RagEngine::default();
977
978 engine.add_document(
980 "API endpoints should follow RESTful conventions with proper HTTP methods. GET for retrieval, POST for creation, PUT for updates, DELETE for removal.".to_string(),
981 HashMap::from([
982 ("domain".to_string(), Value::String("api".to_string())),
983 ("type".to_string(), Value::String("technical".to_string())),
984 ]),
985 )?;
986
987 engine.add_document(
988 "Database records typically have auto-incrementing primary keys, created_at and updated_at timestamps, and foreign key relationships.".to_string(),
989 HashMap::from([
990 ("domain".to_string(), Value::String("database".to_string())),
991 ("type".to_string(), Value::String("technical".to_string())),
992 ]),
993 )?;
994
995 Ok(engine)
996 }
997}
#[cfg(test)]
mod tests {
    use super::*;

    /// All four LLM provider variants are constructible and matchable.
    #[test]
    fn test_llm_provider_variants() {
        let openai = LlmProvider::OpenAI;
        let anthropic = LlmProvider::Anthropic;
        let compatible = LlmProvider::OpenAICompatible;
        let ollama = LlmProvider::Ollama;

        assert!(matches!(openai, LlmProvider::OpenAI));
        assert!(matches!(anthropic, LlmProvider::Anthropic));
        assert!(matches!(compatible, LlmProvider::OpenAICompatible));
        assert!(matches!(ollama, LlmProvider::Ollama));
    }

    /// Pins the `#[serde(rename_all = "lowercase")]` wire format.
    #[test]
    fn test_llm_provider_serializes_lowercase() {
        assert_eq!(serde_json::to_string(&LlmProvider::OpenAI).unwrap(), "\"openai\"");
        assert_eq!(serde_json::to_string(&LlmProvider::Anthropic).unwrap(), "\"anthropic\"");
    }

    /// All embedding provider variants, including the previously untested
    /// `Ollama`.
    #[test]
    fn test_embedding_provider_variants() {
        let openai = EmbeddingProvider::OpenAI;
        let compatible = EmbeddingProvider::OpenAICompatible;
        let ollama = EmbeddingProvider::Ollama;

        assert!(matches!(openai, EmbeddingProvider::OpenAI));
        assert!(matches!(compatible, EmbeddingProvider::OpenAICompatible));
        assert!(matches!(ollama, EmbeddingProvider::Ollama));
    }

    /// Default configuration values are in sane ranges.
    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();

        assert!(config.max_tokens > 0);
        assert!(config.temperature >= 0.0 && config.temperature <= 1.0);
        assert!(config.context_window > 0);
        assert!(config.similarity_threshold > 0.0 && config.similarity_threshold <= 1.0);
        assert!(config.request_timeout_seconds > 0);
    }
}