// mockforge_data/rag/providers.rs
1//! LLM and embedding provider integrations
2//!
3//! This module handles integrations with various LLM and embedding providers,
4//! providing a unified interface for different AI services.
5
6use crate::Result;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
/// Supported LLM providers
///
/// Serialized in lowercase (e.g. `"openai"`, `"openaicompatible"`) via
/// `#[serde(rename_all = "lowercase")]`, so config files use lowercase names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI GPT models
    OpenAI,
    /// Anthropic Claude models
    Anthropic,
    /// Generic OpenAI-compatible API (requires an explicit base URL)
    OpenAICompatible,
    /// Local Ollama instance (uses Ollama's OpenAI-compatible endpoint)
    Ollama,
}
24
/// Supported embedding providers
///
/// Serialized in lowercase via `#[serde(rename_all = "lowercase")]`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI embeddings (text-embedding-ada-002 by default)
    OpenAI,
    /// Generic OpenAI-compatible embeddings API (requires an explicit base URL)
    OpenAICompatible,
    /// Local Ollama instance (uses Ollama's OpenAI-compatible endpoint)
    Ollama,
}
36
/// LLM provider trait
///
/// Unified async interface over text-generation backends. Implementations
/// must be `Send + Sync` so a boxed provider can be shared across async tasks.
#[async_trait::async_trait]
pub trait LlmProviderTrait: Send + Sync {
    /// Generate text completion for a raw prompt.
    ///
    /// Optional parameters left as `None` fall back to implementation (or
    /// remote API) defaults:
    /// * `max_tokens` - cap on generated tokens
    /// * `temperature` - sampling temperature
    /// * `top_p` - nucleus-sampling parameter
    /// * `stop_sequences` - strings at which generation stops
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Generate chat completion from an ordered message history.
    ///
    /// Same optional-parameter semantics as [`generate_completion`].
    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String>;

    /// Get available models (may hit the remote API, depending on provider).
    async fn get_available_models(&self) -> Result<Vec<String>>;

    /// Check if provider is available (implementations typically probe the API).
    async fn is_available(&self) -> bool;

    /// Get provider name (human-readable, e.g. for logging).
    fn name(&self) -> &'static str;

    /// Get maximum context length in tokens.
    fn max_context_length(&self) -> usize;
}
72
/// Embedding provider trait
///
/// Unified async interface over embedding backends. Implementations must be
/// `Send + Sync` so a boxed provider can be shared across async tasks.
#[async_trait::async_trait]
pub trait EmbeddingProviderTrait: Send + Sync {
    /// Generate embedding for a single text; vector length should match
    /// [`embedding_dimensions`].
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts, in the same order as the input.
    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Get embedding dimensions for the configured model.
    fn embedding_dimensions(&self) -> usize;

    /// Get maximum tokens accepted per embedding input.
    fn max_tokens(&self) -> usize;

    /// Get provider name (human-readable, e.g. for logging).
    fn name(&self) -> &'static str;

