1use llmg_core::{
13 provider::{ApiKeyCredentials, Credentials, LlmError, Provider},
14 types::{ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse},
15};
16use std::sync::Arc;
17#[derive(Debug, Clone)]
26pub struct OpenRouterClient {
27 http_client: reqwest::Client,
28 base_url: String,
29 credentials: Arc<dyn Credentials>,
30 app_name: Option<String>,
31 http_referer: Option<String>,
32}
33
34#[derive(Debug, Clone, Default)]
36pub struct OpenRouterExtras {
37 pub provider: Option<serde_json::Value>,
39 pub transforms: Option<Vec<String>>,
41 pub route: Option<String>,
43 pub models: Option<Vec<String>>,
45}
46
/// JSON body sent to `POST {base_url}/chat/completions`.
///
/// The standard chat-completion fields are flattened into the top level, and the
/// OpenRouter-specific fields are added alongside them. Any OpenRouter field that
/// is `None` is omitted from the serialized JSON.
#[derive(Debug, serde::Serialize)]
struct OpenRouterChatRequest {
    /// Standard request fields, serialized at the top level rather than nested.
    #[serde(flatten)]
    base: ChatCompletionRequest,
    /// Provider routing preferences, passed through verbatim as `provider`.
    #[serde(skip_serializing_if = "Option::is_none")]
    provider: Option<serde_json::Value>,
    /// Prompt transforms, serialized as `transforms`.
    #[serde(skip_serializing_if = "Option::is_none")]
    transforms: Option<Vec<String>>,
    /// Routing strategy, serialized as `route`.
    #[serde(skip_serializing_if = "Option::is_none")]
    route: Option<String>,
    /// Candidate model list, serialized as `models`.
    #[serde(skip_serializing_if = "Option::is_none")]
    models: Option<Vec<String>>,
}
61
62impl OpenRouterClient {
63 pub fn from_env() -> Result<Self, LlmError> {
68 let api_key = std::env::var("OPENROUTER_API_KEY").map_err(|_| LlmError::AuthError)?;
69
70 let base_url = std::env::var("OPENROUTER_API_BASE")
71 .unwrap_or_else(|_| "https://openrouter.ai/api/v1".to_string());
72
73 let app_name = std::env::var("OPENROUTER_APP_NAME").ok();
74 let http_referer = std::env::var("OPENROUTER_HTTP_REFERER").ok();
75
76 Ok(Self::with_config(api_key, base_url, app_name, http_referer))
77 }
78
79 pub fn new(api_key: impl Into<String>) -> Self {
81 Self::with_config(
82 api_key,
83 "https://openrouter.ai/api/v1".to_string(),
84 None,
85 None,
86 )
87 }
88
89 pub fn with_config(
91 api_key: impl Into<String>,
92 base_url: impl Into<String>,
93 app_name: Option<String>,
94 http_referer: Option<String>,
95 ) -> Self {
96 let api_key = api_key.into();
97
98 Self {
99 http_client: reqwest::Client::new(),
100 base_url: base_url.into(),
101 credentials: Arc::new(ApiKeyCredentials::bearer(api_key)),
102 app_name,
103 http_referer,
104 }
105 }
106
107 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
109 self.base_url = url.into();
110 self
111 }
112
113 pub fn with_app_name(mut self, name: impl Into<String>) -> Self {
115 self.app_name = Some(name.into());
116 self
117 }
118
119 pub fn with_http_referer(mut self, referer: impl Into<String>) -> Self {
121 self.http_referer = Some(referer.into());
122 self
123 }
124
125 fn build_request(
127 &self,
128 request: ChatCompletionRequest,
129 extras: Option<OpenRouterExtras>,
130 ) -> Result<reqwest::Request, LlmError> {
131 let url = format!("{}/chat/completions", self.base_url);
132
133 let openrouter_req = if let Some(extras) = extras {
135 OpenRouterChatRequest {
136 base: request,
137 provider: extras.provider,
138 transforms: extras.transforms,
139 route: extras.route,
140 models: extras.models,
141 }
142 } else {
143 OpenRouterChatRequest {
144 base: request,
145 provider: None,
146 transforms: None,
147 route: None,
148 models: None,
149 }
150 };
151
152 let mut req_builder = self.http_client.post(&url).json(&openrouter_req);
153
154 if let Some(ref app_name) = self.app_name {
156 req_builder = req_builder.header("X-Title", app_name);
157 }
158
159 if let Some(ref referer) = self.http_referer {
160 req_builder = req_builder.header("HTTP-Referer", referer);
161 }
162
163 let mut req = req_builder
164 .build()
165 .map_err(|e| LlmError::HttpError(e.to_string()))?;
166
167 self.credentials.apply(&mut req)?;
168
169 Ok(req)
170 }
171
172 async fn make_request(
173 &self,
174 request: ChatCompletionRequest,
175 ) -> Result<ChatCompletionResponse, LlmError> {
176 let req = self.build_request(request, None)?;
177
178 let response = self
179 .http_client
180 .execute(req)
181 .await
182 .map_err(|e| LlmError::HttpError(e.to_string()))?;
183
184 if !response.status().is_success() {
185 let status = response.status().as_u16();
186 let text = response.text().await.unwrap_or_default();
187 return Err(LlmError::ApiError {
188 status,
189 message: text,
190 });
191 }
192
193 response
194 .json::<ChatCompletionResponse>()
195 .await
196 .map_err(|e| LlmError::HttpError(e.to_string()))
197 }
198
199 pub async fn chat_completion_with_extras(
201 &self,
202 request: ChatCompletionRequest,
203 extras: OpenRouterExtras,
204 ) -> Result<ChatCompletionResponse, LlmError> {
205 let req = self.build_request(request, Some(extras))?;
206
207 let response = self
208 .http_client
209 .execute(req)
210 .await
211 .map_err(|e| LlmError::HttpError(e.to_string()))?;
212
213 if !response.status().is_success() {
214 let status = response.status().as_u16();
215 let text = response.text().await.unwrap_or_default();
216 return Err(LlmError::ApiError {
217 status,
218 message: text,
219 });
220 }
221
222 response
223 .json::<ChatCompletionResponse>()
224 .await
225 .map_err(|e| LlmError::HttpError(e.to_string()))
226 }
227}
228
229#[async_trait::async_trait]
230impl Provider for OpenRouterClient {
231 async fn chat_completion(
232 &self,
233 request: ChatCompletionRequest,
234 ) -> Result<ChatCompletionResponse, LlmError> {
235 self.make_request(request).await
236 }
237
238 async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError> {
239 let url = format!("{}/embeddings", self.base_url);
240
241 let mut req = self
242 .http_client
243 .post(&url)
244 .json(&request)
245 .build()
246 .map_err(|e| LlmError::HttpError(e.to_string()))?;
247
248 self.credentials.apply(&mut req)?;
249
250 let response = self
251 .http_client
252 .execute(req)
253 .await
254 .map_err(|e| LlmError::HttpError(e.to_string()))?;
255
256 if !response.status().is_success() {
257 let status = response.status().as_u16();
258 let text = response.text().await.unwrap_or_default();
259 return Err(LlmError::ApiError {
260 status,
261 message: text,
262 });
263 }
264
265 response
266 .json::<EmbeddingResponse>()
267 .await
268 .map_err(|e| LlmError::HttpError(e.to_string()))
269 }
270 fn provider_name(&self) -> &'static str {
271 "openrouter"
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use llmg_core::types::Message;

    /// Minimal single-message request shared by the request-building tests.
    fn sample_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "anthropic/claude-3-opus".to_string(),
            messages: vec![Message::User {
                content: "Hello".to_string(),
                name: None,
            }],
            temperature: None,
            max_tokens: None,
            stream: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
            user: None,
            tools: None,
            tool_choice: None,
            response_format: None,
        }
    }

    /// Parses a buffered request body as JSON.
    fn body_json(bytes: &[u8]) -> serde_json::Value {
        serde_json::from_slice(bytes).expect("chat request body should be valid JSON")
    }

    #[test]
    fn test_openrouter_client_creation() {
        let client = OpenRouterClient::new("test-key");
        assert_eq!(client.provider_name(), "openrouter");
    }

    #[test]
    fn test_from_env_missing_key() {
        // Mutates process-wide env: restore the original value *before* asserting so
        // a failing assertion cannot leak the removed key into other tests.
        let original = std::env::var("OPENROUTER_API_KEY").ok();
        std::env::remove_var("OPENROUTER_API_KEY");

        let result = OpenRouterClient::from_env();

        if let Some(key) = original {
            std::env::set_var("OPENROUTER_API_KEY", key);
        }
        assert!(matches!(result, Err(LlmError::AuthError)));
    }

    #[test]
    fn test_custom_config() {
        let client = OpenRouterClient::with_config(
            "test-key",
            "https://custom.openrouter.ai/api/v1",
            Some("MyApp".to_string()),
            Some("https://myapp.com".to_string()),
        );

        assert_eq!(client.base_url, "https://custom.openrouter.ai/api/v1");
        assert_eq!(client.app_name, Some("MyApp".to_string()));
        assert_eq!(client.http_referer, Some("https://myapp.com".to_string()));
    }

    #[test]
    fn test_extras_builder() {
        let extras = OpenRouterExtras {
            provider: Some(serde_json::json!({"order": ["Anthropic", "OpenAI"]})),
            transforms: Some(vec!["middle-out".to_string()]),
            route: Some("fallback".to_string()),
            models: Some(vec!["anthropic/claude-3-opus".to_string()]),
        };

        let client = OpenRouterClient::new("test-key").with_app_name("test-app");
        let built_req = client.build_request(sample_request(), Some(extras)).unwrap();

        assert_eq!(
            built_req.headers().get("x-title").and_then(|v| v.to_str().ok()),
            Some("test-app")
        );

        let body = body_json(built_req.body().unwrap().as_bytes().unwrap());
        assert_eq!(body["model"], "anthropic/claude-3-opus");
        assert_eq!(body["provider"]["order"][0], "Anthropic");
        assert_eq!(body["transforms"][0], "middle-out");
        assert_eq!(body["route"], "fallback");
        assert_eq!(body["models"][0], "anthropic/claude-3-opus");
    }

    #[test]
    fn test_build_request_without_extras_omits_openrouter_fields() {
        let client = OpenRouterClient::new("test-key");
        let built_req = client.build_request(sample_request(), None).unwrap();

        assert!(built_req.headers().get("x-title").is_none());
        assert!(built_req.headers().get("http-referer").is_none());

        let body = body_json(built_req.body().unwrap().as_bytes().unwrap());
        for key in ["provider", "transforms", "route", "models"] {
            assert!(body.get(key).is_none(), "unexpected `{key}` in request body");
        }
    }
}
352