1use crate::ToolSpec;
5use crate::traits::{ChatMessage, ChatRequest, ChatResponse, ModelProvider, TokenUsage, ToolCall};
6use async_trait::async_trait;
7use reqwest::Client;
8use reqwest::header::ACCEPT_ENCODING;
9use serde::{Deserialize, Serialize};
10
/// Upper bound on how many times a chat request is sent when the transport
/// fails (the send itself errors, or the response body cannot be read).
const OPENROUTER_MAX_TRANSPORT_ATTEMPTS: u32 = 3;
12
13pub struct OpenRouterProvider {
14 api_key: Option<String>,
15 client: Client,
16 last_good_provider: std::sync::Mutex<Option<String>>,
19}
20
21#[derive(Debug, Serialize)]
22struct NativeChatRequest {
23 model: String,
24 messages: Vec<NativeMessage>,
25 temperature: f64,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 tools: Option<Vec<NativeToolSpec>>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 tool_choice: Option<String>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 provider: Option<NativeProviderRouting>,
32}
33
34#[derive(Debug, Serialize)]
35struct NativeProviderRouting {
36 order: Vec<String>,
37 allow_fallbacks: bool,
38}
39
40#[derive(Debug, Serialize)]
41struct NativeMessage {
42 role: String,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 content: Option<String>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 tool_call_id: Option<String>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 tool_calls: Option<Vec<NativeToolCall>>,
49}
50
51#[derive(Debug, Serialize)]
52struct NativeToolSpec {
53 #[serde(rename = "type")]
54 kind: String,
55 function: NativeToolFunctionSpec,
56}
57
58#[derive(Debug, Serialize)]
59struct NativeToolFunctionSpec {
60 name: String,
61 description: String,
62 parameters: serde_json::Value,
63}
64
65#[derive(Debug, Serialize, Deserialize)]
66struct NativeToolCall {
67 #[serde(skip_serializing_if = "Option::is_none")]
68 id: Option<String>,
69 #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
70 kind: Option<String>,
71 function: NativeFunctionCall,
72}
73
74#[derive(Debug, Serialize, Deserialize)]
75struct NativeFunctionCall {
76 name: String,
77 arguments: String,
78}
79
80#[derive(Debug, Deserialize)]
81struct NativeUsage {
82 #[serde(default)]
83 prompt_tokens: u64,
84 #[serde(default)]
85 completion_tokens: u64,
86}
87
88#[derive(Debug, Deserialize)]
89struct NativeChatResponse {
90 choices: Vec<NativeChoice>,
91 #[serde(default)]
95 provider: Option<String>,
96 #[serde(default)]
97 openrouter_metadata: Option<NativeOpenRouterMetadata>,
98 #[serde(default)]
99 usage: Option<NativeUsage>,
100}
101
102#[derive(Debug, Deserialize)]
103struct NativeOpenRouterMetadata {
104 #[serde(default)]
105 endpoints: Option<NativeEndpointsMetadata>,
106}
107
108#[derive(Debug, Deserialize)]
109struct NativeEndpointsMetadata {
110 #[serde(default)]
111 available: Vec<NativeEndpointInfo>,
112}
113
114#[derive(Debug, Deserialize)]
115struct NativeEndpointInfo {
116 provider: String,
117 #[serde(default)]
118 selected: bool,
119}
120
121#[derive(Debug, Deserialize)]
122struct NativeChoice {
123 message: NativeResponseMessage,
124}
125
126#[derive(Debug, Deserialize)]
127struct NativeResponseMessage {
128 #[serde(default)]
129 content: Option<String>,
130 #[serde(default)]
131 tool_calls: Option<Vec<NativeToolCall>>,
132}
133
134impl OpenRouterProvider {
135 pub fn new(api_key: Option<&str>) -> Self {
136 Self {
137 api_key: api_key.map(ToString::to_string),
138 client: Client::builder()
139 .timeout(std::time::Duration::from_secs(120))
140 .connect_timeout(std::time::Duration::from_secs(10))
141 .build()
142 .unwrap_or_else(|_| Client::new()),
143 last_good_provider: std::sync::Mutex::new(None),
144 }
145 }
146
147 fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
148 let items = tools?;
149 if items.is_empty() {
150 return None;
151 }
152 Some(
153 items
154 .iter()
155 .map(|tool| NativeToolSpec {
156 kind: "function".to_string(),
157 function: NativeToolFunctionSpec {
158 name: crate::sanitize_tool_name(&tool.name),
159 description: tool.description.clone(),
160 parameters: tool.parameters.clone(),
161 },
162 })
163 .collect(),
164 )
165 }
166
167 fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
168 messages
169 .iter()
170 .map(|m| {
171 if m.role == "assistant"
172 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content)
173 && let Some(tool_calls_value) = value.get("tool_calls")
174 && let Ok(parsed_calls) =
175 serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
176 {
177 let tool_calls = parsed_calls
178 .into_iter()
179 .map(|tc| NativeToolCall {
180 id: Some(tc.id),
181 kind: Some("function".to_string()),
182 function: NativeFunctionCall {
183 name: tc.name,
184 arguments: tc.arguments,
185 },
186 })
187 .collect::<Vec<_>>();
188 let content = value
189 .get("content")
190 .and_then(serde_json::Value::as_str)
191 .map(ToString::to_string);
192 return NativeMessage {
193 role: "assistant".to_string(),
194 content,
195 tool_call_id: None,
196 tool_calls: Some(tool_calls),
197 };
198 }
199
200 if m.role == "tool"
201 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content)
202 {
203 let tool_call_id = value
204 .get("tool_call_id")
205 .and_then(serde_json::Value::as_str)
206 .map(ToString::to_string);
207 let content = value
208 .get("content")
209 .and_then(serde_json::Value::as_str)
210 .map(ToString::to_string);
211 return NativeMessage {
212 role: "tool".to_string(),
213 content,
214 tool_call_id,
215 tool_calls: None,
216 };
217 }
218
219 NativeMessage {
220 role: m.role.clone(),
221 content: Some(m.content.clone()),
222 tool_call_id: None,
223 tool_calls: None,
224 }
225 })
226 .collect()
227 }
228
229 fn parse_native_response(message: NativeResponseMessage) -> ChatResponse {
230 let tool_calls = message
231 .tool_calls
232 .unwrap_or_default()
233 .into_iter()
234 .map(|tc| ToolCall {
235 id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
236 name: tc.function.name,
237 arguments: tc.function.arguments,
238 })
239 .collect::<Vec<_>>();
240
241 ChatResponse {
242 text: message.content,
243 tool_calls,
244 provider_tool_calls: vec![],
245 usage: TokenUsage::default(),
246 }
247 }
248
249 fn selected_provider_name(response: &NativeChatResponse) -> Option<String> {
250 response.provider.clone().or_else(|| {
251 response
252 .openrouter_metadata
253 .as_ref()
254 .and_then(|metadata| metadata.endpoints.as_ref())
255 .and_then(|endpoints| {
256 endpoints
257 .available
258 .iter()
259 .find(|endpoint| endpoint.selected)
260 })
261 .map(|endpoint| endpoint.provider.clone())
262 })
263 }
264}
265
266#[async_trait]
267impl ModelProvider for OpenRouterProvider {
268 async fn warmup(&self) -> anyhow::Result<()> {
269 if let Some(api_key) = self.api_key.as_ref() {
272 self.client
273 .get("https://openrouter.ai/api/v1/auth/key")
274 .header("Authorization", format!("Bearer {api_key}"))
275 .send()
276 .await?
277 .error_for_status()?;
278 }
279 Ok(())
280 }
281
282 async fn chat(
283 &self,
284 request: ChatRequest<'_>,
285 model: &str,
286 temperature: f64,
287 ) -> anyhow::Result<ChatResponse> {
288 let api_key = self.api_key.as_ref().ok_or_else(|| {
289 anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY env var.")
290 })?;
291
292 let tools = Self::convert_tools(request.tools);
293
294 let provider_routing = self
297 .last_good_provider
298 .lock()
299 .ok()
300 .and_then(|guard| guard.clone())
301 .map(|p| NativeProviderRouting {
302 order: vec![p],
303 allow_fallbacks: true,
304 });
305
306 let messages = Self::convert_messages(request.messages);
307
308 let estimated_chars: usize = messages
310 .iter()
311 .map(|m| m.content.as_deref().unwrap_or("").len())
312 .sum();
313 let estimated_tokens = estimated_chars / 4;
314 tracing::info!(
315 model = model,
316 messages = messages.len(),
317 estimated_tokens = estimated_tokens,
318 "OpenRouter request"
319 );
320
321 let native_request = NativeChatRequest {
322 model: model.to_string(),
323 messages,
324 temperature,
325 tool_choice: tools.as_ref().map(|_| "auto".to_string()),
326 tools,
327 provider: provider_routing,
328 };
329
330 let body_text = {
331 let mut last_error = None;
332 let mut body = None;
333
334 for attempt in 1..=OPENROUTER_MAX_TRANSPORT_ATTEMPTS {
335 let response = match self
336 .client
337 .post("https://openrouter.ai/api/v1/chat/completions")
338 .header("Authorization", format!("Bearer {api_key}"))
339 .header("HTTP-Referer", "https://github.com/nenjo-ai/nenjo")
340 .header("X-Title", "Nenjo")
341 .header(ACCEPT_ENCODING, "identity")
342 .json(&native_request)
343 .send()
344 .await
345 {
346 Ok(response) => response,
347 Err(error) => {
348 last_error = Some(anyhow::anyhow!(
349 "OpenRouter: request failed (~{estimated_tokens} input tokens, \
350 {messages_count} messages, attempt {attempt}/{OPENROUTER_MAX_TRANSPORT_ATTEMPTS}): {error}",
351 messages_count = native_request.messages.len(),
352 ));
353 if attempt < OPENROUTER_MAX_TRANSPORT_ATTEMPTS {
354 tokio::time::sleep(std::time::Duration::from_millis(
355 250 * u64::from(attempt),
356 ))
357 .await;
358 continue;
359 }
360 break;
361 }
362 };
363
364 let status = response.status();
365 if !status.is_success() {
366 return Err(crate::api_error("OpenRouter", response).await);
367 }
368
369 match response.text().await {
370 Ok(text) => {
371 body = Some(text);
372 break;
373 }
374 Err(error) => {
375 last_error = Some(anyhow::anyhow!(
376 "OpenRouter: failed to read response body (status {status}, \
377 ~{estimated_tokens} input tokens, {messages_count} messages, \
378 attempt {attempt}/{OPENROUTER_MAX_TRANSPORT_ATTEMPTS}): {error}",
379 messages_count = native_request.messages.len(),
380 ));
381 if attempt < OPENROUTER_MAX_TRANSPORT_ATTEMPTS {
382 tokio::time::sleep(std::time::Duration::from_millis(
383 250 * u64::from(attempt),
384 ))
385 .await;
386 }
387 }
388 }
389 }
390
391 body.ok_or_else(|| {
392 last_error.unwrap_or_else(|| anyhow::anyhow!("OpenRouter: empty response body"))
393 })?
394 };
395 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&body_text)
399 && let Some(err) = value.get("error")
400 {
401 let msg = err
402 .get("message")
403 .and_then(serde_json::Value::as_str)
404 .unwrap_or("unknown error");
405 return Err(anyhow::anyhow!(
406 "OpenRouter returned an error in a 200 response: {msg}"
407 ));
408 }
409
410 let native_response: NativeChatResponse =
411 serde_json::from_str(&body_text).map_err(|e| {
412 anyhow::anyhow!(
413 "OpenRouter response decode error: {e}\nBody: {}",
414 &body_text[..body_text.len().min(500)]
415 )
416 })?;
417
418 if let Some(provider_name) = Self::selected_provider_name(&native_response)
421 && let Ok(mut guard) = self.last_good_provider.lock()
422 {
423 *guard = Some(provider_name);
424 }
425
426 let usage = native_response
427 .usage
428 .map(|u| TokenUsage {
429 input_tokens: u.prompt_tokens,
430 output_tokens: u.completion_tokens,
431 })
432 .unwrap_or_default();
433
434 let message = native_response
435 .choices
436 .into_iter()
437 .next()
438 .map(|c| c.message)
439 .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?;
440 let mut result = Self::parse_native_response(message);
441 result.usage = usage;
442 Ok(result)
443 }
444
445 fn context_window(&self, model: &str) -> Option<usize> {
446 let m = model.to_lowercase();
448 if m.contains("claude-opus-4")
449 || m.contains("claude-sonnet-4.6")
450 || m.contains("claude-sonnet-4-6")
451 {
452 Some(1_000_000)
453 } else if m.contains("claude-sonnet-4")
454 || m.contains("claude-haiku-4")
455 || m.contains("claude-3.5")
456 || m.contains("claude-3-")
457 || m.contains("claude-3.7")
458 {
459 Some(200_000)
460 } else if m.contains("gpt-5") {
461 Some(1_000_000)
462 } else if m.contains("gpt-4o") {
463 Some(128_000)
464 } else if m.contains("o1") || m.contains("o3") || m.contains("o4") {
465 Some(200_000)
466 } else if m.contains("gemini") {
467 Some(1_000_000)
468 } else if m.contains("deepseek") {
469 Some(128_000)
470 } else if m.contains("llama-4") || m.contains("llama4") {
471 Some(1_000_000)
472 } else if m.contains("llama-3") || m.contains("llama3") {
473 Some(128_000)
474 } else if m.contains("mistral-large") || m.contains("qwen") {
475 Some(256_000)
476 } else if m.contains("grok-4") && m.contains("fast") {
477 Some(2_000_000)
478 } else if m.contains("grok-4") {
479 Some(256_000)
480 } else if m.contains("grok-3") {
481 Some(1_000_000)
482 } else if m.contains("kimi") {
483 Some(256_000)
484 } else if m.contains("minimax") {
485 Some(200_000)
486 } else {
487 None
488 }
489 }
490
491 fn supports_native_tools(&self) -> bool {
492 true
493 }
494
495 fn supports_developer_role(&self, model: &str) -> bool {
496 let m = model.to_lowercase();
497 (m.contains("openai/") || m.contains("azure/"))
500 && (m.contains("/o1")
501 || m.contains("/o3")
502 || m.contains("/o4")
503 || m.contains("/gpt-5")
504 || m.contains("/gpt-4.5")
505 || m.contains("/gpt-4.1"))
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{ChatMessage, ChatRequest, ModelProvider};

    /// Decodes a JSON fixture into the response type under test.
    fn decode_response(fixture: serde_json::Value) -> NativeChatResponse {
        serde_json::from_value(fixture).unwrap()
    }

    #[test]
    fn creates_with_key() {
        let with_key = OpenRouterProvider::new(Some("sk-or-123"));
        assert_eq!(with_key.api_key.as_deref(), Some("sk-or-123"));
    }

    #[test]
    fn creates_without_key() {
        let keyless = OpenRouterProvider::new(None);
        assert!(keyless.api_key.is_none());
    }

    #[tokio::test]
    async fn warmup_without_key_is_noop() {
        let keyless = OpenRouterProvider::new(None);
        assert!(keyless.warmup().await.is_ok());
    }

    #[test]
    fn developer_role_only_for_openai_newer_models() {
        let provider = OpenRouterProvider::new(None);

        for supported in ["openai/gpt-5.1", "openai/gpt-4.1", "openai/o3"] {
            assert!(provider.supports_developer_role(supported));
        }
        for unsupported in [
            "openai/gpt-4o",
            "anthropic/claude-sonnet-4",
            "minimax/minimax-m2.5",
        ] {
            assert!(!provider.supports_developer_role(unsupported));
        }
    }

    #[test]
    fn selected_provider_uses_openrouter_metadata() {
        let response = decode_response(serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "ok"
                }
            }],
            "openrouter_metadata": {
                "endpoints": {
                    "available": [
                        {
                            "model": "minimax/minimax-m2.5",
                            "provider": "Clarifai",
                            "selected": false
                        },
                        {
                            "model": "minimax/minimax-m2.5",
                            "provider": "Minimax",
                            "selected": true
                        }
                    ],
                    "total": 2
                }
            }
        }));

        let selected = OpenRouterProvider::selected_provider_name(&response);
        assert_eq!(selected.as_deref(), Some("Minimax"));
    }

    #[test]
    fn selected_provider_preserves_legacy_top_level_provider() {
        let response = decode_response(serde_json::json!({
            "provider": "SambaNova",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "ok"
                }
            }],
            "openrouter_metadata": {
                "endpoints": {
                    "available": [{
                        "model": "meta-llama/llama-3",
                        "provider": "Together",
                        "selected": true
                    }],
                    "total": 1
                }
            }
        }));

        let selected = OpenRouterProvider::selected_provider_name(&response);
        assert_eq!(selected.as_deref(), Some("SambaNova"));
    }

    #[tokio::test]
    async fn chat_fails_without_key() {
        let keyless = OpenRouterProvider::new(None);
        let history = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
        let request = ChatRequest {
            messages: &history,
            tools: None,
            native_tools: None,
        };

        let outcome = keyless.chat(request, "openai/gpt-4o", 0.2).await;

        assert!(outcome.is_err());
        let error_text = outcome.unwrap_err().to_string();
        assert!(error_text.contains("API key not set"));
    }

    #[tokio::test]
    async fn chat_with_history_fails_without_key() {
        let keyless = OpenRouterProvider::new(None);
        let history = vec![
            ChatMessage::system("be concise"),
            ChatMessage::user("hello"),
        ];
        let request = ChatRequest {
            messages: &history,
            tools: None,
            native_tools: None,
        };

        let outcome = keyless
            .chat(request, "anthropic/claude-sonnet-4", 0.7)
            .await;

        assert!(outcome.is_err());
        let error_text = outcome.unwrap_err().to_string();
        assert!(error_text.contains("API key not set"));
    }
}