// codetether_agent/provider/openrouter.rs
1//! OpenRouter provider implementation using raw HTTP
2//!
3//! This provider uses reqwest directly instead of async_openai to handle
4//! OpenRouter's extended response formats (like Kimi's reasoning fields).
5
6use super::{
7    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
8    Role, StreamChunk, ToolDefinition, Usage,
9};
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::Deserialize;
14use serde_json::{Value, json};
15
/// Chat-completions provider that talks to OpenRouter over raw HTTP
/// (reqwest) instead of async_openai, so extended response fields can
/// be parsed.
pub struct OpenRouterProvider {
    /// HTTP client stored once and reused for every request.
    client: Client,
    /// OpenRouter API key, sent as a `Bearer` token.
    api_key: String,
    /// API root, e.g. "https://openrouter.ai/api/v1".
    base_url: String,
}
21
22impl OpenRouterProvider {
23    pub fn new(api_key: String) -> Result<Self> {
24        Ok(Self {
25            client: Client::new(),
26            api_key,
27            base_url: "https://openrouter.ai/api/v1".to_string(),
28        })
29    }
30
31    fn convert_messages(messages: &[Message]) -> Vec<Value> {
32        messages
33            .iter()
34            .map(|msg| {
35                let role = match msg.role {
36                    Role::System => "system",
37                    Role::User => "user",
38                    Role::Assistant => "assistant",
39                    Role::Tool => "tool",
40                };
41
42                match msg.role {
43                    Role::Tool => {
44                        // Tool result message
45                        if let Some(ContentPart::ToolResult {
46                            tool_call_id,
47                            content,
48                        }) = msg.content.first()
49                        {
50                            json!({
51                                "role": "tool",
52                                "tool_call_id": tool_call_id,
53                                "content": content
54                            })
55                        } else {
56                            json!({"role": role, "content": ""})
57                        }
58                    }
59                    Role::Assistant => {
60                        // Assistant message - may have tool calls
61                        let text: String = msg
62                            .content
63                            .iter()
64                            .filter_map(|p| match p {
65                                ContentPart::Text { text } => Some(text.clone()),
66                                _ => None,
67                            })
68                            .collect::<Vec<_>>()
69                            .join("");
70
71                        let tool_calls: Vec<Value> = msg
72                            .content
73                            .iter()
74                            .filter_map(|p| match p {
75                                ContentPart::ToolCall {
76                                    id,
77                                    name,
78                                    arguments,
79                                    ..
80                                } => Some(json!({
81                                    "id": id,
82                                    "type": "function",
83                                    "function": {
84                                        "name": name,
85                                        "arguments": arguments
86                                    }
87                                })),
88                                _ => None,
89                            })
90                            .collect();
91
92                        if tool_calls.is_empty() {
93                            json!({"role": "assistant", "content": text})
94                        } else {
95                            // For assistant with tool calls, content should be empty string or the text
96                            json!({
97                                "role": "assistant",
98                                "content": if text.is_empty() { "".to_string() } else { text },
99                                "tool_calls": tool_calls
100                            })
101                        }
102                    }
103                    _ => {
104                        // System or User message
105                        let text: String = msg
106                            .content
107                            .iter()
108                            .filter_map(|p| match p {
109                                ContentPart::Text { text } => Some(text.clone()),
110                                _ => None,
111                            })
112                            .collect::<Vec<_>>()
113                            .join("\n");
114
115                        json!({"role": role, "content": text})
116                    }
117                }
118            })
119            .collect()
120    }
121
122    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
123        tools
124            .iter()
125            .map(|t| {
126                json!({
127                    "type": "function",
128                    "function": {
129                        "name": t.name,
130                        "description": t.description,
131                        "parameters": t.parameters
132                    }
133                })
134            })
135            .collect()
136    }
137}
138
/// Top-level `/chat/completions` response body.
#[derive(Debug, Deserialize)]
struct OpenRouterResponse {
    /// Response id; defaults to empty when the API omits it.
    #[serde(default)]
    id: String,
    // provider and model fields from OpenRouter
    #[serde(default)]
    provider: Option<String>,
    #[serde(default)]
    model: Option<String>,
    choices: Vec<OpenRouterChoice>,
    #[serde(default)]
    usage: Option<OpenRouterUsage>,
}
152
/// One completion choice; OpenRouter extends the OpenAI schema.
#[derive(Debug, Deserialize)]
struct OpenRouterChoice {
    message: OpenRouterMessage,
    /// OpenAI-style finish reason ("stop", "length", "tool_calls", ...).
    #[serde(default)]
    finish_reason: Option<String>,
    // OpenRouter adds native_finish_reason (the upstream provider's value)
    #[serde(default)]
    native_finish_reason: Option<String>,
}
162
/// Assistant message payload, including OpenRouter's extended fields.
#[derive(Debug, Deserialize)]
struct OpenRouterMessage {
    role: String,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenRouterToolCall>>,
    // Extended fields from thinking models like Kimi K2.5
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    reasoning_details: Option<Vec<Value>>,
    /// Present when the model declined to respond.
    #[serde(default)]
    refusal: Option<String>,
}
178
/// A tool (function) call requested by the model.
#[derive(Debug, Deserialize)]
struct OpenRouterToolCall {
    id: String,
    /// Call type as reported by the API; used only for debug logging.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    call_type: String,
    function: OpenRouterFunction,
    /// Position in the tool-call list; used only for debug logging.
    #[serde(default)]
    #[allow(dead_code)]
    index: Option<usize>,
}
190
/// Function name plus its JSON-encoded arguments string.
#[derive(Debug, Deserialize)]
struct OpenRouterFunction {
    name: String,
    arguments: String,
}
196
/// Token usage counters; each defaults to 0 when the API omits it.
#[derive(Debug, Deserialize)]
struct OpenRouterUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}
206
/// Error envelope returned with non-2xx statuses.
#[derive(Debug, Deserialize)]
struct OpenRouterError {
    error: OpenRouterErrorDetail,
}
211
/// Human-readable error message and optional numeric code.
#[derive(Debug, Deserialize)]
struct OpenRouterErrorDetail {
    message: String,
    #[serde(default)]
    code: Option<i32>,
}
218
219#[async_trait]
220impl Provider for OpenRouterProvider {
221    fn name(&self) -> &str {
222        "openrouter"
223    }
224
225    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
226        // Fetch models from OpenRouter API
227        let response = self
228            .client
229            .get(format!("{}/models", self.base_url))
230            .header("Authorization", format!("Bearer {}", self.api_key))
231            .send()
232            .await
233            .context("Failed to fetch models")?;
234
235        if !response.status().is_success() {
236            return Ok(vec![]); // Return empty on error
237        }
238
239        #[derive(Deserialize)]
240        struct ModelsResponse {
241            data: Vec<ModelData>,
242        }
243
244        #[derive(Deserialize)]
245        struct ModelData {
246            id: String,
247            #[serde(default)]
248            name: Option<String>,
249            #[serde(default)]
250            context_length: Option<usize>,
251        }
252
253        let models: ModelsResponse = response
254            .json()
255            .await
256            .unwrap_or(ModelsResponse { data: vec![] });
257
258        Ok(models
259            .data
260            .into_iter()
261            .map(|m| ModelInfo {
262                id: m.id.clone(),
263                name: m.name.unwrap_or_else(|| m.id.clone()),
264                provider: "openrouter".to_string(),
265                context_window: m.context_length.unwrap_or(128_000),
266                max_output_tokens: Some(16_384),
267                supports_vision: false,
268                supports_tools: true,
269                supports_streaming: true,
270                input_cost_per_million: None,
271                output_cost_per_million: None,
272            })
273            .collect())
274    }
275
276    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
277        let messages = Self::convert_messages(&request.messages);
278        let tools = Self::convert_tools(&request.tools);
279
280        // Build request body
281        let mut body = json!({
282            "model": request.model,
283            "messages": messages,
284        });
285
286        if !tools.is_empty() {
287            body["tools"] = json!(tools);
288        }
289        if let Some(temp) = request.temperature {
290            body["temperature"] = json!(temp);
291        }
292        if let Some(max) = request.max_tokens {
293            body["max_tokens"] = json!(max);
294        }
295
296        tracing::debug!(
297            "OpenRouter request: {}",
298            serde_json::to_string_pretty(&body).unwrap_or_default()
299        );
300
301        let response = self
302            .client
303            .post(format!("{}/chat/completions", self.base_url))
304            .header("Authorization", format!("Bearer {}", self.api_key))
305            .header("Content-Type", "application/json")
306            .header("HTTP-Referer", "https://codetether.run")
307            .header("X-Title", "CodeTether Agent")
308            .json(&body)
309            .send()
310            .await
311            .context("Failed to send request")?;
312
313        let status = response.status();
314        let text = response.text().await.context("Failed to read response")?;
315
316        if !status.is_success() {
317            // Try to parse as error response
318            if let Ok(err) = serde_json::from_str::<OpenRouterError>(&text) {
319                anyhow::bail!(
320                    "OpenRouter API error: {} (code: {:?})",
321                    err.error.message,
322                    err.error.code
323                );
324            }
325            anyhow::bail!("OpenRouter API error: {} {}", status, text);
326        }
327
328        tracing::debug!("OpenRouter response: {}", &text[..text.len().min(500)]);
329
330        let response: OpenRouterResponse = serde_json::from_str(&text).context(format!(
331            "Failed to parse response: {}",
332            &text[..text.len().min(200)]
333        ))?;
334
335        // Log response metadata for debugging
336        tracing::debug!(
337            response_id = %response.id,
338            provider = ?response.provider,
339            model = ?response.model,
340            "Received OpenRouter response"
341        );
342
343        let choice = response
344            .choices
345            .first()
346            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
347
348        // Log native finish reason if present
349        if let Some(ref native_reason) = choice.native_finish_reason {
350            tracing::debug!(native_finish_reason = %native_reason, "OpenRouter native finish reason");
351        }
352
353        // Log reasoning content if present (e.g., Kimi K2 models)
354        if let Some(ref reasoning) = choice.message.reasoning {
355            if !reasoning.is_empty() {
356                tracing::info!(
357                    reasoning_len = reasoning.len(),
358                    "Model reasoning content received"
359                );
360                tracing::debug!(
361                    reasoning = %reasoning,
362                    "Full model reasoning"
363                );
364            }
365        }
366        if let Some(ref details) = choice.message.reasoning_details {
367            if !details.is_empty() {
368                tracing::debug!(
369                    reasoning_details = ?details,
370                    "Model reasoning details"
371                );
372            }
373        }
374
375        let mut content = Vec::new();
376        let mut has_tool_calls = false;
377
378        // Add text content if present
379        if let Some(text) = &choice.message.content {
380            if !text.is_empty() {
381                content.push(ContentPart::Text { text: text.clone() });
382            }
383        }
384
385        // Log message role for debugging
386        tracing::debug!(message_role = %choice.message.role, "OpenRouter message role");
387
388        // Log refusal if present (model declined to respond)
389        if let Some(ref refusal) = choice.message.refusal {
390            tracing::warn!(refusal = %refusal, "Model refused to respond");
391        }
392
393        // Add tool calls if present
394        if let Some(tool_calls) = &choice.message.tool_calls {
395            has_tool_calls = !tool_calls.is_empty();
396            for tc in tool_calls {
397                // Log tool call details (uses call_type and index fields)
398                tracing::debug!(
399                    tool_call_id = %tc.id,
400                    call_type = %tc.call_type,
401                    index = ?tc.index,
402                    function_name = %tc.function.name,
403                    "Processing OpenRouter tool call"
404                );
405                content.push(ContentPart::ToolCall {
406                    id: tc.id.clone(),
407                    name: tc.function.name.clone(),
408                    arguments: tc.function.arguments.clone(),
409                    thought_signature: None,
410                });
411            }
412        }
413
414        // Determine finish reason
415        let finish_reason = if has_tool_calls {
416            FinishReason::ToolCalls
417        } else {
418            match choice.finish_reason.as_deref() {
419                Some("stop") => FinishReason::Stop,
420                Some("length") => FinishReason::Length,
421                Some("tool_calls") => FinishReason::ToolCalls,
422                Some("content_filter") => FinishReason::ContentFilter,
423                _ => FinishReason::Stop,
424            }
425        };
426
427        Ok(CompletionResponse {
428            message: Message {
429                role: Role::Assistant,
430                content,
431            },
432            usage: Usage {
433                prompt_tokens: response
434                    .usage
435                    .as_ref()
436                    .map(|u| u.prompt_tokens)
437                    .unwrap_or(0),
438                completion_tokens: response
439                    .usage
440                    .as_ref()
441                    .map(|u| u.completion_tokens)
442                    .unwrap_or(0),
443                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
444                ..Default::default()
445            },
446            finish_reason,
447        })
448    }
449
450    async fn complete_stream(
451        &self,
452        request: CompletionRequest,
453    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
454        tracing::debug!(
455            provider = "openrouter",
456            model = %request.model,
457            message_count = request.messages.len(),
458            "Starting streaming completion request (falling back to non-streaming)"
459        );
460
461        // For now, fall back to non-streaming
462        let response = self.complete(request).await?;
463        let text = response
464            .message
465            .content
466            .iter()
467            .filter_map(|p| match p {
468                ContentPart::Text { text } => Some(text.clone()),
469                _ => None,
470            })
471            .collect::<Vec<_>>()
472            .join("");
473
474        Ok(Box::pin(futures::stream::once(async move {
475            StreamChunk::Text(text)
476        })))
477    }
478}