// File: matrixcode_core/providers/openai.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::tools::ToolDefinition;
6
7use super::{
8    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
9    Usage,
10};
11
12pub struct OpenAIProvider {
13    api_key: String,
14    model: String,
15    base_url: String,
16    client: reqwest::Client,
17}
18
19impl OpenAIProvider {
20    pub fn new(api_key: String, model: String, base_url: String) -> Self {
21        Self {
22            api_key,
23            model,
24            base_url,
25            client: reqwest::Client::new(),
26        }
27    }
28
29    fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
30        let mut result = Vec::new();
31
32        if let Some(sys) = system {
33            result.push(json!({"role": "system", "content": sys}));
34        }
35
36        for msg in messages {
37            match (&msg.role, &msg.content) {
38                (Role::System, _) => {}
39                (Role::User, MessageContent::Text(text)) => {
40                    result.push(json!({"role": "user", "content": text}));
41                }
42                (Role::Assistant, MessageContent::Text(text)) => {
43                    result.push(json!({"role": "assistant", "content": text}));
44                }
45                (Role::Assistant, MessageContent::Blocks(blocks)) => {
46                    let mut tool_calls = Vec::new();
47                    let mut text_parts = Vec::new();
48
49                    for block in blocks {
50                        match block {
51                            ContentBlock::Text { text } => text_parts.push(text.clone()),
52                            ContentBlock::ToolUse { id, name, input } => {
53                                tool_calls.push(json!({
54                                    "id": id,
55                                    "type": "function",
56                                    "function": {
57                                        "name": name,
58                                        "arguments": input.to_string(),
59                                    }
60                                }));
61                            }
62                            ContentBlock::Thinking { .. } => {}
63                            _ => {}
64                        }
65                    }
66
67                    let mut msg_obj = json!({"role": "assistant"});
68                    if !text_parts.is_empty() {
69                        msg_obj["content"] = json!(text_parts.join("\n"));
70                    }
71                    if !tool_calls.is_empty() {
72                        msg_obj["tool_calls"] = json!(tool_calls);
73                    }
74                    result.push(msg_obj);
75                }
76                (Role::Tool, MessageContent::Blocks(blocks)) => {
77                    for block in blocks {
78                        if let ContentBlock::ToolResult { tool_use_id, content } = block {
79                            result.push(json!({
80                                "role": "tool",
81                                "tool_call_id": tool_use_id,
82                                "content": content,
83                            }));
84                        }
85                    }
86                }
87                (Role::User, MessageContent::Blocks(blocks)) => {
88                    // Check if this is a tool result message (agent wraps tool results as User role)
89                    let tool_results: Vec<&ContentBlock> = blocks
90                        .iter()
91                        .filter(|b| matches!(b, ContentBlock::ToolResult { .. }))
92                        .collect();
93                    
94                    if !tool_results.is_empty() {
95                        // Emit as OpenAI tool messages
96                        for block in blocks {
97                            if let ContentBlock::ToolResult { tool_use_id, content } = block {
98                                result.push(json!({
99                                    "role": "tool",
100                                    "tool_call_id": tool_use_id,
101                                    "content": content,
102                                }));
103                            }
104                        }
105                    } else {
106                        // Regular user message with blocks
107                        let text: String = blocks
108                            .iter()
109                            .filter_map(|b| match b {
110                                ContentBlock::Text { text } => Some(text.as_str()),
111                                _ => None,
112                            })
113                            .collect::<Vec<_>>()
114                            .join("\n");
115                        result.push(json!({"role": "user", "content": text}));
116                    }
117                }
118                _ => {}
119            }
120        }
121
122        result
123    }
124
125    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
126        tools
127            .iter()
128            .map(|t| {
129                json!({
130                    "type": "function",
131                    "function": {
132                        "name": t.name,
133                        "description": t.description,
134                        "parameters": t.parameters,
135                    }
136                })
137            })
138            .collect()
139    }
140}
141
142#[async_trait]
143impl Provider for OpenAIProvider {
144    fn context_size(&self) -> Option<u32> {
145        context_window_for(&self.model)
146    }
147
148    fn clone_box(&self) -> Box<dyn Provider> {
149        Box::new(Self {
150            api_key: self.api_key.clone(),
151            model: self.model.clone(),
152            base_url: self.base_url.clone(),
153            client: reqwest::Client::new(),
154        })
155    }
156
157    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
158        let messages = self.convert_messages(&request.messages, request.system.as_deref());
159
160        let mut body = json!({
161            "model": self.model,
162            "messages": messages,
163            "max_completion_tokens": request.max_tokens,
164        });
165
166        if !request.tools.is_empty() {
167            body["tools"] = json!(self.convert_tools(&request.tools));
168        }
169
170        let url = format!("{}/chat/completions", self.base_url);
171        let response = self
172            .client
173            .post(&url)
174            .header("Authorization", format!("Bearer {}", self.api_key))
175            .header("Content-Type", "application/json")
176            .json(&body)
177            .send()
178            .await?;
179
180        let status = response.status();
181        let response_body: Value = response.json().await?;
182
183        if !status.is_success() {
184            let err_msg = response_body["error"]["message"]
185                .as_str()
186                .unwrap_or("unknown error");
187            anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
188        }
189
190        let choice = &response_body["choices"][0];
191        let message = &choice["message"];
192        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
193
194        let stop_reason = match finish_reason {
195            "tool_calls" => StopReason::ToolUse,
196            "length" => StopReason::MaxTokens,
197            _ => StopReason::EndTurn,
198        };
199
200        let mut content = Vec::new();
201
202        let usage_blob = &response_body["usage"];
203        let usage = Usage {
204            input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
205            output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
206            cache_creation_input_tokens: 0,
207            cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
208                .as_u64()
209                .unwrap_or(0) as u32,
210        };
211
212        if let Some(text) = message["content"].as_str()
213            && !text.is_empty() {
214                content.push(ContentBlock::Text { text: text.to_string() });
215            }
216
217        if let Some(tool_calls) = message["tool_calls"].as_array() {
218            for tc in tool_calls {
219                let id = tc["id"].as_str().unwrap_or_default().to_string();
220                let name = tc["function"]["name"].as_str().unwrap_or_default().to_string();
221                let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
222                let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
223
224                content.push(ContentBlock::ToolUse { id, name, input });
225            }
226
227            if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
228                return Ok(ChatResponse {
229                    content,
230                    stop_reason: StopReason::ToolUse,
231                    usage: usage.clone(),
232                });
233            }
234        }
235
236        Ok(ChatResponse {
237            content,
238            stop_reason,
239            usage,
240        })
241    }
242}
243
/// Best-effort mapping from a model name to its context window size, in tokens.
///
/// A positive integer in the `CONTEXT_SIZE` environment variable always wins, so users can
/// override the table. Matching is case-insensitive and substring-based, so
/// provider-prefixed names (e.g. `openai/o3-mini`) also resolve. Unknown models fall back
/// to 128K so that context usage can always be displayed; the function never returns `None`.
fn context_window_for(model: &str) -> Option<u32> {
    if let Some(n) = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0)
    {
        return Some(n);
    }

    let m = model.to_ascii_lowercase();
    // Last path segment, so prefix checks work for names like `openai/o4-mini`.
    let base = m.rsplit('/').next().unwrap_or(m.as_str());

    // GPT-4.1 family: ~1M tokens. Must precede the generic GPT-4 rules below.
    if m.contains("gpt-4.1") {
        return Some(1_047_576);
    }
    // GPT-4o, GPT-4 Turbo and the GPT-4 `-preview` snapshots (gpt-4-1106-preview,
    // gpt-4-0125-preview, gpt-4-vision-preview, gpt-4.5-preview): 128K.
    if m.contains("gpt-4o")
        || m.contains("gpt-4-turbo")
        || (m.contains("gpt-4") && m.contains("preview"))
    {
        return Some(128_000);
    }
    // Original GPT-4: 32K variant, otherwise 8K (gpt-4, gpt-4-0613, ...). Newer `gpt-4.x`
    // names are excluded and fall through to the default.
    if m.contains("gpt-4-32k") {
        return Some(32_768);
    }
    if m.contains("gpt-4") && !m.contains("gpt-4.") {
        return Some(8_192);
    }
    // GPT-3.5 Turbo: 16K for current snapshots and the -16k alias; the legacy 0301/0613
    // snapshots and the instruct model keep the original 4K window.
    if m.contains("gpt-3.5") {
        if m.contains("16k") {
            return Some(16_384);
        }
        if m.contains("instruct") || m.contains("0301") || m.contains("0613") {
            return Some(4_096);
        }
        return Some(16_384);
    }
    // o-series reasoning models: 200K context
    if m.contains("o1") || base.starts_with("o3") || base.starts_with("o4") {
        return Some(200_000);
    }
    // DeepSeek models
    if m.contains("deepseek") {
        if m.contains("v3") || m.contains("r1") {
            return Some(128_000);
        }
        return Some(64_000);
    }
    // Qwen models (via OpenAI-compatible endpoints)
    if m.contains("qwen") {
        if m.contains("qwen-max") || m.contains("qwen2.5-72b") {
            return Some(128_000);
        }
        return Some(32_000);
    }
    // Llama 3 models (via OpenAI-compatible endpoints)
    if m.contains("llama-3") || m.contains("llama3") {
        if m.contains("70b") || m.contains("405b") {
            return Some(128_000);
        }
        return Some(8_192);
    }
    // GLM models (Zhipu AI) via OpenAI-compatible endpoints
    if m.contains("glm") {
        return Some(128_000);
    }
    // Unknown models: assume 128K (reasonable for modern models) so context usage is
    // always displayed.
    Some(128_000)
}