    /// Check if provider is available (implementations typically probe the API).
    async fn is_available(&self) -> bool;
}
94
/// Chat message for LLM providers
///
/// A single turn in a conversation; converted to each provider's wire format
/// by the respective [`LlmProviderTrait`] implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role (system / user / assistant)
    pub role: ChatRole,
    /// Message content
    pub content: String,
    /// Optional metadata (not forwarded to providers by the implementations
    /// in this module)
    pub metadata: Option<HashMap<String, String>>,
}
105
106impl ChatMessage {
107    /// Create a new system message
108    pub fn system(content: String) -> Self {
109        Self {
110            role: ChatRole::System,
111            content,
112            metadata: None,
113        }
114    }
115
116    /// Create a new user message
117    pub fn user(content: String) -> Self {
118        Self {
119            role: ChatRole::User,
120            content,
121            metadata: None,
122        }
123    }
124
125    /// Create a new assistant message
126    pub fn assistant(content: String) -> Self {
127        Self {
128            role: ChatRole::Assistant,
129            content,
130            metadata: None,
131        }
132    }
133
134    /// Add metadata to message
135    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
136        self.metadata = Some(metadata);
137        self
138    }
139}
140
/// Chat message role
///
/// Serialized in lowercase ("system" / "user" / "assistant") via
/// `#[serde(rename_all = "lowercase")]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// System message
    System,
    /// User message
    User,
    /// Assistant message
    Assistant,
}
152
/// OpenAI provider implementation
///
/// Talks to the OpenAI REST API (or any server mounted at `base_url` that
/// speaks the same protocol) using Bearer authentication.
pub struct OpenAiProvider {
    // API key sent as `Authorization: Bearer <key>`
    api_key: String,
    // Reused HTTP client (connection pooling)
    client: reqwest::Client,
    // API root, e.g. "https://api.openai.com/v1" (no trailing slash)
    base_url: String,
}
159
160impl OpenAiProvider {
161    /// Create a new OpenAI provider
162    pub fn new(api_key: String) -> Self {
163        Self {
164            api_key,
165            client: reqwest::Client::new(),
166            base_url: "https://api.openai.com/v1".to_string(),
167        }
168    }
169
170    /// Create with custom base URL
171    pub fn new_with_base_url(api_key: String, base_url: String) -> Self {
172        Self {
173            api_key,
174            client: reqwest::Client::new(),
175            base_url,
176        }
177    }
178}
179
180#[async_trait::async_trait]
181impl LlmProviderTrait for OpenAiProvider {
182    async fn generate_completion(
183        &self,
184        prompt: &str,
185        max_tokens: Option<usize>,
186        temperature: Option<f32>,
187        top_p: Option<f32>,
188        stop_sequences: Option<Vec<String>>,
189    ) -> Result<String> {
190        let mut request_body = serde_json::json!({
191            "model": "gpt-3.5-turbo-instruct",
192            "prompt": prompt,
193            "max_tokens": max_tokens.unwrap_or(1024),
194            "temperature": temperature.unwrap_or(0.7),
195        });
196
197        if let Some(top_p) = top_p {
198            request_body["top_p"] = serde_json::json!(top_p);
199        }
200
201        if let Some(stop) = stop_sequences {
202            request_body["stop"] = serde_json::json!(stop);
203        }
204
205        let response = self
206            .client
207            .post(format!("{}/completions", self.base_url))
208            .header("Authorization", format!("Bearer {}", self.api_key))
209            .header("Content-Type", "application/json")
210            .json(&request_body)
211            .send()
212            .await?;
213
214        if !response.status().is_success() {
215            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
216        }
217
218        let json: Value = response.json().await?;
219        let content = json["choices"][0]["text"]
220            .as_str()
221            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
222
223        Ok(content.to_string())
224    }
225
226    async fn generate_chat_completion(
227        &self,
228        messages: Vec<ChatMessage>,
229        max_tokens: Option<usize>,
230        temperature: Option<f32>,
231        top_p: Option<f32>,
232        stop_sequences: Option<Vec<String>>,
233    ) -> Result<String> {
234        let openai_messages: Vec<Value> = messages
235            .iter()
236            .map(|msg| {
237                serde_json::json!({
238                    "role": format!("{:?}", msg.role).to_lowercase(),
239                    "content": msg.content
240                })
241            })
242            .collect();
243
244        let mut request_body = serde_json::json!({
245            "model": "gpt-3.5-turbo",
246            "messages": openai_messages,
247            "max_tokens": max_tokens.unwrap_or(1024),
248            "temperature": temperature.unwrap_or(0.7),
249        });
250
251        if let Some(top_p) = top_p {
252            request_body["top_p"] = serde_json::json!(top_p);
253        }
254
255        if let Some(stop) = stop_sequences {
256            request_body["stop"] = serde_json::json!(stop);
257        }
258
259        let response = self
260            .client
261            .post(format!("{}/chat/completions", self.base_url))
262            .header("Authorization", format!("Bearer {}", self.api_key))
263            .header("Content-Type", "application/json")
264            .json(&request_body)
265            .send()
266            .await?;
267
268        if !response.status().is_success() {
269            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
270        }
271
272        let json: Value = response.json().await?;
273        let content = json["choices"][0]["message"]["content"]
274            .as_str()
275            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
276
277        Ok(content.to_string())
278    }
279
280    async fn get_available_models(&self) -> Result<Vec<String>> {
281        let response = self
282            .client
283            .get(format!("{}/models", self.base_url))
284            .header("Authorization", format!("Bearer {}", self.api_key))
285            .send()
286            .await?;
287
288        if !response.status().is_success() {
289            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
290        }
291
292        let json: Value = response.json().await?;
293        let models = json["data"]
294            .as_array()
295            .ok_or_else(|| crate::Error::generic("Invalid models response format"))?;
296
297        let model_names = models
298            .iter()
299            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
300            .collect();
301
302        Ok(model_names)
303    }
304
305    async fn is_available(&self) -> bool {
306        (self.get_available_models().await).is_ok()
307    }
308
309    fn name(&self) -> &'static str {
310        "OpenAI"
311    }
312
313    fn max_context_length(&self) -> usize {
314        4096 // GPT-3.5 context length
315    }
316}
317
/// OpenAI embedding provider implementation
///
/// Calls the OpenAI embeddings endpoint with a configurable model
/// (defaults to `text-embedding-ada-002`).
pub struct OpenAiEmbeddingProvider {
    // API key sent as `Authorization: Bearer <key>`
    api_key: String,
    // Reused HTTP client (connection pooling)
    client: reqwest::Client,
    // API root, e.g. "https://api.openai.com/v1" (no trailing slash)
    base_url: String,
    // Embedding model id, e.g. "text-embedding-3-small"
    model: String,
}
325
326impl OpenAiEmbeddingProvider {
327    /// Create a new OpenAI embedding provider
328    pub fn new(api_key: String) -> Self {
329        Self {
330            api_key,
331            client: reqwest::Client::new(),
332            base_url: "https://api.openai.com/v1".to_string(),
333            model: "text-embedding-ada-002".to_string(),
334        }
335    }
336
337    /// Create with custom model
338    pub fn new_with_model(api_key: String, model: String) -> Self {
339        Self {
340            api_key,
341            client: reqwest::Client::new(),
342            base_url: "https://api.openai.com/v1".to_string(),
343            model,
344        }
345    }
346}
347
348#[async_trait::async_trait]
349impl EmbeddingProviderTrait for OpenAiEmbeddingProvider {
350    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
351        let response = self
352            .client
353            .post(format!("{}/embeddings", self.base_url))
354            .header("Authorization", format!("Bearer {}", self.api_key))
355            .header("Content-Type", "application/json")
356            .json(&serde_json::json!({
357                "input": text,
358                "model": self.model
359            }))
360            .send()
361            .await?;
362
363        if !response.status().is_success() {
364            return Err(crate::Error::generic(format!("OpenAI API error: {}", response.status())));
365        }
366
367        let json: Value = response.json().await?;
368        let embedding = json["data"][0]["embedding"]
369            .as_array()
370            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;
371
372        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
373    }
374
375    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
376        let mut embeddings = Vec::new();
377
378        for text in texts {
379            let embedding = self.generate_embedding(&text).await?;
380            embeddings.push(embedding);
381        }
382
383        Ok(embeddings)
384    }
385
386    fn embedding_dimensions(&self) -> usize {
387        match self.model.as_str() {
388            "text-embedding-ada-002" => 1536,
389            "text-embedding-3-small" => 1536,
390            "text-embedding-3-large" => 3072,
391            _ => 1536, // Default
392        }
393    }
394
395    fn max_tokens(&self) -> usize {
396        match self.model.as_str() {
397            "text-embedding-ada-002" => 8191,
398            "text-embedding-3-small" => 8191,
399            "text-embedding-3-large" => 8191,
400            _ => 8191, // Default
401        }
402    }
403
404    fn name(&self) -> &'static str {
405        "OpenAI"
406    }
407
408    async fn is_available(&self) -> bool {
409        (self.generate_embedding("test").await).is_ok()
410    }
411}
412
/// OpenAI-compatible provider implementation
///
/// Works against any server exposing the OpenAI REST protocol
/// (e.g. Ollama, vLLM, LM Studio); the model is fully configurable.
pub struct OpenAiCompatibleProvider {
    // API key sent as `Authorization: Bearer <key>` (may be empty for
    // servers that don't authenticate, e.g. local Ollama)
    api_key: String,
    // Reused HTTP client (connection pooling)
    client: reqwest::Client,
    // API root (no trailing slash)
    base_url: String,
    // Model id passed through in each request
    model: String,
}
420
421impl OpenAiCompatibleProvider {
422    /// Create a new OpenAI-compatible provider
423    pub fn new(api_key: String, base_url: String, model: String) -> Self {
424        Self {
425            api_key,
426            client: reqwest::Client::new(),
427            base_url,
428            model,
429        }
430    }
431}
432
433#[async_trait::async_trait]
434impl LlmProviderTrait for OpenAiCompatibleProvider {
435    async fn generate_completion(
436        &self,
437        prompt: &str,
438        max_tokens: Option<usize>,
439        temperature: Option<f32>,
440        top_p: Option<f32>,
441        stop_sequences: Option<Vec<String>>,
442    ) -> Result<String> {
443        let mut request_body = serde_json::json!({
444            "model": self.model,
445            "prompt": prompt,
446            "max_tokens": max_tokens.unwrap_or(1024),
447            "temperature": temperature.unwrap_or(0.7),
448        });
449
450        if let Some(top_p) = top_p {
451            request_body["top_p"] = serde_json::json!(top_p);
452        }
453
454        if let Some(stop) = stop_sequences {
455            request_body["stop"] = serde_json::json!(stop);
456        }
457
458        let response = self
459            .client
460            .post(format!("{}/completions", self.base_url))
461            .header("Authorization", format!("Bearer {}", self.api_key))
462            .header("Content-Type", "application/json")
463            .json(&request_body)
464            .send()
465            .await?;
466
467        if !response.status().is_success() {
468            return Err(crate::Error::generic(format!("API error: {}", response.status())));
469        }
470
471        let json: Value = response.json().await?;
472        let content = json["choices"][0]["text"]
473            .as_str()
474            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
475
476        Ok(content.to_string())
477    }
478
479    async fn generate_chat_completion(
480        &self,
481        messages: Vec<ChatMessage>,
482        max_tokens: Option<usize>,
483        temperature: Option<f32>,
484        top_p: Option<f32>,
485        stop_sequences: Option<Vec<String>>,
486    ) -> Result<String> {
487        let openai_messages: Vec<Value> = messages
488            .iter()
489            .map(|msg| {
490                serde_json::json!({
491                    "role": format!("{:?}", msg.role).to_lowercase(),
492                    "content": msg.content
493                })
494            })
495            .collect();
496
497        let mut request_body = serde_json::json!({
498            "model": self.model,
499            "messages": openai_messages,
500            "max_tokens": max_tokens.unwrap_or(1024),
501            "temperature": temperature.unwrap_or(0.7),
502        });
503
504        if let Some(top_p) = top_p {
505            request_body["top_p"] = serde_json::json!(top_p);
506        }
507
508        if let Some(stop) = stop_sequences {
509            request_body["stop"] = serde_json::json!(stop);
510        }
511
512        let response = self
513            .client
514            .post(format!("{}/chat/completions", self.base_url))
515            .header("Authorization", format!("Bearer {}", self.api_key))
516            .header("Content-Type", "application/json")
517            .json(&request_body)
518            .send()
519            .await?;
520
521        if !response.status().is_success() {
522            return Err(crate::Error::generic(format!("API error: {}", response.status())));
523        }
524
525        let json: Value = response.json().await?;
526        let content = json["choices"][0]["message"]["content"]
527            .as_str()
528            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;
529
530        Ok(content.to_string())
531    }
532
533    async fn get_available_models(&self) -> Result<Vec<String>> {
534        // Try to get models, but fall back gracefully if not available
535        match self
536            .client
537            .get(format!("{}/models", self.base_url))
538            .header("Authorization", format!("Bearer {}", self.api_key))
539            .send()
540            .await
541        {
542            Ok(response) if response.status().is_success() => {
543                let json: Value = response.json().await?;
544                let models = json["data"]
545                    .as_array()
546                    .ok_or_else(|| crate::Error::generic("Invalid models response format"))?;
547                Ok(models
548                    .iter()
549                    .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
550                    .collect())
551            }
552            _ => Ok(vec![self.model.clone()]), // Return configured model as fallback
553        }
554    }
555
556    async fn is_available(&self) -> bool {
557        (self.generate_completion("test", Some(1), None, None, None).await).is_ok()
558    }
559
560    fn name(&self) -> &'static str {
561        "OpenAI Compatible"
562    }
563
564    fn max_context_length(&self) -> usize {
565        4096 // Default context length
566    }
567}
568
/// Anthropic provider implementation
///
/// Talks to the Anthropic Messages API using `x-api-key` authentication.
pub struct AnthropicProvider {
    // API key sent as the `x-api-key` header
    api_key: String,
    // Reused HTTP client (connection pooling)
    client: reqwest::Client,
    // API root, e.g. "https://api.anthropic.com/v1" (no trailing slash)
    base_url: String,
    // Claude model id, e.g. "claude-3-5-sonnet-latest"
    model: String,
}
576
577impl AnthropicProvider {
578    /// Create a new Anthropic provider
579    pub fn new(api_key: String, base_url: String, model: String) -> Self {
580        Self {
581            api_key,
582            client: reqwest::Client::new(),
583            base_url,
584            model,
585        }
586    }
587}
588
#[async_trait::async_trait]
impl LlmProviderTrait for AnthropicProvider {
    /// Generate a completion by wrapping the prompt in a single user message
    /// and calling the Messages API (`POST {base_url}/messages`).
    async fn generate_completion(
        &self,
        prompt: &str,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        // The Messages API has no bare-prompt endpoint, so the prompt is sent
        // as a one-turn conversation.
        let mut request_body = serde_json::json!({
            "model": self.model,
            "max_tokens": max_tokens.unwrap_or(1024),
            "messages": [
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
        });

        // Optional sampling parameters are only included when set, leaving
        // the API's own defaults in effect otherwise.
        if let Some(temp) = temperature {
            request_body["temperature"] = serde_json::json!(temp);
        }
        if let Some(p) = top_p {
            request_body["top_p"] = serde_json::json!(p);
        }
        if let Some(stop) = stop_sequences {
            request_body["stop_sequences"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/messages", self.base_url))
            // Anthropic authenticates via `x-api-key`, not a Bearer token.
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!(
                "Anthropic API error: {}",
                response.status()
            )));
        }

        // Responses carry a list of content blocks; only the first block's
        // text is extracted here.
        let json: Value = response.json().await?;
        let content = json["content"][0]["text"]
            .as_str()
            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    /// Generate a chat completion. System messages are hoisted out of the
    /// message list into the top-level `system` field (joined with newlines),
    /// since the Messages API does not accept a "system" role in `messages`.
    async fn generate_chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        max_tokens: Option<usize>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        stop_sequences: Option<Vec<String>>,
    ) -> Result<String> {
        let mut anthropic_messages = Vec::new();
        let mut system_parts = Vec::new();

        // Split system turns from the conversational turns.
        for message in messages {
            match message.role {
                ChatRole::System => system_parts.push(message.content),
                ChatRole::User => anthropic_messages.push(serde_json::json!({
                    "role": "user",
                    "content": message.content,
                })),
                ChatRole::Assistant => anthropic_messages.push(serde_json::json!({
                    "role": "assistant",
                    "content": message.content,
                })),
            }
        }

        // The API requires a non-empty `messages` array, so inject an empty
        // user turn when the input contained only system messages (or nothing).
        if anthropic_messages.is_empty() {
            anthropic_messages.push(serde_json::json!({
                "role": "user",
                "content": "",
            }));
        }

        let mut request_body = serde_json::json!({
            "model": self.model,
            "max_tokens": max_tokens.unwrap_or(1024),
            "messages": anthropic_messages,
        });

        if !system_parts.is_empty() {
            request_body["system"] = serde_json::json!(system_parts.join("\n"));
        }
        // Optional sampling parameters are only included when set.
        if let Some(temp) = temperature {
            request_body["temperature"] = serde_json::json!(temp);
        }
        if let Some(p) = top_p {
            request_body["top_p"] = serde_json::json!(p);
        }
        if let Some(stop) = stop_sequences {
            request_body["stop_sequences"] = serde_json::json!(stop);
        }

        let response = self
            .client
            .post(format!("{}/messages", self.base_url))
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(crate::Error::generic(format!(
                "Anthropic API error: {}",
                response.status()
            )));
        }

