agents_runtime/providers/
openai.rs

1use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use agents_core::tools::ToolSchema;
4use async_trait::async_trait;
5use futures::stream::StreamExt;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::sync::{Arc, Mutex};
9
/// Connection settings for an OpenAI-compatible chat-completions endpoint.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// Secret sent as a bearer token with every request.
    pub api_key: String,
    /// Model identifier placed in every request body.
    pub model: String,
    /// Endpoint override; `None` selects the public OpenAI chat-completions URL.
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    /// Creates a config for `model`, authenticated with `api_key`, targeting the
    /// default endpoint.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        Self {
            api_key,
            model,
            api_url: None,
        }
    }

    /// Returns the config with its endpoint replaced by `api_url`
    /// (`None` restores the default endpoint).
    pub fn with_api_url(self, api_url: Option<String>) -> Self {
        Self { api_url, ..self }
    }
}
31
32pub struct OpenAiChatModel {
33    client: Client,
34    config: OpenAiConfig,
35}
36
37impl OpenAiChatModel {
38    pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
39        Ok(Self {
40            client: Client::builder()
41                .user_agent("rust-deep-agents-sdk/0.1")
42                .build()?,
43            config,
44        })
45    }
46}
47
48#[derive(Serialize)]
49struct ChatRequest<'a> {
50    model: &'a str,
51    messages: &'a [OpenAiMessage],
52    #[serde(skip_serializing_if = "Option::is_none")]
53    stream: Option<bool>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    tools: Option<Vec<OpenAiTool>>,
56}
57
58#[derive(Serialize)]
59struct OpenAiMessage {
60    role: &'static str,
61    content: String,
62}
63
64#[derive(Clone, Serialize)]
65struct OpenAiTool {
66    #[serde(rename = "type")]
67    tool_type: String,
68    function: OpenAiFunction,
69}
70
71#[derive(Clone, Serialize)]
72struct OpenAiFunction {
73    name: String,
74    description: String,
75    parameters: serde_json::Value,
76}
77
78#[derive(Deserialize)]
79struct ChatResponse {
80    choices: Vec<Choice>,
81}
82
83#[derive(Deserialize)]
84struct Choice {
85    message: ChoiceMessage,
86}
87
88#[derive(Deserialize)]
89struct ChoiceMessage {
90    content: Option<String>,
91    #[serde(default)]
92    tool_calls: Vec<OpenAiToolCall>,
93}
94
95#[derive(Deserialize)]
96struct OpenAiToolCall {
97    #[allow(dead_code)]
98    id: String,
99    #[serde(rename = "type")]
100    #[allow(dead_code)]
101    tool_type: String,
102    function: OpenAiFunctionCall,
103}
104
105#[derive(Deserialize)]
106struct OpenAiFunctionCall {
107    name: String,
108    arguments: String,
109}
110
111// Streaming response structures
112#[derive(Deserialize)]
113struct StreamResponse {
114    choices: Vec<StreamChoice>,
115}
116
117#[derive(Deserialize)]
118struct StreamChoice {
119    delta: StreamDelta,
120    finish_reason: Option<String>,
121}
122
123#[derive(Deserialize)]
124struct StreamDelta {
125    content: Option<String>,
126}
127
128fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
129    let mut messages = Vec::with_capacity(request.messages.len() + 1);
130    messages.push(OpenAiMessage {
131        role: "system",
132        content: request.system_prompt.clone(),
133    });
134
135    // Convert all messages to OpenAI format
136    // Note: Tool messages are converted to user messages for compatibility
137    // since we don't have the full tool_calls metadata structure
138    for msg in &request.messages {
139        let role = match msg.role {
140            MessageRole::User => "user",
141            MessageRole::Agent => "assistant",
142            MessageRole::Tool => "user", // Convert tool results to user messages
143            MessageRole::System => "system",
144        };
145
146        let content = match &msg.content {
147            MessageContent::Text(text) => text.clone(),
148            MessageContent::Json(value) => value.to_string(),
149        };
150
151        messages.push(OpenAiMessage { role, content });
152    }
153    messages
154}
155
156/// Convert tool schemas to OpenAI function calling format
157fn to_openai_tools(tools: &[ToolSchema]) -> Option<Vec<OpenAiTool>> {
158    if tools.is_empty() {
159        return None;
160    }
161
162    Some(
163        tools
164            .iter()
165            .map(|tool| OpenAiTool {
166                tool_type: "function".to_string(),
167                function: OpenAiFunction {
168                    name: tool.name.clone(),
169                    description: tool.description.clone(),
170                    parameters: serde_json::to_value(&tool.parameters)
171                        .unwrap_or_else(|_| serde_json::json!({})),
172                },
173            })
174            .collect(),
175    )
176}
177
178#[async_trait]
179impl LanguageModel for OpenAiChatModel {
180    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
181        let messages = to_openai_messages(&request);
182        let tools = to_openai_tools(&request.tools);
183
184        let body = ChatRequest {
185            model: &self.config.model,
186            messages: &messages,
187            stream: None,
188            tools: tools.clone(),
189        };
190        let url = self
191            .config
192            .api_url
193            .as_deref()
194            .unwrap_or("https://api.openai.com/v1/chat/completions");
195
196        // Debug logging
197        tracing::debug!(
198            "OpenAI request: model={}, messages={}, tools={}",
199            self.config.model,
200            messages.len(),
201            tools.as_ref().map(|t| t.len()).unwrap_or(0)
202        );
203        for (i, msg) in messages.iter().enumerate() {
204            tracing::debug!(
205                "Message {}: role={}, content_len={}",
206                i,
207                msg.role,
208                msg.content.len()
209            );
210            if msg.content.len() < 500 {
211                tracing::debug!("Message {} content: {}", i, msg.content);
212            }
213        }
214
215        let response = self
216            .client
217            .post(url)
218            .bearer_auth(&self.config.api_key)
219            .json(&body)
220            .send()
221            .await?;
222
223        if !response.status().is_success() {
224            let status = response.status();
225            let error_text = response.text().await.unwrap_or_default();
226            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
227            return Err(anyhow::anyhow!(
228                "OpenAI API error: {} - {}",
229                status,
230                error_text
231            ));
232        }
233
234        let data: ChatResponse = response.json().await?;
235        let choice = data
236            .choices
237            .into_iter()
238            .next()
239            .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;
240
241        // Handle tool calls if present
242        if !choice.message.tool_calls.is_empty() {
243            // Convert OpenAI tool_calls format to our JSON format
244            let tool_calls: Vec<_> = choice
245                .message
246                .tool_calls
247                .iter()
248                .map(|tc| {
249                    serde_json::json!({
250                        "name": tc.function.name,
251                        "args": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
252                            .unwrap_or_else(|_| serde_json::json!({}))
253                    })
254                })
255                .collect();
256
257            // Enhanced logging for tool call detection
258            let tool_names: Vec<&str> = choice
259                .message
260                .tool_calls
261                .iter()
262                .map(|tc| tc.function.name.as_str())
263                .collect();
264
265            tracing::warn!(
266                "🔧 LLM CALLED {} TOOL(S): {:?}",
267                tool_calls.len(),
268                tool_names
269            );
270
271            // Log argument sizes for debugging
272            for (i, tc) in choice.message.tool_calls.iter().enumerate() {
273                tracing::debug!(
274                    "Tool call {}: {} with {} bytes of arguments",
275                    i + 1,
276                    tc.function.name,
277                    tc.function.arguments.len()
278                );
279            }
280
281            return Ok(LlmResponse {
282                message: AgentMessage {
283                    role: MessageRole::Agent,
284                    content: MessageContent::Json(serde_json::json!({
285                        "tool_calls": tool_calls
286                    })),
287                    metadata: None,
288                },
289            });
290        }
291
292        // Regular text response
293        let content = choice.message.content.unwrap_or_else(|| "".to_string());
294
295        Ok(LlmResponse {
296            message: AgentMessage {
297                role: MessageRole::Agent,
298                content: MessageContent::Text(content),
299                metadata: None,
300            },
301        })
302    }
303
304    async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
305        let messages = to_openai_messages(&request);
306        let tools = to_openai_tools(&request.tools);
307
308        let body = ChatRequest {
309            model: &self.config.model,
310            messages: &messages,
311            stream: Some(true),
312            tools,
313        };
314        let url = self
315            .config
316            .api_url
317            .as_deref()
318            .unwrap_or("https://api.openai.com/v1/chat/completions");
319
320        tracing::debug!(
321            "OpenAI streaming request: model={}, messages={}, tools={}",
322            self.config.model,
323            messages.len(),
324            request.tools.len()
325        );
326
327        let response = self
328            .client
329            .post(url)
330            .bearer_auth(&self.config.api_key)
331            .json(&body)
332            .send()
333            .await?;
334
335        if !response.status().is_success() {
336            let status = response.status();
337            let error_text = response.text().await.unwrap_or_default();
338            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
339            return Err(anyhow::anyhow!(
340                "OpenAI API error: {} - {}",
341                status,
342                error_text
343            ));
344        }
345
346        // Create stream from SSE response
347        let stream = response.bytes_stream();
348        let accumulated_content = Arc::new(Mutex::new(String::new()));
349        let buffer = Arc::new(Mutex::new(String::new()));
350
351        let is_done = Arc::new(Mutex::new(false));
352
353        // Clone Arcs for use in finale
354        let final_accumulated = accumulated_content.clone();
355        let final_is_done = is_done.clone();
356
357        let chunk_stream = stream.map(move |result| {
358            let accumulated = accumulated_content.clone();
359            let buffer = buffer.clone();
360            let is_done = is_done.clone();
361
362            // Check if we're already done
363            if *is_done.lock().unwrap() {
364                return Ok(StreamChunk::TextDelta(String::new()));
365            }
366
367            match result {
368                Ok(bytes) => {
369                    let text = String::from_utf8_lossy(&bytes);
370
371                    // Append to buffer
372                    buffer.lock().unwrap().push_str(&text);
373
374                    let mut buf = buffer.lock().unwrap();
375
376                    // Process complete SSE messages (separated by \n\n)
377                    let mut collected_deltas = String::new();
378                    let mut found_done = false;
379                    let mut found_finish = false;
380
381                    // Split on double newline to get complete SSE messages
382                    let parts: Vec<&str> = buf.split("\n\n").collect();
383                    let complete_messages = if parts.len() > 1 {
384                        &parts[..parts.len() - 1] // All but last (potentially incomplete)
385                    } else {
386                        &[] // No complete messages yet
387                    };
388
389                    // Process each complete SSE message
390                    for msg in complete_messages {
391                        for line in msg.lines() {
392                            if let Some(data) = line.strip_prefix("data: ") {
393                                let json_str = data.trim();
394
395                                // Check for [DONE] marker
396                                if json_str == "[DONE]" {
397                                    found_done = true;
398                                    break;
399                                }
400
401                                // Parse JSON chunk
402                                match serde_json::from_str::<StreamResponse>(json_str) {
403                                    Ok(chunk) => {
404                                        if let Some(choice) = chunk.choices.first() {
405                                            // Collect delta content
406                                            if let Some(content) = &choice.delta.content {
407                                                if !content.is_empty() {
408                                                    accumulated.lock().unwrap().push_str(content);
409                                                    collected_deltas.push_str(content);
410                                                }
411                                            }
412
413                                            // Check if stream is finished
414                                            if choice.finish_reason.is_some() {
415                                                found_finish = true;
416                                            }
417                                        }
418                                    }
419                                    Err(e) => {
420                                        tracing::debug!("Failed to parse SSE message: {}", e);
421                                    }
422                                }
423                            }
424                        }
425                        if found_done || found_finish {
426                            break;
427                        }
428                    }
429
430                    // Clear processed messages from buffer, keep only incomplete part
431                    if !complete_messages.is_empty() {
432                        *buf = parts.last().unwrap_or(&"").to_string();
433                    }
434
435                    // Handle completion
436                    if found_done || found_finish {
437                        let content = accumulated.lock().unwrap().clone();
438                        let final_message = AgentMessage {
439                            role: MessageRole::Agent,
440                            content: MessageContent::Text(content),
441                            metadata: None,
442                        };
443                        *is_done.lock().unwrap() = true;
444                        buf.clear();
445                        return Ok(StreamChunk::Done {
446                            message: final_message,
447                        });
448                    }
449
450                    // Return collected deltas (may be empty)
451                    if !collected_deltas.is_empty() {
452                        return Ok(StreamChunk::TextDelta(collected_deltas));
453                    }
454
455                    Ok(StreamChunk::TextDelta(String::new()))
456                }
457                Err(e) => {
458                    // Stream ended - check if we have accumulated content
459                    if !*is_done.lock().unwrap() {
460                        let content = accumulated.lock().unwrap().clone();
461                        if !content.is_empty() {
462                            let final_message = AgentMessage {
463                                role: MessageRole::Agent,
464                                content: MessageContent::Text(content),
465                                metadata: None,
466                            };
467                            *is_done.lock().unwrap() = true;
468                            return Ok(StreamChunk::Done {
469                                message: final_message,
470                            });
471                        }
472                    }
473                    Err(anyhow::anyhow!("Stream error: {}", e))
474                }
475            }
476        });
477
478        // Chain a final chunk to ensure Done is sent when stream completes
479        let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
480            // Check if we already sent Done
481            if !*final_is_done.lock().unwrap() {
482                let content = final_accumulated.lock().unwrap().clone();
483                if !content.is_empty() {
484                    let final_message = AgentMessage {
485                        role: MessageRole::Agent,
486                        content: MessageContent::Text(content),
487                        metadata: None,
488                    };
489                    let content_text = match &final_message.content {
490                        MessageContent::Text(t) => t.as_str(),
491                        _ => "non-text",
492                    };
493                    tracing::debug!(
494                        "Stream ended naturally, sending final Done chunk with {} chars",
495                        content_text.len()
496                    );
497                    return Ok(StreamChunk::Done {
498                        message: final_message,
499                    });
500                }
501            }
502            // Return empty delta if already done or no content
503            Ok(StreamChunk::TextDelta(String::new()))
504        }));
505
506        Ok(Box::pin(stream_with_finale))
507    }
508}