Skip to main content

matrixcode_core/providers/
openai.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::tools::ToolDefinition;
6
7use super::{
8    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
9    Usage,
10};
11
12pub struct OpenAIProvider {
13    api_key: String,
14    model: String,
15    base_url: String,
16    client: reqwest::Client,
17}
18
19impl OpenAIProvider {
20    pub fn new(api_key: String, model: String, base_url: String) -> Self {
21        Self {
22            api_key,
23            model,
24            base_url,
25            client: reqwest::Client::new(),
26        }
27    }
28
29    fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
30        let mut result = Vec::new();
31
32        if let Some(sys) = system {
33            result.push(json!({"role": "system", "content": sys}));
34        }
35
36        for msg in messages {
37            match (&msg.role, &msg.content) {
38                (Role::System, _) => {}
39                (Role::User, MessageContent::Text(text)) => {
40                    result.push(json!({"role": "user", "content": text}));
41                }
42                (Role::Assistant, MessageContent::Text(text)) => {
43                    result.push(json!({"role": "assistant", "content": text}));
44                }
45                (Role::Assistant, MessageContent::Blocks(blocks)) => {
46                    let mut tool_calls = Vec::new();
47                    let mut text_parts = Vec::new();
48
49                    for block in blocks {
50                        match block {
51                            ContentBlock::Text { text } => text_parts.push(text.clone()),
52                            ContentBlock::ToolUse { id, name, input } => {
53                                tool_calls.push(json!({
54                                    "id": id,
55                                    "type": "function",
56                                    "function": {
57                                        "name": name,
58                                        "arguments": input.to_string(),
59                                    }
60                                }));
61                            }
62                            ContentBlock::Thinking { .. } => {}
63                            _ => {}
64                        }
65                    }
66
67                    let mut msg_obj = json!({"role": "assistant"});
68                    if !text_parts.is_empty() {
69                        msg_obj["content"] = json!(text_parts.join("\n"));
70                    }
71                    if !tool_calls.is_empty() {
72                        msg_obj["tool_calls"] = json!(tool_calls);
73                    }
74                    result.push(msg_obj);
75                }
76                (Role::Tool, MessageContent::Blocks(blocks)) => {
77                    for block in blocks {
78                        if let ContentBlock::ToolResult { tool_use_id, content } = block {
79                            result.push(json!({
80                                "role": "tool",
81                                "tool_call_id": tool_use_id,
82                                "content": content,
83                            }));
84                        }
85                    }
86                }
87                (Role::User, MessageContent::Blocks(blocks)) => {
88                    let text: String = blocks
89                        .iter()
90                        .filter_map(|b| match b {
91                            ContentBlock::Text { text } => Some(text.as_str()),
92                            _ => None,
93                        })
94                        .collect::<Vec<_>>()
95                        .join("\n");
96                    result.push(json!({"role": "user", "content": text}));
97                }
98                _ => {}
99            }
100        }
101
102        result
103    }
104
105    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
106        tools
107            .iter()
108            .map(|t| {
109                json!({
110                    "type": "function",
111                    "function": {
112                        "name": t.name,
113                        "description": t.description,
114                        "parameters": t.parameters,
115                    }
116                })
117            })
118            .collect()
119    }
120}
121
122#[async_trait]
123impl Provider for OpenAIProvider {
124    fn context_size(&self) -> Option<u32> {
125        context_window_for(&self.model)
126    }
127
128    fn clone_box(&self) -> Box<dyn Provider> {
129        Box::new(Self {
130            api_key: self.api_key.clone(),
131            model: self.model.clone(),
132            base_url: self.base_url.clone(),
133            client: reqwest::Client::new(),
134        })
135    }
136
137    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
138        let messages = self.convert_messages(&request.messages, request.system.as_deref());
139
140        let mut body = json!({
141            "model": self.model,
142            "messages": messages,
143            "max_completion_tokens": request.max_tokens,
144        });
145
146        if !request.tools.is_empty() {
147            body["tools"] = json!(self.convert_tools(&request.tools));
148        }
149
150        let url = format!("{}/chat/completions", self.base_url);
151        let response = self
152            .client
153            .post(&url)
154            .header("Authorization", format!("Bearer {}", self.api_key))
155            .header("Content-Type", "application/json")
156            .json(&body)
157            .send()
158            .await?;
159
160        let status = response.status();
161        let response_body: Value = response.json().await?;
162
163        if !status.is_success() {
164            let err_msg = response_body["error"]["message"]
165                .as_str()
166                .unwrap_or("unknown error");
167            anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
168        }
169
170        let choice = &response_body["choices"][0];
171        let message = &choice["message"];
172        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
173
174        let stop_reason = match finish_reason {
175            "tool_calls" => StopReason::ToolUse,
176            "length" => StopReason::MaxTokens,
177            _ => StopReason::EndTurn,
178        };
179
180        let mut content = Vec::new();
181
182        let usage_blob = &response_body["usage"];
183        let usage = Usage {
184            input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
185            output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
186            cache_creation_input_tokens: 0,
187            cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
188                .as_u64()
189                .unwrap_or(0) as u32,
190        };
191
192        if let Some(text) = message["content"].as_str()
193            && !text.is_empty() {
194                content.push(ContentBlock::Text { text: text.to_string() });
195            }
196
197        if let Some(tool_calls) = message["tool_calls"].as_array() {
198            for tc in tool_calls {
199                let id = tc["id"].as_str().unwrap_or_default().to_string();
200                let name = tc["function"]["name"].as_str().unwrap_or_default().to_string();
201                let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
202                let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
203
204                content.push(ContentBlock::ToolUse { id, name, input });
205            }
206
207            if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
208                return Ok(ChatResponse {
209                    content,
210                    stop_reason: StopReason::ToolUse,
211                    usage: usage.clone(),
212                });
213            }
214        }
215
216        Ok(ChatResponse {
217            content,
218            stop_reason,
219            usage,
220        })
221    }
222}
223
/// Best-effort mapping from a model name to its context window size in tokens.
///
/// A positive integer in the `CONTEXT_SIZE` environment variable (surrounding
/// whitespace ignored) takes precedence so users can override the table.
/// Otherwise the lower-cased model name is matched by substring against known
/// model families, most specific first. Unknown models fall back to 128K, so
/// this function always returns `Some`.
fn context_window_for(model: &str) -> Option<u32> {
    let env_override = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0);
    if env_override.is_some() {
        return env_override;
    }

    let m = model.to_ascii_lowercase();
    let has = |needle: &str| m.contains(needle);

    // GPT-4o and GPT-4 Turbo: 128K context.
    if has("gpt-4o") || has("gpt-4-turbo") {
        return Some(128_000);
    }
    // GPT-4.1 family (including mini/nano): ~1M context.
    if has("gpt-4.1") {
        return Some(1_047_576);
    }
    // GPT-4 (original): 32K variant.
    if has("gpt-4-32k") {
        return Some(32_768);
    }
    // Remaining GPT-4 names. The 4o/turbo families already returned above, so
    // no exclusion is needed here; in particular, testing for the letter "o"
    // would misclassify names such as "openai/gpt-4".
    if has("gpt-4") {
        // Preview and vision snapshots have the 128K window.
        if has("preview") || has("vision") {
            return Some(128_000);
        }
        return Some(8_192);
    }
    // GPT-3.5 Turbo: 16K (the 4K chat variant is deprecated); the instruct
    // model keeps the 4K window.
    if has("gpt-3.5") {
        if has("instruct") {
            return Some(4_096);
        }
        return Some(16_385);
    }
    // o1 series: 200K context.
    if has("o1") {
        return Some(200_000);
    }
    // DeepSeek models.
    if has("deepseek") {
        if has("v3") || has("r1") {
            return Some(128_000);
        }
        return Some(64_000);
    }
    // Qwen models (via OpenAI-compatible endpoints).
    if has("qwen") {
        if has("qwen-max") || has("qwen2.5-72b") {
            return Some(128_000);
        }
        return Some(32_000);
    }
    // Llama 3 models (via OpenAI-compatible endpoints).
    if has("llama-3") || has("llama3") {
        if has("70b") || has("405b") {
            return Some(128_000);
        }
        return Some(8_192);
    }
    // GLM models (Zhipu AI) via OpenAI-compatible endpoints.
    if has("glm") {
        return Some(128_000);
    }
    // Unknown models: assume 128K so context usage can always be displayed.
    Some(128_000)
}