//! OpenAI-compatible chat-completions provider (`matrixcode_core/providers/openai.rs`).

use std::fmt;

use anyhow::Result;
use async_trait::async_trait;
use serde_json::{Value, json};

use crate::models::context_window_for;
use crate::tools::ToolDefinition;

use super::{
    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
    Usage,
};

13pub struct OpenAIProvider {
14    api_key: String,
15    model: String,
16    base_url: String,
17    client: reqwest::Client,
18}
19
20impl OpenAIProvider {
21    pub fn new(api_key: String, model: String, base_url: String) -> Self {
22        let client = reqwest::Client::builder()
23            .timeout(std::time::Duration::from_secs(120))
24            .connect_timeout(std::time::Duration::from_secs(10))
25            .build()
26            .unwrap_or_else(|_| reqwest::Client::new());
27        Self {
28            api_key,
29            model,
30            base_url,
31            client,
32        }
33    }
34
35    fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
36        let mut result = Vec::new();
37
38        if let Some(sys) = system {
39            result.push(json!({"role": "system", "content": sys}));
40        }
41
42        for msg in messages {
43            match (&msg.role, &msg.content) {
44                (Role::System, _) => {}
45                (Role::User, MessageContent::Text(text)) => {
46                    result.push(json!({"role": "user", "content": text}));
47                }
48                (Role::Assistant, MessageContent::Text(text)) => {
49                    result.push(json!({"role": "assistant", "content": text}));
50                }
51                (Role::Assistant, MessageContent::Blocks(blocks)) => {
52                    let mut tool_calls = Vec::new();
53                    let mut text_parts = Vec::new();
54
55                    for block in blocks {
56                        match block {
57                            ContentBlock::Text { text } => text_parts.push(text.clone()),
58                            ContentBlock::ToolUse { id, name, input } => {
59                                tool_calls.push(json!({
60                                    "id": id,
61                                    "type": "function",
62                                    "function": {
63                                        "name": name,
64                                        "arguments": input.to_string(),
65                                    }
66                                }));
67                            }
68                            ContentBlock::Thinking { .. } => {}
69                            _ => {}
70                        }
71                    }
72
73                    let mut msg_obj = json!({"role": "assistant"});
74                    if !text_parts.is_empty() {
75                        msg_obj["content"] = json!(text_parts.join("\n"));
76                    }
77                    if !tool_calls.is_empty() {
78                        msg_obj["tool_calls"] = json!(tool_calls);
79                    }
80                    result.push(msg_obj);
81                }
82                (Role::Tool, MessageContent::Blocks(blocks)) => {
83                    self.push_tool_results(blocks, &mut result);
84                }
85                (Role::User, MessageContent::Blocks(blocks)) => {
86                    // Check if this is a tool result message (agent wraps tool results as User role)
87                    if blocks
88                        .iter()
89                        .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
90                    {
91                        // Emit as OpenAI tool messages
92                        self.push_tool_results(blocks, &mut result);
93                    } else {
94                        // Regular user message with blocks
95                        let text: String = blocks
96                            .iter()
97                            .filter_map(|b| match b {
98                                ContentBlock::Text { text } => Some(text.as_str()),
99                                _ => None,
100                            })
101                            .collect::<Vec<_>>()
102                            .join("\n");
103                        result.push(json!({"role": "user", "content": text}));
104                    }
105                }
106                _ => {}
107            }
108        }
109
110        result
111    }
112
113    /// Push tool result blocks to message array
114    fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
115        for block in blocks {
116            if let ContentBlock::ToolResult {
117                tool_use_id,
118                content,
119            } = block
120            {
121                result.push(json!({
122                    "role": "tool",
123                    "tool_call_id": tool_use_id,
124                    "content": content,
125                }));
126            }
127        }
128    }
129
130    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
131        tools
132            .iter()
133            .map(|t| {
134                json!({
135                    "type": "function",
136                    "function": {
137                        "name": t.name,
138                        "description": t.description,
139                        "parameters": t.parameters,
140                    }
141                })
142            })
143            .collect()
144    }
145}
146
147#[async_trait]
148impl Provider for OpenAIProvider {
149    fn context_size(&self) -> Option<u32> {
150        context_window_for(&self.model)
151    }
152
153    fn clone_box(&self) -> Box<dyn Provider> {
154        Box::new(Self {
155            api_key: self.api_key.clone(),
156            model: self.model.clone(),
157            base_url: self.base_url.clone(),
158            client: reqwest::Client::new(),
159        })
160    }
161
162    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
163        let messages = self.convert_messages(&request.messages, request.system.as_deref());
164
165        let mut body = json!({
166            "model": self.model,
167            "messages": messages,
168            "max_completion_tokens": request.max_tokens,
169        });
170
171        if !request.tools.is_empty() {
172            body["tools"] = json!(self.convert_tools(&request.tools));
173        }
174
175        let url = format!("{}/chat/completions", self.base_url);
176        let response = self
177            .client
178            .post(&url)
179            .header("Authorization", format!("Bearer {}", self.api_key))
180            .header("Content-Type", "application/json")
181            .json(&body)
182            .send()
183            .await?;
184
185        let status = response.status();
186        let response_body: Value = response.json().await?;
187
188        if !status.is_success() {
189            let err_msg = response_body["error"]["message"]
190                .as_str()
191                .unwrap_or("unknown error");
192            anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
193        }
194
195        let choice = &response_body["choices"][0];
196        let message = &choice["message"];
197        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
198
199        let stop_reason = match finish_reason {
200            "tool_calls" => StopReason::ToolUse,
201            "length" => StopReason::MaxTokens,
202            _ => StopReason::EndTurn,
203        };
204
205        let mut content = Vec::new();
206
207        let usage_blob = &response_body["usage"];
208        let usage = Usage {
209            input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
210            output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
211            cache_creation_input_tokens: 0,
212            cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
213                .as_u64()
214                .unwrap_or(0) as u32,
215        };
216
217        if let Some(text) = message["content"].as_str()
218            && !text.is_empty()
219        {
220            content.push(ContentBlock::Text {
221                text: text.to_string(),
222            });
223        }
224
225        if let Some(tool_calls) = message["tool_calls"].as_array() {
226            for tc in tool_calls {
227                let id = tc["id"].as_str().unwrap_or_default().to_string();
228                let name = tc["function"]["name"]
229                    .as_str()
230                    .unwrap_or_default()
231                    .to_string();
232                let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
233                let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
234
235                content.push(ContentBlock::ToolUse { id, name, input });
236            }
237
238            if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
239                return Ok(ChatResponse {
240                    content,
241                    stop_reason: StopReason::ToolUse,
242                    usage: usage.clone(),
243                });
244            }
245        }
246
247        Ok(ChatResponse {
248            content,
249            stop_reason,
250            usage,
251        })
252    }
253}