agents_runtime/providers/openai.rs

use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::tools::ToolSchema;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};

#[derive(Clone)]
pub struct OpenAiConfig {
    pub api_key: String,
    pub model: String,
    pub api_url: Option<String>,
    pub custom_headers: Vec<(String, String)>,
}

impl OpenAiConfig {
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
            custom_headers: Vec::new(),
        }
    }

    pub fn with_api_url(mut self, api_url: Option<String>) -> Self {
        self.api_url = api_url;
        self
    }

    pub fn with_custom_headers(mut self, headers: Vec<(String, String)>) -> Self {
        self.custom_headers = headers;
        self
    }
}

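/// Chat-completions client for OpenAI (or any OpenAI-compatible endpoint).
///
/// A minimal usage sketch, assuming the `OPENAI_API_KEY` environment variable
/// is set and that this module is reachable as `agents_runtime::providers::openai`:
///
/// ```ignore
/// use agents_runtime::providers::openai::{OpenAiChatModel, OpenAiConfig};
///
/// let config = OpenAiConfig::new(std::env::var("OPENAI_API_KEY")?, "gpt-4o-mini")
///     // None falls back to https://api.openai.com/v1/chat/completions
///     .with_api_url(None)
///     .with_custom_headers(vec![("X-Example".into(), "docs".into())]);
/// let model = OpenAiChatModel::new(config)?;
/// ```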
pub struct OpenAiChatModel {
    client: Client,
    config: OpenAiConfig,
}

impl OpenAiChatModel {
    pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
        Ok(Self {
            client: Client::builder()
                .user_agent("rust-deep-agents-sdk/0.1")
                .build()?,
            config,
        })
    }
}

#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: &'a [OpenAiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: &'static str,
    content: String,
}

#[derive(Clone, Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunction,
}

#[derive(Clone, Serialize)]
struct OpenAiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ChoiceMessage,
}

#[derive(Deserialize)]
struct ChoiceMessage {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<OpenAiToolCall>,
}

#[derive(Deserialize)]
struct OpenAiToolCall {
    #[allow(dead_code)]
    id: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    tool_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

// Streaming response structures
#[derive(Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct StreamDelta {
    content: Option<String>,
}
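
// For reference, the streaming endpoint emits server-sent events with one JSON
// chunk per `data:` line and a literal [DONE] marker at the end, e.g.:
//   data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}
//   data: {"choices":[{"delta":{},"finish_reason":"stop"}]}
//   data: [DONE]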

fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
    let mut messages = Vec::with_capacity(request.messages.len() + 1);
    messages.push(OpenAiMessage {
        role: "system",
        content: request.system_prompt.clone(),
    });

    // Convert all messages to OpenAI format.
    // Note: tool messages are converted to user messages for compatibility,
    // since we don't have the full tool_calls metadata structure.
    for msg in &request.messages {
        let role = match msg.role {
            MessageRole::User => "user",
            MessageRole::Agent => "assistant",
            MessageRole::Tool => "user", // Convert tool results to user messages
            MessageRole::System => "system",
        };

        let content = match &msg.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Json(value) => value.to_string(),
        };

        messages.push(OpenAiMessage { role, content });
    }
    messages
}
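
// For example, given system prompt "You are helpful" and one user message "Hi",
// the array above serializes as:
//   [{"role":"system","content":"You are helpful"},
//    {"role":"user","content":"Hi"}]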

/// Convert tool schemas to the OpenAI function-calling format.
fn to_openai_tools(tools: &[ToolSchema]) -> Option<Vec<OpenAiTool>> {
    if tools.is_empty() {
        return None;
    }

    Some(
        tools
            .iter()
            .map(|tool| OpenAiTool {
                tool_type: "function".to_string(),
                function: OpenAiFunction {
                    name: tool.name.clone(),
                    description: tool.description.clone(),
                    parameters: serde_json::to_value(&tool.parameters)
                        .unwrap_or_else(|_| serde_json::json!({})),
                },
            })
            .collect(),
    )
}
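
// A schema named "search" would serialize into the request body as:
//   {"type":"function","function":{"name":"search","description":"...","parameters":{...}}}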

#[async_trait]
impl LanguageModel for OpenAiChatModel {
    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
        let messages = to_openai_messages(&request);
        let tools = to_openai_tools(&request.tools);

        // Debug logging (before `tools` is moved into the request body)
        tracing::debug!(
            "OpenAI request: model={}, messages={}, tools={}",
            self.config.model,
            messages.len(),
            tools.as_ref().map(|t| t.len()).unwrap_or(0)
        );
        for (i, msg) in messages.iter().enumerate() {
            tracing::debug!(
                "Message {}: role={}, content_len={}",
                i,
                msg.role,
                msg.content.len()
            );
            if msg.content.len() < 500 {
                tracing::debug!("Message {} content: {}", i, msg.content);
            }
        }

        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: None,
            tools,
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        let mut http_request = self.client.post(url).bearer_auth(&self.config.api_key);

        for (key, value) in &self.config.custom_headers {
            http_request = http_request.header(key, value);
        }

        let response = http_request.json(&body).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

        let data: ChatResponse = response.json().await?;
        let choice = data
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;

        // Handle tool calls if present
        if !choice.message.tool_calls.is_empty() {
            // Convert the OpenAI tool_calls format to our JSON format
            let tool_calls: Vec<_> = choice
                .message
                .tool_calls
                .iter()
                .map(|tc| {
                    serde_json::json!({
                        "name": tc.function.name,
                        "args": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                            .unwrap_or_else(|_| serde_json::json!({}))
                    })
                })
                .collect();

            // Log which tools the model requested
            let tool_names: Vec<&str> = choice
                .message
                .tool_calls
                .iter()
                .map(|tc| tc.function.name.as_str())
                .collect();

            tracing::info!(
                "LLM requested {} tool call(s): {:?}",
                tool_calls.len(),
                tool_names
            );

            // Log argument sizes for debugging
            for (i, tc) in choice.message.tool_calls.iter().enumerate() {
                tracing::debug!(
                    "Tool call {}: {} with {} bytes of arguments",
                    i + 1,
                    tc.function.name,
                    tc.function.arguments.len()
                );
            }

            return Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": tool_calls
                    })),
                    metadata: None,
                },
            });
        }

        // Regular text response
        let content = choice.message.content.unwrap_or_default();

        Ok(LlmResponse {
            message: AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text(content),
                metadata: None,
            },
        })
    }

    async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
        let messages = to_openai_messages(&request);
        let tools = to_openai_tools(&request.tools);

        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: Some(true),
            tools,
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        tracing::debug!(
            "OpenAI streaming request: model={}, messages={}, tools={}",
            self.config.model,
            messages.len(),
            request.tools.len()
        );

        let mut http_request = self.client.post(url).bearer_auth(&self.config.api_key);

        for (key, value) in &self.config.custom_headers {
            http_request = http_request.header(key, value);
        }

        let response = http_request.json(&body).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

        // Create stream from SSE response
        let stream = response.bytes_stream();
        let accumulated_content = Arc::new(Mutex::new(String::new()));
        let buffer = Arc::new(Mutex::new(String::new()));

        let is_done = Arc::new(Mutex::new(false));

        // Clone the Arcs for use in the final chained chunk below
        let final_accumulated = accumulated_content.clone();
        let final_is_done = is_done.clone();

        let chunk_stream = stream.map(move |result| {
            let accumulated = accumulated_content.clone();
            let buffer = buffer.clone();
            let is_done = is_done.clone();

            // Check if we're already done
            if *is_done.lock().unwrap() {
                return Ok(StreamChunk::TextDelta(String::new()));
            }

            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);

                    // Append the new bytes to the carry-over buffer
                    let mut buf = buffer.lock().unwrap();
                    buf.push_str(&text);

                    // Process complete SSE messages (separated by \n\n)
                    let mut collected_deltas = String::new();
                    let mut found_done = false;
                    let mut found_finish = false;

                    // Split on double newline to get complete SSE messages
                    let parts: Vec<&str> = buf.split("\n\n").collect();
                    let complete_messages = if parts.len() > 1 {
                        &parts[..parts.len() - 1] // All but the last (potentially incomplete) part
                    } else {
                        &[] // No complete messages yet
                    };

                    // Process each complete SSE message
                    for msg in complete_messages {
                        for line in msg.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                let json_str = data.trim();

                                // Check for the [DONE] marker
                                if json_str == "[DONE]" {
                                    found_done = true;
                                    break;
                                }

                                // Parse the JSON chunk
                                match serde_json::from_str::<StreamResponse>(json_str) {
                                    Ok(chunk) => {
                                        if let Some(choice) = chunk.choices.first() {
                                            // Collect delta content
                                            if let Some(content) = &choice.delta.content {
                                                if !content.is_empty() {
                                                    accumulated.lock().unwrap().push_str(content);
                                                    collected_deltas.push_str(content);
                                                }
                                            }

                                            // Check if the stream is finished
                                            if choice.finish_reason.is_some() {
                                                found_finish = true;
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        tracing::debug!("Failed to parse SSE message: {}", e);
                                    }
                                }
                            }
                        }
                        if found_done || found_finish {
                            break;
                        }
                    }

                    // Clear processed messages from the buffer, keeping only the incomplete tail
                    if !complete_messages.is_empty() {
                        *buf = parts.last().unwrap_or(&"").to_string();
                    }

                    // Handle completion
                    if found_done || found_finish {
                        let content = accumulated.lock().unwrap().clone();
                        let final_message = AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        };
                        *is_done.lock().unwrap() = true;
                        buf.clear();
                        return Ok(StreamChunk::Done {
                            message: final_message,
                        });
                    }

                    // Return the collected deltas (may be empty)
                    if !collected_deltas.is_empty() {
                        return Ok(StreamChunk::TextDelta(collected_deltas));
                    }

                    Ok(StreamChunk::TextDelta(String::new()))
                }
                Err(e) => {
                    // Transport error mid-stream: if content was already accumulated,
                    // emit it as the final message instead of failing outright
                    if !*is_done.lock().unwrap() {
                        let content = accumulated.lock().unwrap().clone();
                        if !content.is_empty() {
                            let final_message = AgentMessage {
                                role: MessageRole::Agent,
                                content: MessageContent::Text(content),
                                metadata: None,
                            };
                            *is_done.lock().unwrap() = true;
                            return Ok(StreamChunk::Done {
                                message: final_message,
                            });
                        }
                    }
                    Err(anyhow::anyhow!("Stream error: {}", e))
                }
            }
        });

        // Chain a final chunk to ensure Done is sent when the stream completes
        let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
            // Check if we already sent Done
            if !*final_is_done.lock().unwrap() {
                let content = final_accumulated.lock().unwrap().clone();
                if !content.is_empty() {
                    tracing::debug!(
                        "Stream ended naturally, sending final Done chunk with {} bytes",
                        content.len()
                    );
                    return Ok(StreamChunk::Done {
                        message: AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        },
                    });
                }
            }
            // Return an empty delta if already done or there is no content
            Ok(StreamChunk::TextDelta(String::new()))
        }));

        Ok(Box::pin(stream_with_finale))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn openai_config_new_initializes_empty_custom_headers() {
        let config = OpenAiConfig::new("test-key", "gpt-4");
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.model, "gpt-4");
        assert!(config.custom_headers.is_empty());
        assert!(config.api_url.is_none());
    }

    #[test]
    fn openai_config_with_custom_headers_sets_headers() {
        let headers = vec![
            ("X-Custom-Header".to_string(), "value1".to_string()),
            ("X-Another-Header".to_string(), "value2".to_string()),
        ];
        let config = OpenAiConfig::new("test-key", "gpt-4").with_custom_headers(headers.clone());

        assert_eq!(config.custom_headers.len(), 2);
        assert_eq!(config.custom_headers[0].0, "X-Custom-Header");
        assert_eq!(config.custom_headers[0].1, "value1");
        assert_eq!(config.custom_headers[1].0, "X-Another-Header");
        assert_eq!(config.custom_headers[1].1, "value2");
    }
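
    // Additional coverage: these tests exercise the pure helpers and the serde
    // structs defined above, so they need no network access. The sample JSON
    // bodies are hand-written to mirror the OpenAI wire format.
    #[test]
    fn to_openai_tools_returns_none_for_empty_slice() {
        assert!(to_openai_tools(&[]).is_none());
    }

    #[test]
    fn chat_response_deserializes_tool_calls() {
        let json = r#"{"choices":[{"message":{"content":null,"tool_calls":[
            {"id":"call_1","type":"function",
             "function":{"name":"search","arguments":"{\"q\":\"rust\"}"}}]}}]}"#;
        let parsed: ChatResponse = serde_json::from_str(json).expect("valid response JSON");
        let choice = &parsed.choices[0];
        assert!(choice.message.content.is_none());
        assert_eq!(choice.message.tool_calls[0].function.name, "search");
    }

    #[test]
    fn stream_response_deserializes_delta_and_finish_reason() {
        let json = r#"{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let parsed: StreamResponse = serde_json::from_str(json).expect("valid chunk JSON");
        assert_eq!(parsed.choices[0].delta.content.as_deref(), Some("Hello"));
        assert!(parsed.choices[0].finish_reason.is_none());
    }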
}