Skip to main content

codetether_agent/provider/
openrouter.rs

1//! OpenRouter provider implementation using raw HTTP
2//!
3//! This provider uses reqwest directly instead of async_openai to handle
4//! OpenRouter's extended response formats (like Kimi's reasoning fields).
5
6use super::{
7    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
8    Role, StreamChunk, ToolDefinition, Usage,
9};
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::Deserialize;
14use serde_json::{json, Value};
15
/// Provider that talks to OpenRouter's OpenAI-compatible API over raw HTTP.
pub struct OpenRouterProvider {
    /// Reused reqwest HTTP client (connection pooling).
    client: Client,
    /// OpenRouter API key; sent as a `Bearer` token on every request.
    api_key: String,
    /// API root, e.g. "https://openrouter.ai/api/v1"; paths are appended to it.
    base_url: String,
}
21
22impl OpenRouterProvider {
23    pub fn new(api_key: String) -> Result<Self> {
24        Ok(Self {
25            client: Client::new(),
26            api_key,
27            base_url: "https://openrouter.ai/api/v1".to_string(),
28        })
29    }
30
31    fn convert_messages(messages: &[Message]) -> Vec<Value> {
32        messages
33            .iter()
34            .map(|msg| {
35                let role = match msg.role {
36                    Role::System => "system",
37                    Role::User => "user",
38                    Role::Assistant => "assistant",
39                    Role::Tool => "tool",
40                };
41
42                match msg.role {
43                    Role::Tool => {
44                        // Tool result message
45                        if let Some(ContentPart::ToolResult { tool_call_id, content }) = msg.content.first() {
46                            json!({
47                                "role": "tool",
48                                "tool_call_id": tool_call_id,
49                                "content": content
50                            })
51                        } else {
52                            json!({"role": role, "content": ""})
53                        }
54                    }
55                    Role::Assistant => {
56                        // Assistant message - may have tool calls
57                        let text: String = msg.content.iter()
58                            .filter_map(|p| match p {
59                                ContentPart::Text { text } => Some(text.clone()),
60                                _ => None,
61                            })
62                            .collect::<Vec<_>>()
63                            .join("");
64
65                        let tool_calls: Vec<Value> = msg.content.iter()
66                            .filter_map(|p| match p {
67                                ContentPart::ToolCall { id, name, arguments } => Some(json!({
68                                    "id": id,
69                                    "type": "function",
70                                    "function": {
71                                        "name": name,
72                                        "arguments": arguments
73                                    }
74                                })),
75                                _ => None,
76                            })
77                            .collect();
78
79                        if tool_calls.is_empty() {
80                            json!({"role": "assistant", "content": text})
81                        } else {
82                            // For assistant with tool calls, content should be empty string or the text
83                            json!({
84                                "role": "assistant",
85                                "content": if text.is_empty() { "".to_string() } else { text },
86                                "tool_calls": tool_calls
87                            })
88                        }
89                    }
90                    _ => {
91                        // System or User message
92                        let text: String = msg.content.iter()
93                            .filter_map(|p| match p {
94                                ContentPart::Text { text } => Some(text.clone()),
95                                _ => None,
96                            })
97                            .collect::<Vec<_>>()
98                            .join("\n");
99
100                        json!({"role": role, "content": text})
101                    }
102                }
103            })
104            .collect()
105    }
106
107    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
108        tools
109            .iter()
110            .map(|t| {
111                json!({
112                    "type": "function",
113                    "function": {
114                        "name": t.name,
115                        "description": t.description,
116                        "parameters": t.parameters
117                    }
118                })
119            })
120            .collect()
121    }
122}
123
/// Top-level body of a `/chat/completions` response from OpenRouter.
/// Fields not listed here are silently ignored by serde.
#[derive(Debug, Deserialize)]
struct OpenRouterResponse {
    /// Response identifier; empty string when absent.
    #[serde(default)]
    id: String,
    // provider and model fields are OpenRouter extensions over the OpenAI schema.
    /// Upstream provider that served the request, as reported by OpenRouter.
    #[serde(default)]
    provider: Option<String>,
    /// Model that handled the request, as reported by OpenRouter.
    #[serde(default)]
    model: Option<String>,
    /// Completion choices; `complete()` consumes only the first one.
    choices: Vec<OpenRouterChoice>,
    /// Token accounting; optional because some responses omit it.
    #[serde(default)]
    usage: Option<OpenRouterUsage>,
}
137
/// A single completion choice within an OpenRouter response.
#[derive(Debug, Deserialize)]
struct OpenRouterChoice {
    /// The assistant message produced for this choice.
    message: OpenRouterMessage,
    /// OpenAI-style finish reason ("stop", "length", "tool_calls", ...).
    #[serde(default)]
    finish_reason: Option<String>,
    // OpenRouter adds native_finish_reason: the upstream provider's raw
    // finish reason, logged for debugging but not mapped to FinishReason.
    #[serde(default)]
    native_finish_reason: Option<String>,
}
147
/// Assistant message payload inside a choice, including OpenRouter's
/// extended fields for reasoning/thinking models.
#[derive(Debug, Deserialize)]
struct OpenRouterMessage {
    /// Role string as returned by the API (logged for debugging).
    role: String,
    /// Text content; may be absent or empty when only tool calls are returned.
    #[serde(default)]
    content: Option<String>,
    /// Tool invocations requested by the model, if any.
    #[serde(default)]
    tool_calls: Option<Vec<OpenRouterToolCall>>,
    // Extended fields from thinking models like Kimi K2.5
    /// Free-form reasoning text emitted by thinking models.
    #[serde(default)]
    reasoning: Option<String>,
    /// Structured reasoning payload; kept as raw JSON since its shape varies.
    #[serde(default)]
    reasoning_details: Option<Vec<Value>>,
    /// Set when the model declined to answer; surfaced as a warning log.
    #[serde(default)]
    refusal: Option<String>,
}
163
/// A tool call requested by the model in an (OpenAI-compatible) response.
#[derive(Debug, Deserialize)]
struct OpenRouterToolCall {
    /// Identifier echoed back in the matching tool-result message.
    id: String,
    /// Wire value is "function"; renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    call_type: String,
    /// Function name and serialized arguments.
    function: OpenRouterFunction,
    /// Position of this call in the response; only used for debug logging.
    #[serde(default)]
    #[allow(dead_code)]
    index: Option<usize>,
}
175
/// Function name plus arguments for a tool call. `arguments` is a raw JSON
/// string as sent on the wire, not a parsed object.
#[derive(Debug, Deserialize)]
struct OpenRouterFunction {
    name: String,
    arguments: String,
}
181
/// Token usage accounting; each counter defaults to 0 when missing so a
/// partial usage object still deserializes.
#[derive(Debug, Deserialize)]
struct OpenRouterUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}
191
/// Error envelope returned by OpenRouter on non-2xx responses:
/// `{"error": {"message": ..., "code": ...}}`.
#[derive(Debug, Deserialize)]
struct OpenRouterError {
    error: OpenRouterErrorDetail,
}
196
/// Inner error detail: human-readable message plus an optional numeric code.
#[derive(Debug, Deserialize)]
struct OpenRouterErrorDetail {
    message: String,
    #[serde(default)]
    code: Option<i32>,
}
203
204#[async_trait]
205impl Provider for OpenRouterProvider {
206    fn name(&self) -> &str {
207        "openrouter"
208    }
209
210    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
211        // Fetch models from OpenRouter API
212        let response = self
213            .client
214            .get(format!("{}/models", self.base_url))
215            .header("Authorization", format!("Bearer {}", self.api_key))
216            .send()
217            .await
218            .context("Failed to fetch models")?;
219
220        if !response.status().is_success() {
221            return Ok(vec![]); // Return empty on error
222        }
223
224        #[derive(Deserialize)]
225        struct ModelsResponse {
226            data: Vec<ModelData>,
227        }
228        
229        #[derive(Deserialize)]
230        struct ModelData {
231            id: String,
232            #[serde(default)]
233            name: Option<String>,
234            #[serde(default)]
235            context_length: Option<usize>,
236        }
237
238        let models: ModelsResponse = response.json().await.unwrap_or(ModelsResponse { data: vec![] });
239
240        Ok(models.data.into_iter().map(|m| ModelInfo {
241            id: m.id.clone(),
242            name: m.name.unwrap_or_else(|| m.id.clone()),
243            provider: "openrouter".to_string(),
244            context_window: m.context_length.unwrap_or(128_000),
245            max_output_tokens: Some(16_384),
246            supports_vision: false,
247            supports_tools: true,
248            supports_streaming: true,
249            input_cost_per_million: None,
250            output_cost_per_million: None,
251        }).collect())
252    }
253
254    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
255        let messages = Self::convert_messages(&request.messages);
256        let tools = Self::convert_tools(&request.tools);
257
258        // Build request body
259        let mut body = json!({
260            "model": request.model,
261            "messages": messages,
262        });
263
264        if !tools.is_empty() {
265            body["tools"] = json!(tools);
266        }
267        if let Some(temp) = request.temperature {
268            body["temperature"] = json!(temp);
269        }
270        if let Some(max) = request.max_tokens {
271            body["max_tokens"] = json!(max);
272        }
273
274        tracing::debug!("OpenRouter request: {}", serde_json::to_string_pretty(&body).unwrap_or_default());
275
276        let response = self
277            .client
278            .post(format!("{}/chat/completions", self.base_url))
279            .header("Authorization", format!("Bearer {}", self.api_key))
280            .header("Content-Type", "application/json")
281            .header("HTTP-Referer", "https://codetether.run")
282            .header("X-Title", "CodeTether Agent")
283            .json(&body)
284            .send()
285            .await
286            .context("Failed to send request")?;
287
288        let status = response.status();
289        let text = response.text().await.context("Failed to read response")?;
290
291        if !status.is_success() {
292            // Try to parse as error response
293            if let Ok(err) = serde_json::from_str::<OpenRouterError>(&text) {
294                anyhow::bail!("OpenRouter API error: {} (code: {:?})", err.error.message, err.error.code);
295            }
296            anyhow::bail!("OpenRouter API error: {} {}", status, text);
297        }
298
299        tracing::debug!("OpenRouter response: {}", &text[..text.len().min(500)]);
300
301        let response: OpenRouterResponse = serde_json::from_str(&text)
302            .context(format!("Failed to parse response: {}", &text[..text.len().min(200)]))?;
303
304        // Log response metadata for debugging
305        tracing::debug!(
306            response_id = %response.id,
307            provider = ?response.provider,
308            model = ?response.model,
309            "Received OpenRouter response"
310        );
311
312        let choice = response.choices.first().ok_or_else(|| anyhow::anyhow!("No choices"))?;
313        
314        // Log native finish reason if present
315        if let Some(ref native_reason) = choice.native_finish_reason {
316            tracing::debug!(native_finish_reason = %native_reason, "OpenRouter native finish reason");
317        }
318
319        // Log reasoning content if present (e.g., Kimi K2 models)
320        if let Some(ref reasoning) = choice.message.reasoning {
321            if !reasoning.is_empty() {
322                tracing::info!(
323                    reasoning_len = reasoning.len(),
324                    "Model reasoning content received"
325                );
326                tracing::debug!(
327                    reasoning = %reasoning,
328                    "Full model reasoning"
329                );
330            }
331        }
332        if let Some(ref details) = choice.message.reasoning_details {
333            if !details.is_empty() {
334                tracing::debug!(
335                    reasoning_details = ?details,
336                    "Model reasoning details"
337                );
338            }
339        }
340
341        let mut content = Vec::new();
342        let mut has_tool_calls = false;
343
344        // Add text content if present
345        if let Some(text) = &choice.message.content {
346            if !text.is_empty() {
347                content.push(ContentPart::Text { text: text.clone() });
348            }
349        }
350
351        // Log message role for debugging
352        tracing::debug!(message_role = %choice.message.role, "OpenRouter message role");
353        
354        // Log refusal if present (model declined to respond)
355        if let Some(ref refusal) = choice.message.refusal {
356            tracing::warn!(refusal = %refusal, "Model refused to respond");
357        }
358
359        // Add tool calls if present
360        if let Some(tool_calls) = &choice.message.tool_calls {
361            has_tool_calls = !tool_calls.is_empty();
362            for tc in tool_calls {
363                // Log tool call details (uses call_type and index fields)
364                tracing::debug!(
365                    tool_call_id = %tc.id,
366                    call_type = %tc.call_type,
367                    index = ?tc.index,
368                    function_name = %tc.function.name,
369                    "Processing OpenRouter tool call"
370                );
371                content.push(ContentPart::ToolCall {
372                    id: tc.id.clone(),
373                    name: tc.function.name.clone(),
374                    arguments: tc.function.arguments.clone(),
375                });
376            }
377        }
378
379        // Determine finish reason
380        let finish_reason = if has_tool_calls {
381            FinishReason::ToolCalls
382        } else {
383            match choice.finish_reason.as_deref() {
384                Some("stop") => FinishReason::Stop,
385                Some("length") => FinishReason::Length,
386                Some("tool_calls") => FinishReason::ToolCalls,
387                Some("content_filter") => FinishReason::ContentFilter,
388                _ => FinishReason::Stop,
389            }
390        };
391
392        Ok(CompletionResponse {
393            message: Message {
394                role: Role::Assistant,
395                content,
396            },
397            usage: Usage {
398                prompt_tokens: response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
399                completion_tokens: response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
400                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
401                ..Default::default()
402            },
403            finish_reason,
404        })
405    }
406
407    async fn complete_stream(
408        &self,
409        request: CompletionRequest,
410    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
411        tracing::debug!(
412            provider = "openrouter",
413            model = %request.model,
414            message_count = request.messages.len(),
415            "Starting streaming completion request (falling back to non-streaming)"
416        );
417        
418        // For now, fall back to non-streaming
419        let response = self.complete(request).await?;
420        let text = response.message.content.iter()
421            .filter_map(|p| match p {
422                ContentPart::Text { text } => Some(text.clone()),
423                _ => None,
424            })
425            .collect::<Vec<_>>()
426            .join("");
427        
428        Ok(Box::pin(futures::stream::once(async move {
429            StreamChunk::Text(text)
430        })))
431    }
432}