mockforge_data/rag/
providers.rs

1//! LLM and embedding provider integrations
2//!
3//! This module handles integrations with various LLM and embedding providers,
4//! providing a unified interface for different AI services.
5
6use crate::Result;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
/// Supported LLM providers
///
/// Variants serialize in lowercase via serde (e.g. `"openai"`,
/// `"openaicompatible"`), so config files and JSON payloads use lowercase ids.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI GPT models (serialized as `"openai"`)
    OpenAI,
    /// Anthropic Claude models (serialized as `"anthropic"`)
    Anthropic,
    /// Generic OpenAI-compatible API (serialized as `"openaicompatible"`)
    OpenAICompatible,
    /// Local Ollama instance (serialized as `"ollama"`)
    Ollama,
}
24
/// Supported embedding providers
///
/// Variants serialize in lowercase via serde, matching [`LlmProvider`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI embedding models (e.g. text-embedding-ada-002)
    OpenAI,
    /// Generic OpenAI-compatible embeddings API
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}
36
/// LLM provider trait
///
/// Unified async interface over completion/chat backends. Implementations
/// must be `Send + Sync` so they can be shared behind
/// `Box<dyn LlmProviderTrait>` across threads/tasks.
#[async_trait::async_trait]
pub trait LlmProviderTrait: Send + Sync {
    /// Generate text completion for a raw prompt
    ///
    /// Optional sampling controls (`max_tokens`, `temperature`, `top_p`,
    /// `stop_sequences`) fall back to implementation defaults when `None`.
    ///
    /// # Errors
    /// Fails when the backend request errors or the response cannot be parsed.
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Generate chat completion from a sequence of role-tagged messages
    ///
    /// # Errors
    /// Fails when the backend request errors or the response cannot be parsed.
    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Get the list of model ids the backend offers
    ///
    /// # Errors
    /// Fails when the backend request errors or the response cannot be parsed.
    async fn get_available_models(&self) -> Result<Vec<String>>;

    /// Check if provider is available
    ///
    /// NOTE(review): implementations in this module probe with a live API
    /// call, which may incur latency and cost.
    async fn is_available(&self) -> bool;

    /// Get provider name (human-readable, for logging/diagnostics)
    fn name(&self) -> &'static str;

    /// Get maximum context length (tokens)
    fn max_context_length(&self) -> usize;
}
72
/// Embedding provider trait
///
/// Unified async interface over embedding backends. Implementations must be
/// `Send + Sync` so they can be shared behind `Box<dyn EmbeddingProviderTrait>`.
#[async_trait::async_trait]
pub trait EmbeddingProviderTrait: Send + Sync {
    /// Generate an embedding vector for a single text
    ///
    /// # Errors
    /// Fails when the backend request errors or the response cannot be parsed.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts; output order matches `texts`
    ///
    /// # Errors
    /// Fails when any underlying embedding request fails.
    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Get embedding dimensions (length of each returned vector)
    fn embedding_dimensions(&self) -> usize;

    /// Get maximum input tokens accepted per embedding request
    fn max_tokens(&self) -> usize;

    /// Get provider name (human-readable, for logging/diagnostics)
    fn name(&self) -> &'static str;

