// agent_code_lib/llm/openai.rs
1//! OpenAI Chat Completions provider.
2//!
3//! Handles GPT models and any OpenAI-compatible endpoint (Groq,
4//! Together, Ollama, DeepSeek, OpenRouter, vLLM, LMStudio, etc.).
5//! The only difference between providers is the base URL and auth.
6
7use async_trait::async_trait;
8use futures::StreamExt;
9use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
10use tokio::sync::mpsc;
11use tracing::debug;
12
13use super::message::{ContentBlock, Message, StopReason, Usage};
14use super::provider::{Provider, ProviderError, ProviderRequest};
15use super::stream::StreamEvent;
16
/// OpenAI Chat Completions provider (GPT, Groq, Together, DeepSeek, etc.).
pub struct OpenAiProvider {
    // Reused HTTP client; one client gives connection pooling across requests.
    http: reqwest::Client,
    // Endpoint base URL; any trailing '/' is stripped in `new`.
    base_url: String,
    // API key sent as a Bearer token in the Authorization header.
    api_key: String,
}
23
24impl OpenAiProvider {
25    pub fn new(base_url: &str, api_key: &str) -> Self {
26        let http = reqwest::Client::builder()
27            .timeout(std::time::Duration::from_secs(300))
28            .build()
29            .expect("failed to build HTTP client");
30
31        Self {
32            http,
33            base_url: base_url.trim_end_matches('/').to_string(),
34            api_key: api_key.to_string(),
35        }
36    }
37
38    /// Build the request body in OpenAI format.
39    fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
40        // Convert our messages to OpenAI format.
41        // Key difference: system message goes in the messages array, not separate.
42        let mut messages = Vec::new();
43
44        // System message as first message.
45        if !request.system_prompt.is_empty() {
46            messages.push(serde_json::json!({
47                "role": "system",
48                "content": request.system_prompt,
49            }));
50        }
51
52        // Convert conversation messages.
53        for msg in &request.messages {
54            match msg {
55                Message::User(u) => {
56                    let content = blocks_to_openai_content(&u.content);
57                    messages.push(serde_json::json!({
58                        "role": "user",
59                        "content": content,
60                    }));
61                }
62                Message::Assistant(a) => {
63                    let mut msg_json = serde_json::json!({
64                        "role": "assistant",
65                    });
66
67                    // Check for tool calls.
68                    let tool_calls: Vec<serde_json::Value> = a
69                        .content
70                        .iter()
71                        .filter_map(|b| match b {
72                            ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
73                                "id": id,
74                                "type": "function",
75                                "function": {
76                                    "name": name,
77                                    "arguments": serde_json::to_string(input).unwrap_or_default(),
78                                }
79                            })),
80                            _ => None,
81                        })
82                        .collect();
83
84                    // Text content.
85                    let text: String = a
86                        .content
87                        .iter()
88                        .filter_map(|b| match b {
89                            ContentBlock::Text { text } => Some(text.as_str()),
90                            _ => None,
91                        })
92                        .collect::<Vec<_>>()
93                        .join("");
94
95                    // OpenAI requires content to be a string, never null.
96                    msg_json["content"] = serde_json::Value::String(text);
97                    if !tool_calls.is_empty() {
98                        msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
99                    }
100
101                    messages.push(msg_json);
102                }
103                Message::System(_) => {} // Already handled above.
104            }
105        }
106
107        // Handle tool results (OpenAI uses role: "tool").
108        // We need a second pass to convert our tool_result content blocks.
109        let mut final_messages = Vec::new();
110        for msg in messages {
111            if msg.get("role").and_then(|r| r.as_str()) == Some("user") {
112                // Check if this is actually a tool result message.
113                if let Some(content) = msg.get("content")
114                    && let Some(arr) = content.as_array()
115                {
116                    let mut tool_results = Vec::new();
117                    let mut other_content = Vec::new();
118
119                    for block in arr {
120                        if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
121                            tool_results.push(serde_json::json!({
122                                    "role": "tool",
123                                    "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
124                                    "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
125                                }));
126                        } else {
127                            other_content.push(block.clone());
128                        }
129                    }
130
131                    if !tool_results.is_empty() {
132                        // Emit tool results as separate messages.
133                        for tr in tool_results {
134                            final_messages.push(tr);
135                        }
136                        if !other_content.is_empty() {
137                            let mut m = msg.clone();
138                            m["content"] = serde_json::Value::Array(other_content);
139                            final_messages.push(m);
140                        }
141                        continue;
142                    }
143                }
144            }
145            final_messages.push(msg);
146        }
147
148        // Build tools in OpenAI format.
149        let tools: Vec<serde_json::Value> = request
150            .tools
151            .iter()
152            .map(|t| {
153                serde_json::json!({
154                    "type": "function",
155                    "function": {
156                        "name": t.name,
157                        "description": t.description,
158                        "parameters": t.input_schema,
159                    }
160                })
161            })
162            .collect();
163
164        // Newer models (o1, o3, gpt-5.x) use max_completion_tokens.
165        let model_lower = request.model.to_lowercase();
166        let uses_new_token_param = model_lower.starts_with("o1")
167            || model_lower.starts_with("o3")
168            || model_lower.contains("gpt-5")
169            || model_lower.contains("gpt-4.1");
170
171        let mut body = serde_json::json!({
172            "model": request.model,
173            "messages": final_messages,
174            "stream": true,
175            "stream_options": { "include_usage": true },
176        });
177
178        if uses_new_token_param {
179            body["max_completion_tokens"] = serde_json::json!(request.max_tokens);
180        } else {
181            body["max_tokens"] = serde_json::json!(request.max_tokens);
182        }
183
184        if !tools.is_empty() {
185            body["tools"] = serde_json::Value::Array(tools);
186
187            // Tool choice.
188            use super::provider::ToolChoice;
189            match &request.tool_choice {
190                ToolChoice::Auto => {
191                    body["tool_choice"] = serde_json::json!("auto");
192                }
193                ToolChoice::Any => {
194                    body["tool_choice"] = serde_json::json!("required");
195                }
196                ToolChoice::None => {
197                    body["tool_choice"] = serde_json::json!("none");
198                }
199                ToolChoice::Specific(name) => {
200                    body["tool_choice"] = serde_json::json!({
201                        "type": "function",
202                        "function": { "name": name }
203                    });
204                }
205            }
206        }
207        if let Some(temp) = request.temperature {
208            body["temperature"] = serde_json::json!(temp);
209        }
210
211        body
212    }
213}
214
215#[async_trait]
216impl Provider for OpenAiProvider {
217    fn name(&self) -> &str {
218        "openai"
219    }
220
221    async fn stream(
222        &self,
223        request: &ProviderRequest,
224    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
225        let url = format!("{}/chat/completions", self.base_url);
226        let body = self.build_body(request);
227
228        let mut headers = HeaderMap::new();
229        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
230        headers.insert(
231            AUTHORIZATION,
232            HeaderValue::from_str(&format!("Bearer {}", self.api_key))
233                .map_err(|e| ProviderError::Auth(e.to_string()))?,
234        );
235
236        debug!("OpenAI request to {url}");
237
238        let response = self
239            .http
240            .post(&url)
241            .headers(headers)
242            .json(&body)
243            .send()
244            .await
245            .map_err(|e| ProviderError::Network(e.to_string()))?;
246
247        let status = response.status();
248        if !status.is_success() {
249            let body_text = response.text().await.unwrap_or_default();
250            return match status.as_u16() {
251                401 | 403 => Err(ProviderError::Auth(body_text)),
252                429 => Err(ProviderError::RateLimited {
253                    retry_after_ms: 1000,
254                }),
255                529 => Err(ProviderError::Overloaded),
256                413 => Err(ProviderError::RequestTooLarge(body_text)),
257                _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
258            };
259        }
260
261        // Parse OpenAI SSE stream.
262        let (tx, rx) = mpsc::channel(64);
263        let cancel = request.cancel.clone();
264        tokio::spawn(async move {
265            let mut byte_stream = response.bytes_stream();
266            let mut buffer = String::new();
267            let mut current_tool_id = String::new();
268            let mut current_tool_name = String::new();
269            let mut current_tool_args = String::new();
270            let mut usage = Usage::default();
271            let mut stop_reason: Option<StopReason> = None;
272
273            loop {
274                // Race the next SSE chunk against cancellation. On cancel,
275                // drop the byte_stream (and therefore the reqwest::Response),
276                // which aborts the underlying HTTP connection immediately.
277                let chunk_result = tokio::select! {
278                    biased;
279                    _ = cancel.cancelled() => return,
280                    chunk = byte_stream.next() => match chunk {
281                        Some(c) => c,
282                        None => break,
283                    },
284                };
285                let chunk = match chunk_result {
286                    Ok(c) => c,
287                    Err(e) => {
288                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
289                        break;
290                    }
291                };
292
293                buffer.push_str(&String::from_utf8_lossy(&chunk));
294
295                while let Some(pos) = buffer.find("\n\n") {
296                    let event_text = buffer[..pos].to_string();
297                    buffer = buffer[pos + 2..].to_string();
298
299                    for line in event_text.lines() {
300                        let data = if let Some(d) = line.strip_prefix("data: ") {
301                            d
302                        } else {
303                            continue;
304                        };
305
306                        if data == "[DONE]" {
307                            // Emit any remaining tool call before Done.
308                            if !current_tool_id.is_empty() {
309                                let input: serde_json::Value =
310                                    serde_json::from_str(&current_tool_args).unwrap_or_default();
311                                let _ = tx
312                                    .send(StreamEvent::ContentBlockComplete(
313                                        ContentBlock::ToolUse {
314                                            id: current_tool_id.clone(),
315                                            name: current_tool_name.clone(),
316                                            input,
317                                        },
318                                    ))
319                                    .await;
320                                current_tool_id.clear();
321                                current_tool_name.clear();
322                                current_tool_args.clear();
323                            }
324
325                            let _ = tx
326                                .send(StreamEvent::Done {
327                                    usage: usage.clone(),
328                                    stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
329                                })
330                                .await;
331                            return;
332                        }
333
334                        let parsed: serde_json::Value = match serde_json::from_str(data) {
335                            Ok(v) => v,
336                            Err(_) => continue,
337                        };
338
339                        // Extract delta from choices[0].delta
340                        let delta = match parsed
341                            .get("choices")
342                            .and_then(|c| c.get(0))
343                            .and_then(|c| c.get("delta"))
344                        {
345                            Some(d) => d,
346                            None => {
347                                // Check for usage in the final chunk.
348                                if let Some(u) = parsed.get("usage") {
349                                    usage.input_tokens = u
350                                        .get("prompt_tokens")
351                                        .and_then(|v| v.as_u64())
352                                        .unwrap_or(0);
353                                    usage.output_tokens = u
354                                        .get("completion_tokens")
355                                        .and_then(|v| v.as_u64())
356                                        .unwrap_or(0);
357                                }
358                                continue;
359                            }
360                        };
361
362                        // Text content.
363                        if let Some(content) = delta.get("content").and_then(|c| c.as_str())
364                            && !content.is_empty()
365                        {
366                            debug!("OpenAI text delta: {}", &content[..content.len().min(80)]);
367                            let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
368                        }
369
370                        // Check for finish_reason on the choice level.
371                        if let Some(finish) = parsed
372                            .get("choices")
373                            .and_then(|c| c.get(0))
374                            .and_then(|c| c.get("finish_reason"))
375                            .and_then(|f| f.as_str())
376                        {
377                            debug!("OpenAI finish_reason: {finish}");
378                            match finish {
379                                "stop" => {
380                                    stop_reason = Some(StopReason::EndTurn);
381                                }
382                                "tool_calls" => {
383                                    stop_reason = Some(StopReason::ToolUse);
384                                }
385                                "length" => {
386                                    stop_reason = Some(StopReason::MaxTokens);
387                                }
388                                _ => {}
389                            }
390                        }
391
392                        // Tool calls.
393                        if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
394                        {
395                            for tc in tool_calls {
396                                if let Some(func) = tc.get("function") {
397                                    if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
398                                        // New tool call starting.
399                                        if !current_tool_id.is_empty()
400                                            && !current_tool_args.is_empty()
401                                        {
402                                            // Emit the previous tool call.
403                                            let input: serde_json::Value =
404                                                serde_json::from_str(&current_tool_args)
405                                                    .unwrap_or_default();
406                                            let _ = tx
407                                                .send(StreamEvent::ContentBlockComplete(
408                                                    ContentBlock::ToolUse {
409                                                        id: current_tool_id.clone(),
410                                                        name: current_tool_name.clone(),
411                                                        input,
412                                                    },
413                                                ))
414                                                .await;
415                                        }
416                                        current_tool_id = tc
417                                            .get("id")
418                                            .and_then(|i| i.as_str())
419                                            .unwrap_or("")
420                                            .to_string();
421                                        current_tool_name = name.to_string();
422                                        current_tool_args.clear();
423                                    }
424                                    if let Some(args) =
425                                        func.get("arguments").and_then(|a| a.as_str())
426                                    {
427                                        current_tool_args.push_str(args);
428                                    }
429                                }
430                            }
431                        }
432                    }
433                }
434            }
435
436            // Emit any remaining tool call.
437            if !current_tool_id.is_empty() {
438                let input: serde_json::Value =
439                    serde_json::from_str(&current_tool_args).unwrap_or_default();
440                let _ = tx
441                    .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
442                        id: current_tool_id,
443                        name: current_tool_name,
444                        input,
445                    }))
446                    .await;
447            }
448
449            let _ = tx
450                .send(StreamEvent::Done {
451                    usage,
452                    stop_reason: Some(StopReason::EndTurn),
453                })
454                .await;
455        });
456
457        Ok(rx)
458    }
459}
460
461/// Convert content blocks to OpenAI format.
462fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
463    if blocks.len() == 1
464        && let ContentBlock::Text { text } = &blocks[0]
465    {
466        return serde_json::Value::String(text.clone());
467    }
468
469    let parts: Vec<serde_json::Value> = blocks
470        .iter()
471        .map(|b| match b {
472            ContentBlock::Text { text } => serde_json::json!({
473                "type": "text",
474                "text": text,
475            }),
476            ContentBlock::Image { media_type, data } => serde_json::json!({
477                "type": "image_url",
478                "image_url": {
479                    "url": format!("data:{media_type};base64,{data}"),
480                }
481            }),
482            ContentBlock::ToolResult {
483                tool_use_id,
484                content,
485                is_error,
486                ..
487            } => serde_json::json!({
488                "type": "tool_result",
489                "tool_use_id": tool_use_id,
490                "content": content,
491                "is_error": is_error,
492            }),
493            ContentBlock::Thinking { thinking, .. } => serde_json::json!({
494                "type": "text",
495                "text": thinking,
496            }),
497            ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
498                "type": "text",
499                "text": format!("[Tool call: {name}({input})]"),
500            }),
501            ContentBlock::Document { title, .. } => serde_json::json!({
502                "type": "text",
503                "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
504            }),
505        })
506        .collect();
507
508    serde_json::Value::Array(parts)
509}