pub mod config;
pub mod engine;
pub mod providers;
pub mod storage;
pub mod utils;

pub use config::*;
pub use providers::*;
pub use utils::*;

// `engine` and `storage` both define a `StorageStats` type; re-export them
// under distinct names rather than glob-importing to avoid a name collision.
pub use engine::StorageStats as EngineStorageStats;
pub use storage::StorageStats as StorageStorageStats;

use crate::{schema::SchemaDefinition, DataConfig};
use mockforge_core::Result;
use reqwest::{Client, ClientBuilder};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, warn};

/// Supported LLM providers for RAG-based generation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI chat completions API.
    OpenAI,
    /// Anthropic messages API.
    Anthropic,
    /// Any server that speaks the OpenAI chat completions protocol.
    OpenAICompatible,
    /// Local Ollama server.
    Ollama,
}

/// Supported providers for text embeddings.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI embeddings API.
    OpenAI,
    /// Any server that speaks the OpenAI embeddings protocol.
    OpenAICompatible,
}

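/// Configuration for the retrieval-augmented generation (RAG) engine.
///
/// A minimal override sketch (hypothetical endpoint and model values;
/// `..RagConfig::default()` fills the remaining fields):
///
/// ```ignore
/// let config = RagConfig {
///     provider: LlmProvider::Ollama,
///     api_endpoint: "http://localhost:11434/api/generate".to_string(),
///     model: "llama3".to_string(),
///     ..RagConfig::default()
/// };
/// ```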
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM provider used for generation.
    pub provider: LlmProvider,
    /// Endpoint URL of the completion API.
    pub api_endpoint: String,
    /// API key, if the provider requires one.
    pub api_key: Option<String>,
    /// Model identifier to request.
    pub model: String,
    /// Maximum number of tokens to generate per request.
    pub max_tokens: usize,
    /// Sampling temperature.
    pub temperature: f64,
    /// Context window size, in tokens.
    pub context_window: usize,

    /// Whether embedding-based semantic search is enabled.
    pub semantic_search_enabled: bool,
    /// Provider used for text embeddings.
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model identifier.
    pub embedding_model: String,
    /// Optional dedicated embeddings endpoint; defaults to `api_endpoint`
    /// with `chat/completions` rewritten to `embeddings`.
    pub embedding_endpoint: Option<String>,
    /// Minimum cosine similarity for a chunk to count as relevant.
    pub similarity_threshold: f64,
    /// Maximum number of chunks to retrieve per query.
    pub max_chunks: usize,

    /// HTTP request timeout, in seconds.
    pub request_timeout_seconds: u64,
    /// Maximum number of retries for failed LLM calls.
    pub max_retries: usize,
}

impl Default for RagConfig {
    fn default() -> Self {
        Self {
            provider: LlmProvider::OpenAI,
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            api_key: None,
            model: "gpt-3.5-turbo".to_string(),
            max_tokens: 1000,
            temperature: 0.7,
            context_window: 4000,
            semantic_search_enabled: true,
            embedding_provider: EmbeddingProvider::OpenAI,
            embedding_model: "text-embedding-ada-002".to_string(),
            embedding_endpoint: None,
            similarity_threshold: 0.7,
            max_chunks: 5,
            request_timeout_seconds: 30,
            max_retries: 3,
        }
    }
}

/// A chunk of reference text stored in the engine's knowledge base.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Unique chunk identifier.
    pub id: String,
    /// Raw text content.
    pub content: String,
    /// Arbitrary metadata attached to the chunk.
    pub metadata: HashMap<String, Value>,
    /// Embedding vector; empty until computed.
    pub embedding: Vec<f32>,
}

/// A scored match returned by semantic search.
#[derive(Debug)]
pub struct SearchResult<'a> {
    /// The matching chunk.
    pub chunk: &'a DocumentChunk,
    /// Cosine similarity between the query and the chunk.
    pub score: f64,
}

/// Retrieval-augmented generation engine that grounds LLM output in indexed
/// schemas and reference documents.
#[derive(Debug)]
pub struct RagEngine {
    config: RagConfig,
    chunks: Vec<DocumentChunk>,
    schema_kb: HashMap<String, Vec<String>>,
    client: Client,
}

impl RagEngine {
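    /// Create a new engine with the given configuration.
    ///
    /// Sketch (no network access happens until a generation or embedding
    /// call is made; the key shown is a placeholder):
    ///
    /// ```ignore
    /// let engine = RagEngine::new(RagConfig {
    ///     api_key: Some("sk-...".to_string()),
    ///     ..RagConfig::default()
    /// });
    /// ```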
    pub fn new(config: RagConfig) -> Self {
        let client = ClientBuilder::new()
            .timeout(Duration::from_secs(config.request_timeout_seconds))
            .build()
            .unwrap_or_else(|e| {
                warn!("Failed to create HTTP client with timeout, using default: {}", e);
                Client::new()
            });

        Self {
            config,
            chunks: Vec::new(),
            schema_kb: HashMap::new(),
            client,
        }
    }

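    /// Add a reference document to the knowledge base and return its chunk ID.
    ///
    /// The embedding is left empty until [`compute_embeddings`](Self::compute_embeddings)
    /// is called. Sketch:
    ///
    /// ```ignore
    /// let mut engine = RagEngine::default();
    /// let id = engine.add_document("Customers have a name and email.".to_string(), HashMap::new())?;
    /// assert_eq!(id, "chunk_0");
    /// ```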
    pub fn add_document(
        &mut self,
        content: String,
        metadata: HashMap<String, Value>,
    ) -> Result<String> {
        let id = format!("chunk_{}", self.chunks.len());
        let chunk = DocumentChunk {
            id: id.clone(),
            content,
            metadata,
            embedding: Vec::new(),
        };

        self.chunks.push(chunk);
        Ok(id)
    }

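    /// Index a schema's fields and relationships as plain-text facts that
    /// `build_generation_prompt` can splice into prompts. Re-adding a schema
    /// with the same name replaces the previous entry.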
    pub fn add_schema(&mut self, schema: &SchemaDefinition) -> Result<()> {
        let mut schema_info = Vec::new();

        schema_info.push(format!("Schema: {}", schema.name));

        if let Some(description) = &schema.description {
            schema_info.push(format!("Description: {}", description));
        }

        for field in &schema.fields {
            let mut field_info = format!(
                "Field '{}': type={}, required={}",
                field.name, field.field_type, field.required
            );

            if let Some(description) = &field.description {
                field_info.push_str(&format!(" - {}", description));
            }

            schema_info.push(field_info);
        }

        for (rel_name, relationship) in &schema.relationships {
            schema_info.push(format!(
                "Relationship '{}': {} -> {} ({:?})",
                rel_name, schema.name, relationship.target_schema, relationship.relationship_type
            ));
        }

        self.schema_kb.insert(schema.name.clone(), schema_info);
        Ok(())
    }

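    /// Generate `config.rows` rows for `schema`, substituting fallback values
    /// for rows that fail and aborting once more than a quarter of the rows
    /// have failed. Requires `config.rag_enabled` and an API key. Sketch
    /// (hypothetical `schema` and `data_config` values; performs live API calls):
    ///
    /// ```ignore
    /// let rows = engine.generate_with_rag(&schema, &data_config).await?;
    /// assert_eq!(rows.len(), data_config.rows);
    /// ```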
    pub async fn generate_with_rag(
        &self,
        schema: &SchemaDefinition,
        config: &DataConfig,
    ) -> Result<Vec<Value>> {
        if !config.rag_enabled {
            return Err(mockforge_core::Error::generic("RAG is not enabled in config"));
        }

        if self.config.api_key.is_none() {
            return Err(mockforge_core::Error::generic(
                "RAG is enabled but no API key is configured. Please set MOCKFORGE_RAG_API_KEY or provide --rag-api-key"
            ));
        }

        let mut results = Vec::new();
        let mut failed_rows = 0;

        for i in 0..config.rows {
            match self.generate_single_row_with_rag(schema, i).await {
                Ok(data) => results.push(data),
                Err(e) => {
                    failed_rows += 1;
                    warn!("Failed to generate RAG data for row {}: {}", i, e);

                    // Abort once more than 25% of the requested rows have failed.
                    if failed_rows > config.rows / 4 {
                        return Err(mockforge_core::Error::generic(
                            format!("Too many RAG generation failures ({} out of {} rows failed). Check API configuration and network connectivity.", failed_rows, config.rows)
                        ));
                    }

                    let fallback_data = self.generate_fallback_data(schema);
                    results.push(fallback_data);
                }
            }
        }

        if failed_rows > 0 {
            warn!(
                "RAG generation completed with {} failed rows out of {}",
                failed_rows, config.rows
            );
        }

        Ok(results)
    }

    async fn generate_single_row_with_rag(
        &self,
        schema: &SchemaDefinition,
        row_index: usize,
    ) -> Result<Value> {
        let prompt = self.build_generation_prompt(schema, row_index).await?;
        let generated_data = self.call_llm(&prompt).await?;
        self.parse_llm_response(&generated_data)
    }

    /// Produce deterministic placeholder values for each schema field, used
    /// when an individual RAG generation attempt fails.
    fn generate_fallback_data(&self, schema: &SchemaDefinition) -> Value {
        let mut obj = serde_json::Map::new();

        for field in &schema.fields {
            let value = match field.field_type.as_str() {
                "string" => Value::String("sample_data".to_string()),
                "integer" | "number" => Value::Number(42.into()),
                "boolean" => Value::Bool(true),
                _ => Value::String("sample_data".to_string()),
            };
            obj.insert(field.name.clone(), value);
        }

        Value::Object(obj)
    }

    async fn build_generation_prompt(
        &self,
        schema: &SchemaDefinition,
        _row_index: usize,
    ) -> Result<String> {
        let mut prompt =
            format!("Generate a single row of data for the '{}' schema.\n\n", schema.name);

        if let Some(schema_info) = self.schema_kb.get(&schema.name) {
            prompt.push_str("Schema Information:\n");
            for info in schema_info {
                prompt.push_str(&format!("- {}\n", info));
            }
            prompt.push('\n');
        }

        let relevant_chunks = self.retrieve_relevant_chunks(&schema.name, 3).await?;
        if !relevant_chunks.is_empty() {
            prompt.push_str("Relevant Context:\n");
            for chunk in relevant_chunks {
                prompt.push_str(&format!("- {}\n", chunk.content));
            }
            prompt.push('\n');
        }

        prompt.push_str("Instructions:\n");
        prompt.push_str("- Generate realistic data that matches the schema\n");
        prompt.push_str("- Ensure all required fields are present\n");
        prompt.push_str("- Use appropriate data types and formats\n");
        prompt.push_str("- Make relationships consistent if referenced\n");
        prompt.push_str("- Output only valid JSON for a single object\n\n");

        prompt.push_str("Generate the data:");

        Ok(prompt)
    }

    async fn retrieve_relevant_chunks(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<&DocumentChunk>> {
        if self.config.semantic_search_enabled {
            let results = self.semantic_search(query, limit).await?;
            Ok(results.into_iter().map(|r| r.chunk).collect())
        } else {
            Ok(self.keyword_search(query, limit))
        }
    }

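    /// Case-insensitive substring search over chunk content and string
    /// metadata values; used in place of semantic search when embeddings are
    /// disabled. Sketch:
    ///
    /// ```ignore
    /// let hits = engine.keyword_search("customer", 5);
    /// assert!(hits.len() <= 5);
    /// ```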
    pub fn keyword_search(&self, query: &str, limit: usize) -> Vec<&DocumentChunk> {
        let query_lower = query.to_lowercase();
        self.chunks
            .iter()
            .filter(|chunk| {
                chunk.content.to_lowercase().contains(&query_lower)
                    || chunk.metadata.values().any(|v| {
                        v.as_str().map_or(false, |s| s.to_lowercase().contains(&query_lower))
                    })
            })
            .take(limit)
            .collect()
    }

    async fn semantic_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult<'_>>> {
        let query_embedding = self.generate_embedding(query).await?;

        let mut results: Vec<SearchResult> = Vec::new();

        for chunk in &self.chunks {
            // Skip chunks whose embeddings have not been computed yet.
            if chunk.embedding.is_empty() {
                continue;
            }

            let score = Self::cosine_similarity(&query_embedding, &chunk.embedding);
            if score >= self.config.similarity_threshold {
                results.push(SearchResult { chunk, score });
            }
        }

        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        results.truncate(limit);

        Ok(results)
    }

    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        match &self.config.embedding_provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::OpenAICompatible => {
                self.generate_openai_compatible_embedding(text).await
            }
        }
    }

    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| mockforge_core::Error::generic("OpenAI API key not configured"))?;

        // Derive the embeddings endpoint from the chat endpoint unless one is
        // configured explicitly.
        let endpoint = self
            .config
            .embedding_endpoint
            .as_ref()
            .unwrap_or(&self.config.api_endpoint)
            .replace("chat/completions", "embeddings");

        let request_body = serde_json::json!({
            "model": self.config.embedding_model,
            "input": text
        });

        debug!("Generating embedding for text with OpenAI API");

        let response = self
            .client
            .post(&endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                mockforge_core::Error::generic(format!("Embedding API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "Embedding API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse embedding response: {}", e))
        })?;

        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
            if let Some(first_item) = data.first() {
                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
                    let embedding_vec: Vec<f32> =
                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
                    return Ok(embedding_vec);
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid embedding response format"))
    }

    async fn generate_openai_compatible_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let endpoint = self
            .config
            .embedding_endpoint
            .as_ref()
            .unwrap_or(&self.config.api_endpoint)
            .replace("chat/completions", "embeddings");

        let request_body = serde_json::json!({
            "model": self.config.embedding_model,
            "input": text
        });

        debug!("Generating embedding for text with OpenAI-compatible API");

        let mut request = self
            .client
            .post(&endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body);

        if let Some(api_key) = &self.config.api_key {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request.send().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Embedding API request failed: {}", e))
        })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "Embedding API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse embedding response: {}", e))
        })?;

        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
            if let Some(first_item) = data.first() {
                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
                    let embedding_vec: Vec<f32> =
                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
                    return Ok(embedding_vec);
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid embedding response format"))
    }

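    /// Compute embeddings for every chunk that does not yet have one.
    ///
    /// Call after adding documents so semantic search has vectors to compare
    /// against. Sketch (requires a reachable embeddings endpoint):
    ///
    /// ```ignore
    /// engine.add_document("Orders reference customers by ID.".to_string(), HashMap::new())?;
    /// engine.compute_embeddings().await?;
    /// ```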
    pub async fn compute_embeddings(&mut self) -> Result<()> {
        debug!("Computing embeddings for {} chunks", self.chunks.len());

        // Snapshot the chunks that still need embeddings so the results can
        // be written back into `self.chunks` inside the loop below.
        let chunks_to_embed: Vec<(usize, String)> = self
            .chunks
            .iter()
            .enumerate()
            .filter(|(_, chunk)| chunk.embedding.is_empty())
            .map(|(idx, chunk)| (idx, chunk.content.clone()))
            .collect();

        for (idx, content) in chunks_to_embed {
            let embedding = self.generate_embedding(&content).await?;
            self.chunks[idx].embedding = embedding;
            debug!("Computed embedding for chunk {}", self.chunks[idx].id);
        }

        Ok(())
    }

    /// Cosine similarity between two vectors, computed in f64 for stability.
    /// Returns 0.0 for empty, mismatched-length, or zero-norm inputs.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        let mut dot_product = 0.0;
        let mut norm_a = 0.0;
        let mut norm_b = 0.0;

        for i in 0..a.len() {
            dot_product += a[i] as f64 * b[i] as f64;
            norm_a += (a[i] as f64).powi(2);
            norm_b += (b[i] as f64).powi(2);
        }

        norm_a = norm_a.sqrt();
        norm_b = norm_b.sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot_product / (norm_a * norm_b)
        }
    }

    /// Call the configured LLM, retrying failed attempts with a linearly
    /// increasing delay (500ms, 1s, 1.5s, ...).
    async fn call_llm(&self, prompt: &str) -> Result<String> {
        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            match self.call_llm_single_attempt(prompt).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = Some(e);
                    if attempt < self.config.max_retries {
                        let delay = Duration::from_millis(500 * (attempt + 1) as u64);
                        warn!(
                            "LLM API call failed (attempt {}), retrying in {:?}: {:?}",
                            attempt + 1,
                            delay,
                            last_error
                        );
                        sleep(delay).await;
                    }
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| mockforge_core::Error::generic("All LLM API retry attempts failed")))
    }

    async fn call_llm_single_attempt(&self, prompt: &str) -> Result<String> {
        match &self.config.provider {
            LlmProvider::OpenAI => self.call_openai(prompt).await,
            LlmProvider::Anthropic => self.call_anthropic(prompt).await,
            LlmProvider::OpenAICompatible => self.call_openai_compatible(prompt).await,
            LlmProvider::Ollama => self.call_ollama(prompt).await,
        }
    }

    async fn call_openai(&self, prompt: &str) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| mockforge_core::Error::generic("OpenAI API key not configured"))?;

        let request_body = serde_json::json!({
            "model": self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        });

        debug!("Calling OpenAI API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                mockforge_core::Error::generic(format!("OpenAI API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse OpenAI response: {}", e))
        })?;

        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
            if let Some(choice) = choices.first() {
                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
                    if let Some(content) = message.as_str() {
                        return Ok(content.to_string());
                    }
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid OpenAI response format"))
    }

    async fn call_anthropic(&self, prompt: &str) -> Result<String> {
        let api_key =
            self.config.api_key.as_ref().ok_or_else(|| {
                mockforge_core::Error::generic("Anthropic API key not configured")
            })?;

        let request_body = serde_json::json!({
            "model": self.config.model,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ]
        });

        debug!("Calling Anthropic API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("x-api-key", api_key)
            .header("Content-Type", "application/json")
            .header("anthropic-version", "2023-06-01")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                mockforge_core::Error::generic(format!("Anthropic API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "Anthropic API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse Anthropic response: {}", e))
        })?;

        if let Some(content) = response_json.get("content") {
            if let Some(content_array) = content.as_array() {
                if let Some(first_content) = content_array.first() {
                    if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
                        return Ok(text.to_string());
                    }
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid Anthropic response format"))
    }

    async fn call_openai_compatible(&self, prompt: &str) -> Result<String> {
        let request_body = serde_json::json!({
            "model": self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        });

        debug!("Calling OpenAI-compatible API with model: {}", self.config.model);

        // The API key is optional here: many self-hosted OpenAI-compatible
        // servers do not require authentication.
        let mut request = self
            .client
            .post(&self.config.api_endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body);

        if let Some(api_key) = &self.config.api_key {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request.send().await.map_err(|e| {
            mockforge_core::Error::generic(format!("OpenAI-compatible API request failed: {}", e))
        })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI-compatible API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!(
                "Failed to parse OpenAI-compatible response: {}",
                e
            ))
        })?;

        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
            if let Some(choice) = choices.first() {
                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
                    if let Some(content) = message.as_str() {
                        return Ok(content.to_string());
                    }
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid OpenAI-compatible response format"))
    }

    async fn call_ollama(&self, prompt: &str) -> Result<String> {
        let request_body = serde_json::json!({
            "model": self.config.model,
            "prompt": prompt,
            "stream": false
        });

        debug!("Calling Ollama API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                mockforge_core::Error::generic(format!("Ollama API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "Ollama API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse Ollama response: {}", e))
        })?;

        if let Some(response_text) = response_json.get("response").and_then(|r| r.as_str()) {
            return Ok(response_text.to_string());
        }

        Err(mockforge_core::Error::generic("Invalid Ollama response format"))
    }

    /// Parse the LLM output as JSON; if the whole response is not valid JSON,
    /// fall back to the outermost `{...}` span, since models often wrap JSON
    /// in explanatory text.
    fn parse_llm_response(&self, response: &str) -> Result<Value> {
        match serde_json::from_str(response) {
            Ok(value) => Ok(value),
            Err(e) => {
                if let Some(start) = response.find('{') {
                    if let Some(end) = response.rfind('}') {
                        let json_str = &response[start..=end];
                        match serde_json::from_str(json_str) {
                            Ok(value) => Ok(value),
                            Err(_) => Err(mockforge_core::Error::generic(format!(
                                "Failed to parse LLM response: {}",
                                e
                            ))),
                        }
                    } else {
                        Err(mockforge_core::Error::generic(format!(
                            "No closing brace found in response: {}",
                            e
                        )))
                    }
                } else {
                    Err(mockforge_core::Error::generic(format!("No JSON found in response: {}", e)))
                }
            }
        }
    }

    /// Replace the engine configuration.
    pub fn update_config(&mut self, config: RagConfig) {
        self.config = config;
    }

    /// Current configuration.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Number of document chunks in the knowledge base.
    pub fn chunk_count(&self) -> usize {
        self.chunks.len()
    }

    /// Number of schemas that have been indexed.
    pub fn schema_count(&self) -> usize {
        self.schema_kb.len()
    }

    /// Get a chunk by index, if it exists.
    pub fn get_chunk(&self, index: usize) -> Option<&DocumentChunk> {
        self.chunks.get(index)
    }

    /// Whether a schema with the given name has been indexed.
    pub fn has_schema(&self, name: &str) -> bool {
        self.schema_kb.contains_key(name)
    }

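    /// Send a raw prompt to the configured LLM, with the same retry policy as
    /// RAG generation, and return the plain-text completion. Sketch:
    ///
    /// ```ignore
    /// let text = engine.generate_text("Suggest a realistic company name.").await?;
    /// ```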
    pub async fn generate_text(&self, prompt: &str) -> Result<String> {
        self.call_llm(prompt).await
    }
}

impl Default for RagEngine {
    fn default() -> Self {
        Self::new(RagConfig::default())
    }
}

/// Convenience constructors for RAG engines preloaded with domain knowledge.
pub mod rag_utils {
    use super::*;

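    /// Build a `RagEngine` preloaded with general notes about customer,
    /// product, and order data. Uses the default configuration, so an API key
    /// must still be supplied before generation. Sketch:
    ///
    /// ```ignore
    /// let mut engine = rag_utils::create_business_rag_engine()?;
    /// engine.update_config(RagConfig {
    ///     api_key: std::env::var("MOCKFORGE_RAG_API_KEY").ok(),
    ///     ..RagConfig::default()
    /// });
    /// ```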
    pub fn create_business_rag_engine() -> Result<RagEngine> {
        let mut engine = RagEngine::default();

        engine.add_document(
            "Customer data typically includes personal information like name, email, phone, and address. Customers usually have unique identifiers and account creation dates.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("customer".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        engine.add_document(
            "Product information includes name, description, price, category, and stock status. Products should have unique SKUs or IDs.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("product".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        engine.add_document(
            "Order data contains customer references, product lists, total amounts, status, and timestamps. Orders should maintain referential integrity with customers and products.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("order".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        Ok(engine)
    }

    /// Build a `RagEngine` preloaded with reference notes about REST APIs and
    /// database records.
    pub fn create_technical_rag_engine() -> Result<RagEngine> {
        let mut engine = RagEngine::default();

        engine.add_document(
            "API endpoints should follow RESTful conventions with proper HTTP methods. GET for retrieval, POST for creation, PUT for updates, DELETE for removal.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("api".to_string())),
                ("type".to_string(), Value::String("technical".to_string())),
            ]),
        )?;

        engine.add_document(
            "Database records typically have auto-incrementing primary keys, created_at and updated_at timestamps, and foreign key relationships.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("database".to_string())),
                ("type".to_string(), Value::String("technical".to_string())),
            ]),
        )?;

        Ok(engine)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_provider_variants() {
        let openai = LlmProvider::OpenAI;
        let anthropic = LlmProvider::Anthropic;
        let compatible = LlmProvider::OpenAICompatible;
        let ollama = LlmProvider::Ollama;

        assert!(matches!(openai, LlmProvider::OpenAI));
        assert!(matches!(anthropic, LlmProvider::Anthropic));
        assert!(matches!(compatible, LlmProvider::OpenAICompatible));
        assert!(matches!(ollama, LlmProvider::Ollama));
    }

    #[test]
    fn test_embedding_provider_variants() {
        let openai = EmbeddingProvider::OpenAI;
        let compatible = EmbeddingProvider::OpenAICompatible;

        assert!(matches!(openai, EmbeddingProvider::OpenAI));
        assert!(matches!(compatible, EmbeddingProvider::OpenAICompatible));
    }

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();

        assert!(config.max_tokens > 0);
        assert!(config.temperature >= 0.0 && config.temperature <= 1.0);
        assert!(config.context_window > 0);
    }
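
    // The tests below exercise pure, in-process helpers only (no network):
    // cosine_similarity, add_document/keyword_search, and parse_llm_response.

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors are maximally similar.
        let sim = RagEngine::cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!((sim - 1.0).abs() < 1e-9);

        // Orthogonal vectors score zero.
        let sim = RagEngine::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 1e-9);

        // Mismatched lengths and empty inputs degrade to zero.
        assert_eq!(RagEngine::cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(RagEngine::cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_add_document_and_keyword_search() {
        let mut engine = RagEngine::default();
        let id = engine
            .add_document("Customer records include a name and email.".to_string(), HashMap::new())
            .unwrap();
        assert_eq!(id, "chunk_0");
        assert_eq!(engine.chunk_count(), 1);

        // Keyword search is a case-insensitive substring match.
        assert_eq!(engine.keyword_search("CUSTOMER", 5).len(), 1);
        assert!(engine.keyword_search("unrelated", 5).is_empty());
    }

    #[test]
    fn test_parse_llm_response_extracts_embedded_json() {
        let engine = RagEngine::default();
        // The parser falls back to the outermost brace-delimited span.
        let value = engine.parse_llm_response("Here you go: {\"name\": \"Ada\"} Done.").unwrap();
        assert_eq!(value["name"], "Ada");
        assert!(engine.parse_llm_response("no json here").is_err());
    }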
}