Skip to main content

matrixcode_core/providers/
openai.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::constants::{
8    DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS,
9};
10use crate::models::context_window_for;
11use crate::tools::ToolDefinition;
12
13use super::{
14    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
15    Usage,
16};
17
18pub struct OpenAIProvider {
19    api_key: String,
20    model: String,
21    base_url: String,
22    client: reqwest::Client,
23    /// Extra headers from config
24    extra_headers: Vec<(String, String)>,
25}
26
27impl OpenAIProvider {
28    pub fn new(api_key: String, model: String, base_url: String) -> Self {
29        Self::with_headers(api_key, model, base_url, None)
30    }
31
32    pub fn with_headers(
33        api_key: String,
34        model: String,
35        base_url: String,
36        extra_headers: Option<std::collections::HashMap<String, String>>,
37    ) -> Self {
38        // Use longer timeout for streaming responses (like Anthropic provider)
39        // - Total timeout: 300s (for long thinking/reasoning)
40        // - Connect timeout: 10s
41        // - Read timeout per chunk: 60s (for slow responses between chunks)
42        let client = reqwest::Client::builder()
43            .timeout(Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS))
44            .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
45            .read_timeout(Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
46            .build()
47            .unwrap_or_else(|_| reqwest::Client::new());
48        let extra_headers: Vec<(String, String)> = extra_headers
49            .map(|h| h.into_iter().collect())
50            .unwrap_or_default();
51        Self {
52            api_key,
53            model,
54            base_url,
55            client,
56            extra_headers,
57        }
58    }
59
60    fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
61        let mut result = Vec::new();
62
63        if let Some(sys) = system {
64            result.push(json!({"role": "system", "content": sys}));
65        }
66
67        for msg in messages {
68            match (&msg.role, &msg.content) {
69                (Role::System, _) => {}
70                (Role::User, MessageContent::Text(text)) => {
71                    result.push(json!({"role": "user", "content": text}));
72                }
73                (Role::Assistant, MessageContent::Text(text)) => {
74                    result.push(json!({"role": "assistant", "content": text}));
75                }
76                (Role::Assistant, MessageContent::Blocks(blocks)) => {
77                    let mut tool_calls = Vec::new();
78                    let mut text_parts = Vec::new();
79                    let mut reasoning_parts = Vec::new();
80
81                    for block in blocks {
82                        match block {
83                            ContentBlock::Thinking { thinking, .. } => {
84                                // DeepSeek requires reasoning_content before content
85                                reasoning_parts.push(thinking.clone());
86                            }
87                            ContentBlock::Text { text } => text_parts.push(text.clone()),
88                            ContentBlock::ToolUse { id, name, input } => {
89                                tool_calls.push(json!({
90                                    "id": id,
91                                    "type": "function",
92                                    "function": {
93                                        "name": name,
94                                        "arguments": input.to_string(),
95                                    }
96                                }));
97                            }
98                            _ => {}
99                        }
100                    }
101
102                    let mut msg_obj = json!({"role": "assistant"});
103                    // DeepSeek: reasoning_content must be before content
104                    if !reasoning_parts.is_empty() {
105                        msg_obj["reasoning_content"] = json!(reasoning_parts.join("\n"));
106                    }
107                    if !text_parts.is_empty() {
108                        msg_obj["content"] = json!(text_parts.join("\n"));
109                    }
110                    if !tool_calls.is_empty() {
111                        msg_obj["tool_calls"] = json!(tool_calls);
112                    }
113                    result.push(msg_obj);
114                }
115                (Role::Tool, MessageContent::Blocks(blocks)) => {
116                    self.push_tool_results(blocks, &mut result);
117                }
118                (Role::User, MessageContent::Blocks(blocks)) => {
119                    // Check if this is a tool result message (agent wraps tool results as User role)
120                    if blocks
121                        .iter()
122                        .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
123                    {
124                        // Emit as OpenAI tool messages
125                        self.push_tool_results(blocks, &mut result);
126                    } else {
127                        // Regular user message with blocks
128                        let text: String = blocks
129                            .iter()
130                            .filter_map(|b| match b {
131                                ContentBlock::Text { text } => Some(text.as_str()),
132                                _ => None,
133                            })
134                            .collect::<Vec<_>>()
135                            .join("\n");
136                        result.push(json!({"role": "user", "content": text}));
137                    }
138                }
139                _ => {}
140            }
141        }
142
143        result
144    }
145
146    /// Push tool result blocks to message array
147    fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
148        for block in blocks {
149            if let ContentBlock::ToolResult {
150                tool_use_id,
151                content,
152            } = block
153            {
154                result.push(json!({
155                    "role": "tool",
156                    "tool_call_id": tool_use_id,
157                    "content": content,
158                }));
159            }
160        }
161    }
162
163    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
164        tools
165            .iter()
166            .map(|t| {
167                json!({
168                    "type": "function",
169                    "function": {
170                        "name": t.name,
171                        "description": t.description,
172                        "parameters": t.parameters,
173                    }
174                })
175            })
176            .collect()
177    }
178}
179
180#[async_trait]
181impl Provider for OpenAIProvider {
182    fn context_size(&self) -> Option<u32> {
183        context_window_for(&self.model)
184    }
185
186    fn model_name(&self) -> &str {
187        &self.model
188    }
189
190    fn clone_box(&self) -> Box<dyn Provider> {
191        Box::new(Self {
192            api_key: self.api_key.clone(),
193            model: self.model.clone(),
194            base_url: self.base_url.clone(),
195            client: reqwest::Client::new(),
196            extra_headers: self.extra_headers.clone(),
197        })
198    }
199
200    fn clone_arc(&self) -> Arc<dyn Provider> {
201        Arc::new(Self {
202            api_key: self.api_key.clone(),
203            model: self.model.clone(),
204            base_url: self.base_url.clone(),
205            client: reqwest::Client::new(),
206            extra_headers: self.extra_headers.clone(),
207        })
208    }
209
210    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
211        let messages = self.convert_messages(&request.messages, request.system.as_deref());
212
213        let mut body = json!({
214            "model": self.model,
215            "messages": messages,
216            "max_completion_tokens": request.max_tokens,
217        });
218
219        if !request.tools.is_empty() {
220            body["tools"] = json!(self.convert_tools(&request.tools));
221        }
222
223        let url = format!("{}/chat/completions", self.base_url);
224
225        // Debug: log request
226        crate::debug::debug_log()
227            .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
228
229        let mut req = self
230            .client
231            .post(&url)
232            .header("Authorization", format!("Bearer {}", self.api_key))
233            .header("Content-Type", "application/json")
234            .json(&body);
235
236        // Add extra headers from config
237        for (name, value) in &self.extra_headers {
238            req = req.header(name, value);
239        }
240
241        let response = req
242            .send()
243            .await
244            .map_err(|e| anyhow::anyhow!("HTTP request failed: {} (url: {})", e, url))?;
245
246        let status = response.status();
247        let response_body: Value = response
248            .json()
249            .await
250            .map_err(|e| anyhow::anyhow!("Failed to parse response JSON: {}", e))?;
251
252        // Debug: log response
253        crate::debug::debug_log().api_response(
254            status.as_u16(),
255            &serde_json::to_string(&response_body).unwrap_or_default(),
256        );
257
258        if !status.is_success() {
259            let err_msg = response_body["error"]["message"]
260                .as_str()
261                .unwrap_or_else(|| response_body["error"].as_str().unwrap_or("unknown error"));
262            anyhow::bail!("API error ({}): {}", status, err_msg);
263        }
264
265        let choice = &response_body["choices"][0];
266        let message = &choice["message"];
267        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
268
269        let stop_reason = match finish_reason {
270            "tool_calls" => StopReason::ToolUse,
271            "length" => StopReason::MaxTokens,
272            _ => StopReason::EndTurn,
273        };
274
275        let mut content = Vec::new();
276
277        let usage_blob = &response_body["usage"];
278        let usage = Usage {
279            input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
280            output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
281            cache_creation_input_tokens: 0,
282            cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
283                .as_u64()
284                .unwrap_or(0) as u32,
285        };
286
287        // DeepSeek thinking mode: reasoning_content must come before content
288        if let Some(reasoning) = message["reasoning_content"].as_str()
289            && !reasoning.is_empty()
290        {
291            content.push(ContentBlock::Thinking {
292                thinking: reasoning.to_string(),
293                signature: None,
294            });
295        }
296
297        if let Some(text) = message["content"].as_str()
298            && !text.is_empty()
299        {
300            content.push(ContentBlock::Text {
301                text: text.to_string(),
302            });
303        }
304
305        if let Some(tool_calls) = message["tool_calls"].as_array() {
306            for tc in tool_calls {
307                let id = tc["id"].as_str().unwrap_or_default().to_string();
308                let name = tc["function"]["name"]
309                    .as_str()
310                    .unwrap_or_default()
311                    .to_string();
312                let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
313                let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
314
315                content.push(ContentBlock::ToolUse { id, name, input });
316            }
317
318            if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
319                return Ok(ChatResponse {
320                    content,
321                    stop_reason: StopReason::ToolUse,
322                    usage: usage.clone(),
323                });
324            }
325        }
326
327        Ok(ChatResponse {
328            content,
329            stop_reason,
330            usage,
331        })
332    }
333}