codetether_agent/provider/openrouter.rs

//! OpenRouter provider implementation using raw HTTP
//!
//! This provider uses reqwest directly instead of async_openai to handle
//! OpenRouter's extended response formats (like Kimi's reasoning fields).
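//!
//! A minimal usage sketch (illustrative only; assumes an async context and an
//! `api_key` string obtained elsewhere; everything else is defined in this file
//! or the parent `provider` module):
//!
//! ```ignore
//! let provider = OpenRouterProvider::new(api_key)?;
//! let models = provider.list_models().await?;
//! println!("{} OpenRouter models available", models.len());
//! ```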

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};

pub struct OpenRouterProvider {
    client: Client,
    api_key: String,
    base_url: String,
}

impl OpenRouterProvider {
    pub fn new(api_key: String) -> Result<Self> {
        Ok(Self {
            client: Client::new(),
            api_key,
            base_url: "https://openrouter.ai/api/v1".to_string(),
        })
    }

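    /// Converts internal `Message` values into the OpenAI-style chat message
    /// objects OpenRouter expects: tool results become `role: "tool"` entries
    /// keyed by `tool_call_id`, assistant messages carry a `tool_calls` array
    /// when they issued tools, and system/user text parts collapse into one
    /// `content` string. Illustrative assistant-with-tool-call output (values
    /// are made up; the shape mirrors the `json!` literals below):
    ///
    /// ```json
    /// {
    ///   "role": "assistant",
    ///   "content": "",
    ///   "tool_calls": [{
    ///     "id": "call_0",
    ///     "type": "function",
    ///     "function": { "name": "read_file", "arguments": "{\"path\":\"src/main.rs\"}" }
    ///   }]
    /// }
    /// ```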
    fn convert_messages(messages: &[Message]) -> Vec<Value> {
        messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::System => "system",
                    Role::User => "user",
                    Role::Assistant => "assistant",
                    Role::Tool => "tool",
                };

                match msg.role {
                    Role::Tool => {
                        // Tool result message
                        if let Some(ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        }) = msg.content.first()
                        {
                            json!({
                                "role": "tool",
                                "tool_call_id": tool_call_id,
                                "content": content
                            })
                        } else {
                            json!({"role": role, "content": ""})
                        }
                    }
                    Role::Assistant => {
                        // Assistant message - may have tool calls
                        let text: String = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>()
                            .join("");

                        let tool_calls: Vec<Value> = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::ToolCall {
                                    id,
                                    name,
                                    arguments,
                                } => Some(json!({
                                    "id": id,
                                    "type": "function",
                                    "function": {
                                        "name": name,
                                        "arguments": arguments
                                    }
                                })),
                                _ => None,
                            })
                            .collect();

                        if tool_calls.is_empty() {
                            json!({"role": "assistant", "content": text})
                        } else {
                            // Assistant message that carries tool calls; include any
                            // accompanying text (an empty string if there was none).
                            json!({
                                "role": "assistant",
                                "content": text,
                                "tool_calls": tool_calls
                            })
                        }
                    }
                    _ => {
                        // System or User message
                        let text: String = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>()
                            .join("\n");

                        json!({"role": role, "content": text})
                    }
                }
            })
            .collect()
    }

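    /// Wraps each `ToolDefinition` in the OpenAI-style
    /// `{"type": "function", "function": {...}}` envelope used by OpenRouter's
    /// chat completions endpoint.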
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters
                    }
                })
            })
            .collect()
    }
}

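// Response-side types for OpenRouter's OpenAI-compatible completion payload,
// plus its extensions. Illustrative shape (truncated, made-up values):
//
//   {
//     "id": "...",
//     "provider": "...",
//     "model": "...",
//     "choices": [{
//       "message": { "role": "assistant", "content": "...", "reasoning": "..." },
//       "finish_reason": "stop",
//       "native_finish_reason": "stop"
//     }],
//     "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
//   }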
#[derive(Debug, Deserialize)]
struct OpenRouterResponse {
    #[serde(default)]
    id: String,
    // provider and model fields from OpenRouter
    #[serde(default)]
    provider: Option<String>,
    #[serde(default)]
    model: Option<String>,
    choices: Vec<OpenRouterChoice>,
    #[serde(default)]
    usage: Option<OpenRouterUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenRouterChoice {
    message: OpenRouterMessage,
    #[serde(default)]
    finish_reason: Option<String>,
    // OpenRouter adds native_finish_reason
    #[serde(default)]
    native_finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenRouterMessage {
    role: String,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenRouterToolCall>>,
    // Extended fields from thinking models like Kimi K2.5
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_details: Option<Vec<Value>>,
    #[serde(default)]
    refusal: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenRouterToolCall {
    id: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    call_type: String,
    function: OpenRouterFunction,
    #[serde(default)]
    #[allow(dead_code)]
    index: Option<usize>,
}

#[derive(Debug, Deserialize)]
struct OpenRouterFunction {
    name: String,
    arguments: String,
}

#[derive(Debug, Deserialize)]
struct OpenRouterUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct OpenRouterError {
    error: OpenRouterErrorDetail,
}

#[derive(Debug, Deserialize)]
struct OpenRouterErrorDetail {
    message: String,
    #[serde(default)]
    code: Option<i32>,
}

#[async_trait]
impl Provider for OpenRouterProvider {
    fn name(&self) -> &str {
        "openrouter"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        // Fetch models from OpenRouter API
        let response = self
            .client
            .get(format!("{}/models", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
            .context("Failed to fetch models")?;

        if !response.status().is_success() {
            return Ok(vec![]); // Return empty on error
        }

        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<ModelData>,
        }

        #[derive(Deserialize)]
        struct ModelData {
            id: String,
            #[serde(default)]
            name: Option<String>,
            #[serde(default)]
            context_length: Option<usize>,
        }

        let models: ModelsResponse = response
            .json()
            .await
            .unwrap_or(ModelsResponse { data: vec![] });

        Ok(models
            .data
            .into_iter()
            .map(|m| ModelInfo {
                id: m.id.clone(),
                name: m.name.unwrap_or_else(|| m.id.clone()),
                provider: "openrouter".to_string(),
                context_window: m.context_length.unwrap_or(128_000),
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect())
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        // Build request body
        let mut body = json!({
            "model": request.model,
            "messages": messages,
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(temp) = request.temperature {
            body["temperature"] = json!(temp);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(
            "OpenRouter request: {}",
            serde_json::to_string_pretty(&body).unwrap_or_default()
        );

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
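            // Optional OpenRouter attribution headers; they identify the calling
            // app so traffic can be attributed to it on OpenRouter's side.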
            .header("HTTP-Referer", "https://codetether.run")
            .header("X-Title", "CodeTether Agent")
            .json(&body)
            .send()
            .await
            .context("Failed to send request")?;

        let status = response.status();
        let text = response.text().await.context("Failed to read response")?;

        if !status.is_success() {
            // Try to parse as error response
            if let Ok(err) = serde_json::from_str::<OpenRouterError>(&text) {
                anyhow::bail!(
                    "OpenRouter API error: {} (code: {:?})",
                    err.error.message,
                    err.error.code
                );
            }
            anyhow::bail!("OpenRouter API error: {} {}", status, text);
        }

        // Truncate on character boundaries so logging never panics on multi-byte UTF-8.
        let preview: String = text.chars().take(500).collect();
        tracing::debug!("OpenRouter response: {}", preview);

        let response: OpenRouterResponse = serde_json::from_str(&text).with_context(|| {
            format!(
                "Failed to parse response: {}",
                text.chars().take(200).collect::<String>()
            )
        })?;

        // Log response metadata for debugging
        tracing::debug!(
            response_id = %response.id,
            provider = ?response.provider,
            model = ?response.model,
            "Received OpenRouter response"
        );

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        // Log native finish reason if present
        if let Some(ref native_reason) = choice.native_finish_reason {
            tracing::debug!(native_finish_reason = %native_reason, "OpenRouter native finish reason");
        }

        // Log reasoning content if present (e.g., Kimi K2 models)
        if let Some(ref reasoning) = choice.message.reasoning {
            if !reasoning.is_empty() {
                tracing::info!(
                    reasoning_len = reasoning.len(),
                    "Model reasoning content received"
                );
                tracing::debug!(
                    reasoning = %reasoning,
                    "Full model reasoning"
                );
            }
        }
        if let Some(ref details) = choice.message.reasoning_details {
            if !details.is_empty() {
                tracing::debug!(
                    reasoning_details = ?details,
                    "Model reasoning details"
                );
            }
        }

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        // Add text content if present
        if let Some(text) = &choice.message.content {
            if !text.is_empty() {
                content.push(ContentPart::Text { text: text.clone() });
            }
        }

        // Log message role for debugging
        tracing::debug!(message_role = %choice.message.role, "OpenRouter message role");

        // Log refusal if present (model declined to respond)
        if let Some(ref refusal) = choice.message.refusal {
            tracing::warn!(refusal = %refusal, "Model refused to respond");
        }

        // Add tool calls if present
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                // Log tool call details (uses call_type and index fields)
                tracing::debug!(
                    tool_call_id = %tc.id,
                    call_type = %tc.call_type,
                    index = ?tc.index,
                    function_name = %tc.function.name,
                    "Processing OpenRouter tool call"
                );
                content.push(ContentPart::ToolCall {
                    id: tc.id.clone(),
                    name: tc.function.name.clone(),
                    arguments: tc.function.arguments.clone(),
                });
            }
        }

        // Determine finish reason
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason.as_deref() {
                Some("stop") => FinishReason::Stop,
                Some("length") => FinishReason::Length,
                Some("tool_calls") => FinishReason::ToolCalls,
                Some("content_filter") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens)
                    .unwrap_or(0),
                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = "openrouter",
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request (falling back to non-streaming)"
        );

        // For now, fall back to non-streaming
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}