//! OpenAI Chat Completions provider.
//!
//! Handles GPT models and any OpenAI-compatible endpoint (Groq,
//! Together, Ollama, DeepSeek, OpenRouter, vLLM, LMStudio, etc.).
//! The only difference between providers is the base URL and auth.
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
use tokio::sync::mpsc;
use tracing::debug;

use super::message::{ContentBlock, Message, StopReason, Usage};
use super::provider::{Provider, ProviderError, ProviderRequest};
use super::stream::StreamEvent;
/// Provider for the OpenAI Chat Completions API and for any
/// OpenAI-compatible endpoint (only the base URL and key differ).
pub struct OpenAiProvider {
    // Reusable HTTP client; built once in `new` with a 300 s timeout.
    http: reqwest::Client,
    // Endpoint root, stored without a trailing slash.
    base_url: String,
    // Sent as a `Bearer` token in the Authorization header.
    api_key: String,
}
22
23impl OpenAiProvider {
24    pub fn new(base_url: &str, api_key: &str) -> Self {
25        let http = reqwest::Client::builder()
26            .timeout(std::time::Duration::from_secs(300))
27            .build()
28            .expect("failed to build HTTP client");
29
30        Self {
31            http,
32            base_url: base_url.trim_end_matches('/').to_string(),
33            api_key: api_key.to_string(),
34        }
35    }
36
37    /// Build the request body in OpenAI format.
38    fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
39        // Convert our messages to OpenAI format.
40        // Key difference: system message goes in the messages array, not separate.
41        let mut messages = Vec::new();
42
43        // System message as first message.
44        if !request.system_prompt.is_empty() {
45            messages.push(serde_json::json!({
46                "role": "system",
47                "content": request.system_prompt,
48            }));
49        }
50
51        // Convert conversation messages.
52        for msg in &request.messages {
53            match msg {
54                Message::User(u) => {
55                    let content = blocks_to_openai_content(&u.content);
56                    messages.push(serde_json::json!({
57                        "role": "user",
58                        "content": content,
59                    }));
60                }
61                Message::Assistant(a) => {
62                    let mut msg_json = serde_json::json!({
63                        "role": "assistant",
64                    });
65
66                    // Check for tool calls.
67                    let tool_calls: Vec<serde_json::Value> = a
68                        .content
69                        .iter()
70                        .filter_map(|b| match b {
71                            ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
72                                "id": id,
73                                "type": "function",
74                                "function": {
75                                    "name": name,
76                                    "arguments": serde_json::to_string(input).unwrap_or_default(),
77                                }
78                            })),
79                            _ => None,
80                        })
81                        .collect();
82
83                    // Text content.
84                    let text: String = a
85                        .content
86                        .iter()
87                        .filter_map(|b| match b {
88                            ContentBlock::Text { text } => Some(text.as_str()),
89                            _ => None,
90                        })
91                        .collect::<Vec<_>>()
92                        .join("");
93
94                    // OpenAI requires content to be a string, never null.
95                    msg_json["content"] = serde_json::Value::String(text);
96                    if !tool_calls.is_empty() {
97                        msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
98                    }
99
100                    messages.push(msg_json);
101                }
102                Message::System(_) => {} // Already handled above.
103            }
104        }
105
106        // Handle tool results (OpenAI uses role: "tool").
107        // We need a second pass to convert our tool_result content blocks.
108        let mut final_messages = Vec::new();
109        for msg in messages {
110            if msg.get("role").and_then(|r| r.as_str()) == Some("user") {
111                // Check if this is actually a tool result message.
112                if let Some(content) = msg.get("content")
113                    && let Some(arr) = content.as_array()
114                {
115                    let mut tool_results = Vec::new();
116                    let mut other_content = Vec::new();
117
118                    for block in arr {
119                        if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
120                            tool_results.push(serde_json::json!({
121                                    "role": "tool",
122                                    "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
123                                    "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
124                                }));
125                        } else {
126                            other_content.push(block.clone());
127                        }
128                    }
129
130                    if !tool_results.is_empty() {
131                        // Emit tool results as separate messages.
132                        for tr in tool_results {
133                            final_messages.push(tr);
134                        }
135                        if !other_content.is_empty() {
136                            let mut m = msg.clone();
137                            m["content"] = serde_json::Value::Array(other_content);
138                            final_messages.push(m);
139                        }
140                        continue;
141                    }
142                }
143            }
144            final_messages.push(msg);
145        }
146
147        // Build tools in OpenAI format.
148        let tools: Vec<serde_json::Value> = request
149            .tools
150            .iter()
151            .map(|t| {
152                serde_json::json!({
153                    "type": "function",
154                    "function": {
155                        "name": t.name,
156                        "description": t.description,
157                        "parameters": t.input_schema,
158                    }
159                })
160            })
161            .collect();
162
163        // Newer models (o1, o3, gpt-5.x) use max_completion_tokens.
164        let model_lower = request.model.to_lowercase();
165        let uses_new_token_param = model_lower.starts_with("o1")
166            || model_lower.starts_with("o3")
167            || model_lower.contains("gpt-5")
168            || model_lower.contains("gpt-4.1");
169
170        let mut body = serde_json::json!({
171            "model": request.model,
172            "messages": final_messages,
173            "stream": true,
174            "stream_options": { "include_usage": true },
175        });
176
177        if uses_new_token_param {
178            body["max_completion_tokens"] = serde_json::json!(request.max_tokens);
179        } else {
180            body["max_tokens"] = serde_json::json!(request.max_tokens);
181        }
182
183        if !tools.is_empty() {
184            body["tools"] = serde_json::Value::Array(tools);
185
186            // Tool choice.
187            use super::provider::ToolChoice;
188            match &request.tool_choice {
189                ToolChoice::Auto => {
190                    body["tool_choice"] = serde_json::json!("auto");
191                }
192                ToolChoice::Any => {
193                    body["tool_choice"] = serde_json::json!("required");
194                }
195                ToolChoice::None => {
196                    body["tool_choice"] = serde_json::json!("none");
197                }
198                ToolChoice::Specific(name) => {
199                    body["tool_choice"] = serde_json::json!({
200                        "type": "function",
201                        "function": { "name": name }
202                    });
203                }
204            }
205        }
206        if let Some(temp) = request.temperature {
207            body["temperature"] = serde_json::json!(temp);
208        }
209
210        body
211    }
212}
213
214#[async_trait]
215impl Provider for OpenAiProvider {
216    fn name(&self) -> &str {
217        "openai"
218    }
219
220    async fn stream(
221        &self,
222        request: &ProviderRequest,
223    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
224        let url = format!("{}/chat/completions", self.base_url);
225        let body = self.build_body(request);
226
227        let mut headers = HeaderMap::new();
228        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
229        headers.insert(
230            AUTHORIZATION,
231            HeaderValue::from_str(&format!("Bearer {}", self.api_key))
232                .map_err(|e| ProviderError::Auth(e.to_string()))?,
233        );
234
235        debug!("OpenAI request to {url}");
236
237        let response = self
238            .http
239            .post(&url)
240            .headers(headers)
241            .json(&body)
242            .send()
243            .await
244            .map_err(|e| ProviderError::Network(e.to_string()))?;
245
246        let status = response.status();
247        if !status.is_success() {
248            let body_text = response.text().await.unwrap_or_default();
249            return match status.as_u16() {
250                401 | 403 => Err(ProviderError::Auth(body_text)),
251                429 => Err(ProviderError::RateLimited {
252                    retry_after_ms: 1000,
253                }),
254                529 => Err(ProviderError::Overloaded),
255                413 => Err(ProviderError::RequestTooLarge(body_text)),
256                _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
257            };
258        }
259
260        // Parse OpenAI SSE stream.
261        let (tx, rx) = mpsc::channel(64);
262        tokio::spawn(async move {
263            let mut byte_stream = response.bytes_stream();
264            let mut buffer = String::new();
265            let mut current_tool_id = String::new();
266            let mut current_tool_name = String::new();
267            let mut current_tool_args = String::new();
268            let mut usage = Usage::default();
269            let mut stop_reason: Option<StopReason> = None;
270
271            while let Some(chunk_result) = byte_stream.next().await {
272                let chunk = match chunk_result {
273                    Ok(c) => c,
274                    Err(e) => {
275                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
276                        break;
277                    }
278                };
279
280                buffer.push_str(&String::from_utf8_lossy(&chunk));
281
282                while let Some(pos) = buffer.find("\n\n") {
283                    let event_text = buffer[..pos].to_string();
284                    buffer = buffer[pos + 2..].to_string();
285
286                    for line in event_text.lines() {
287                        let data = if let Some(d) = line.strip_prefix("data: ") {
288                            d
289                        } else {
290                            continue;
291                        };
292
293                        if data == "[DONE]" {
294                            // Emit any remaining tool call before Done.
295                            if !current_tool_id.is_empty() {
296                                let input: serde_json::Value =
297                                    serde_json::from_str(&current_tool_args).unwrap_or_default();
298                                let _ = tx
299                                    .send(StreamEvent::ContentBlockComplete(
300                                        ContentBlock::ToolUse {
301                                            id: current_tool_id.clone(),
302                                            name: current_tool_name.clone(),
303                                            input,
304                                        },
305                                    ))
306                                    .await;
307                                current_tool_id.clear();
308                                current_tool_name.clear();
309                                current_tool_args.clear();
310                            }
311
312                            let _ = tx
313                                .send(StreamEvent::Done {
314                                    usage: usage.clone(),
315                                    stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
316                                })
317                                .await;
318                            return;
319                        }
320
321                        let parsed: serde_json::Value = match serde_json::from_str(data) {
322                            Ok(v) => v,
323                            Err(_) => continue,
324                        };
325
326                        // Extract delta from choices[0].delta
327                        let delta = match parsed
328                            .get("choices")
329                            .and_then(|c| c.get(0))
330                            .and_then(|c| c.get("delta"))
331                        {
332                            Some(d) => d,
333                            None => {
334                                // Check for usage in the final chunk.
335                                if let Some(u) = parsed.get("usage") {
336                                    usage.input_tokens = u
337                                        .get("prompt_tokens")
338                                        .and_then(|v| v.as_u64())
339                                        .unwrap_or(0);
340                                    usage.output_tokens = u
341                                        .get("completion_tokens")
342                                        .and_then(|v| v.as_u64())
343                                        .unwrap_or(0);
344                                }
345                                continue;
346                            }
347                        };
348
349                        // Text content.
350                        if let Some(content) = delta.get("content").and_then(|c| c.as_str())
351                            && !content.is_empty()
352                        {
353                            debug!("OpenAI text delta: {}", &content[..content.len().min(80)]);
354                            let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
355                        }
356
357                        // Check for finish_reason on the choice level.
358                        if let Some(finish) = parsed
359                            .get("choices")
360                            .and_then(|c| c.get(0))
361                            .and_then(|c| c.get("finish_reason"))
362                            .and_then(|f| f.as_str())
363                        {
364                            debug!("OpenAI finish_reason: {finish}");
365                            match finish {
366                                "stop" => {
367                                    stop_reason = Some(StopReason::EndTurn);
368                                }
369                                "tool_calls" => {
370                                    stop_reason = Some(StopReason::ToolUse);
371                                }
372                                "length" => {
373                                    stop_reason = Some(StopReason::MaxTokens);
374                                }
375                                _ => {}
376                            }
377                        }
378
379                        // Tool calls.
380                        if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
381                        {
382                            for tc in tool_calls {
383                                if let Some(func) = tc.get("function") {
384                                    if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
385                                        // New tool call starting.
386                                        if !current_tool_id.is_empty()
387                                            && !current_tool_args.is_empty()
388                                        {
389                                            // Emit the previous tool call.
390                                            let input: serde_json::Value =
391                                                serde_json::from_str(&current_tool_args)
392                                                    .unwrap_or_default();
393                                            let _ = tx
394                                                .send(StreamEvent::ContentBlockComplete(
395                                                    ContentBlock::ToolUse {
396                                                        id: current_tool_id.clone(),
397                                                        name: current_tool_name.clone(),
398                                                        input,
399                                                    },
400                                                ))
401                                                .await;
402                                        }
403                                        current_tool_id = tc
404                                            .get("id")
405                                            .and_then(|i| i.as_str())
406                                            .unwrap_or("")
407                                            .to_string();
408                                        current_tool_name = name.to_string();
409                                        current_tool_args.clear();
410                                    }
411                                    if let Some(args) =
412                                        func.get("arguments").and_then(|a| a.as_str())
413                                    {
414                                        current_tool_args.push_str(args);
415                                    }
416                                }
417                            }
418                        }
419                    }
420                }
421            }
422
423            // Emit any remaining tool call.
424            if !current_tool_id.is_empty() {
425                let input: serde_json::Value =
426                    serde_json::from_str(&current_tool_args).unwrap_or_default();
427                let _ = tx
428                    .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
429                        id: current_tool_id,
430                        name: current_tool_name,
431                        input,
432                    }))
433                    .await;
434            }
435
436            let _ = tx
437                .send(StreamEvent::Done {
438                    usage,
439                    stop_reason: Some(StopReason::EndTurn),
440                })
441                .await;
442        });
443
444        Ok(rx)
445    }
446}
447
448/// Convert content blocks to OpenAI format.
449fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
450    if blocks.len() == 1
451        && let ContentBlock::Text { text } = &blocks[0]
452    {
453        return serde_json::Value::String(text.clone());
454    }
455
456    let parts: Vec<serde_json::Value> = blocks
457        .iter()
458        .map(|b| match b {
459            ContentBlock::Text { text } => serde_json::json!({
460                "type": "text",
461                "text": text,
462            }),
463            ContentBlock::Image { media_type, data } => serde_json::json!({
464                "type": "image_url",
465                "image_url": {
466                    "url": format!("data:{media_type};base64,{data}"),
467                }
468            }),
469            ContentBlock::ToolResult {
470                tool_use_id,
471                content,
472                is_error,
473                ..
474            } => serde_json::json!({
475                "type": "tool_result",
476                "tool_use_id": tool_use_id,
477                "content": content,
478                "is_error": is_error,
479            }),
480            ContentBlock::Thinking { thinking, .. } => serde_json::json!({
481                "type": "text",
482                "text": thinking,
483            }),
484            ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
485                "type": "text",
486                "text": format!("[Tool call: {name}({input})]"),
487            }),
488            ContentBlock::Document { title, .. } => serde_json::json!({
489                "type": "text",
490                "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
491            }),
492        })
493        .collect();
494
495    serde_json::Value::Array(parts)
496}