1use anyhow::Result;
2use enum_dispatch::enum_dispatch;
3use reqwest;
4use rig::{
5 completion::{CompletionError, CompletionRequest, CompletionResponse},
6 providers::*,
7};
8use std::borrow::Cow;
9
10#[enum_dispatch]
11#[allow(async_fn_in_trait)]
12pub trait CompletionProvider {
13 async fn completion(
14 &self,
15 request: CompletionRequest,
16 ) -> Result<CompletionResponse<()>, CompletionError>;
17}
18
19#[enum_dispatch(CompletionProvider)]
20#[derive(Clone)]
21pub enum LMClient {
22 OpenAI(openai::completion::CompletionModel),
23 Gemini(gemini::completion::CompletionModel),
24 Anthropic(anthropic::completion::CompletionModel),
25 Groq(groq::CompletionModel<reqwest::Client>),
26 OpenRouter(openrouter::completion::CompletionModel),
27 Ollama(ollama::CompletionModel<reqwest::Client>),
28 Azure(azure::CompletionModel<reqwest::Client>),
29 XAI(xai::completion::CompletionModel),
30 Cohere(cohere::completion::CompletionModel),
31 Mistral(mistral::completion::CompletionModel),
32 Together(together::completion::CompletionModel),
33 Deepseek(deepseek::CompletionModel<reqwest::Client>),
34}
35
36impl CompletionProvider for openai::completion::CompletionModel {
38 async fn completion(
39 &self,
40 request: CompletionRequest,
41 ) -> Result<CompletionResponse<()>, CompletionError> {
42 let response = rig::completion::CompletionModel::completion(self, request).await?;
43 Ok(CompletionResponse {
45 choice: response.choice,
46 usage: response.usage,
47 raw_response: (),
48 })
49 }
50}
51
52impl CompletionProvider for anthropic::completion::CompletionModel {
53 async fn completion(
54 &self,
55 request: CompletionRequest,
56 ) -> Result<CompletionResponse<()>, CompletionError> {
57 let response = rig::completion::CompletionModel::completion(self, request).await?;
58 Ok(CompletionResponse {
59 choice: response.choice,
60 usage: response.usage,
61 raw_response: (),
62 })
63 }
64}
65
66impl CompletionProvider for gemini::completion::CompletionModel {
67 async fn completion(
68 &self,
69 request: CompletionRequest,
70 ) -> Result<CompletionResponse<()>, CompletionError> {
71 let response = rig::completion::CompletionModel::completion(self, request).await?;
72 Ok(CompletionResponse {
73 choice: response.choice,
74 usage: response.usage,
75 raw_response: (),
76 })
77 }
78}
79
80impl CompletionProvider for groq::CompletionModel<reqwest::Client> {
81 async fn completion(
82 &self,
83 request: CompletionRequest,
84 ) -> Result<CompletionResponse<()>, CompletionError> {
85 let response = rig::completion::CompletionModel::completion(self, request).await?;
86 Ok(CompletionResponse {
87 choice: response.choice,
88 usage: response.usage,
89 raw_response: (),
90 })
91 }
92}
93
94impl CompletionProvider for openrouter::completion::CompletionModel {
95 async fn completion(
96 &self,
97 request: CompletionRequest,
98 ) -> Result<CompletionResponse<()>, CompletionError> {
99 let response = rig::completion::CompletionModel::completion(self, request).await?;
100 Ok(CompletionResponse {
101 choice: response.choice,
102 usage: response.usage,
103 raw_response: (),
104 })
105 }
106}
107
108impl CompletionProvider for ollama::CompletionModel<reqwest::Client> {
109 async fn completion(
110 &self,
111 request: CompletionRequest,
112 ) -> Result<CompletionResponse<()>, CompletionError> {
113 let response = rig::completion::CompletionModel::completion(self, request).await?;
114 Ok(CompletionResponse {
115 choice: response.choice,
116 usage: response.usage,
117 raw_response: (),
118 })
119 }
120}
121
122impl CompletionProvider for azure::CompletionModel<reqwest::Client> {
123 async fn completion(
124 &self,
125 request: CompletionRequest,
126 ) -> Result<CompletionResponse<()>, CompletionError> {
127 let response = rig::completion::CompletionModel::completion(self, request).await?;
128 Ok(CompletionResponse {
129 choice: response.choice,
130 usage: response.usage,
131 raw_response: (),
132 })
133 }
134}
135impl CompletionProvider for xai::completion::CompletionModel {
136 async fn completion(
137 &self,
138 request: CompletionRequest,
139 ) -> Result<CompletionResponse<()>, CompletionError> {
140 let response = rig::completion::CompletionModel::completion(self, request).await?;
141 Ok(CompletionResponse {
142 choice: response.choice,
143 usage: response.usage,
144 raw_response: (),
145 })
146 }
147}
148
149impl CompletionProvider for cohere::completion::CompletionModel {
150 async fn completion(
151 &self,
152 request: CompletionRequest,
153 ) -> Result<CompletionResponse<()>, CompletionError> {
154 let response = rig::completion::CompletionModel::completion(self, request).await?;
155 Ok(CompletionResponse {
156 choice: response.choice,
157 usage: response.usage,
158 raw_response: (),
159 })
160 }
161}
162
163impl CompletionProvider for mistral::completion::CompletionModel {
164 async fn completion(
165 &self,
166 request: CompletionRequest,
167 ) -> Result<CompletionResponse<()>, CompletionError> {
168 let response = rig::completion::CompletionModel::completion(self, request).await?;
169 Ok(CompletionResponse {
170 choice: response.choice,
171 usage: response.usage,
172 raw_response: (),
173 })
174 }
175}
176
177impl CompletionProvider for together::completion::CompletionModel {
178 async fn completion(
179 &self,
180 request: CompletionRequest,
181 ) -> Result<CompletionResponse<()>, CompletionError> {
182 let response = rig::completion::CompletionModel::completion(self, request).await?;
183 Ok(CompletionResponse {
184 choice: response.choice,
185 usage: response.usage,
186 raw_response: (),
187 })
188 }
189}
190
191impl CompletionProvider for deepseek::CompletionModel<reqwest::Client> {
192 async fn completion(
193 &self,
194 request: CompletionRequest,
195 ) -> Result<CompletionResponse<()>, CompletionError> {
196 let response = rig::completion::CompletionModel::completion(self, request).await?;
197 Ok(CompletionResponse {
198 choice: response.choice,
199 usage: response.usage,
200 raw_response: (),
201 })
202 }
203}
204
205impl LMClient {
206 fn get_api_key<'a>(provided: Option<&'a str>, env_var: &str) -> Result<Cow<'a, str>> {
207 match provided {
208 Some(k) => Ok(Cow::Borrowed(k)),
209 None => Ok(Cow::Owned(std::env::var(env_var).map_err(|_| {
210 anyhow::anyhow!("{} environment variable not set", env_var)
211 })?)),
212 }
213 }
214
215 pub fn from_openai_compatible(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
217 let client = openai::ClientBuilder::new(api_key)
218 .base_url(base_url)
219 .build();
220 Ok(LMClient::OpenAI(openai::completion::CompletionModel::new(
221 client, model,
222 )))
223 }
224
225 pub fn from_local(base_url: &str, model: &str) -> Result<Self> {
228 let client = openai::ClientBuilder::new("dummy-key-for-local-server")
229 .base_url(base_url)
230 .build();
231 Ok(LMClient::OpenAI(openai::completion::CompletionModel::new(
232 client, model,
233 )))
234 }
235
236 pub fn from_model_string(model_str: &str, api_key: Option<&str>) -> Result<Self> {
238 let (provider, model_id) = model_str.split_once(':').ok_or(anyhow::anyhow!(
239 "Model string must be in format 'provider:model_name'"
240 ))?;
241
242 match provider {
243 "openai" => {
244 let key = Self::get_api_key(api_key, "OPENAI_API_KEY")?;
245 let client = openai::ClientBuilder::new(&key).build();
246 Ok(LMClient::OpenAI(openai::completion::CompletionModel::new(
247 client, model_id,
248 )))
249 }
250 "anthropic" => {
251 let key = Self::get_api_key(api_key, "ANTHROPIC_API_KEY")?;
252 let client = anthropic::ClientBuilder::new(&key).build()?;
253 Ok(LMClient::Anthropic(
254 anthropic::completion::CompletionModel::new(client, model_id),
255 ))
256 }
257 "gemini" => {
258 let key = Self::get_api_key(api_key, "GEMINI_API_KEY")?;
259 let client = gemini::client::ClientBuilder::<reqwest::Client>::new(&key).build()?;
260 Ok(LMClient::Gemini(gemini::completion::CompletionModel::new(
261 client, model_id,
262 )))
263 }
264 "ollama" => {
265 let client = ollama::ClientBuilder::new().build();
266 Ok(LMClient::Ollama(ollama::CompletionModel::new(
267 client, model_id,
268 )))
269 }
270 "openrouter" => {
271 let key = Self::get_api_key(api_key, "OPENROUTER_API_KEY")?;
272 let client = openrouter::ClientBuilder::new(&key).build();
273 Ok(LMClient::OpenRouter(
274 openrouter::completion::CompletionModel::new(client, model_id),
275 ))
276 }
277 _ => {
278 anyhow::bail!(
279 "Unsupported provider: {}. Supported providers are: openai, anthropic, gemini, groq, openrouter, ollama",
280 provider
281 );
282 }
283 }
284 }
285
286 pub fn from_custom<T: Into<LMClient>>(client: T) -> Self
292 {
293 client.into()
294 }
295}