// nenjo_models/openrouter.rs

//! OpenRouter aggregator provider. Authenticates with a Bearer token and routes
//! requests to many upstream models, preferring the upstream provider that
//! served the last successful response (provider-order pinning).

use crate::ToolSpec;
use crate::traits::{ChatMessage, ChatRequest, ChatResponse, ModelProvider, TokenUsage, ToolCall};
use async_trait::async_trait;
use reqwest::Client;
use reqwest::header::ACCEPT_ENCODING;
use serde::{Deserialize, Serialize};

/// Maximum number of attempts for one chat request when the failure happens at
/// the transport layer (the request could not be sent, or the response body
/// could not be read). Non-success HTTP statuses are returned immediately and
/// are not retried.
const OPENROUTER_MAX_TRANSPORT_ATTEMPTS: u32 = 3;

13pub struct OpenRouterProvider {
14    api_key: Option<String>,
15    client: Client,
16    /// Track the last upstream provider that served a successful response
17    /// so we can pin future requests to it and avoid broken fallbacks.
18    last_good_provider: std::sync::Mutex<Option<String>>,
19}
20
21#[derive(Debug, Serialize)]
22struct NativeChatRequest {
23    model: String,
24    messages: Vec<NativeMessage>,
25    temperature: f64,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    tools: Option<Vec<NativeToolSpec>>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    tool_choice: Option<String>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    provider: Option<NativeProviderRouting>,
32}
33
34#[derive(Debug, Serialize)]
35struct NativeProviderRouting {
36    order: Vec<String>,
37    allow_fallbacks: bool,
38}
39
40#[derive(Debug, Serialize)]
41struct NativeMessage {
42    role: String,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    content: Option<String>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    tool_call_id: Option<String>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    tool_calls: Option<Vec<NativeToolCall>>,
49}
50
51#[derive(Debug, Serialize)]
52struct NativeToolSpec {
53    #[serde(rename = "type")]
54    kind: String,
55    function: NativeToolFunctionSpec,
56}
57
58#[derive(Debug, Serialize)]
59struct NativeToolFunctionSpec {
60    name: String,
61    description: String,
62    parameters: serde_json::Value,
63}
64
65#[derive(Debug, Serialize, Deserialize)]
66struct NativeToolCall {
67    #[serde(skip_serializing_if = "Option::is_none")]
68    id: Option<String>,
69    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
70    kind: Option<String>,
71    function: NativeFunctionCall,
72}
73
74#[derive(Debug, Serialize, Deserialize)]
75struct NativeFunctionCall {
76    name: String,
77    arguments: String,
78}
79
80#[derive(Debug, Deserialize)]
81struct NativeUsage {
82    #[serde(default)]
83    prompt_tokens: u64,
84    #[serde(default)]
85    completion_tokens: u64,
86}
87
88#[derive(Debug, Deserialize)]
89struct NativeChatResponse {
90    choices: Vec<NativeChoice>,
91    /// The upstream provider that served this response (e.g. "SambaNova").
92    /// Older OpenRouter responses exposed this at the top level; the current
93    /// OpenAPI schema exposes it in `openrouter_metadata`.
94    #[serde(default)]
95    provider: Option<String>,
96    #[serde(default)]
97    openrouter_metadata: Option<NativeOpenRouterMetadata>,
98    #[serde(default)]
99    usage: Option<NativeUsage>,
100}
101
102#[derive(Debug, Deserialize)]
103struct NativeOpenRouterMetadata {
104    #[serde(default)]
105    endpoints: Option<NativeEndpointsMetadata>,
106}
107
108#[derive(Debug, Deserialize)]
109struct NativeEndpointsMetadata {
110    #[serde(default)]
111    available: Vec<NativeEndpointInfo>,
112}
113
114#[derive(Debug, Deserialize)]
115struct NativeEndpointInfo {
116    provider: String,
117    #[serde(default)]
118    selected: bool,
119}
120
121#[derive(Debug, Deserialize)]
122struct NativeChoice {
123    message: NativeResponseMessage,
124}
125
126#[derive(Debug, Deserialize)]
127struct NativeResponseMessage {
128    #[serde(default)]
129    content: Option<String>,
130    #[serde(default)]
131    tool_calls: Option<Vec<NativeToolCall>>,
132}
133
134impl OpenRouterProvider {
135    pub fn new(api_key: Option<&str>) -> Self {
136        Self {
137            api_key: api_key.map(ToString::to_string),
138            client: Client::builder()
139                .timeout(std::time::Duration::from_secs(120))
140                .connect_timeout(std::time::Duration::from_secs(10))
141                .build()
142                .unwrap_or_else(|_| Client::new()),
143            last_good_provider: std::sync::Mutex::new(None),
144        }
145    }
146
147    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
148        let items = tools?;
149        if items.is_empty() {
150            return None;
151        }
152        Some(
153            items
154                .iter()
155                .map(|tool| NativeToolSpec {
156                    kind: "function".to_string(),
157                    function: NativeToolFunctionSpec {
158                        name: crate::sanitize_tool_name(&tool.name),
159                        description: tool.description.clone(),
160                        parameters: tool.parameters.clone(),
161                    },
162                })
163                .collect(),
164        )
165    }
166
167    fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
168        messages
169            .iter()
170            .map(|m| {
171                if m.role == "assistant"
172                    && let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content)
173                    && let Some(tool_calls_value) = value.get("tool_calls")
174                    && let Ok(parsed_calls) =
175                        serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
176                {
177                    let tool_calls = parsed_calls
178                        .into_iter()
179                        .map(|tc| NativeToolCall {
180                            id: Some(tc.id),
181                            kind: Some("function".to_string()),
182                            function: NativeFunctionCall {
183                                name: tc.name,
184                                arguments: tc.arguments,
185                            },
186                        })
187                        .collect::<Vec<_>>();
188                    let content = value
189                        .get("content")
190                        .and_then(serde_json::Value::as_str)
191                        .map(ToString::to_string);
192                    return NativeMessage {
193                        role: "assistant".to_string(),
194                        content,
195                        tool_call_id: None,
196                        tool_calls: Some(tool_calls),
197                    };
198                }
199
200                if m.role == "tool"
201                    && let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content)
202                {
203                    let tool_call_id = value
204                        .get("tool_call_id")
205                        .and_then(serde_json::Value::as_str)
206                        .map(ToString::to_string);
207                    let content = value
208                        .get("content")
209                        .and_then(serde_json::Value::as_str)
210                        .map(ToString::to_string);
211                    return NativeMessage {
212                        role: "tool".to_string(),
213                        content,
214                        tool_call_id,
215                        tool_calls: None,
216                    };
217                }
218
219                NativeMessage {
220                    role: m.role.clone(),
221                    content: Some(m.content.clone()),
222                    tool_call_id: None,
223                    tool_calls: None,
224                }
225            })
226            .collect()
227    }
228
229    fn parse_native_response(message: NativeResponseMessage) -> ChatResponse {
230        let tool_calls = message
231            .tool_calls
232            .unwrap_or_default()
233            .into_iter()
234            .map(|tc| ToolCall {
235                id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
236                name: tc.function.name,
237                arguments: tc.function.arguments,
238            })
239            .collect::<Vec<_>>();
240
241        ChatResponse {
242            text: message.content,
243            tool_calls,
244            provider_tool_calls: vec![],
245            usage: TokenUsage::default(),
246        }
247    }
248
249    fn selected_provider_name(response: &NativeChatResponse) -> Option<String> {
250        response.provider.clone().or_else(|| {
251            response
252                .openrouter_metadata
253                .as_ref()
254                .and_then(|metadata| metadata.endpoints.as_ref())
255                .and_then(|endpoints| {
256                    endpoints
257                        .available
258                        .iter()
259                        .find(|endpoint| endpoint.selected)
260                })
261                .map(|endpoint| endpoint.provider.clone())
262        })
263    }
264}
265
266#[async_trait]
267impl ModelProvider for OpenRouterProvider {
268    async fn warmup(&self) -> anyhow::Result<()> {
269        // Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
270        // This prevents the first real chat request from timing out on cold start.
271        if let Some(api_key) = self.api_key.as_ref() {
272            self.client
273                .get("https://openrouter.ai/api/v1/auth/key")
274                .header("Authorization", format!("Bearer {api_key}"))
275                .send()
276                .await?
277                .error_for_status()?;
278        }
279        Ok(())
280    }
281
282    async fn chat(
283        &self,
284        request: ChatRequest<'_>,
285        model: &str,
286        temperature: f64,
287    ) -> anyhow::Result<ChatResponse> {
288        let api_key = self.api_key.as_ref().ok_or_else(|| {
289            anyhow::anyhow!("OpenRouter API key not set. Set OPENROUTER_API_KEY env var.")
290        })?;
291
292        let tools = Self::convert_tools(request.tools);
293
294        // Pin to the last successful upstream provider to avoid broken
295        // fallbacks (e.g. Clarifai failing for minimax models).
296        let provider_routing = self
297            .last_good_provider
298            .lock()
299            .ok()
300            .and_then(|guard| guard.clone())
301            .map(|p| NativeProviderRouting {
302                order: vec![p],
303                allow_fallbacks: true,
304            });
305
306        let messages = Self::convert_messages(request.messages);
307
308        // Log estimated request size so context-too-large issues are visible
309        let estimated_chars: usize = messages
310            .iter()
311            .map(|m| m.content.as_deref().unwrap_or("").len())
312            .sum();
313        let estimated_tokens = estimated_chars / 4;
314        tracing::info!(
315            model = model,
316            messages = messages.len(),
317            estimated_tokens = estimated_tokens,
318            "OpenRouter request"
319        );
320
321        let native_request = NativeChatRequest {
322            model: model.to_string(),
323            messages,
324            temperature,
325            tool_choice: tools.as_ref().map(|_| "auto".to_string()),
326            tools,
327            provider: provider_routing,
328        };
329
330        let body_text = {
331            let mut last_error = None;
332            let mut body = None;
333
334            for attempt in 1..=OPENROUTER_MAX_TRANSPORT_ATTEMPTS {
335                let response = match self
336                    .client
337                    .post("https://openrouter.ai/api/v1/chat/completions")
338                    .header("Authorization", format!("Bearer {api_key}"))
339                    .header("HTTP-Referer", "https://github.com/nenjo-ai/nenjo")
340                    .header("X-Title", "Nenjo")
341                    .header(ACCEPT_ENCODING, "identity")
342                    .json(&native_request)
343                    .send()
344                    .await
345                {
346                    Ok(response) => response,
347                    Err(error) => {
348                        last_error = Some(anyhow::anyhow!(
349                            "OpenRouter: request failed (~{estimated_tokens} input tokens, \
350                             {messages_count} messages, attempt {attempt}/{OPENROUTER_MAX_TRANSPORT_ATTEMPTS}): {error}",
351                            messages_count = native_request.messages.len(),
352                        ));
353                        if attempt < OPENROUTER_MAX_TRANSPORT_ATTEMPTS {
354                            tokio::time::sleep(std::time::Duration::from_millis(
355                                250 * u64::from(attempt),
356                            ))
357                            .await;
358                            continue;
359                        }
360                        break;
361                    }
362                };
363
364                let status = response.status();
365                if !status.is_success() {
366                    return Err(crate::api_error("OpenRouter", response).await);
367                }
368
369                match response.text().await {
370                    Ok(text) => {
371                        body = Some(text);
372                        break;
373                    }
374                    Err(error) => {
375                        last_error = Some(anyhow::anyhow!(
376                            "OpenRouter: failed to read response body (status {status}, \
377                             ~{estimated_tokens} input tokens, {messages_count} messages, \
378                             attempt {attempt}/{OPENROUTER_MAX_TRANSPORT_ATTEMPTS}): {error}",
379                            messages_count = native_request.messages.len(),
380                        ));
381                        if attempt < OPENROUTER_MAX_TRANSPORT_ATTEMPTS {
382                            tokio::time::sleep(std::time::Duration::from_millis(
383                                250 * u64::from(attempt),
384                            ))
385                            .await;
386                        }
387                    }
388                }
389            }
390
391            body.ok_or_else(|| {
392                last_error.unwrap_or_else(|| anyhow::anyhow!("OpenRouter: empty response body"))
393            })?
394        };
395        // OpenRouter can return HTTP 200 with an error payload when a
396        // downstream provider (e.g. Clarifai) fails.  Detect this before
397        // trying to parse as a normal chat completion.
398        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&body_text)
399            && let Some(err) = value.get("error")
400        {
401            let msg = err
402                .get("message")
403                .and_then(serde_json::Value::as_str)
404                .unwrap_or("unknown error");
405            return Err(anyhow::anyhow!(
406                "OpenRouter returned an error in a 200 response: {msg}"
407            ));
408        }
409
410        let native_response: NativeChatResponse =
411            serde_json::from_str(&body_text).map_err(|e| {
412                anyhow::anyhow!(
413                    "OpenRouter response decode error: {e}\nBody: {}",
414                    &body_text[..body_text.len().min(500)]
415                )
416            })?;
417
418        // Track the upstream provider that served this response so we
419        // can pin future requests to it.
420        if let Some(provider_name) = Self::selected_provider_name(&native_response)
421            && let Ok(mut guard) = self.last_good_provider.lock()
422        {
423            *guard = Some(provider_name);
424        }
425
426        let usage = native_response
427            .usage
428            .map(|u| TokenUsage {
429                input_tokens: u.prompt_tokens,
430                output_tokens: u.completion_tokens,
431            })
432            .unwrap_or_default();
433
434        let message = native_response
435            .choices
436            .into_iter()
437            .next()
438            .map(|c| c.message)
439            .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?;
440        let mut result = Self::parse_native_response(message);
441        result.usage = usage;
442        Ok(result)
443    }
444
445    fn context_window(&self, model: &str) -> Option<usize> {
446        // OpenRouter routes to many models. Match on the model slug.
447        let m = model.to_lowercase();
448        if m.contains("claude-opus-4")
449            || m.contains("claude-sonnet-4.6")
450            || m.contains("claude-sonnet-4-6")
451        {
452            Some(1_000_000)
453        } else if m.contains("claude-sonnet-4")
454            || m.contains("claude-haiku-4")
455            || m.contains("claude-3.5")
456            || m.contains("claude-3-")
457            || m.contains("claude-3.7")
458        {
459            Some(200_000)
460        } else if m.contains("gpt-5") {
461            Some(1_000_000)
462        } else if m.contains("gpt-4o") {
463            Some(128_000)
464        } else if m.contains("o1") || m.contains("o3") || m.contains("o4") {
465            Some(200_000)
466        } else if m.contains("gemini") {
467            Some(1_000_000)
468        } else if m.contains("deepseek") {
469            Some(128_000)
470        } else if m.contains("llama-4") || m.contains("llama4") {
471            Some(1_000_000)
472        } else if m.contains("llama-3") || m.contains("llama3") {
473            Some(128_000)
474        } else if m.contains("mistral-large") || m.contains("qwen") {
475            Some(256_000)
476        } else if m.contains("grok-4") && m.contains("fast") {
477            Some(2_000_000)
478        } else if m.contains("grok-4") {
479            Some(256_000)
480        } else if m.contains("grok-3") {
481            Some(1_000_000)
482        } else if m.contains("kimi") {
483            Some(256_000)
484        } else if m.contains("minimax") {
485            Some(200_000)
486        } else {
487            None
488        }
489    }
490
491    fn supports_native_tools(&self) -> bool {
492        true
493    }
494
495    fn supports_developer_role(&self, model: &str) -> bool {
496        let m = model.to_lowercase();
497        // Only OpenAI newer models support the developer role.
498        // Other providers behind OpenRouter (Anthropic, Google, Meta, etc.) do not.
499        (m.contains("openai/") || m.contains("azure/"))
500            && (m.contains("/o1")
501                || m.contains("/o3")
502                || m.contains("/o4")
503                || m.contains("/gpt-5")
504                || m.contains("/gpt-4.5")
505                || m.contains("/gpt-4.1"))
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{ChatMessage, ChatRequest, ModelProvider};

    #[test]
    fn creates_with_key() {
        let provider = OpenRouterProvider::new(Some("sk-or-123"));
        assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
    }

    #[test]
    fn creates_without_key() {
        let provider = OpenRouterProvider::new(None);
        assert!(provider.api_key.is_none());
    }

    #[tokio::test]
    async fn warmup_without_key_is_noop() {
        let provider = OpenRouterProvider::new(None);
        let result = provider.warmup().await;
        assert!(result.is_ok());
    }

    #[test]
    fn developer_role_only_for_openai_newer_models() {
        let provider = OpenRouterProvider::new(None);
        assert!(provider.supports_developer_role("openai/gpt-5.1"));
        assert!(provider.supports_developer_role("openai/gpt-4.1"));
        assert!(provider.supports_developer_role("openai/o3"));
        assert!(!provider.supports_developer_role("openai/gpt-4o"));
        assert!(!provider.supports_developer_role("anthropic/claude-sonnet-4"));
        assert!(!provider.supports_developer_role("minimax/minimax-m2.5"));
    }

    #[test]
    fn context_window_matches_known_model_families() {
        let provider = OpenRouterProvider::new(None);
        assert_eq!(provider.context_window("anthropic/claude-sonnet-4"), Some(200_000));
        assert_eq!(provider.context_window("openai/gpt-4o"), Some(128_000));
        assert_eq!(provider.context_window("x-ai/grok-4-fast"), Some(2_000_000));
        assert_eq!(provider.context_window("unknown/model"), None);
    }

    #[test]
    fn convert_tools_omits_missing_or_empty_lists() {
        assert!(OpenRouterProvider::convert_tools(None).is_none());
        let empty: &[ToolSpec] = &[];
        assert!(OpenRouterProvider::convert_tools(Some(empty)).is_none());
    }

    #[test]
    fn convert_messages_passes_plain_messages_through() {
        let messages = vec![ChatMessage::system("be concise"), ChatMessage::user("hello")];
        let native = OpenRouterProvider::convert_messages(&messages);

        assert_eq!(native.len(), messages.len());
        for (converted, original) in native.iter().zip(&messages) {
            assert_eq!(converted.role, original.role);
            assert_eq!(converted.content.as_deref(), Some(original.content.as_str()));
            assert!(converted.tool_call_id.is_none());
            assert!(converted.tool_calls.is_none());
        }
    }

    #[test]
    fn parse_native_response_fills_missing_tool_call_id() {
        let message = NativeResponseMessage {
            content: Some("ok".to_string()),
            tool_calls: Some(vec![NativeToolCall {
                id: None,
                kind: None,
                function: NativeFunctionCall {
                    name: "search".to_string(),
                    arguments: "{\"q\":\"rust\"}".to_string(),
                },
            }]),
        };

        let response = OpenRouterProvider::parse_native_response(message);

        assert_eq!(response.text.as_deref(), Some("ok"));
        assert_eq!(response.tool_calls.len(), 1);
        assert_eq!(response.tool_calls[0].name, "search");
        assert_eq!(response.tool_calls[0].arguments, "{\"q\":\"rust\"}");
        assert!(!response.tool_calls[0].id.is_empty());
    }

    #[test]
    fn selected_provider_uses_openrouter_metadata() {
        let response: NativeChatResponse = serde_json::from_value(serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "ok"
                }
            }],
            "openrouter_metadata": {
                "endpoints": {
                    "available": [
                        {
                            "model": "minimax/minimax-m2.5",
                            "provider": "Clarifai",
                            "selected": false
                        },
                        {
                            "model": "minimax/minimax-m2.5",
                            "provider": "Minimax",
                            "selected": true
                        }
                    ],
                    "total": 2
                }
            }
        }))
        .unwrap();

        assert_eq!(
            OpenRouterProvider::selected_provider_name(&response).as_deref(),
            Some("Minimax")
        );
    }

    #[test]
    fn selected_provider_preserves_legacy_top_level_provider() {
        let response: NativeChatResponse = serde_json::from_value(serde_json::json!({
            "provider": "SambaNova",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "ok"
                }
            }],
            "openrouter_metadata": {
                "endpoints": {
                    "available": [{
                        "model": "meta-llama/llama-3",
                        "provider": "Together",
                        "selected": true
                    }],
                    "total": 1
                }
            }
        }))
        .unwrap();

        assert_eq!(
            OpenRouterProvider::selected_provider_name(&response).as_deref(),
            Some("SambaNova")
        );
    }

    #[tokio::test]
    async fn chat_fails_without_key() {
        let provider = OpenRouterProvider::new(None);
        let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
        let request = ChatRequest {
            messages: &messages,
            tools: None,
            native_tools: None,
        };
        let result = provider.chat(request, "openai/gpt-4o", 0.2).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }

    #[tokio::test]
    async fn chat_with_history_fails_without_key() {
        let provider = OpenRouterProvider::new(None);
        let messages = vec![
            ChatMessage::system("be concise"),
            ChatMessage::user("hello"),
        ];
        let request = ChatRequest {
            messages: &messages,
            tools: None,
            native_tools: None,
        };
        let result = provider
            .chat(request, "anthropic/claude-sonnet-4", 0.7)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }
}