Skip to main content

matrixcode_core/providers/
openai.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::constants::{
8    DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS,
9};
10use crate::models::context_window_for;
11use crate::tools::ToolDefinition;
12
13use super::{
14    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
15    Usage,
16};
17
18pub struct OpenAIProvider {
19    api_key: String,
20    model: String,
21    base_url: String,
22    client: reqwest::Client,
23    /// Extra headers from config
24    extra_headers: Vec<(String, String)>,
25}
26
27impl OpenAIProvider {
28    pub fn new(api_key: String, model: String, base_url: String) -> Self {
29        Self::with_headers(api_key, model, base_url, None)
30    }
31
32    pub fn with_headers(
33        api_key: String,
34        model: String,
35        base_url: String,
36        extra_headers: Option<std::collections::HashMap<String, String>>,
37    ) -> Self {
38        // Use longer timeout for streaming responses (like Anthropic provider)
39        // - Total timeout: 300s (for long thinking/reasoning)
40        // - Connect timeout: 10s
41        // - Read timeout per chunk: 60s (for slow responses between chunks)
42        let client = reqwest::Client::builder()
43            .timeout(Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS))
44            .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
45            .read_timeout(Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
46            .build()
47            .unwrap_or_else(|_| reqwest::Client::new());
48        let extra_headers: Vec<(String, String)> = extra_headers
49            .map(|h| h.into_iter().collect())
50            .unwrap_or_default();
51        Self {
52            api_key,
53            model,
54            base_url,
55            client,
56            extra_headers,
57        }
58    }
59
60    fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
61        let mut result = Vec::new();
62
63        if let Some(sys) = system {
64            result.push(json!({"role": "system", "content": sys}));
65        }
66
67        for msg in messages {
68            match (&msg.role, &msg.content) {
69                (Role::System, _) => {}
70                (Role::User, MessageContent::Text(text)) => {
71                    result.push(json!({"role": "user", "content": text}));
72                }
73                (Role::Assistant, MessageContent::Text(text)) => {
74                    result.push(json!({"role": "assistant", "content": text}));
75                }
76                (Role::Assistant, MessageContent::Blocks(blocks)) => {
77                    let mut tool_calls = Vec::new();
78                    let mut text_parts = Vec::new();
79                    // Note: We intentionally DO NOT include thinking/reasoning blocks from history
80                    // to prevent the model from repeating similar thinking patterns.
81                    // Thinking blocks are for user visibility only, not for API context.
82
83                    for block in blocks {
84                        match block {
85                            // Skip thinking blocks - they should not be sent back to the API
86                            ContentBlock::Thinking { .. } => {
87                                continue;
88                            }
89                            ContentBlock::Text { text } => text_parts.push(text.clone()),
90                            ContentBlock::ToolUse { id, name, input } => {
91                                tool_calls.push(json!({
92                                    "id": id,
93                                    "type": "function",
94                                    "function": {
95                                        "name": name,
96                                        "arguments": input.to_string(),
97                                    }
98                                }));
99                            }
100                            _ => {}
101                        }
102                    }
103
104                    let mut msg_obj = json!({"role": "assistant"});
105                    // Note: reasoning_content is NOT included from history to prevent repeated thinking
106                    if !text_parts.is_empty() {
107                        msg_obj["content"] = json!(text_parts.join("\n"));
108                    }
109                    if !tool_calls.is_empty() {
110                        msg_obj["tool_calls"] = json!(tool_calls);
111                    }
112                    result.push(msg_obj);
113                }
114                (Role::Tool, MessageContent::Blocks(blocks)) => {
115                    self.push_tool_results(blocks, &mut result);
116                }
117                (Role::User, MessageContent::Blocks(blocks)) => {
118                    // Check if this is a tool result message (agent wraps tool results as User role)
119                    if blocks
120                        .iter()
121                        .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
122                    {
123                        // Emit as OpenAI tool messages
124                        self.push_tool_results(blocks, &mut result);
125                    } else {
126                        // Regular user message with blocks
127                        let text: String = blocks
128                            .iter()
129                            .filter_map(|b| match b {
130                                ContentBlock::Text { text } => Some(text.as_str()),
131                                _ => None,
132                            })
133                            .collect::<Vec<_>>()
134                            .join("\n");
135                        result.push(json!({"role": "user", "content": text}));
136                    }
137                }
138                _ => {}
139            }
140        }
141
142        result
143    }
144
145    /// Push tool result blocks to message array
146    fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
147        for block in blocks {
148            if let ContentBlock::ToolResult {
149                tool_use_id,
150                content,
151            } = block
152            {
153                result.push(json!({
154                    "role": "tool",
155                    "tool_call_id": tool_use_id,
156                    "content": content,
157                }));
158            }
159        }
160    }
161
162    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
163        tools
164            .iter()
165            .map(|t| {
166                json!({
167                    "type": "function",
168                    "function": {
169                        "name": t.name,
170                        "description": t.description,
171                        "parameters": t.parameters,
172                    }
173                })
174            })
175            .collect()
176    }
177}
178
179#[async_trait]
180impl Provider for OpenAIProvider {
181    fn context_size(&self) -> Option<u32> {
182        context_window_for(&self.model)
183    }
184
185    fn model_name(&self) -> &str {
186        &self.model
187    }
188
189    fn clone_box(&self) -> Box<dyn Provider> {
190        Box::new(Self {
191            api_key: self.api_key.clone(),
192            model: self.model.clone(),
193            base_url: self.base_url.clone(),
194            client: reqwest::Client::new(),
195            extra_headers: self.extra_headers.clone(),
196        })
197    }
198
199    fn clone_arc(&self) -> Arc<dyn Provider> {
200        Arc::new(Self {
201            api_key: self.api_key.clone(),
202            model: self.model.clone(),
203            base_url: self.base_url.clone(),
204            client: reqwest::Client::new(),
205            extra_headers: self.extra_headers.clone(),
206        })
207    }
208
209    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
210        let messages = self.convert_messages(&request.messages, request.system.as_deref());
211
212        let mut body = json!({
213            "model": self.model,
214            "messages": messages,
215            "max_completion_tokens": request.max_tokens,
216        });
217
218        if !request.tools.is_empty() {
219            body["tools"] = json!(self.convert_tools(&request.tools));
220        }
221
222        let url = format!("{}/chat/completions", self.base_url);
223
224        // Debug: log request
225        crate::debug::debug_log()
226            .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
227
228        let mut req = self
229            .client
230            .post(&url)
231            .header("Authorization", format!("Bearer {}", self.api_key))
232            .header("Content-Type", "application/json")
233            .json(&body);
234
235        // Add extra headers from config
236        for (name, value) in &self.extra_headers {
237            req = req.header(name, value);
238        }
239
240        let response = req
241            .send()
242            .await
243            .map_err(|e| anyhow::anyhow!("HTTP request failed: {} (url: {})", e, url))?;
244
245        let status = response.status();
246        let response_body: Value = response
247            .json()
248            .await
249            .map_err(|e| anyhow::anyhow!("Failed to parse response JSON: {}", e))?;
250
251        // Debug: log response
252        crate::debug::debug_log().api_response(
253            status.as_u16(),
254            &serde_json::to_string(&response_body).unwrap_or_default(),
255        );
256
257        if !status.is_success() {
258            let err_msg = response_body["error"]["message"]
259                .as_str()
260                .unwrap_or_else(|| response_body["error"].as_str().unwrap_or("unknown error"));
261            anyhow::bail!("API error ({}): {}", status, err_msg);
262        }
263
264        let choice = &response_body["choices"][0];
265        let message = &choice["message"];
266        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
267
268        let stop_reason = match finish_reason {
269            "tool_calls" => StopReason::ToolUse,
270            "length" => StopReason::MaxTokens,
271            _ => StopReason::EndTurn,
272        };
273
274        let mut content = Vec::new();
275
276        let usage_blob = &response_body["usage"];
277        let usage = Usage {
278            input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
279            output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
280            cache_creation_input_tokens: 0,
281            cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
282                .as_u64()
283                .unwrap_or(0) as u32,
284        };
285
286        // DeepSeek thinking mode: reasoning_content must come before content
287        if let Some(reasoning) = message["reasoning_content"].as_str()
288            && !reasoning.is_empty()
289        {
290            content.push(ContentBlock::Thinking {
291                thinking: reasoning.to_string(),
292                signature: None,
293            });
294        }
295
296        if let Some(text) = message["content"].as_str()
297            && !text.is_empty()
298        {
299            content.push(ContentBlock::Text {
300                text: text.to_string(),
301            });
302        }
303
304        if let Some(tool_calls) = message["tool_calls"].as_array() {
305            for tc in tool_calls {
306                let id = tc["id"].as_str().unwrap_or_default().to_string();
307                let name = tc["function"]["name"]
308                    .as_str()
309                    .unwrap_or_default()
310                    .to_string();
311                let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
312                let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
313
314                content.push(ContentBlock::ToolUse { id, name, input });
315            }
316
317            if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
318                return Ok(ChatResponse {
319                    content,
320                    stop_reason: StopReason::ToolUse,
321                    usage: usage.clone(),
322                });
323            }
324        }
325
326        Ok(ChatResponse {
327            content,
328            stop_reason,
329            usage,
330        })
331    }
332}