// dspy_rs/core/lm/client_registry.rs

1use anyhow::Result;
2use enum_dispatch::enum_dispatch;
3use reqwest;
4use rig::{
5    completion::{CompletionError, CompletionRequest, CompletionResponse},
6    providers::*,
7};
8use std::borrow::Cow;
9
/// Unified async completion interface over all supported LM providers.
///
/// `enum_dispatch` generates a statically-dispatched implementation for the
/// `LMClient` enum below (no `dyn Trait` vtable calls).
#[enum_dispatch]
#[allow(async_fn_in_trait)] // crate-internal trait; auto-trait leakage from `async fn` is acceptable
pub trait CompletionProvider {
    /// Execute a completion request, returning a response whose
    /// provider-specific raw payload has been erased to `()`.
    async fn completion(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse<()>, CompletionError>;
}
18
/// Statically-dispatched client wrapping one concrete rig completion model
/// per supported provider.
///
/// `enum_dispatch` also generates `From<T>` for each variant's inner type,
/// which is what `LMClient::from_custom` relies on.
#[enum_dispatch(CompletionProvider)]
#[derive(Clone)]
pub enum LMClient {
    OpenAI(openai::completion::CompletionModel),
    Gemini(gemini::completion::CompletionModel),
    Anthropic(anthropic::completion::CompletionModel),
    // Some providers are generic over the HTTP client; this crate pins reqwest.
    Groq(groq::CompletionModel<reqwest::Client>),
    OpenRouter(openrouter::completion::CompletionModel),
    Ollama(ollama::CompletionModel<reqwest::Client>),
    Azure(azure::CompletionModel<reqwest::Client>),
    XAI(xai::completion::CompletionModel),
    Cohere(cohere::completion::CompletionModel),
    Mistral(mistral::completion::CompletionModel),
    Together(together::completion::CompletionModel),
    Deepseek(deepseek::CompletionModel<reqwest::Client>),
}
35
36// Implement the trait for each concrete provider type using the CompletionModel trait from rig
37impl CompletionProvider for openai::completion::CompletionModel {
38    async fn completion(
39        &self,
40        request: CompletionRequest,
41    ) -> Result<CompletionResponse<()>, CompletionError> {
42        let response = rig::completion::CompletionModel::completion(self, request).await?;
43        // Convert the typed response to unit type
44        Ok(CompletionResponse {
45            choice: response.choice,
46            usage: response.usage,
47            raw_response: (),
48        })
49    }
50}
51
52impl CompletionProvider for anthropic::completion::CompletionModel {
53    async fn completion(
54        &self,
55        request: CompletionRequest,
56    ) -> Result<CompletionResponse<()>, CompletionError> {
57        let response = rig::completion::CompletionModel::completion(self, request).await?;
58        Ok(CompletionResponse {
59            choice: response.choice,
60            usage: response.usage,
61            raw_response: (),
62        })
63    }
64}
65
66impl CompletionProvider for gemini::completion::CompletionModel {
67    async fn completion(
68        &self,
69        request: CompletionRequest,
70    ) -> Result<CompletionResponse<()>, CompletionError> {
71        let response = rig::completion::CompletionModel::completion(self, request).await?;
72        Ok(CompletionResponse {
73            choice: response.choice,
74            usage: response.usage,
75            raw_response: (),
76        })
77    }
78}
79
80impl CompletionProvider for groq::CompletionModel<reqwest::Client> {
81    async fn completion(
82        &self,
83        request: CompletionRequest,
84    ) -> Result<CompletionResponse<()>, CompletionError> {
85        let response = rig::completion::CompletionModel::completion(self, request).await?;
86        Ok(CompletionResponse {
87            choice: response.choice,
88            usage: response.usage,
89            raw_response: (),
90        })
91    }
92}
93
94impl CompletionProvider for openrouter::completion::CompletionModel {
95    async fn completion(
96        &self,
97        request: CompletionRequest,
98    ) -> Result<CompletionResponse<()>, CompletionError> {
99        let response = rig::completion::CompletionModel::completion(self, request).await?;
100        Ok(CompletionResponse {
101            choice: response.choice,
102            usage: response.usage,
103            raw_response: (),
104        })
105    }
106}
107
108impl CompletionProvider for ollama::CompletionModel<reqwest::Client> {
109    async fn completion(
110        &self,
111        request: CompletionRequest,
112    ) -> Result<CompletionResponse<()>, CompletionError> {
113        let response = rig::completion::CompletionModel::completion(self, request).await?;
114        Ok(CompletionResponse {
115            choice: response.choice,
116            usage: response.usage,
117            raw_response: (),
118        })
119    }
120}
121
122impl CompletionProvider for azure::CompletionModel<reqwest::Client> {
123    async fn completion(
124        &self,
125        request: CompletionRequest,
126    ) -> Result<CompletionResponse<()>, CompletionError> {
127        let response = rig::completion::CompletionModel::completion(self, request).await?;
128        Ok(CompletionResponse {
129            choice: response.choice,
130            usage: response.usage,
131            raw_response: (),
132        })
133    }
134}
135impl CompletionProvider for xai::completion::CompletionModel {
136    async fn completion(
137        &self,
138        request: CompletionRequest,
139    ) -> Result<CompletionResponse<()>, CompletionError> {
140        let response = rig::completion::CompletionModel::completion(self, request).await?;
141        Ok(CompletionResponse {
142            choice: response.choice,
143            usage: response.usage,
144            raw_response: (),
145        })
146    }
147}
148
149impl CompletionProvider for cohere::completion::CompletionModel {
150    async fn completion(
151        &self,
152        request: CompletionRequest,
153    ) -> Result<CompletionResponse<()>, CompletionError> {
154        let response = rig::completion::CompletionModel::completion(self, request).await?;
155        Ok(CompletionResponse {
156            choice: response.choice,
157            usage: response.usage,
158            raw_response: (),
159        })
160    }
161}
162
163impl CompletionProvider for mistral::completion::CompletionModel {
164    async fn completion(
165        &self,
166        request: CompletionRequest,
167    ) -> Result<CompletionResponse<()>, CompletionError> {
168        let response = rig::completion::CompletionModel::completion(self, request).await?;
169        Ok(CompletionResponse {
170            choice: response.choice,
171            usage: response.usage,
172            raw_response: (),
173        })
174    }
175}
176
177impl CompletionProvider for together::completion::CompletionModel {
178    async fn completion(
179        &self,
180        request: CompletionRequest,
181    ) -> Result<CompletionResponse<()>, CompletionError> {
182        let response = rig::completion::CompletionModel::completion(self, request).await?;
183        Ok(CompletionResponse {
184            choice: response.choice,
185            usage: response.usage,
186            raw_response: (),
187        })
188    }
189}
190
191impl CompletionProvider for deepseek::CompletionModel<reqwest::Client> {
192    async fn completion(
193        &self,
194        request: CompletionRequest,
195    ) -> Result<CompletionResponse<()>, CompletionError> {
196        let response = rig::completion::CompletionModel::completion(self, request).await?;
197        Ok(CompletionResponse {
198            choice: response.choice,
199            usage: response.usage,
200            raw_response: (),
201        })
202    }
203}
204
205impl LMClient {
206    fn get_api_key<'a>(provided: Option<&'a str>, env_var: &str) -> Result<Cow<'a, str>> {
207        match provided {
208            Some(k) => Ok(Cow::Borrowed(k)),
209            None => Ok(Cow::Owned(std::env::var(env_var).map_err(|_| {
210                anyhow::anyhow!("{} environment variable not set", env_var)
211            })?)),
212        }
213    }
214
215    /// Build case 1: OpenAI-compatible API from base_url + api_key
216    pub fn from_openai_compatible(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
217        let client = openai::ClientBuilder::new(api_key)
218            .base_url(base_url)
219            .build();
220        Ok(LMClient::OpenAI(openai::completion::CompletionModel::new(
221            client, model,
222        )))
223    }
224
225    /// Build case 2: Local OpenAI-compatible model from base_url (vLLM, etc.)
226    /// Uses a dummy API key since local servers don't require authentication
227    pub fn from_local(base_url: &str, model: &str) -> Result<Self> {
228        let client = openai::ClientBuilder::new("dummy-key-for-local-server")
229            .base_url(base_url)
230            .build();
231        Ok(LMClient::OpenAI(openai::completion::CompletionModel::new(
232            client, model,
233        )))
234    }
235
236    /// Build case 3: From provider via model name (provider:model format)
237    pub fn from_model_string(model_str: &str, api_key: Option<&str>) -> Result<Self> {
238        let (provider, model_id) = model_str.split_once(':').ok_or(anyhow::anyhow!(
239            "Model string must be in format 'provider:model_name'"
240        ))?;
241
242        match provider {
243            "openai" => {
244                let key = Self::get_api_key(api_key, "OPENAI_API_KEY")?;
245                let client = openai::ClientBuilder::new(&key).build();
246                Ok(LMClient::OpenAI(openai::completion::CompletionModel::new(
247                    client, model_id,
248                )))
249            }
250            "anthropic" => {
251                let key = Self::get_api_key(api_key, "ANTHROPIC_API_KEY")?;
252                let client = anthropic::ClientBuilder::new(&key).build()?;
253                Ok(LMClient::Anthropic(
254                    anthropic::completion::CompletionModel::new(client, model_id),
255                ))
256            }
257            "gemini" => {
258                let key = Self::get_api_key(api_key, "GEMINI_API_KEY")?;
259                let client = gemini::client::ClientBuilder::<reqwest::Client>::new(&key).build()?;
260                Ok(LMClient::Gemini(gemini::completion::CompletionModel::new(
261                    client, model_id,
262                )))
263            }
264            "ollama" => {
265                let client = ollama::ClientBuilder::new().build();
266                Ok(LMClient::Ollama(ollama::CompletionModel::new(
267                    client, model_id,
268                )))
269            }
270            "openrouter" => {
271                let key = Self::get_api_key(api_key, "OPENROUTER_API_KEY")?;
272                let client = openrouter::ClientBuilder::new(&key).build();
273                Ok(LMClient::OpenRouter(
274                    openrouter::completion::CompletionModel::new(client, model_id),
275                ))
276            }
277            _ => {
278                anyhow::bail!(
279                    "Unsupported provider: {}. Supported providers are: openai, anthropic, gemini, groq, openrouter, ollama",
280                    provider
281                );
282            }
283        }
284    }
285
286    /// Convert a concrete completion model to LMClient
287    ///
288    /// This function accepts concrete types that can be converted to LMClient.
289    /// The enum_dispatch macro automatically generates From implementations for
290    /// each variant type, so you can use this with any concrete completion model.
291    pub fn from_custom<T: Into<LMClient>>(client: T) -> Self
292    {
293        client.into()
294    }
295}