// codetether_agent/provider/openrouter.rs

//! OpenRouter provider implementation using raw HTTP
//!
//! This provider uses reqwest directly instead of async_openai to handle
//! OpenRouter's extended response formats (like Kimi's reasoning fields).
5
6use super::util;
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::{Value, json};
16
17pub struct OpenRouterProvider {
18    client: Client,
19    api_key: String,
20    base_url: String,
21}
22
23impl OpenRouterProvider {
24    pub fn new(api_key: String) -> Result<Self> {
25        let client = Client::builder()
26            .connect_timeout(std::time::Duration::from_secs(15))
27            .timeout(std::time::Duration::from_secs(300))
28            .build()
29            .context("Failed to build reqwest client")?;
30        Ok(Self {
31            client,
32            api_key,
33            base_url: "https://openrouter.ai/api/v1".to_string(),
34        })
35    }
36
37    fn convert_messages(messages: &[Message]) -> Vec<Value> {
38        messages
39            .iter()
40            .map(|msg| {
41                let role = match msg.role {
42                    Role::System => "system",
43                    Role::User => "user",
44                    Role::Assistant => "assistant",
45                    Role::Tool => "tool",
46                };
47
48                match msg.role {
49                    Role::Tool => {
50                        // Tool result message
51                        if let Some(ContentPart::ToolResult {
52                            tool_call_id,
53                            content,
54                        }) = msg.content.first()
55                        {
56                            json!({
57                                "role": "tool",
58                                "tool_call_id": tool_call_id,
59                                "content": content
60                            })
61                        } else {
62                            json!({"role": role, "content": ""})
63                        }
64                    }
65                    Role::Assistant => {
66                        // Assistant message - may have tool calls
67                        let text: String = msg
68                            .content
69                            .iter()
70                            .filter_map(|p| match p {
71                                ContentPart::Text { text } => Some(text.clone()),
72                                _ => None,
73                            })
74                            .collect::<Vec<_>>()
75                            .join("");
76
77                        let tool_calls: Vec<Value> = msg
78                            .content
79                            .iter()
80                            .filter_map(|p| match p {
81                                ContentPart::ToolCall {
82                                    id,
83                                    name,
84                                    arguments,
85                                    ..
86                                } => Some(json!({
87                                    "id": id,
88                                    "type": "function",
89                                    "function": {
90                                        "name": name,
91                                        "arguments": arguments
92                                    }
93                                })),
94                                _ => None,
95                            })
96                            .collect();
97
98                        if tool_calls.is_empty() {
99                            json!({"role": "assistant", "content": text})
100                        } else {
101                            // For assistant with tool calls, content should be empty string or the text
102                            json!({
103                                "role": "assistant",
104                                "content": if text.is_empty() { "".to_string() } else { text },
105                                "tool_calls": tool_calls
106                            })
107                        }
108                    }
109                    _ => {
110                        // System or User message
111                        let text: String = msg
112                            .content
113                            .iter()
114                            .filter_map(|p| match p {
115                                ContentPart::Text { text } => Some(text.clone()),
116                                _ => None,
117                            })
118                            .collect::<Vec<_>>()
119                            .join("\n");
120
121                        json!({"role": role, "content": text})
122                    }
123                }
124            })
125            .collect()
126    }
127
128    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
129        tools
130            .iter()
131            .map(|t| {
132                json!({
133                    "type": "function",
134                    "function": {
135                        "name": t.name,
136                        "description": t.description,
137                        "parameters": t.parameters
138                    }
139                })
140            })
141            .collect()
142    }
143
144    fn parse_error_body(text: &str) -> Option<String> {
145        let err = serde_json::from_str::<OpenRouterError>(text).ok()?;
146        let mut message = format!("OpenRouter API error: {}", err.error.message);
147        if let Some(code) = err.error.code {
148            message.push_str(&format!(" (code: {code})"));
149        }
150        Some(message)
151    }
152}
153
154#[derive(Debug, Deserialize)]
155struct OpenRouterResponse {
156    #[serde(default)]
157    id: String,
158    // provider and model fields from OpenRouter
159    #[serde(default)]
160    provider: Option<String>,
161    #[serde(default)]
162    model: Option<String>,
163    choices: Vec<OpenRouterChoice>,
164    #[serde(default)]
165    usage: Option<OpenRouterUsage>,
166}
167
168#[derive(Debug, Deserialize)]
169struct OpenRouterChoice {
170    message: OpenRouterMessage,
171    #[serde(default)]
172    finish_reason: Option<String>,
173    // OpenRouter adds native_finish_reason
174    #[serde(default)]
175    native_finish_reason: Option<String>,
176}
177
178#[derive(Debug, Deserialize)]
179struct OpenRouterMessage {
180    role: String,
181    #[serde(default)]
182    content: Option<String>,
183    #[serde(default)]
184    tool_calls: Option<Vec<OpenRouterToolCall>>,
185    // Extended fields from thinking models like Kimi K2.5
186    #[serde(default)]
187    reasoning: Option<String>,
188    #[serde(default)]
189    reasoning_details: Option<Vec<Value>>,
190    #[serde(default)]
191    refusal: Option<String>,
192}
193
194#[derive(Debug, Deserialize)]
195struct OpenRouterToolCall {
196    id: String,
197    #[serde(rename = "type")]
198    #[allow(dead_code)]
199    call_type: String,
200    function: OpenRouterFunction,
201    #[serde(default)]
202    #[allow(dead_code)]
203    index: Option<usize>,
204}
205
206#[derive(Debug, Deserialize)]
207struct OpenRouterFunction {
208    name: String,
209    arguments: String,
210}
211
212#[derive(Debug, Deserialize)]
213struct OpenRouterUsage {
214    #[serde(default)]
215    prompt_tokens: usize,
216    #[serde(default)]
217    completion_tokens: usize,
218    #[serde(default)]
219    total_tokens: usize,
220}
221
222#[derive(Debug, Deserialize)]
223struct OpenRouterError {
224    error: OpenRouterErrorDetail,
225}
226
227#[derive(Debug, Deserialize)]
228struct OpenRouterErrorDetail {
229    message: String,
230    #[serde(default)]
231    code: Option<Value>,
232}
233
234#[async_trait]
235impl Provider for OpenRouterProvider {
236    fn name(&self) -> &str {
237        "openrouter"
238    }
239
240    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
241        // Fetch models from OpenRouter API
242        let response = self
243            .client
244            .get(format!("{}/models", self.base_url))
245            .header("Authorization", format!("Bearer {}", self.api_key))
246            .send()
247            .await
248            .context("Failed to fetch models")?;
249
250        if !response.status().is_success() {
251            return Ok(vec![]); // Return empty on error
252        }
253
254        #[derive(Deserialize)]
255        struct ModelsResponse {
256            data: Vec<ModelData>,
257        }
258
259        #[derive(Deserialize)]
260        struct ModelData {
261            id: String,
262            #[serde(default)]
263            name: Option<String>,
264            #[serde(default)]
265            context_length: Option<usize>,
266        }
267
268        let models: ModelsResponse = response
269            .json()
270            .await
271            .unwrap_or(ModelsResponse { data: vec![] });
272
273        Ok(models
274            .data
275            .into_iter()
276            .map(|m| ModelInfo {
277                id: m.id.clone(),
278                name: m.name.unwrap_or_else(|| m.id.clone()),
279                provider: "openrouter".to_string(),
280                context_window: m.context_length.unwrap_or(128_000),
281                max_output_tokens: Some(16_384),
282                supports_vision: false,
283                supports_tools: true,
284                supports_streaming: true,
285                input_cost_per_million: None,
286                output_cost_per_million: None,
287            })
288            .collect())
289    }
290
291    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
292        let messages = Self::convert_messages(&request.messages);
293        let tools = Self::convert_tools(&request.tools);
294
295        // Build request body
296        let mut body = json!({
297            "model": request.model,
298            "messages": messages,
299        });
300
301        if !tools.is_empty() {
302            body["tools"] = json!(tools);
303        }
304        if let Some(temp) = request.temperature {
305            body["temperature"] = json!(temp);
306        }
307        if let Some(max) = request.max_tokens {
308            body["max_tokens"] = json!(max);
309        }
310
311        tracing::debug!(
312            "OpenRouter request: {}",
313            serde_json::to_string_pretty(&body).unwrap_or_default()
314        );
315
316        let response = self
317            .client
318            .post(format!("{}/chat/completions", self.base_url))
319            .header("Authorization", format!("Bearer {}", self.api_key))
320            .header("Content-Type", "application/json")
321            .header("HTTP-Referer", "https://codetether.run")
322            .header("X-Title", "CodeTether Agent")
323            .json(&body)
324            .send()
325            .await
326            .context("Failed to send request")?;
327
328        let status = response.status();
329        let text = response.text().await.context("Failed to read response")?;
330
331        if let Some(error_message) = Self::parse_error_body(&text) {
332            anyhow::bail!(error_message);
333        }
334
335        if !status.is_success() {
336            anyhow::bail!("OpenRouter API error: {} {}", status, text);
337        }
338
339        tracing::debug!(
340            "OpenRouter response: {}",
341            util::truncate_bytes_safe(&text, 500)
342        );
343
344        let response: OpenRouterResponse = serde_json::from_str(&text).context(format!(
345            "Failed to parse response: {}",
346            util::truncate_bytes_safe(&text, 200)
347        ))?;
348
349        // Log response metadata for debugging
350        tracing::debug!(
351            response_id = %response.id,
352            provider = ?response.provider,
353            model = ?response.model,
354            "Received OpenRouter response"
355        );
356
357        let choice = response
358            .choices
359            .first()
360            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
361
362        // Log native finish reason if present
363        if let Some(ref native_reason) = choice.native_finish_reason {
364            tracing::debug!(native_finish_reason = %native_reason, "OpenRouter native finish reason");
365        }
366
367        // Log reasoning content if present (e.g., Kimi K2 models)
368        if let Some(ref reasoning) = choice.message.reasoning
369            && !reasoning.is_empty()
370        {
371            tracing::info!(
372                reasoning_len = reasoning.len(),
373                "Model reasoning content received"
374            );
375            tracing::debug!(
376                reasoning = %reasoning,
377                "Full model reasoning"
378            );
379        }
380        if let Some(ref details) = choice.message.reasoning_details
381            && !details.is_empty()
382        {
383            tracing::debug!(
384                reasoning_details = ?details,
385                "Model reasoning details"
386            );
387        }
388
389        let mut content = Vec::new();
390        let mut has_tool_calls = false;
391
392        // Add text content if present
393        if let Some(text) = &choice.message.content
394            && !text.is_empty()
395        {
396            content.push(ContentPart::Text { text: text.clone() });
397        }
398
399        // Log message role for debugging
400        tracing::debug!(message_role = %choice.message.role, "OpenRouter message role");
401
402        // Log refusal if present (model declined to respond)
403        if let Some(ref refusal) = choice.message.refusal {
404            tracing::warn!(refusal = %refusal, "Model refused to respond");
405        }
406
407        // Add tool calls if present
408        if let Some(tool_calls) = &choice.message.tool_calls {
409            has_tool_calls = !tool_calls.is_empty();
410            for tc in tool_calls {
411                // Log tool call details (uses call_type and index fields)
412                tracing::debug!(
413                    tool_call_id = %tc.id,
414                    call_type = %tc.call_type,
415                    index = ?tc.index,
416                    function_name = %tc.function.name,
417                    "Processing OpenRouter tool call"
418                );
419                content.push(ContentPart::ToolCall {
420                    id: tc.id.clone(),
421                    name: tc.function.name.clone(),
422                    arguments: tc.function.arguments.clone(),
423                    thought_signature: None,
424                });
425            }
426        }
427
428        // Determine finish reason
429        let finish_reason = if has_tool_calls {
430            FinishReason::ToolCalls
431        } else {
432            match choice.finish_reason.as_deref() {
433                Some("stop") => FinishReason::Stop,
434                Some("length") => FinishReason::Length,
435                Some("tool_calls") => FinishReason::ToolCalls,
436                Some("content_filter") => FinishReason::ContentFilter,
437                _ => FinishReason::Stop,
438            }
439        };
440
441        Ok(CompletionResponse {
442            message: Message {
443                role: Role::Assistant,
444                content,
445            },
446            usage: Usage {
447                prompt_tokens: response
448                    .usage
449                    .as_ref()
450                    .map(|u| u.prompt_tokens)
451                    .unwrap_or(0),
452                completion_tokens: response
453                    .usage
454                    .as_ref()
455                    .map(|u| u.completion_tokens)
456                    .unwrap_or(0),
457                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
458                ..Default::default()
459            },
460            finish_reason,
461        })
462    }
463
464    async fn complete_stream(
465        &self,
466        request: CompletionRequest,
467    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
468        use futures::StreamExt;
469
470        let messages = Self::convert_messages(&request.messages);
471        let tools = Self::convert_tools(&request.tools);
472
473        let mut body = json!({
474            "model": request.model,
475            "messages": messages,
476            "stream": true,
477        });
478        if !tools.is_empty() {
479            body["tools"] = json!(tools);
480        }
481        if let Some(temp) = request.temperature {
482            body["temperature"] = json!(temp);
483        }
484        if let Some(max) = request.max_tokens {
485            body["max_tokens"] = json!(max);
486        }
487
488        tracing::debug!(
489            provider = "openrouter",
490            model = %request.model,
491            message_count = request.messages.len(),
492            "Starting streaming completion request"
493        );
494
495        let response = self
496            .client
497            .post(format!("{}/chat/completions", self.base_url))
498            .header("Authorization", format!("Bearer {}", self.api_key))
499            .header("Content-Type", "application/json")
500            .header("HTTP-Referer", "https://codetether.run")
501            .header("X-Title", "CodeTether Agent")
502            .json(&body)
503            .send()
504            .await
505            .context("Failed to send streaming request to OpenRouter")?;
506
507        if !response.status().is_success() {
508            let status = response.status();
509            let text = response.text().await.unwrap_or_default();
510            if let Some(error_message) = Self::parse_error_body(&text) {
511                anyhow::bail!(error_message);
512            }
513            anyhow::bail!("OpenRouter streaming error: {} {}", status, text);
514        }
515
516        let stream = response.bytes_stream();
517        let mut buffer = String::new();
518
519        Ok(stream
520            .flat_map(move |chunk_result| {
521                let mut chunks: Vec<StreamChunk> = Vec::new();
522                match chunk_result {
523                    Ok(bytes) => {
524                        let text = String::from_utf8_lossy(&bytes);
525                        buffer.push_str(&text);
526
527                        while let Some(line_end) = buffer.find('\n') {
528                            let line = buffer[..line_end].trim().to_string();
529                            buffer = buffer[line_end + 1..].to_string();
530
531                            if line.is_empty() {
532                                continue;
533                            }
534
535                            if line == "data: [DONE]" {
536                                chunks.push(StreamChunk::Done { usage: None });
537                                continue;
538                            }
539
540                            if let Some(data) = line.strip_prefix("data: ") {
541                                if let Ok(parsed) =
542                                    serde_json::from_str::<OpenRouterStreamResponse>(data)
543                                {
544                                    if let Some(choice) = parsed.choices.first() {
545                                        if let Some(ref content) = choice.delta.content {
546                                            if !content.is_empty() {
547                                                chunks.push(StreamChunk::Text(content.clone()));
548                                            }
549                                        }
550                                        if let Some(ref tool_calls) = choice.delta.tool_calls {
551                                            for tc in tool_calls {
552                                                if let Some(ref func) = tc.function {
553                                                    if let Some(ref name) = func.name {
554                                                        let id = tc.id.clone().unwrap_or_default();
555                                                        chunks.push(StreamChunk::ToolCallStart {
556                                                            id: id.clone(),
557                                                            name: name.clone(),
558                                                        });
559                                                    }
560                                                    if let Some(ref args) = func.arguments {
561                                                        let id = tc.id.clone().unwrap_or_default();
562                                                        if !args.is_empty() {
563                                                            chunks.push(
564                                                                StreamChunk::ToolCallDelta {
565                                                                    id,
566                                                                    arguments_delta: args.clone(),
567                                                                },
568                                                            );
569                                                        }
570                                                    }
571                                                }
572                                            }
573                                        }
574                                        if choice.finish_reason.as_deref() == Some("stop")
575                                            || choice.finish_reason.as_deref() == Some("tool_calls")
576                                        {
577                                            let usage = parsed.usage.map(|u| Usage {
578                                                prompt_tokens: u.prompt_tokens,
579                                                completion_tokens: u.completion_tokens,
580                                                total_tokens: u.total_tokens,
581                                                ..Default::default()
582                                            });
583                                            chunks.push(StreamChunk::Done { usage });
584                                        }
585                                    }
586                                }
587                            }
588                        }
589                    }
590                    Err(e) => {
591                        chunks.push(StreamChunk::Error(e.to_string()));
592                    }
593                }
594                futures::stream::iter(chunks)
595            })
596            .boxed())
597    }
598}
599
600/// Streaming SSE delta types for OpenRouter (OpenAI-compatible)
601#[derive(Debug, Deserialize)]
602struct OpenRouterStreamResponse {
603    #[serde(default)]
604    choices: Vec<OpenRouterStreamChoice>,
605    #[serde(default)]
606    usage: Option<OpenRouterUsage>,
607}
608
609#[derive(Debug, Deserialize)]
610struct OpenRouterStreamChoice {
611    #[serde(default)]
612    delta: OpenRouterStreamDelta,
613    #[serde(default)]
614    finish_reason: Option<String>,
615}
616
617#[derive(Debug, Default, Deserialize)]
618struct OpenRouterStreamDelta {
619    #[serde(default)]
620    content: Option<String>,
621    #[serde(default)]
622    tool_calls: Option<Vec<OpenRouterStreamToolCall>>,
623}
624
625#[derive(Debug, Deserialize)]
626struct OpenRouterStreamToolCall {
627    #[serde(default)]
628    id: Option<String>,
629    #[serde(default)]
630    function: Option<OpenRouterStreamFunction>,
631}
632
633#[derive(Debug, Deserialize)]
634struct OpenRouterStreamFunction {
635    #[serde(default)]
636    name: Option<String>,
637    #[serde(default)]
638    arguments: Option<String>,
639}
640
#[cfg(test)]
mod tests {
    use super::OpenRouterProvider;

    /// An error envelope is rendered with both its message and its code.
    #[test]
    fn parses_embedded_error_body() {
        let rendered = OpenRouterProvider::parse_error_body(
            r#"{"error":{"message":"Internal Server Error","code":500}}"#,
        );

        assert_eq!(
            rendered,
            Some("OpenRouter API error: Internal Server Error (code: 500)".to_string())
        );
    }

    /// A regular completion body has no `error` key and must not be
    /// mistaken for an error.
    #[test]
    fn ignores_success_body_without_error_envelope() {
        let success_body = r#"{
            "id":"chatcmpl-123",
            "choices":[{
                "message":{"role":"assistant","content":"ok"},
                "finish_reason":"stop"
            }]
        }"#;

        assert!(OpenRouterProvider::parse_error_body(success_body).is_none());
    }
}
668}