pub mod config;
pub mod engine;
pub mod providers;
pub mod storage;
pub mod utils;

pub use config::*;
pub use providers::*;
pub use utils::*;

pub use engine::StorageStats as EngineStorageStats;
pub use storage::StorageStats as StorageStorageStats;

use crate::Result;
use crate::{schema::SchemaDefinition, DataConfig};
use reqwest::{Client, ClientBuilder};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, warn};

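/// Supported LLM backends for text generation. Variants serialize in
/// lowercase (e.g. "openai", "anthropic").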
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    OpenAI,
    Anthropic,
    OpenAICompatible,
    Ollama,
}

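/// Supported backends for generating embeddings; both speak the OpenAI
/// embeddings wire format.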
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    OpenAI,
    OpenAICompatible,
}

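/// Configuration for [`RagEngine`]: LLM provider and endpoint settings,
/// semantic-search (embedding) settings, and request timeout/retry behavior.
/// `Default` targets OpenAI's chat completions endpoint.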
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    pub provider: LlmProvider,
    pub api_endpoint: String,
    pub api_key: Option<String>,
    pub model: String,
    pub max_tokens: usize,
    pub temperature: f64,
    pub context_window: usize,

    pub semantic_search_enabled: bool,
    pub embedding_provider: EmbeddingProvider,
    pub embedding_model: String,
    pub embedding_endpoint: Option<String>,
    pub similarity_threshold: f64,
    pub max_chunks: usize,

    pub request_timeout_seconds: u64,
    pub max_retries: usize,
}

impl Default for RagConfig {
    fn default() -> Self {
        Self {
            provider: LlmProvider::OpenAI,
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            api_key: None,
            model: "gpt-3.5-turbo".to_string(),
            max_tokens: 1000,
            temperature: 0.7,
            context_window: 4000,
            semantic_search_enabled: true,
            embedding_provider: EmbeddingProvider::OpenAI,
            embedding_model: "text-embedding-ada-002".to_string(),
            embedding_endpoint: None,
            similarity_threshold: 0.7,
            max_chunks: 5,
            request_timeout_seconds: 30,
            max_retries: 3,
        }
    }
}

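/// A document chunk in the knowledge base. `embedding` stays empty until
/// [`RagEngine::compute_embeddings`] fills it in.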
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    pub id: String,
    pub content: String,
    pub metadata: HashMap<String, Value>,
    pub embedding: Vec<f32>,
}

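/// A chunk matched by semantic search, paired with its cosine-similarity score.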
#[derive(Debug)]
pub struct SearchResult<'a> {
    pub chunk: &'a DocumentChunk,
    pub score: f64,
}

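/// Retrieval-augmented generation engine: holds document chunks, a per-schema
/// knowledge base, and an HTTP client for the configured LLM and embedding
/// backends.
///
/// A minimal sketch of the offline (non-network) surface, marked `ignore`
/// since imports and error handling are elided:
///
/// ```ignore
/// let mut engine = RagEngine::new(RagConfig::default());
/// engine.add_document("Customers have names and emails.".to_string(), HashMap::new())?;
/// let hits = engine.keyword_search("customer", 3);
/// assert_eq!(hits.len(), 1);
/// ```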
#[derive(Debug)]
pub struct RagEngine {
    config: RagConfig,
    chunks: Vec<DocumentChunk>,
    schema_kb: HashMap<String, Vec<String>>,
    client: Client,
}

impl RagEngine {
    pub fn new(config: RagConfig) -> Self {
        let client = ClientBuilder::new()
            .timeout(Duration::from_secs(config.request_timeout_seconds))
            .build()
            .unwrap_or_else(|e| {
                warn!("Failed to create HTTP client with timeout, using default: {}", e);
                Client::new()
            });

        Self {
            config,
            chunks: Vec::new(),
            schema_kb: HashMap::new(),
            client,
        }
    }

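    /// Adds a document chunk and returns its generated id. Ids are positional
    /// (`chunk_0`, `chunk_1`, ...); the embedding is left empty until
    /// [`compute_embeddings`](Self::compute_embeddings) runs.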
    pub fn add_document(
        &mut self,
        content: String,
        metadata: HashMap<String, Value>,
    ) -> Result<String> {
        let id = format!("chunk_{}", self.chunks.len());
        let chunk = DocumentChunk {
            id: id.clone(),
            content,
            metadata,
            embedding: Vec::new(),
        };

        self.chunks.push(chunk);
        Ok(id)
    }

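    /// Indexes a schema's fields and relationships as plain-text facts in the
    /// per-schema knowledge base consulted when building generation prompts.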
    pub fn add_schema(&mut self, schema: &SchemaDefinition) -> Result<()> {
        let mut schema_info = Vec::new();

        schema_info.push(format!("Schema: {}", schema.name));

        if let Some(description) = &schema.description {
            schema_info.push(format!("Description: {}", description));
        }

        for field in &schema.fields {
            let mut field_info = format!(
                "Field '{}': type={}, required={}",
                field.name, field.field_type, field.required
            );

            if let Some(description) = &field.description {
                field_info.push_str(&format!(" - {}", description));
            }

            schema_info.push(field_info);
        }

        for (rel_name, relationship) in &schema.relationships {
            schema_info.push(format!(
                "Relationship '{}': {} -> {} ({:?})",
                rel_name, schema.name, relationship.target_schema, relationship.relationship_type
            ));
        }

        self.schema_kb.insert(schema.name.clone(), schema_info);
        Ok(())
    }

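    /// Generates `config.rows` rows via the LLM. A failed row falls back to
    /// placeholder data; if more than a quarter of the requested rows fail,
    /// the whole run is aborted.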
    pub async fn generate_with_rag(
        &self,
        schema: &SchemaDefinition,
        config: &DataConfig,
    ) -> Result<Vec<Value>> {
        if !config.rag_enabled {
            return Err(crate::Error::generic("RAG is not enabled in config"));
        }

        if self.config.api_key.is_none() {
            return Err(crate::Error::generic(
                "RAG is enabled but no API key is configured. Please set MOCKFORGE_RAG_API_KEY or provide --rag-api-key",
            ));
        }

        let mut results = Vec::new();
        let mut failed_rows = 0;

        for i in 0..config.rows {
            match self.generate_single_row_with_rag(schema, i).await {
                Ok(data) => results.push(data),
                Err(e) => {
                    failed_rows += 1;
                    warn!("Failed to generate RAG data for row {}: {}", i, e);

                    if failed_rows > config.rows / 4 {
                        return Err(crate::Error::generic(format!(
                            "Too many RAG generation failures ({} out of {} rows failed). Check API configuration and network connectivity.",
                            failed_rows, config.rows
                        )));
                    }

                    let fallback_data = self.generate_fallback_data(schema);
                    results.push(fallback_data);
                }
            }
        }

        if failed_rows > 0 {
            warn!(
                "RAG generation completed with {} failed rows out of {}",
                failed_rows, config.rows
            );
        }

        Ok(results)
    }

    async fn generate_single_row_with_rag(
        &self,
        schema: &SchemaDefinition,
        row_index: usize,
    ) -> Result<Value> {
        let prompt = self.build_generation_prompt(schema, row_index).await?;
        let generated_data = self.call_llm(&prompt).await?;
        self.parse_llm_response(&generated_data)
    }

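    /// Builds a typed placeholder row ("sample_data", 42, true) for rows
    /// whose LLM generation failed.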
    fn generate_fallback_data(&self, schema: &SchemaDefinition) -> Value {
        let mut obj = serde_json::Map::new();

        for field in &schema.fields {
            let value = match field.field_type.as_str() {
                "string" => Value::String("sample_data".to_string()),
                "integer" | "number" => Value::Number(42.into()),
                "boolean" => Value::Bool(true),
                _ => Value::String("sample_data".to_string()),
            };
            obj.insert(field.name.clone(), value);
        }

        Value::Object(obj)
    }

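    /// Assembles the generation prompt: schema facts from the knowledge base,
    /// up to three retrieved context chunks, and fixed output instructions.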
    async fn build_generation_prompt(
        &self,
        schema: &SchemaDefinition,
        _row_index: usize,
    ) -> Result<String> {
        let mut prompt =
            format!("Generate a single row of data for the '{}' schema.\n\n", schema.name);

        if let Some(schema_info) = self.schema_kb.get(&schema.name) {
            prompt.push_str("Schema Information:\n");
            for info in schema_info {
                prompt.push_str(&format!("- {}\n", info));
            }
            prompt.push('\n');
        }

        let relevant_chunks = self.retrieve_relevant_chunks(&schema.name, 3).await?;
        if !relevant_chunks.is_empty() {
            prompt.push_str("Relevant Context:\n");
            for chunk in relevant_chunks {
                prompt.push_str(&format!("- {}\n", chunk.content));
            }
            prompt.push('\n');
        }

        prompt.push_str("Instructions:\n");
        prompt.push_str("- Generate realistic data that matches the schema\n");
        prompt.push_str("- Ensure all required fields are present\n");
        prompt.push_str("- Use appropriate data types and formats\n");
        prompt.push_str("- Make relationships consistent if referenced\n");
        prompt.push_str("- Output only valid JSON for a single object\n\n");

        prompt.push_str("Generate the data:");

        Ok(prompt)
    }

    async fn retrieve_relevant_chunks(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<&DocumentChunk>> {
        if self.config.semantic_search_enabled {
            let results = self.semantic_search(query, limit).await?;
            Ok(results.into_iter().map(|r| r.chunk).collect())
        } else {
            Ok(self.keyword_search(query, limit))
        }
    }

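    /// Case-insensitive substring search over chunk content and string
    /// metadata values; used when semantic search is disabled.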
    pub fn keyword_search(&self, query: &str, limit: usize) -> Vec<&DocumentChunk> {
        // Lowercase the query once rather than per chunk.
        let query_lower = query.to_lowercase();
        self.chunks
            .iter()
            .filter(|chunk| {
                chunk.content.to_lowercase().contains(&query_lower)
                    || chunk.metadata.values().any(|v| {
                        v.as_str().map_or(false, |s| s.to_lowercase().contains(&query_lower))
                    })
            })
            .take(limit)
            .collect()
    }

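    /// Embeds the query and ranks chunks by cosine similarity, keeping those
    /// scoring at or above `similarity_threshold`. Chunks without a computed
    /// embedding are skipped.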
    async fn semantic_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult<'_>>> {
        let query_embedding = self.generate_embedding(query).await?;

        let mut results: Vec<SearchResult> = Vec::new();

        for chunk in &self.chunks {
            if chunk.embedding.is_empty() {
                continue;
            }

            let score = Self::cosine_similarity(&query_embedding, &chunk.embedding);
            if score >= self.config.similarity_threshold {
                results.push(SearchResult { chunk, score });
            }
        }

        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        results.truncate(limit);

        Ok(results)
    }

    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        match &self.config.embedding_provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::OpenAICompatible => {
                self.generate_openai_compatible_embedding(text).await
            }
        }
    }

    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;

        let endpoint = self
            .config
            .embedding_endpoint
            .as_ref()
            .unwrap_or(&self.config.api_endpoint)
            .replace("chat/completions", "embeddings");

        let request_body = serde_json::json!({
            "model": self.config.embedding_model,
            "input": text
        });

        debug!("Generating embedding for text with OpenAI API");

        let response = self
            .client
            .post(&endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("Embedding API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("Embedding API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse embedding response: {}", e))
        })?;

        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
            if let Some(first_item) = data.first() {
                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
                    let embedding_vec: Vec<f32> =
                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
                    return Ok(embedding_vec);
                }
            }
        }

        Err(crate::Error::generic("Invalid embedding response format"))
    }

    async fn generate_openai_compatible_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let endpoint = self
            .config
            .embedding_endpoint
            .as_ref()
            .unwrap_or(&self.config.api_endpoint)
            .replace("chat/completions", "embeddings");

        let request_body = serde_json::json!({
            "model": self.config.embedding_model,
            "input": text
        });

        debug!("Generating embedding for text with OpenAI-compatible API");

        let mut request = self
            .client
            .post(&endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body);

        if let Some(api_key) = &self.config.api_key {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("Embedding API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("Embedding API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse embedding response: {}", e))
        })?;

        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
            if let Some(first_item) = data.first() {
                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
                    let embedding_vec: Vec<f32> =
                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
                    return Ok(embedding_vec);
                }
            }
        }

        Err(crate::Error::generic("Invalid embedding response format"))
    }

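    /// Computes embeddings for every chunk that lacks one, issuing one
    /// sequential API call per chunk. Call after adding documents and before
    /// relying on semantic search.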
    pub async fn compute_embeddings(&mut self) -> Result<()> {
        debug!("Computing embeddings for {} chunks", self.chunks.len());

        let chunks_to_embed: Vec<(usize, String)> = self
            .chunks
            .iter()
            .enumerate()
            .filter(|(_, chunk)| chunk.embedding.is_empty())
            .map(|(idx, chunk)| (idx, chunk.content.clone()))
            .collect();

        for (idx, content) in chunks_to_embed {
            let embedding = self.generate_embedding(&content).await?;
            self.chunks[idx].embedding = embedding;
            debug!("Computed embedding for chunk {}", self.chunks[idx].id);
        }

        Ok(())
    }

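    /// Cosine similarity, dot(a, b) / (|a| * |b|), accumulated in f64.
    /// Returns 0.0 for mismatched lengths, empty input, or a zero norm.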
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        let mut dot_product = 0.0;
        let mut norm_a = 0.0;
        let mut norm_b = 0.0;

        for i in 0..a.len() {
            dot_product += a[i] as f64 * b[i] as f64;
            norm_a += (a[i] as f64).powi(2);
            norm_b += (b[i] as f64).powi(2);
        }

        norm_a = norm_a.sqrt();
        norm_b = norm_b.sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot_product / (norm_a * norm_b)
        }
    }

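    /// Dispatches to the configured provider, retrying up to `max_retries`
    /// times with a linearly growing delay (500 ms, 1 s, 1.5 s, ...).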
    async fn call_llm(&self, prompt: &str) -> Result<String> {
        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            match self.call_llm_single_attempt(prompt).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    if attempt < self.config.max_retries {
                        let delay = Duration::from_millis(500 * (attempt + 1) as u64);
                        warn!(
                            "LLM API call failed (attempt {}), retrying in {:?}: {:?}",
                            attempt + 1,
                            delay,
                            e
                        );
                        sleep(delay).await;
                    }
                    last_error = Some(e);
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| crate::Error::generic("All LLM API retry attempts failed")))
    }

    async fn call_llm_single_attempt(&self, prompt: &str) -> Result<String> {
        match &self.config.provider {
            LlmProvider::OpenAI => self.call_openai(prompt).await,
            LlmProvider::Anthropic => self.call_anthropic(prompt).await,
            LlmProvider::OpenAICompatible => self.call_openai_compatible(prompt).await,
            LlmProvider::Ollama => self.call_ollama(prompt).await,
        }
    }

    async fn call_openai(&self, prompt: &str) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;

        let request_body = serde_json::json!({
            "model": self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        });

        debug!("Calling OpenAI API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("OpenAI API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("OpenAI API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse OpenAI response: {}", e))
        })?;

        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
            if let Some(choice) = choices.first() {
                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
                    if let Some(content) = message.as_str() {
                        return Ok(content.to_string());
                    }
                }
            }
        }

        Err(crate::Error::generic("Invalid OpenAI response format"))
    }

    async fn call_anthropic(&self, prompt: &str) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("Anthropic API key not configured"))?;

        let request_body = serde_json::json!({
            "model": self.config.model,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ]
        });

        debug!("Calling Anthropic API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("x-api-key", api_key)
            .header("Content-Type", "application/json")
            .header("anthropic-version", "2023-06-01")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("Anthropic API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("Anthropic API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse Anthropic response: {}", e))
        })?;

        if let Some(content) = response_json.get("content") {
            if let Some(content_array) = content.as_array() {
                if let Some(first_content) = content_array.first() {
                    if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
                        return Ok(text.to_string());
                    }
                }
            }
        }

        Err(crate::Error::generic("Invalid Anthropic response format"))
    }

    async fn call_openai_compatible(&self, prompt: &str) -> Result<String> {
        let request_body = serde_json::json!({
            "model": self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        });

        debug!("Calling OpenAI-compatible API with model: {}", self.config.model);

        let mut request = self
            .client
            .post(&self.config.api_endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body);

        if let Some(api_key) = &self.config.api_key {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request.send().await.map_err(|e| {
            crate::Error::generic(format!("OpenAI-compatible API request failed: {}", e))
        })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!(
                "OpenAI-compatible API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse OpenAI-compatible response: {}", e))
        })?;

        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
            if let Some(choice) = choices.first() {
                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
                    if let Some(content) = message.as_str() {
                        return Ok(content.to_string());
                    }
                }
            }
        }

        Err(crate::Error::generic("Invalid OpenAI-compatible response format"))
    }

    async fn call_ollama(&self, prompt: &str) -> Result<String> {
        let request_body = serde_json::json!({
            "model": self.config.model,
            "prompt": prompt,
            "stream": false
        });

        debug!("Calling Ollama API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("Ollama API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("Ollama API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse Ollama response: {}", e))
        })?;

        if let Some(response_text) = response_json.get("response").and_then(|r| r.as_str()) {
            return Ok(response_text.to_string());
        }

        Err(crate::Error::generic("Invalid Ollama response format"))
    }

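    /// Parses LLM output as JSON. If the full response is not valid JSON
    /// (e.g. the model wrapped it in prose), falls back to the substring
    /// between the first '{' and the last '}'.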
    fn parse_llm_response(&self, response: &str) -> Result<Value> {
        match serde_json::from_str(response) {
            Ok(value) => Ok(value),
            Err(e) => {
                if let Some(start) = response.find('{') {
                    if let Some(end) = response.rfind('}') {
                        let json_str = &response[start..=end];
                        match serde_json::from_str(json_str) {
                            Ok(value) => Ok(value),
                            Err(_) => Err(crate::Error::generic(format!(
                                "Failed to parse LLM response: {}",
                                e
                            ))),
                        }
                    } else {
                        Err(crate::Error::generic(format!(
                            "No closing brace found in response: {}",
                            e
                        )))
                    }
                } else {
                    Err(crate::Error::generic(format!("No JSON found in response: {}", e)))
                }
            }
        }
    }

    pub fn update_config(&mut self, config: RagConfig) {
        self.config = config;
    }

    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    pub fn chunk_count(&self) -> usize {
        self.chunks.len()
    }

    pub fn schema_count(&self) -> usize {
        self.schema_kb.len()
    }

    pub fn get_chunk(&self, index: usize) -> Option<&DocumentChunk> {
        self.chunks.get(index)
    }

    pub fn has_schema(&self, name: &str) -> bool {
        self.schema_kb.contains_key(name)
    }

    pub async fn generate_text(&self, prompt: &str) -> Result<String> {
        self.call_llm(prompt).await
    }
}

impl Default for RagEngine {
    fn default() -> Self {
        Self::new(RagConfig::default())
    }
}

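/// Convenience constructors that seed a default [`RagEngine`] with generic
/// business or technical domain knowledge.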
pub mod rag_utils {
    use super::*;

    pub fn create_business_rag_engine() -> Result<RagEngine> {
        let mut engine = RagEngine::default();

        engine.add_document(
            "Customer data typically includes personal information like name, email, phone, and address. Customers usually have unique identifiers and account creation dates.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("customer".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        engine.add_document(
            "Product information includes name, description, price, category, and stock status. Products should have unique SKUs or IDs.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("product".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        engine.add_document(
            "Order data contains customer references, product lists, total amounts, status, and timestamps. Orders should maintain referential integrity with customers and products.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("order".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        Ok(engine)
    }

    pub fn create_technical_rag_engine() -> Result<RagEngine> {
        let mut engine = RagEngine::default();

        engine.add_document(
            "API endpoints should follow RESTful conventions with proper HTTP methods. GET for retrieval, POST for creation, PUT for updates, DELETE for removal.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("api".to_string())),
                ("type".to_string(), Value::String("technical".to_string())),
            ]),
        )?;

        engine.add_document(
            "Database records typically have auto-incrementing primary keys, created_at and updated_at timestamps, and foreign key relationships.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("database".to_string())),
                ("type".to_string(), Value::String("technical".to_string())),
            ]),
        )?;

        Ok(engine)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_provider_variants() {
        let openai = LlmProvider::OpenAI;
        let anthropic = LlmProvider::Anthropic;
        let compatible = LlmProvider::OpenAICompatible;
        let ollama = LlmProvider::Ollama;

        assert!(matches!(openai, LlmProvider::OpenAI));
        assert!(matches!(anthropic, LlmProvider::Anthropic));
        assert!(matches!(compatible, LlmProvider::OpenAICompatible));
        assert!(matches!(ollama, LlmProvider::Ollama));
    }

    #[test]
    fn test_embedding_provider_variants() {
        let openai = EmbeddingProvider::OpenAI;
        let compatible = EmbeddingProvider::OpenAICompatible;

        assert!(matches!(openai, EmbeddingProvider::OpenAI));
        assert!(matches!(compatible, EmbeddingProvider::OpenAICompatible));
    }

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();

        assert!(config.max_tokens > 0);
        assert!(config.temperature >= 0.0 && config.temperature <= 1.0);
        assert!(config.context_window > 0);
    }
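
    // Sanity checks for the pure, non-network helpers defined above.
    #[test]
    fn test_cosine_similarity_basic() {
        // Identical direction scores 1.0; orthogonal vectors score 0.0.
        assert!((RagEngine::cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-9);
        assert!(RagEngine::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-9);
        // Mismatched lengths and empty inputs fall back to 0.0.
        assert_eq!(RagEngine::cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(RagEngine::cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_keyword_search_is_case_insensitive() {
        let mut engine = RagEngine::default();
        engine
            .add_document("Customers have names and emails.".to_string(), HashMap::new())
            .unwrap();
        assert_eq!(engine.keyword_search("CUSTOMER", 5).len(), 1);
        assert!(engine.keyword_search("invoice", 5).is_empty());
    }

    #[test]
    fn test_parse_llm_response_extracts_embedded_json() {
        // The brace-extraction fallback should recover JSON wrapped in prose.
        let engine = RagEngine::default();
        let value = engine.parse_llm_response("Here you go: {\"name\": \"Ada\"} Done.").unwrap();
        assert_eq!(value["name"], "Ada");
    }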
}