Skip to main content

matrixcode_core/providers/
openai.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::models::context_window_for;
6use crate::tools::ToolDefinition;
7
8use super::{
9    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
10    Usage,
11};
12
13pub struct OpenAIProvider {
14    api_key: String,
15    model: String,
16    base_url: String,
17    client: reqwest::Client,
18    /// Extra headers from config
19    extra_headers: Vec<(String, String)>,
20}
21
22impl OpenAIProvider {
23    pub fn new(api_key: String, model: String, base_url: String) -> Self {
24        Self::with_headers(api_key, model, base_url, None)
25    }
26
27    pub fn with_headers(
28        api_key: String,
29        model: String,
30        base_url: String,
31        extra_headers: Option<std::collections::HashMap<String, String>>,
32    ) -> Self {
33        // Use longer timeout for streaming responses (like Anthropic provider)
34        // - Total timeout: 300s (for long thinking/reasoning)
35        // - Connect timeout: 10s
36        // - Read timeout per chunk: 60s (for slow responses between chunks)
37        let client = reqwest::Client::builder()
38            .timeout(std::time::Duration::from_secs(300))
39            .connect_timeout(std::time::Duration::from_secs(10))
40            .read_timeout(std::time::Duration::from_secs(60))
41            .build()
42            .unwrap_or_else(|_| reqwest::Client::new());
43        let extra_headers: Vec<(String, String)> = extra_headers
44            .map(|h| h.into_iter().collect())
45            .unwrap_or_default();
46        Self {
47            api_key,
48            model,
49            base_url,
50            client,
51            extra_headers,
52        }
53    }
54
55    fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
56        let mut result = Vec::new();
57
58        if let Some(sys) = system {
59            result.push(json!({"role": "system", "content": sys}));
60        }
61
62        for msg in messages {
63            match (&msg.role, &msg.content) {
64                (Role::System, _) => {}
65                (Role::User, MessageContent::Text(text)) => {
66                    result.push(json!({"role": "user", "content": text}));
67                }
68                (Role::Assistant, MessageContent::Text(text)) => {
69                    result.push(json!({"role": "assistant", "content": text}));
70                }
71                (Role::Assistant, MessageContent::Blocks(blocks)) => {
72                    let mut tool_calls = Vec::new();
73                    let mut text_parts = Vec::new();
74                    let mut reasoning_parts = Vec::new();
75
76                    for block in blocks {
77                        match block {
78                            ContentBlock::Thinking { thinking, .. } => {
79                                // DeepSeek requires reasoning_content before content
80                                reasoning_parts.push(thinking.clone());
81                            }
82                            ContentBlock::Text { text } => text_parts.push(text.clone()),
83                            ContentBlock::ToolUse { id, name, input } => {
84                                tool_calls.push(json!({
85                                    "id": id,
86                                    "type": "function",
87                                    "function": {
88                                        "name": name,
89                                        "arguments": input.to_string(),
90                                    }
91                                }));
92                            }
93                            _ => {}
94                        }
95                    }
96
97                    let mut msg_obj = json!({"role": "assistant"});
98                    // DeepSeek: reasoning_content must be before content
99                    if !reasoning_parts.is_empty() {
100                        msg_obj["reasoning_content"] = json!(reasoning_parts.join("\n"));
101                    }
102                    if !text_parts.is_empty() {
103                        msg_obj["content"] = json!(text_parts.join("\n"));
104                    }
105                    if !tool_calls.is_empty() {
106                        msg_obj["tool_calls"] = json!(tool_calls);
107                    }
108                    result.push(msg_obj);
109                }
110                (Role::Tool, MessageContent::Blocks(blocks)) => {
111                    self.push_tool_results(blocks, &mut result);
112                }
113                (Role::User, MessageContent::Blocks(blocks)) => {
114                    // Check if this is a tool result message (agent wraps tool results as User role)
115                    if blocks
116                        .iter()
117                        .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
118                    {
119                        // Emit as OpenAI tool messages
120                        self.push_tool_results(blocks, &mut result);
121                    } else {
122                        // Regular user message with blocks
123                        let text: String = blocks
124                            .iter()
125                            .filter_map(|b| match b {
126                                ContentBlock::Text { text } => Some(text.as_str()),
127                                _ => None,
128                            })
129                            .collect::<Vec<_>>()
130                            .join("\n");
131                        result.push(json!({"role": "user", "content": text}));
132                    }
133                }
134                _ => {}
135            }
136        }
137
138        result
139    }
140
141    /// Push tool result blocks to message array
142    fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
143        for block in blocks {
144            if let ContentBlock::ToolResult {
145                tool_use_id,
146                content,
147            } = block
148            {
149                result.push(json!({
150                    "role": "tool",
151                    "tool_call_id": tool_use_id,
152                    "content": content,
153                }));
154            }
155        }
156    }
157
158    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
159        tools
160            .iter()
161            .map(|t| {
162                json!({
163                    "type": "function",
164                    "function": {
165                        "name": t.name,
166                        "description": t.description,
167                        "parameters": t.parameters,
168                    }
169                })
170            })
171            .collect()
172    }
173}
174
175#[async_trait]
176impl Provider for OpenAIProvider {
177    fn context_size(&self) -> Option<u32> {
178        context_window_for(&self.model)
179    }
180
181    fn model_name(&self) -> &str {
182        &self.model
183    }
184
185    fn clone_box(&self) -> Box<dyn Provider> {
186        Box::new(Self {
187            api_key: self.api_key.clone(),
188            model: self.model.clone(),
189            base_url: self.base_url.clone(),
190            client: reqwest::Client::new(),
191            extra_headers: self.extra_headers.clone(),
192        })
193    }
194
195    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
196        let messages = self.convert_messages(&request.messages, request.system.as_deref());
197
198        let mut body = json!({
199            "model": self.model,
200            "messages": messages,
201            "max_completion_tokens": request.max_tokens,
202        });
203
204        if !request.tools.is_empty() {
205            body["tools"] = json!(self.convert_tools(&request.tools));
206        }
207
208        let url = format!("{}/chat/completions", self.base_url);
209
210        // Debug: log request
211        crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
212
213        let mut req = self
214            .client
215            .post(&url)
216            .header("Authorization", format!("Bearer {}", self.api_key))
217            .header("Content-Type", "application/json")
218            .json(&body);
219
220        // Add extra headers from config
221        for (name, value) in &self.extra_headers {
222            req = req.header(name, value);
223        }
224
225        let response = req.send().await?;
226
227        let status = response.status();
228        let response_body: Value = response.json().await?;
229
230        // Debug: log response
231        crate::debug::debug_log().api_response(status.as_u16(), &serde_json::to_string(&response_body).unwrap_or_default());
232
233        if !status.is_success() {
234            let err_msg = response_body["error"]["message"]
235                .as_str()
236                .unwrap_or("unknown error");
237            anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
238        }
239
240        let choice = &response_body["choices"][0];
241        let message = &choice["message"];
242        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
243
244        let stop_reason = match finish_reason {
245            "tool_calls" => StopReason::ToolUse,
246            "length" => StopReason::MaxTokens,
247            _ => StopReason::EndTurn,
248        };
249
250        let mut content = Vec::new();
251
252        let usage_blob = &response_body["usage"];
253        let usage = Usage {
254            input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
255            output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
256            cache_creation_input_tokens: 0,
257            cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
258                .as_u64()
259                .unwrap_or(0) as u32,
260        };
261
262        // DeepSeek thinking mode: reasoning_content must come before content
263        if let Some(reasoning) = message["reasoning_content"].as_str()
264            && !reasoning.is_empty()
265        {
266            content.push(ContentBlock::Thinking {
267                thinking: reasoning.to_string(),
268                signature: None,
269            });
270        }
271
272        if let Some(text) = message["content"].as_str()
273            && !text.is_empty()
274        {
275            content.push(ContentBlock::Text {
276                text: text.to_string(),
277            });
278        }
279
280        if let Some(tool_calls) = message["tool_calls"].as_array() {
281            for tc in tool_calls {
282                let id = tc["id"].as_str().unwrap_or_default().to_string();
283                let name = tc["function"]["name"]
284                    .as_str()
285                    .unwrap_or_default()
286                    .to_string();
287                let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
288                let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
289
290                content.push(ContentBlock::ToolUse { id, name, input });
291            }
292
293            if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
294                return Ok(ChatResponse {
295                    content,
296                    stop_reason: StopReason::ToolUse,
297                    usage: usage.clone(),
298                });
299            }
300        }
301
302        Ok(ChatResponse {
303            content,
304            stop_reason,
305            usage,
306        })
307    }
308}