mockforge_data/rag/providers.rs

//! LLM and embedding provider integrations
//!
//! This module integrates with various LLM and embedding providers and exposes
//! a unified interface over the different AI services.
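//!
//! # Example
//!
//! A minimal sketch of requesting a chat completion through `ProviderFactory`.
//! The crate path, environment variable, and model name are illustrative
//! assumptions; the calls need an async context, a compatible error type for
//! `?`, and network access:
//!
//! ```ignore
//! use mockforge_data::rag::providers::{ChatMessage, LlmProvider, ProviderFactory};
//!
//! let provider = ProviderFactory::create_llm_provider(
//!     LlmProvider::OpenAI,
//!     std::env::var("OPENAI_API_KEY")?,
//!     None,
//!     "gpt-3.5-turbo".to_string(),
//! )?;
//!
//! let reply = provider
//!     .generate_chat_completion(
//!         vec![
//!             ChatMessage::system("You are a helpful assistant.".to_string()),
//!             ChatMessage::user("Say hello.".to_string()),
//!         ],
//!         Some(64),
//!         Some(0.2),
//!         None,
//!         None,
//!     )
//!     .await?;
//! println!("{reply}");
//! ```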

use crate::Result;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

/// Supported LLM providers
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI GPT models
    OpenAI,
    /// Anthropic Claude models
    Anthropic,
    /// Generic OpenAI-compatible API
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}

/// Supported embedding providers
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI text-embedding-ada-002
    OpenAI,
    /// Generic OpenAI-compatible embeddings API
    OpenAICompatible,
}

/// LLM provider trait
#[async_trait::async_trait]
pub trait LlmProviderTrait: Send + Sync {
    /// Generate text completion
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Generate chat completion
    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Get available models
    async fn get_available_models(&self) -> Result<Vec<String>>;

    /// Check if provider is available
    async fn is_available(&self) -> bool;

    /// Get provider name
    fn name(&self) -> &'static str;

    /// Get maximum context length
    fn max_context_length(&self) -> usize;
}

/// Embedding provider trait
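///
/// A minimal usage sketch via `ProviderFactory` (illustrative; assumes an async
/// context, network access, and an `OPENAI_API_KEY` environment variable):
///
/// ```ignore
/// let embedder = ProviderFactory::create_embedding_provider(
///     EmbeddingProvider::OpenAI,
///     std::env::var("OPENAI_API_KEY")?,
///     None,
///     "text-embedding-3-small".to_string(),
/// )?;
/// // For the known models the vector length should match embedding_dimensions().
/// let vector = embedder.generate_embedding("hello world").await?;
/// ```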
#[async_trait::async_trait]
pub trait EmbeddingProviderTrait: Send + Sync {
    /// Generate embedding for text
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts
    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Get embedding dimensions
    fn embedding_dimensions(&self) -> usize;

    /// Get maximum tokens for embedding
    fn max_tokens(&self) -> usize;

    /// Get provider name
    fn name(&self) -> &'static str;

    /// Check if provider is available
    async fn is_available(&self) -> bool;
}

/// Chat message for LLM providers
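///
/// A short construction sketch using the helpers below (values are illustrative):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let msg = ChatMessage::user("Hello".to_string())
///     .with_metadata(HashMap::from([("source".to_string(), "demo".to_string())]));
/// ```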
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role
    pub role: ChatRole,
    /// Message content
    pub content: String,
    /// Optional metadata
    pub metadata: Option<HashMap<String, String>>,
}

impl ChatMessage {
    /// Create a new system message
    pub fn system(content: String) -> Self {
        Self {
            role: ChatRole::System,
            content,
            metadata: None,
        }
    }

    /// Create a new user message
    pub fn user(content: String) -> Self {
        Self {
            role: ChatRole::User,
            content,
            metadata: None,
        }
    }

    /// Create a new assistant message
    pub fn assistant(content: String) -> Self {
        Self {
            role: ChatRole::Assistant,
            content,
            metadata: None,
        }
    }

    /// Add metadata to message
    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = Some(metadata);
        self
    }
}

/// Chat message role
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// System message
    System,
    /// User message
    User,
    /// Assistant message
    Assistant,
}

/// OpenAI provider implementation
pub struct OpenAiProvider {
    api_key: String,
    client: reqwest::Client,
    base_url: String,
}

impl OpenAiProvider {
    /// Create a new OpenAI provider
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
        }
    }

    /// Create with custom base URL
    pub fn new_with_base_url(api_key: String, base_url: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url,
        }
    }
}

#[async_trait::async_trait]
impl LlmProviderTrait for OpenAiProvider {
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        let mut request_body = serde_json::json!({
            "model": "gpt-3.5-turbo-instruct",
            "prompt": prompt,
            "max_tokens": max_tokens.unwrap_or(1024),
            "temperature": temperature.unwrap_or(0.7),
        });

        if let Some(top_p) = top_p {
            request_body["top_p"] = serde_json::json!(top_p);
        }

        if let Some(stop) = stop_sequences {
            request_body["stop"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["text"]
            .as_str()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        let openai_messages: Vec<Value> = messages
            .iter()
            .map(|msg| {
                serde_json::json!({
                    "role": format!("{:?}", msg.role).to_lowercase(),
                    "content": msg.content
                })
            })
            .collect();

        let mut request_body = serde_json::json!({
            "model": "gpt-3.5-turbo",
            "messages": openai_messages,
            "max_tokens": max_tokens.unwrap_or(1024),
            "temperature": temperature.unwrap_or(0.7),
        });

        if let Some(top_p) = top_p {
            request_body["top_p"] = serde_json::json!(top_p);
        }

        if let Some(stop) = stop_sequences {
            request_body["stop"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    async fn get_available_models(&self) -> Result<Vec<String>> {
        let response = self
            .client
            .get(format!("{}/models", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let models = json["data"]
            .as_array()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid models response format"))?;

        let model_names = models
            .iter()
            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
            .collect();

        Ok(model_names)
    }

    async fn is_available(&self) -> bool {
        (self.get_available_models().await).is_ok()
    }

    fn name(&self) -> &'static str {
        "OpenAI"
    }

    fn max_context_length(&self) -> usize {
        4096 // GPT-3.5 context length
    }
}

/// OpenAI embedding provider implementation
pub struct OpenAiEmbeddingProvider {
    api_key: String,
    client: reqwest::Client,
    base_url: String,
    model: String,
}

impl OpenAiEmbeddingProvider {
    /// Create a new OpenAI embedding provider
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
            model: "text-embedding-ada-002".to_string(),
        }
    }

    /// Create with custom model
    pub fn new_with_model(api_key: String, model: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
            model,
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingProviderTrait for OpenAiEmbeddingProvider {
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "input": text,
                "model": self.model
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid embedding response format"))?;

        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
    }

    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::new();

        for text in texts {
            let embedding = self.generate_embedding(&text).await?;
            embeddings.push(embedding);
        }

        Ok(embeddings)
    }

    fn embedding_dimensions(&self) -> usize {
        match self.model.as_str() {
            "text-embedding-ada-002" => 1536,
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            _ => 1536, // Default
        }
    }

    fn max_tokens(&self) -> usize {
        match self.model.as_str() {
            "text-embedding-ada-002" => 8191,
            "text-embedding-3-small" => 8191,
            "text-embedding-3-large" => 8191,
            _ => 8191, // Default
        }
    }

    fn name(&self) -> &'static str {
        "OpenAI"
    }

    async fn is_available(&self) -> bool {
        (self.generate_embedding("test").await).is_ok()
    }
}

/// OpenAI-compatible provider implementation
pub struct OpenAiCompatibleProvider {
    api_key: String,
    client: reqwest::Client,
    base_url: String,
    model: String,
}

impl OpenAiCompatibleProvider {
    /// Create a new OpenAI-compatible provider
    pub fn new(api_key: String, base_url: String, model: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url,
            model,
        }
    }
}

#[async_trait::async_trait]
impl LlmProviderTrait for OpenAiCompatibleProvider {
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        let mut request_body = serde_json::json!({
            "model": self.model,
            "prompt": prompt,
            "max_tokens": max_tokens.unwrap_or(1024),
            "temperature": temperature.unwrap_or(0.7),
        });

        if let Some(top_p) = top_p {
            request_body["top_p"] = serde_json::json!(top_p);
        }

        if let Some(stop) = stop_sequences {
            request_body["stop"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["text"]
            .as_str()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        let openai_messages: Vec<Value> = messages
            .iter()
            .map(|msg| {
                serde_json::json!({
                    "role": format!("{:?}", msg.role).to_lowercase(),
                    "content": msg.content
                })
            })
            .collect();

        let mut request_body = serde_json::json!({
            "model": self.model,
            "messages": openai_messages,
            "max_tokens": max_tokens.unwrap_or(1024),
            "temperature": temperature.unwrap_or(0.7),
        });

        if let Some(top_p) = top_p {
            request_body["top_p"] = serde_json::json!(top_p);
        }

        if let Some(stop) = stop_sequences {
            request_body["stop"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    async fn get_available_models(&self) -> Result<Vec<String>> {
        // Try to get models, but fall back gracefully if not available
        match self
            .client
            .get(format!("{}/models", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
        {
            Ok(response) if response.status().is_success() => {
                let json: Value = response.json().await?;
                let models = json["data"].as_array().ok_or_else(|| {
                    mockforge_core::Error::generic("Invalid models response format")
                })?;
                Ok(models
                    .iter()
                    .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
                    .collect())
            }
            _ => Ok(vec![self.model.clone()]), // Return configured model as fallback
        }
    }

    async fn is_available(&self) -> bool {
        (self.generate_completion("test", Some(1), None, None, None).await).is_ok()
    }

    fn name(&self) -> &'static str {
        "OpenAI Compatible"
    }

    fn max_context_length(&self) -> usize {
        4096 // Default context length
    }
}

/// OpenAI-compatible embedding provider implementation
pub struct OpenAiCompatibleEmbeddingProvider {
    api_key: String,
    client: reqwest::Client,
    base_url: String,
    model: String,
}

impl OpenAiCompatibleEmbeddingProvider {
    /// Create a new OpenAI-compatible embedding provider
    pub fn new(api_key: String, base_url: String, model: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            base_url,
            model,
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingProviderTrait for OpenAiCompatibleEmbeddingProvider {
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({
                "input": text,
                "model": self.model
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(mockforge_core::Error::generic(format!(
                "API error: {}",
                response.status()
            )));
        }

        let json: Value = response.json().await?;
        let embedding = json["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| mockforge_core::Error::generic("Invalid embedding response format"))?;

        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
    }

    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::new();

        for text in texts {
            let embedding = self.generate_embedding(&text).await?;
            embeddings.push(embedding);
        }

        Ok(embeddings)
    }

    fn embedding_dimensions(&self) -> usize {
        1536 // Default OpenAI embedding dimensions
    }

    fn max_tokens(&self) -> usize {
        8191 // Default OpenAI token limit
    }

    fn name(&self) -> &'static str {
        "OpenAI Compatible"
    }

    async fn is_available(&self) -> bool {
        (self.generate_embedding("test").await).is_ok()
    }
}

/// Provider factory for creating LLM and embedding providers
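///
/// The OpenAI-compatible variants require a base URL; a minimal sketch where the
/// endpoint, key, and model name are illustrative assumptions:
///
/// ```ignore
/// let llm = ProviderFactory::create_llm_provider(
///     LlmProvider::OpenAICompatible,
///     "local-key".to_string(),
///     Some("http://localhost:11434/v1".to_string()),
///     "llama3".to_string(),
/// )?;
/// ```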
pub struct ProviderFactory;

impl ProviderFactory {
    /// Create LLM provider from configuration
    pub fn create_llm_provider(
        provider_type: LlmProvider,
        api_key: String,
        base_url: Option<String>,
        model: String,
    ) -> Result<Box<dyn LlmProviderTrait>> {
        match provider_type {
            LlmProvider::OpenAI => {
                let base_url = base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
                Ok(Box::new(OpenAiProvider::new_with_base_url(api_key, base_url)))
            }
            LlmProvider::OpenAICompatible => {
                let base_url = base_url.ok_or_else(|| {
                    mockforge_core::Error::generic(
                        "Base URL required for OpenAI compatible provider",
                    )
                })?;
                Ok(Box::new(OpenAiCompatibleProvider::new(api_key, base_url, model)))
            }
            _ => Err(mockforge_core::Error::generic(format!(
                "Provider type {:?} not yet implemented",
                provider_type
            ))),
        }
    }

    /// Create embedding provider from configuration
    pub fn create_embedding_provider(
        provider_type: EmbeddingProvider,
        api_key: String,
        base_url: Option<String>,
        model: String,
    ) -> Result<Box<dyn EmbeddingProviderTrait>> {
        match provider_type {
            EmbeddingProvider::OpenAI => {
                Ok(Box::new(OpenAiEmbeddingProvider::new_with_model(api_key, model)))
            }
            EmbeddingProvider::OpenAICompatible => {
                let base_url = base_url.ok_or_else(|| {
                    mockforge_core::Error::generic(
                        "Base URL required for OpenAI compatible embedding provider",
                    )
                })?;
                Ok(Box::new(OpenAiCompatibleEmbeddingProvider::new(api_key, base_url, model)))
            }
        }
    }
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_module_compiles() {
        // Basic compilation test
    }
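
    // A small offline sketch of the public surface above: message constructors,
    // serde casing, model-specific embedding dimensions, and the factory's
    // base-URL requirement. No network calls are made; "test-key" and the model
    // names are illustrative placeholders.
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn chat_message_helpers_set_roles_and_metadata() {
        let system = ChatMessage::system("rules".to_string());
        assert!(matches!(system.role, ChatRole::System));

        let user = ChatMessage::user("hi".to_string())
            .with_metadata(HashMap::from([("source".to_string(), "test".to_string())]));
        assert!(matches!(user.role, ChatRole::User));
        assert_eq!(user.metadata.unwrap()["source"], "test");
    }

    #[test]
    fn roles_and_provider_enums_serialize_lowercase() {
        assert_eq!(serde_json::to_value(ChatRole::Assistant).unwrap(), "assistant");
        assert_eq!(serde_json::to_value(LlmProvider::Ollama).unwrap(), "ollama");
        assert_eq!(serde_json::to_value(EmbeddingProvider::OpenAI).unwrap(), "openai");
    }

    #[test]
    fn embedding_dimensions_follow_model_name() {
        let large = OpenAiEmbeddingProvider::new_with_model(
            "test-key".to_string(),
            "text-embedding-3-large".to_string(),
        );
        assert_eq!(large.embedding_dimensions(), 3072);

        let default = OpenAiEmbeddingProvider::new("test-key".to_string());
        assert_eq!(default.embedding_dimensions(), 1536);
    }

    #[test]
    fn compatible_llm_provider_requires_base_url() {
        let result = ProviderFactory::create_llm_provider(
            LlmProvider::OpenAICompatible,
            "test-key".to_string(),
            None,
            "some-model".to_string(),
        );
        assert!(result.is_err());
    }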
}