        // Only the first content block's text is extracted.
        let json: Value = response.json().await?;
        let content = json["content"][0]["text"]
            .as_str()
            .ok_or_else(|| crate::Error::generic("Invalid response format"))?;

        Ok(content.to_string())
    }

    /// Static model list; no network call is made.
    async fn get_available_models(&self) -> Result<Vec<String>> {
        Ok(vec![
            "claude-3-5-sonnet-latest".to_string(),
            "claude-3-5-haiku-latest".to_string(),
        ])
    }

    fn name(&self) -> &'static str {
        "Anthropic"
    }

    /// Claude 3.5 context window (200k tokens).
    fn max_context_length(&self) -> usize {
        200_000
    }

    /// NOTE(review): backed by the hard-coded model list above, so this
    /// always returns `true` without contacting the API — confirm that is
    /// the intended availability semantics.
    async fn is_available(&self) -> bool {
        (self.get_available_models().await).is_ok()
    }
}
740
/// OpenAI-compatible embedding provider implementation
///
/// Works against any server exposing the OpenAI embeddings protocol
/// (e.g. Ollama); base URL and model are fully configurable.
pub struct OpenAiCompatibleEmbeddingProvider {
    // API key sent as `Authorization: Bearer <key>` (may be empty for
    // servers that don't authenticate)
    api_key: String,
    // Reused HTTP client (connection pooling)
    client: reqwest::Client,
    // API root (no trailing slash)
    base_url: String,
    // Embedding model id passed through in each request
    model: String,
}
748
749impl OpenAiCompatibleEmbeddingProvider {
750    /// Create a new OpenAI-compatible embedding provider
751    pub fn new(api_key: String, base_url: String, model: String) -> Self {
752        Self {
753            api_key,
754            client: reqwest::Client::new(),
755            base_url,
756            model,
757        }
758    }
759}
760
761#[async_trait::async_trait]
762impl EmbeddingProviderTrait for OpenAiCompatibleEmbeddingProvider {
763    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
764        let response = self
765            .client
766            .post(format!("{}/embeddings", self.base_url))
767            .header("Authorization", format!("Bearer {}", self.api_key))
768            .header("Content-Type", "application/json")
769            .json(&serde_json::json!({
770                "input": text,
771                "model": self.model
772            }))
773            .send()
774            .await?;
775
776        if !response.status().is_success() {
777            return Err(crate::Error::generic(format!("API error: {}", response.status())));
778        }
779
780        let json: Value = response.json().await?;
781        let embedding = json["data"][0]["embedding"]
782            .as_array()
783            .ok_or_else(|| crate::Error::generic("Invalid embedding response format"))?;
784
785        Ok(embedding.iter().map(|v| v.as_f64().unwrap_or(0.0) as f32).collect())
786    }
787
788    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
789        let mut embeddings = Vec::new();
790
791        for text in texts {
792            let embedding = self.generate_embedding(&text).await?;
793            embeddings.push(embedding);
794        }
795
796        Ok(embeddings)
797    }
798
799    fn embedding_dimensions(&self) -> usize {
800        1536 // Default OpenAI embedding dimensions
801    }
802
803    fn max_tokens(&self) -> usize {
804        8191 // Default OpenAI token limit
805    }
806
807    fn name(&self) -> &'static str {
808        "OpenAI Compatible"
809    }
810
811    async fn is_available(&self) -> bool {
812        (self.generate_embedding("test").await).is_ok()
813    }
814}
815
/// Provider factory for creating LLM and embedding providers
///
/// Stateless namespace for the `create_*` constructors; maps a provider enum
/// plus connection settings to a boxed trait object.
pub struct ProviderFactory;
818
impl ProviderFactory {
    /// Create LLM provider from configuration.
    ///
    /// `base_url` falls back to each provider's conventional default where
    /// one exists; it is mandatory for `OpenAICompatible`.
    ///
    /// NOTE(review): `model` is ignored for the `OpenAI` variant —
    /// `OpenAiProvider` takes only an API key and base URL and hard-codes
    /// its model names internally. Confirm this is intentional.
    pub fn create_llm_provider(
        provider_type: LlmProvider,
        api_key: String,
        base_url: Option<String>,
        model: String,
    ) -> Result<Box<dyn LlmProviderTrait>> {
        match provider_type {
            LlmProvider::OpenAI => {
                let base_url = base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
                Ok(Box::new(OpenAiProvider::new_with_base_url(api_key, base_url)))
            }
            LlmProvider::Anthropic => {
                let base_url =
                    base_url.unwrap_or_else(|| "https://api.anthropic.com/v1".to_string());
                Ok(Box::new(AnthropicProvider::new(api_key, base_url, model)))
            }
            LlmProvider::Ollama => {
                // Ollama exposes an OpenAI-compatible API under /v1, so it
                // reuses the compatible provider with a localhost default.
                let base_url = base_url.unwrap_or_else(|| "http://localhost:11434/v1".to_string());
                Ok(Box::new(OpenAiCompatibleProvider::new(api_key, base_url, model)))
            }
            LlmProvider::OpenAICompatible => {
                // No sensible default endpoint exists for a generic server.
                let base_url = base_url.ok_or_else(|| {
                    crate::Error::generic("Base URL required for OpenAI compatible provider")
                })?;
                Ok(Box::new(OpenAiCompatibleProvider::new(api_key, base_url, model)))
            }
        }
    }