    /// Check if provider is available
    ///
    /// NOTE(review): implementations in this module probe with a real
    /// embedding request, which may incur latency and cost.
    async fn is_available(&self) -> bool;
}
94
/// Chat message for LLM providers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role (system / user / assistant)
    pub role: ChatRole,
    /// Message content (plain text sent to the model)
    pub content: String,
    /// Optional metadata for caller bookkeeping; the chat-completion
    /// implementations in this module do not include it in request payloads
    pub metadata: Option<HashMap<String, String>>,
}
105
106impl ChatMessage {
107    /// Create a new system message
108    pub fn system(content: String) -> Self {
109        Self {
110            role: ChatRole::System,
111            content,
112            metadata: None,
113        }
114    }
115
116    /// Create a new user message
117    pub fn user(content: String) -> Self {
118        Self {
119            role: ChatRole::User,
120            content,
121            metadata: None,
122        }
123    }
124
125    /// Create a new assistant message
126    pub fn assistant(content: String) -> Self {
127        Self {
128            role: ChatRole::Assistant,
129            content,
130            metadata: None,
131        }
132    }
133
134    /// Add metadata to message
135    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
136        self.metadata = Some(metadata);
137        self
138    }
139}
140
/// Chat message role
///
/// Serializes in lowercase (`"system"`, `"user"`, `"assistant"`), matching
/// the role strings used by the chat-completion request payloads below.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// System message (instructions/context for the model)
    System,
    /// User message
    User,
    /// Assistant message (model output)
    Assistant,
}
152
/// OpenAI provider implementation
pub struct OpenAiProvider {
    /// Bearer token sent in the `Authorization` header
    api_key: String,
    /// HTTP client reused across requests
    client: reqwest::Client,
    /// API root, e.g. `https://api.openai.com/v1`; expected without a
    /// trailing slash (paths are appended as `{base_url}/completions`)
    base_url: String,
}
159
160impl OpenAiProvider {
161    /// Create a new OpenAI provider
162    pub fn new(api_key: String) -> Self {
163        Self {
164            api_key,
165            client: reqwest::Client::new(),
166            base_url: "https://api.openai.com/v1".to_string(),
167        }
168    }
169
170    /// Create with custom base URL
171    pub fn new_with_base_url(api_key: String, base_url: String) -> Self {
172        Self {
173            api_key,
174            client: reqwest::Client::new(),
175            base_url,
176        }
177    }
178}
179
180#[async_trait::async_trait]
181impl LlmProviderTrait for OpenAiProvider {
182    async fn generate_completion(
183        &self,
184        prompt: &str,
185        max_tokens: Option<usize>,
186        temperature: Option<f32>,
187        top_p: Option<f32>,
188        stop_sequences: Option<Vec<String>>,
189    ) -> Result<String> {
190        let mut request_body = serde_json::json!({
191            "model": "gpt-3.5-turbo-instruct",
192            "prompt": prompt,
193            "max_tokens": max_tokens.unwrap_or(1024),
194            "temperature": temperature.unwrap_or(0.7),
195        });
196
197        if let Some(top_p) = top_p {
198            request_body["top_p"] = serde_json::json!(top_p);
199        }
200
201        if let Some(stop) = stop_sequences {
202            request_body["stop"] = serde_json::json!(stop);
203        }
204
205        let response = self
206            .client
207            .post(format!("{}/completions", self.base_url))
208            .header("Authorization", format!("Bearer {}", self.api_key))
209            .header("Content-Type", "application/json")
210            .json(&request_body)
211            .send()
212            .await?;
213
214        if !response.status().is_success() {
215            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
216        }
217
218        let json: Value = response.json().await?;
219        let content = json["choices"][0]["text"]
220            .as_str()
221            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
222
223        Ok(content.to_string())
224    }
225
226    async fn generate_chat_completion(
227        &self,
228        messages: Vec<ChatMessage>,
229        max_tokens: Option<usize>,
230        temperature: Option<f32>,
231        top_p: Option<f32>,
232        stop_sequences: Option<Vec<String>>,
233    ) -> Result<String> {
234        let openai_messages: Vec<Value> = messages
235            .iter()
236            .map(|msg| {
237                serde_json::json!({
238                    "role": format!("{:?}", msg.role).to_lowercase(),
239                    "content": msg.content
240                })
241            })
242            .collect();
243
244        let mut request_body = serde_json::json!({
245            "model": "gpt-3.5-turbo",
246            "messages": openai_messages,
247            "max_tokens": max_tokens.unwrap_or(1024),
248            "temperature": temperature.unwrap_or(0.7),
249        });
250
251        if let Some(top_p) = top_p {
252            request_body["top_p"] = serde_json::json!(top_p);
253        }
254
255        if let Some(stop) = stop_sequences {
256            request_body["stop"] = serde_json::json!(stop);
257        }
258
259        let response = self
260            .client
261            .post(format!("{}/chat/completions", self.base_url))
262            .header("Authorization", format!("Bearer {}", self.api_key))
263            .header("Content-Type", "application/json")
264            .json(&request_body)
265            .send()
266            .await?;
267
268        if !response.status().is_success() {
269            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
270        }
271
272        let json: Value = response.json().await?;
273        let content = json["choices"][0]["message"]["content"]
274            .as_str()
275            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
276
277        Ok(content.to_string())
278    }
279
280    async fn get_available_models(&self) -> Result<Vec<String>> {
281        let response = self
282            .client
283            .get(format!("{}/models", self.base_url))
284            .header("Authorization", format!("Bearer {}", self.api_key))
285            .send()
286            .await?;
287
288        if !response.status().is_success() {
289            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
290        }
291
292        let json: Value = response.json().await?;
293        let models = json["data"]
294            .as_array()
295            .ok_or_else(|| crate::Error::generic("Invalid models response format"))?;
296
297        let model_names = models
298            .iter()
299            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
300            .collect();
301
302        Ok(model_names)
303    }
304
305    async fn is_available(&self) -> bool {
306        (self.get_available_models().await).is_ok()
307    }
308
309    fn name(&self) -> &'static str {
310        "OpenAI"
311    }
312
313    fn max_context_length(&self) -> usize {
314        4096 // GPT-3.5 context length
315    }
316}
317
/// OpenAI embedding provider implementation
pub struct OpenAiEmbeddingProvider {
    /// Bearer token sent in the `Authorization` header
    api_key: String,
    /// HTTP client reused across requests
    client: reqwest::Client,
    /// API root, expected without a trailing slash
    base_url: String,
    /// Embedding model id, e.g. `text-embedding-ada-002`
    model: String,
}
325
326impl OpenAiEmbeddingProvider {
327    /// Create a new OpenAI embedding provider
328    pub fn new(api_key: String) -> Self {
329        Self {
330            api_key,
331            client: reqwest::Client::new(),
332            base_url: "https://api.openai.com/v1".to_string(),
333            model: "text-embedding-ada-002".to_string(),
334        }
335    }
336
337    /// Create with custom model
338    pub fn new_with_model(api_key: String, model: String) -> Self {
339        Self {
340            api_key,
341            client: reqwest::Client::new(),
342            base_url: "https://api.openai.com/v1".to_string(),
343            model,
344        }
345    }
346}
347
348#[async_trait::async_trait]
349impl EmbeddingProviderTrait for OpenAiEmbeddingProvider {
350    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
351        let response = self
352            .client
353            .post(format!("{}/embeddings", self.base_url))
354            .header("Authorization", format!("Bearer {}", self.api_key))
355            .header("Content-Type", "application/json")
356            .json(&serde_json::json!({
357                "input": text,
358                "model": self.model
359            }))
360            .send()
361            .await?;
362
363        if !response.status().is_success() {
364            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
365        }
366
367        let json: Value = response.json().await?;
368        let embedding = json["data"][0]["embedding"]
369            .as_array()
370            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;
371
372        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
373    }
374
375    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
376        let mut embeddings = Vec::new();
377
378        for text in texts {
379            let embedding = self.generate_embedding(&text).await?;
380            embeddings.push(embedding);
381        }
382
383        Ok(embeddings)
384    }
385
386    fn embedding_dimensions(&self) -> usize {
387        match self.model.as_str() {
388            "text-embedding-ada-002" => 1536,
389            "text-embedding-3-small" => 1536,
390            "text-embedding-3-large" => 3072,
391            _ => 1536, // Default
392        }
393    }
394
395    fn max_tokens(&self) -> usize {
396        match self.model.as_str() {
397            "text-embedding-ada-002" => 8191,
398            "text-embedding-3-small" => 8191,
399            "text-embedding-3-large" => 8191,
400            _ => 8191, // Default
401        }
402    }
403
404    fn name(&self) -> &'static str {
405        "OpenAI"
406    }
407
408    async fn is_available(&self) -> bool {
409        (self.generate_embedding("test").await).is_ok()
410    }
411}
412
/// OpenAI-compatible provider implementation
///
/// Works against any server exposing the OpenAI REST shape (completions,
/// chat completions, models) at a configurable base URL.
pub struct OpenAiCompatibleProvider {
    /// Bearer token sent in the `Authorization` header (may be empty for
    /// servers that do not require auth, e.g. local Ollama)
    api_key: String,
    /// HTTP client reused across requests
    client: reqwest::Client,
    /// API root, expected without a trailing slash
    base_url: String,
    /// Model id sent with every request
    model: String,
}
420
421impl OpenAiCompatibleProvider {
422    /// Create a new OpenAI-compatible provider
423    pub fn new(api_key: String, base_url: String, model: String) -> Self {
424        Self {
425            api_key,
426            client: reqwest::Client::new(),
427            base_url,
428            model,
429        }
430    }
431}
432
433#[async_trait::async_trait]
434impl LlmProviderTrait for OpenAiCompatibleProvider {
435    async fn generate_completion(
436        &self,
437        prompt: &str,
438        max_tokens: Option<usize>,
439        temperature: Option<f32>,
440        top_p: Option<f32>,
441        stop_sequences: Option<Vec<String>>,
442    ) -> Result<String> {
443        let mut request_body = serde_json::json!({
444            "model": self.model,
445            "prompt": prompt,
446            "max_tokens": max_tokens.unwrap_or(1024),
447            "temperature": temperature.unwrap_or(0.7),
448        });
449
450        if let Some(top_p) = top_p {
451            request_body["top_p"] = serde_json::json!(top_p);
452        }
453
454        if let Some(stop) = stop_sequences {
455            request_body["stop"] = serde_json::json!(stop);
456        }
457
458        let response = self
459            .client
460            .post(format!("{}/completions", self.base_url))
461            .header("Authorization", format!("Bearer {}", self.api_key))
462            .header("Content-Type", "application/json")
463            .json(&request_body)
464            .send()
465            .await?;
466
467        if !response.status().is_success() {
468            return Err(crate::Error::generic(format!("API error: {}", response.status())));
469        }
470
471        let json: Value = response.json().await?;
472        let content = json["choices"][0]["text"]
473            .as_str()
474            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
475
476        Ok(content.to_string())
477    }
478
479    async fn generate_chat_completion(
480        &self,
481        messages: Vec<ChatMessage>,
482        max_tokens: Option<usize>,
483        temperature: Option<f32>,
484        top_p: Option<f32>,
485        stop_sequences: Option<Vec<String>>,
486    ) -> Result<String> {
487        let openai_messages: Vec<Value> = messages
488            .iter()
489            .map(|msg| {
490                serde_json::json!({
491                    "role": format!("{:?}", msg.role).to_lowercase(),
492                    "content": msg.content
493                })
494            })
495            .collect();
496
497        let mut request_body = serde_json::json!({
498            "model": self.model,
499            "messages": openai_messages,
500            "max_tokens": max_tokens.unwrap_or(1024),
501            "temperature": temperature.unwrap_or(0.7),
502        });
503
504        if let Some(top_p) = top_p {
505            request_body["top_p"] = serde_json::json!(top_p);
506        }
507
508        if let Some(stop) = stop_sequences {
509            request_body["stop"] = serde_json::json!(stop);
510        }
511
512        let response = self
513            .client
514            .post(format!("{}/chat/completions", self.base_url))
515            .header("Authorization", format!("Bearer {}", self.api_key))
516            .header("Content-Type", "application/json")
517            .json(&request_body)
518            .send()
519            .await?;
520
521        if !response.status().is_success() {
522            return Err(crate::Error::generic(format!("API error: {}", response.status())));
523        }
524
525        let json: Value = response.json().await?;
526        let content = json["choices"][0]["message"]["content"]
527            .as_str()
528            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
529
530        Ok(content.to_string())
531    }
532
533    async fn get_available_models(&self) -> Result<Vec<String>> {
534        // Try to get models, but fall back gracefully if not available
535        match self
536            .client
537            .get(format!("{}/models", self.base_url))
538            .header("Authorization", format!("Bearer {}", self.api_key))
539            .send()
540            .await
541        {
542            Ok(response) if response.status().is_success() => {
543                let json: Value = response.json().await?;
544                let models = json["data"]
545                    .as_array()
546                    .ok_or_else(|| crate::Error::generic("Invalid models response format"))?;
547                Ok(models
548                    .iter()
549                    .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
550                    .collect())
551            }
552            _ => Ok(vec![self.model.clone()]), // Return configured model as fallback
553        }
554    }
555
556    async fn is_available(&self) -> bool {
557        (self.generate_completion("test", Some(1), None, None, None).await).is_ok()
558    }
559
560    fn name(&self) -> &'static str {
561        "OpenAI Compatible"
562    }
563
564    fn max_context_length(&self) -> usize {
565        4096 // Default context length
566    }
567}
568
/// OpenAI-compatible embedding provider implementation
pub struct OpenAiCompatibleEmbeddingProvider {
    /// Bearer token for the `Authorization` header (may be empty, e.g. Ollama)
    api_key: String,
    /// HTTP client reused across requests
    client: reqwest::Client,
    /// API root, expected without a trailing slash
    base_url: String,
    /// Embedding model id sent with every request
    model: String,
}
576
577impl OpenAiCompatibleEmbeddingProvider {
578    /// Create a new OpenAI-compatible embedding provider
579    pub fn new(api_key: String, base_url: String, model: String) -> Self {
580        Self {
581            api_key,
582            client: reqwest::Client::new(),
583            base_url,
584            model,
585        }
586    }
587}
588
589#[async_trait::async_trait]
590impl EmbeddingProviderTrait for OpenAiCompatibleEmbeddingProvider {
591    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
592        let response = self
593            .client
594            .post(format!("{}/embeddings", self.base_url))
595            .header("Authorization", format!("Bearer {}", self.api_key))
596            .header("Content-Type", "application/json")
597            .json(&serde_json::json!({
598                "input": text,
599                "model": self.model
600            }))
601            .send()
602            .await?;
603
604        if !response.status().is_success() {
605            return Err(crate::Error::generic(format!("API error: {}", response.status())));
606        }
607
608        let json: Value = response.json().await?;
609        let embedding = json["data"][0]["embedding"]
610            .as_array()
611            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;
612
613        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
614    }
615
616    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
617        let mut embeddings = Vec::new();
618
619        for text in texts {
620            let embedding = self.generate_embedding(&text).await?;
621            embeddings.push(embedding);
622        }
623
624        Ok(embeddings)
625    }
626
627    fn embedding_dimensions(&self) -> usize {
628        1536 // Default OpenAI embedding dimensions
629    }
630
631    fn max_tokens(&self) -> usize {
632        8191 // Default OpenAI token limit
633    }
634
635    fn name(&self) -> &'static str {
636        "OpenAI Compatible"
637    }
638
639    async fn is_available(&self) -> bool {
640        (self.generate_embedding("test").await).is_ok()
641    }
642}
643
/// Provider factory for creating LLM and embedding providers
///
/// Stateless; all constructors are associated functions on this unit struct.
pub struct ProviderFactory;
646
647impl ProviderFactory {
648    /// Create LLM provider from configuration
649    pub fn create_llm_provider(
650        provider_type: LlmProvider,
651        api_key: String,
652        base_url: Option<String>,
653        model: String,
654    ) -> Result<Box<dyn LlmProviderTrait>> {
655        match provider_type {
656            LlmProvider::OpenAI => {
657                let base_url = base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
658                Ok(Box::new(OpenAiProvider::new_with_base_url(api_key, base_url)))
659            }
660            LlmProvider::OpenAICompatible => {
661                let base_url = base_url.ok_or_else(|| {
662                    crate::Error::generic("Base URL required for OpenAI compatible provider")
663                })?;
664                Ok(Box::new(OpenAiCompatibleProvider::new(api_key, base_url, model)))
665            }
666            _ => Err(crate::Error::generic(format!(
667                "Provider type {:?} not yet implemented",
668                provider_type
669            ))),
670        }
671    }
672
673    /// Create embedding provider from configuration
674    pub fn create_embedding_provider(
675        provider_type: EmbeddingProvider,
676        api_key: String,
677        base_url: Option<String>,
678        model: String,
679    ) -> Result<Box<dyn EmbeddingProviderTrait>> {
680        match provider_type {
681            EmbeddingProvider::OpenAI => {
682                Ok(Box::new(OpenAiEmbeddingProvider::new_with_model(api_key, model)))
683            }
684            EmbeddingProvider::OpenAICompatible => {
685                let base_url = base_url.ok_or_else(|| {
686                    crate::Error::generic(
687                        "Base URL required for OpenAI compatible embedding provider",
688                    )
689                })?;
690                Ok(Box::new(OpenAiCompatibleEmbeddingProvider::new(api_key, base_url, model)))
691            }
692            EmbeddingProvider::Ollama => {
693                // Ollama embeddings use OpenAI-compatible API
694                let base_url = base_url.ok_or_else(|| {
695                    crate::Error::generic("Base URL required for Ollama embedding provider")
696                })?;
697                // Ollama doesn't require API key, use empty string
698                Ok(Box::new(OpenAiCompatibleEmbeddingProvider::new(String::new(), base_url, model)))
699            }
700        }
701    }
702}
703
#[cfg(test)]
mod tests {

    /// Placeholder test: passes by compiling and running an empty body.
    /// TODO(review): add real unit tests — e.g. `ChatMessage` constructor
    /// roles/metadata and `ProviderFactory` missing-base-URL error paths;
    /// the network-backed providers need a mock HTTP server to test further.
    #[test]
    fn test_module_compiles() {
        // Basic compilation test
    }
}