agents_runtime/providers/openai.rs

use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::tools::ToolSchema;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};

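/// Connection settings for the OpenAI Chat Completions API.
///
/// A minimal usage sketch; the key, model name, and gateway URL below are
/// illustrative, and the module path in the example is assumed from this
/// file's location:
///
/// ```no_run
/// use agents_runtime::providers::openai::OpenAiConfig;
///
/// let config = OpenAiConfig::new("sk-...", "gpt-4o-mini")
///     .with_api_url(Some("https://my-gateway.example/v1/chat/completions".into()));
/// ```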
#[derive(Clone)]
pub struct OpenAiConfig {
    pub api_key: String,
    pub model: String,
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
        }
    }

    pub fn with_api_url(mut self, api_url: Option<String>) -> Self {
        self.api_url = api_url;
        self
    }
}

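/// Chat Completions client implementing the `LanguageModel` trait.
///
/// Construction sketch (the key and model name are illustrative; the module
/// path is assumed from this file's location):
///
/// ```no_run
/// use agents_runtime::providers::openai::{OpenAiChatModel, OpenAiConfig};
///
/// # fn main() -> anyhow::Result<()> {
/// let model = OpenAiChatModel::new(OpenAiConfig::new("sk-...", "gpt-4o-mini"))?;
/// # Ok(())
/// # }
/// ```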
pub struct OpenAiChatModel {
    client: Client,
    config: OpenAiConfig,
}

impl OpenAiChatModel {
    pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
        Ok(Self {
            client: Client::builder()
                .user_agent("rust-deep-agents-sdk/0.1")
                .build()?,
            config,
        })
    }
}

#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: &'a [OpenAiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: &'static str,
    content: String,
}

#[derive(Clone, Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunction,
}

#[derive(Clone, Serialize)]
struct OpenAiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

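// Subset of the non-streaming completion response we care about, e.g.:
//   {"choices":[{"message":{"content":"Hi!","tool_calls":[]}}]}
// Unmodeled fields (id, usage, ...) are ignored by serde's default behavior.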
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ChoiceMessage,
}

#[derive(Deserialize)]
struct ChoiceMessage {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<OpenAiToolCall>,
}

#[derive(Deserialize)]
struct OpenAiToolCall {
    #[allow(dead_code)]
    id: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    tool_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

// Streaming response structures
#[derive(Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct StreamDelta {
    content: Option<String>,
}

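/// Flatten an `LlmRequest` into the Chat Completions message array,
/// prepending the system prompt and dropping tool results that would be
/// rejected for lacking a preceding `tool_calls` message.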
fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
    let mut messages = Vec::with_capacity(request.messages.len() + 1);
    messages.push(OpenAiMessage {
        role: "system",
        content: request.system_prompt.clone(),
    });

    // Filter and validate the message sequence for OpenAI compatibility.
    // NOTE: the Chat Completions API also expects a `tool_call_id` on
    // tool-role messages; `OpenAiMessage` would need an extra field to
    // carry it.
    let mut last_was_tool_call = false;

    for msg in &request.messages {
        let role = match msg.role {
            MessageRole::User => "user",
            MessageRole::Agent => "assistant",
            MessageRole::Tool => {
                // Only include tool messages that follow a tool call (or a
                // sibling tool result, since one assistant turn may issue
                // several parallel tool calls)
                if !last_was_tool_call {
                    tracing::warn!("Skipping tool message without preceding tool_calls");
                    continue;
                }
                "tool"
            }
            MessageRole::System => "system",
        };

        let content = match &msg.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Json(value) => value.to_string(),
        };

        // Heuristic: treat an assistant message whose serialized content
        // mentions "tool_calls" as a tool-call turn, and allow tool results
        // to chain after one another
        last_was_tool_call = matches!(msg.role, MessageRole::Tool)
            || (matches!(msg.role, MessageRole::Agent) && content.contains("tool_calls"));

        messages.push(OpenAiMessage { role, content });
    }
    messages
}

/// Convert tool schemas to OpenAI function calling format.
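///
/// For example, a schema named "search" serializes to the function-tool
/// wrapper the API expects:
///
/// ```json
/// {"type": "function",
///  "function": {"name": "search", "description": "Web search",
///               "parameters": {"type": "object",
///                              "properties": {"q": {"type": "string"}}}}}
/// ```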
fn to_openai_tools(tools: &[ToolSchema]) -> Option<Vec<OpenAiTool>> {
    if tools.is_empty() {
        return None;
    }

    Some(
        tools
            .iter()
            .map(|tool| OpenAiTool {
                tool_type: "function".to_string(),
                function: OpenAiFunction {
                    name: tool.name.clone(),
                    description: tool.description.clone(),
                    parameters: serde_json::to_value(&tool.parameters)
                        .unwrap_or_else(|_| serde_json::json!({})),
                },
            })
            .collect(),
    )
}

#[async_trait]
impl LanguageModel for OpenAiChatModel {
    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
        let messages = to_openai_messages(&request);
        let tools = to_openai_tools(&request.tools);

        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: None,
            tools: tools.clone(),
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        // Debug logging
        tracing::debug!(
            "OpenAI request: model={}, messages={}, tools={}",
            self.config.model,
            messages.len(),
            tools.as_ref().map(|t| t.len()).unwrap_or(0)
        );
        for (i, msg) in messages.iter().enumerate() {
            tracing::debug!(
                "Message {}: role={}, content_len={}",
                i,
                msg.role,
                msg.content.len()
            );
            if msg.content.len() < 500 {
                tracing::debug!("Message {} content: {}", i, msg.content);
            }
        }

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

        let data: ChatResponse = response.json().await?;
        let choice = data
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;

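        // OpenAI returns tool calls as objects like
        //   {"id": "call_abc", "type": "function",
        //    "function": {"name": "search", "arguments": "{\"q\":\"rust\"}"}}
        // where `arguments` is a JSON-encoded *string*; decode it into a
        // structured value before re-emitting under our internal shape.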
        // Handle tool calls if present
        if !choice.message.tool_calls.is_empty() {
            // Convert OpenAI tool_calls format to our JSON format
            let tool_calls: Vec<_> = choice
                .message
                .tool_calls
                .iter()
                .map(|tc| {
                    serde_json::json!({
                        "name": tc.function.name,
                        "args": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                            .unwrap_or_else(|_| serde_json::json!({}))
                    })
                })
                .collect();

            tracing::debug!("OpenAI response contains {} tool calls", tool_calls.len());

            return Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": tool_calls
                    })),
                    metadata: None,
                },
            });
        }

        // Regular text response
        let content = choice.message.content.unwrap_or_default();

        Ok(LlmResponse {
            message: AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text(content),
                metadata: None,
            },
        })
    }

    async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
        let messages = to_openai_messages(&request);
        let tools = to_openai_tools(&request.tools);

        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: Some(true),
            tools,
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        tracing::debug!(
            "OpenAI streaming request: model={}, messages={}, tools={}",
            self.config.model,
            messages.len(),
            request.tools.len()
        );

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

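        // The streaming endpoint replies with Server-Sent Events, e.g.:
        //
        //   data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
        //
        //   data: {"choices":[{"delta":{"content":"lo"},"finish_reason":null}]}
        //
        //   data: {"choices":[{"delta":{},"finish_reason":"stop"}]}
        //
        //   data: [DONE]
        //
        // Events are separated by a blank line ("\n\n"), and a network read
        // may end mid-event, so we buffer bytes until a full event arrives.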
        // Create stream from SSE response
        let stream = response.bytes_stream();
        let accumulated_content = Arc::new(Mutex::new(String::new()));
        let buffer = Arc::new(Mutex::new(String::new()));

        let is_done = Arc::new(Mutex::new(false));

        // Clone Arcs for the final chunk chained after the stream below
        let final_accumulated = accumulated_content.clone();
        let final_is_done = is_done.clone();

        let chunk_stream = stream.map(move |result| {
            let accumulated = accumulated_content.clone();
            let buffer = buffer.clone();
            let is_done = is_done.clone();

            // Once Done has been emitted, swallow any trailing bytes
            if *is_done.lock().unwrap() {
                return Ok(StreamChunk::TextDelta(String::new()));
            }

            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);

                    // Append to the carry-over buffer (one lock covers both
                    // the append and the parsing below)
                    let mut buf = buffer.lock().unwrap();
                    buf.push_str(&text);

                    // Process complete SSE messages (separated by \n\n)
                    let mut collected_deltas = String::new();
                    let mut found_done = false;
                    let mut found_finish = false;

                    // Split on double newline to get complete SSE messages
                    let parts: Vec<&str> = buf.split("\n\n").collect();
                    let complete_messages = if parts.len() > 1 {
                        &parts[..parts.len() - 1] // All but last (potentially incomplete)
                    } else {
                        &[] // No complete messages yet
                    };

                    // Process each complete SSE message
                    for msg in complete_messages {
                        for line in msg.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                let json_str = data.trim();

                                // Check for [DONE] marker
                                if json_str == "[DONE]" {
                                    found_done = true;
                                    break;
                                }

                                // Parse JSON chunk
                                match serde_json::from_str::<StreamResponse>(json_str) {
                                    Ok(chunk) => {
                                        if let Some(choice) = chunk.choices.first() {
                                            // Collect delta content
                                            if let Some(content) = &choice.delta.content {
                                                if !content.is_empty() {
                                                    accumulated.lock().unwrap().push_str(content);
                                                    collected_deltas.push_str(content);
                                                }
                                            }

                                            // Check if stream is finished
                                            if choice.finish_reason.is_some() {
                                                found_finish = true;
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        tracing::debug!("Failed to parse SSE message: {}", e);
                                    }
                                }
                            }
                        }
                        if found_done || found_finish {
                            break;
                        }
                    }

                    // Clear processed messages from buffer, keep only incomplete part
                    if !complete_messages.is_empty() {
                        *buf = parts.last().unwrap_or(&"").to_string();
                    }

                    // Handle completion
                    if found_done || found_finish {
                        let content = accumulated.lock().unwrap().clone();
                        let final_message = AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        };
                        *is_done.lock().unwrap() = true;
                        buf.clear();
                        return Ok(StreamChunk::Done {
                            message: final_message,
                        });
                    }

                    // Return collected deltas (may be empty)
                    if !collected_deltas.is_empty() {
                        return Ok(StreamChunk::TextDelta(collected_deltas));
                    }

                    Ok(StreamChunk::TextDelta(String::new()))
                }
                Err(e) => {
                    // Transport error mid-stream: salvage whatever content has
                    // accumulated so far, otherwise surface the error
                    if !*is_done.lock().unwrap() {
                        let content = accumulated.lock().unwrap().clone();
                        if !content.is_empty() {
                            let final_message = AgentMessage {
                                role: MessageRole::Agent,
                                content: MessageContent::Text(content),
                                metadata: None,
                            };
                            *is_done.lock().unwrap() = true;
                            return Ok(StreamChunk::Done {
                                message: final_message,
                            });
                        }
                    }
                    Err(anyhow::anyhow!("Stream error: {}", e))
                }
            }
        });

        // Chain a final chunk to ensure Done is sent when the stream completes
        let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
            // Emit Done here only if the main stream never did
            if !*final_is_done.lock().unwrap() {
                let content = final_accumulated.lock().unwrap().clone();
                if !content.is_empty() {
                    tracing::debug!(
                        "Stream ended naturally, sending final Done chunk with {} chars",
                        content.len()
                    );
                    return Ok(StreamChunk::Done {
                        message: AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        },
                    });
                }
            }
            // Return an empty delta if already done or there is no content
            Ok(StreamChunk::TextDelta(String::new()))
        }));

        Ok(Box::pin(stream_with_finale))
    }
}
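
// Consuming the stream (sketch): assuming `ChunkStream` is the pinned, boxed
// `Stream<Item = anyhow::Result<StreamChunk>>` suggested by the `Box::pin`
// above, a caller might drive it like this:
//
//     use futures::StreamExt;
//
//     let mut stream = model.generate_stream(request).await?;
//     while let Some(chunk) = stream.next().await {
//         match chunk? {
//             StreamChunk::TextDelta(delta) => print!("{delta}"),
//             StreamChunk::Done { message: _ } => break, // final AgentMessage
//             _ => {} // hedge against variants not used in this file
//         }
//     }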