// agent_code_lib/llm/openai.rs

//! OpenAI Chat Completions provider.
//!
//! Handles GPT models and any OpenAI-compatible endpoint (Groq,
//! Together, Ollama, DeepSeek, OpenRouter, vLLM, LMStudio, etc.).
//! The only difference between providers is the base URL and auth.

use async_trait::async_trait;
use futures::StreamExt;
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue, RETRY_AFTER};
use tokio::sync::mpsc;
use tracing::debug;

use super::message::{ContentBlock, Message, StopReason, Usage};
use super::provider::{Provider, ProviderError, ProviderRequest};
use super::stream::StreamEvent;

17/// OpenAI Chat Completions provider (GPT, Groq, Together, DeepSeek, etc.).
18pub struct OpenAiProvider {
19    http: reqwest::Client,
20    base_url: String,
21    api_key: String,
22}
23
24impl OpenAiProvider {
25    pub fn new(base_url: &str, api_key: &str) -> Self {
26        let http = reqwest::Client::builder()
27            .timeout(std::time::Duration::from_secs(300))
28            .build()
29            .expect("failed to build HTTP client");
30
31        Self {
32            http,
33            base_url: base_url.trim_end_matches('/').to_string(),
34            api_key: api_key.to_string(),
35        }
36    }
37
38    /// Build the request body in OpenAI format.
39    fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
40        // Convert our messages to OpenAI format.
41        // Key difference: system message goes in the messages array, not separate.
42        let mut messages = Vec::new();
43
44        // System message as first message.
45        if !request.system_prompt.is_empty() {
46            messages.push(serde_json::json!({
47                "role": "system",
48                "content": request.system_prompt,
49            }));
50        }
51
52        // Convert conversation messages.
53        for msg in &request.messages {
54            match msg {
55                Message::User(u) => {
56                    let content = blocks_to_openai_content(&u.content);
57                    messages.push(serde_json::json!({
58                        "role": "user",
59                        "content": content,
60                    }));
61                }
62                Message::Assistant(a) => {
63                    let mut msg_json = serde_json::json!({
64                        "role": "assistant",
65                    });
66
67                    // Check for tool calls.
68                    let tool_calls: Vec<serde_json::Value> = a
69                        .content
70                        .iter()
71                        .filter_map(|b| match b {
72                            ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
73                                "id": id,
74                                "type": "function",
75                                "function": {
76                                    "name": name,
77                                    "arguments": serde_json::to_string(input).unwrap_or_default(),
78                                }
79                            })),
80                            _ => None,
81                        })
82                        .collect();
83
84                    // Text content.
85                    let text: String = a
86                        .content
87                        .iter()
88                        .filter_map(|b| match b {
89                            ContentBlock::Text { text } => Some(text.as_str()),
90                            _ => None,
91                        })
92                        .collect::<Vec<_>>()
93                        .join("");
94
95                    // OpenAI requires content to be a string, never null.
96                    msg_json["content"] = serde_json::Value::String(text);
97                    if !tool_calls.is_empty() {
98                        msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
99                    }
100
101                    messages.push(msg_json);
102                }
103                Message::System(_) => {} // Already handled above.
104            }
105        }
106
107        // Handle tool results (OpenAI uses role: "tool").
108        // We need a second pass to convert our tool_result content blocks.
109        let mut final_messages = Vec::new();
110        for msg in messages {
111            if msg.get("role").and_then(|r| r.as_str()) == Some("user") {
112                // Check if this is actually a tool result message.
113                if let Some(content) = msg.get("content")
114                    && let Some(arr) = content.as_array()
115                {
116                    let mut tool_results = Vec::new();
117                    let mut other_content = Vec::new();
118
119                    for block in arr {
120                        if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
121                            tool_results.push(serde_json::json!({
122                                    "role": "tool",
123                                    "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
124                                    "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
125                                }));
126                        } else {
127                            other_content.push(block.clone());
128                        }
129                    }
130
131                    if !tool_results.is_empty() {
132                        // Emit tool results as separate messages.
133                        for tr in tool_results {
134                            final_messages.push(tr);
135                        }
136                        if !other_content.is_empty() {
137                            let mut m = msg.clone();
138                            m["content"] = serde_json::Value::Array(other_content);
139                            final_messages.push(m);
140                        }
141                        continue;
142                    }
143                }
144            }
145            final_messages.push(msg);
146        }
147
148        // Build tools in OpenAI format.
149        let tools: Vec<serde_json::Value> = request
150            .tools
151            .iter()
152            .map(|t| {
153                serde_json::json!({
154                    "type": "function",
155                    "function": {
156                        "name": t.name,
157                        "description": t.description,
158                        "parameters": t.input_schema,
159                    }
160                })
161            })
162            .collect();
163
164        // Newer models (o1, o3, gpt-5.x) use max_completion_tokens.
165        let model_lower = request.model.to_lowercase();
166        let uses_new_token_param = model_lower.starts_with("o1")
167            || model_lower.starts_with("o3")
168            || model_lower.contains("gpt-5")
169            || model_lower.contains("gpt-4.1");
170
171        let mut body = serde_json::json!({
172            "model": request.model,
173            "messages": final_messages,
174            "stream": true,
175            "stream_options": { "include_usage": true },
176        });
177
178        if uses_new_token_param {
179            body["max_completion_tokens"] = serde_json::json!(request.max_tokens);
180        } else {
181            body["max_tokens"] = serde_json::json!(request.max_tokens);
182        }
183
184        if !tools.is_empty() {
185            body["tools"] = serde_json::Value::Array(tools);
186
187            // Tool choice.
188            use super::provider::ToolChoice;
189            match &request.tool_choice {
190                ToolChoice::Auto => {
191                    body["tool_choice"] = serde_json::json!("auto");
192                }
193                ToolChoice::Any => {
194                    body["tool_choice"] = serde_json::json!("required");
195                }
196                ToolChoice::None => {
197                    body["tool_choice"] = serde_json::json!("none");
198                }
199                ToolChoice::Specific(name) => {
200                    body["tool_choice"] = serde_json::json!({
201                        "type": "function",
202                        "function": { "name": name }
203                    });
204                }
205            }
206        }
207        if let Some(temp) = request.temperature {
208            body["temperature"] = serde_json::json!(temp);
209        }
210
211        body
212    }
213}
214
215#[async_trait]
216impl Provider for OpenAiProvider {
217    fn name(&self) -> &str {
218        "openai"
219    }
220
221    async fn stream(
222        &self,
223        request: &ProviderRequest,
224    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
225        let url = format!("{}/chat/completions", self.base_url);
226        let body = self.build_body(request);
227
228        let mut headers = HeaderMap::new();
229        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
230        headers.insert(
231            AUTHORIZATION,
232            HeaderValue::from_str(&format!("Bearer {}", self.api_key))
233                .map_err(|e| ProviderError::Auth(e.to_string()))?,
234        );
235
236        debug!("OpenAI request to {url}");
237
238        let response = self
239            .http
240            .post(&url)
241            .headers(headers)
242            .json(&body)
243            .send()
244            .await
245            .map_err(|e| ProviderError::Network(e.to_string()))?;
246
247        let status = response.status();
248        if !status.is_success() {
249            let body_text = response.text().await.unwrap_or_default();
250            return match status.as_u16() {
251                401 | 403 => Err(ProviderError::Auth(body_text)),
252                429 => Err(ProviderError::RateLimited {
253                    retry_after_ms: 1000,
254                }),
255                529 => Err(ProviderError::Overloaded),
256                413 => Err(ProviderError::RequestTooLarge(body_text)),
257                _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
258            };
259        }
260
261        // Parse OpenAI SSE stream.
262        let (tx, rx) = mpsc::channel(64);
263        tokio::spawn(async move {
264            let mut byte_stream = response.bytes_stream();
265            let mut buffer = String::new();
266            let mut current_tool_id = String::new();
267            let mut current_tool_name = String::new();
268            let mut current_tool_args = String::new();
269            let mut usage = Usage::default();
270            let mut stop_reason: Option<StopReason> = None;
271
272            while let Some(chunk_result) = byte_stream.next().await {
273                let chunk = match chunk_result {
274                    Ok(c) => c,
275                    Err(e) => {
276                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
277                        break;
278                    }
279                };
280
281                buffer.push_str(&String::from_utf8_lossy(&chunk));
282
283                while let Some(pos) = buffer.find("\n\n") {
284                    let event_text = buffer[..pos].to_string();
285                    buffer = buffer[pos + 2..].to_string();
286
287                    for line in event_text.lines() {
288                        let data = if let Some(d) = line.strip_prefix("data: ") {
289                            d
290                        } else {
291                            continue;
292                        };
293
294                        if data == "[DONE]" {
295                            // Emit any remaining tool call before Done.
296                            if !current_tool_id.is_empty() {
297                                let input: serde_json::Value =
298                                    serde_json::from_str(&current_tool_args).unwrap_or_default();
299                                let _ = tx
300                                    .send(StreamEvent::ContentBlockComplete(
301                                        ContentBlock::ToolUse {
302                                            id: current_tool_id.clone(),
303                                            name: current_tool_name.clone(),
304                                            input,
305                                        },
306                                    ))
307                                    .await;
308                                current_tool_id.clear();
309                                current_tool_name.clear();
310                                current_tool_args.clear();
311                            }
312
313                            let _ = tx
314                                .send(StreamEvent::Done {
315                                    usage: usage.clone(),
316                                    stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
317                                })
318                                .await;
319                            return;
320                        }
321
322                        let parsed: serde_json::Value = match serde_json::from_str(data) {
323                            Ok(v) => v,
324                            Err(_) => continue,
325                        };
326
327                        // Extract delta from choices[0].delta
328                        let delta = match parsed
329                            .get("choices")
330                            .and_then(|c| c.get(0))
331                            .and_then(|c| c.get("delta"))
332                        {
333                            Some(d) => d,
334                            None => {
335                                // Check for usage in the final chunk.
336                                if let Some(u) = parsed.get("usage") {
337                                    usage.input_tokens = u
338                                        .get("prompt_tokens")
339                                        .and_then(|v| v.as_u64())
340                                        .unwrap_or(0);
341                                    usage.output_tokens = u
342                                        .get("completion_tokens")
343                                        .and_then(|v| v.as_u64())
344                                        .unwrap_or(0);
345                                }
346                                continue;
347                            }
348                        };
349
350                        // Text content.
351                        if let Some(content) = delta.get("content").and_then(|c| c.as_str())
352                            && !content.is_empty()
353                        {
354                            debug!("OpenAI text delta: {}", &content[..content.len().min(80)]);
355                            let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
356                        }
357
358                        // Check for finish_reason on the choice level.
359                        if let Some(finish) = parsed
360                            .get("choices")
361                            .and_then(|c| c.get(0))
362                            .and_then(|c| c.get("finish_reason"))
363                            .and_then(|f| f.as_str())
364                        {
365                            debug!("OpenAI finish_reason: {finish}");
366                            match finish {
367                                "stop" => {
368                                    stop_reason = Some(StopReason::EndTurn);
369                                }
370                                "tool_calls" => {
371                                    stop_reason = Some(StopReason::ToolUse);
372                                }
373                                "length" => {
374                                    stop_reason = Some(StopReason::MaxTokens);
375                                }
376                                _ => {}
377                            }
378                        }
379
380                        // Tool calls.
381                        if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
382                        {
383                            for tc in tool_calls {
384                                if let Some(func) = tc.get("function") {
385                                    if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
386                                        // New tool call starting.
387                                        if !current_tool_id.is_empty()
388                                            && !current_tool_args.is_empty()
389                                        {
390                                            // Emit the previous tool call.
391                                            let input: serde_json::Value =
392                                                serde_json::from_str(&current_tool_args)
393                                                    .unwrap_or_default();
394                                            let _ = tx
395                                                .send(StreamEvent::ContentBlockComplete(
396                                                    ContentBlock::ToolUse {
397                                                        id: current_tool_id.clone(),
398                                                        name: current_tool_name.clone(),
399                                                        input,
400                                                    },
401                                                ))
402                                                .await;
403                                        }
404                                        current_tool_id = tc
405                                            .get("id")
406                                            .and_then(|i| i.as_str())
407                                            .unwrap_or("")
408                                            .to_string();
409                                        current_tool_name = name.to_string();
410                                        current_tool_args.clear();
411                                    }
412                                    if let Some(args) =
413                                        func.get("arguments").and_then(|a| a.as_str())
414                                    {
415                                        current_tool_args.push_str(args);
416                                    }
417                                }
418                            }
419                        }
420                    }
421                }
422            }
423
424            // Emit any remaining tool call.
425            if !current_tool_id.is_empty() {
426                let input: serde_json::Value =
427                    serde_json::from_str(&current_tool_args).unwrap_or_default();
428                let _ = tx
429                    .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
430                        id: current_tool_id,
431                        name: current_tool_name,
432                        input,
433                    }))
434                    .await;
435            }
436
437            let _ = tx
438                .send(StreamEvent::Done {
439                    usage,
440                    stop_reason: Some(StopReason::EndTurn),
441                })
442                .await;
443        });
444
445        Ok(rx)
446    }
447}
448
449/// Convert content blocks to OpenAI format.
450fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
451    if blocks.len() == 1
452        && let ContentBlock::Text { text } = &blocks[0]
453    {
454        return serde_json::Value::String(text.clone());
455    }
456
457    let parts: Vec<serde_json::Value> = blocks
458        .iter()
459        .map(|b| match b {
460            ContentBlock::Text { text } => serde_json::json!({
461                "type": "text",
462                "text": text,
463            }),
464            ContentBlock::Image { media_type, data } => serde_json::json!({
465                "type": "image_url",
466                "image_url": {
467                    "url": format!("data:{media_type};base64,{data}"),
468                }
469            }),
470            ContentBlock::ToolResult {
471                tool_use_id,
472                content,
473                is_error,
474                ..
475            } => serde_json::json!({
476                "type": "tool_result",
477                "tool_use_id": tool_use_id,
478                "content": content,
479                "is_error": is_error,
480            }),
481            ContentBlock::Thinking { thinking, .. } => serde_json::json!({
482                "type": "text",
483                "text": thinking,
484            }),
485            ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
486                "type": "text",
487                "text": format!("[Tool call: {name}({input})]"),
488            }),
489            ContentBlock::Document { title, .. } => serde_json::json!({
490                "type": "text",
491                "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
492            }),
493        })
494        .collect();
495
496    serde_json::Value::Array(parts)
497}