    /// Create embedding provider from configuration.
    ///
    /// `base_url` is mandatory for the `OpenAICompatible` and `Ollama`
    /// variants; the `OpenAI` variant always targets the official endpoint.
    pub fn create_embedding_provider(
        provider_type: EmbeddingProvider,
        api_key: String,
        base_url: Option<String>,
        model: String,
    ) -> Result<Box<dyn EmbeddingProviderTrait>> {
        match provider_type {
            EmbeddingProvider::OpenAI => {
                Ok(Box::new(OpenAiEmbeddingProvider::new_with_model(api_key, model)))
            }
            EmbeddingProvider::OpenAICompatible => {
                let base_url = base_url.ok_or_else(|| {
                    crate::Error::generic(
                        "Base URL required for OpenAI compatible embedding provider",
                    )
                })?;
                Ok(Box::new(OpenAiCompatibleEmbeddingProvider::new(api_key, base_url, model)))
            }
            EmbeddingProvider::Ollama => {
                // Ollama embeddings use OpenAI-compatible API
                let base_url = base_url.ok_or_else(|| {
                    crate::Error::generic("Base URL required for Ollama embedding provider")
                })?;
                // Ollama doesn't require API key, use empty string
                Ok(Box::new(OpenAiCompatibleEmbeddingProvider::new(String::new(), base_url, model)))
            }
        }
    }
}
880
#[cfg(test)]
mod tests {
    use super::{LlmProvider, ProviderFactory};

    /// Smoke test: the module compiles and links.
    #[test]
    fn test_module_compiles() {}

    /// The factory wires the Anthropic variant to the Anthropic provider.
    #[test]
    fn test_create_anthropic_provider() {
        let result = ProviderFactory::create_llm_provider(
            LlmProvider::Anthropic,
            String::from("key"),
            None,
            String::from("claude-3-5-sonnet-latest"),
        );
        let provider = result.expect("provider");
        assert_eq!("Anthropic", provider.name());
    }

    /// The factory wires the Ollama variant to the OpenAI-compatible provider.
    #[test]
    fn test_create_ollama_provider() {
        let result = ProviderFactory::create_llm_provider(
            LlmProvider::Ollama,
            String::new(),
            None,
            String::from("llama3.1"),
        );
        let provider = result.expect("provider");
        assert_eq!("OpenAI Compatible", provider.name());
    }
}
913}