agents_runtime/providers/openai.rs

use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};

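/// Connection settings for the OpenAI chat completions API.
///
/// `api_url` overrides the default endpoint
/// (`https://api.openai.com/v1/chat/completions`), which is useful for
/// OpenAI-compatible servers. A minimal usage sketch; the key, model, and
/// URL below are placeholders:
///
/// ```ignore
/// let config = OpenAiConfig::new("sk-...", "gpt-4o-mini")
///     .with_api_url(Some("http://localhost:8080/v1/chat/completions".into()));
/// ```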
#[derive(Clone)]
pub struct OpenAiConfig {
    pub api_key: String,
    pub model: String,
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
        }
    }

    pub fn with_api_url(mut self, api_url: Option<String>) -> Self {
        self.api_url = api_url;
        self
    }
}

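/// `LanguageModel` implementation backed by the OpenAI chat completions
/// endpoint, supporting both one-shot (`generate`) and streamed
/// (`generate_stream`) generation.
///
/// A sketch of the intended call pattern, assuming an async runtime and an
/// `LlmRequest` built elsewhere:
///
/// ```ignore
/// let model = OpenAiChatModel::new(OpenAiConfig::new(api_key, model_name))?;
/// let response = model.generate(request).await?;
/// ```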
pub struct OpenAiChatModel {
    client: Client,
    config: OpenAiConfig,
}

impl OpenAiChatModel {
    pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
        Ok(Self {
            client: Client::builder()
                .user_agent("rust-deep-agents-sdk/0.1")
                .build()?,
            config,
        })
    }
}

#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: &'a [OpenAiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: &'static str,
    content: String,
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ChoiceMessage,
}

#[derive(Deserialize)]
struct ChoiceMessage {
    content: String,
}

// Streaming response structures
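//
// Each SSE `data:` payload deserializes into `StreamResponse`. A typical
// delta chunk looks like (abridged; unknown fields are ignored by serde):
//   {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}
// and the final chunk carries a non-null finish_reason such as "stop",
// followed by the literal sentinel `data: [DONE]`.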
#[derive(Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct StreamDelta {
    content: Option<String>,
}

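/// Maps an `LlmRequest` onto the OpenAI wire format: the system prompt is
/// prepended as the first message, and `tool` messages that do not follow a
/// tool-calling assistant turn are dropped, since the API rejects such
/// sequences.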
fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
    let mut messages = Vec::with_capacity(request.messages.len() + 1);
    messages.push(OpenAiMessage {
        role: "system",
        content: request.system_prompt.clone(),
    });

    // Filter the sequence for OpenAI compatibility: the API rejects `tool`
    // messages that are not preceded by an assistant message with tool calls.
    let mut last_was_tool_call = false;

    for msg in &request.messages {
        let role = match msg.role {
            MessageRole::User => "user",
            MessageRole::Agent => "assistant",
            MessageRole::Tool => {
                // Only include tool messages that follow a tool call
                if !last_was_tool_call {
                    tracing::warn!("Skipping tool message without preceding tool_calls");
                    continue;
                }
                "tool"
            }
            MessageRole::System => "system",
        };

        let content = match &msg.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Json(value) => value.to_string(),
        };

        // Heuristic: an assistant message counts as a tool call if its
        // serialized content mentions "tool_calls".
        last_was_tool_call =
            matches!(msg.role, MessageRole::Agent) && content.contains("tool_calls");

        messages.push(OpenAiMessage { role, content });
    }
    messages
}

#[async_trait]
impl LanguageModel for OpenAiChatModel {
    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
        let messages = to_openai_messages(&request);
        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: None,
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        // Debug logging
        tracing::debug!(
            "OpenAI request: model={}, messages={}",
            self.config.model,
            messages.len()
        );
        for (i, msg) in messages.iter().enumerate() {
            tracing::debug!(
                "Message {}: role={}, content_len={}",
                i,
                msg.role,
                msg.content.len()
            );
            if msg.content.len() < 500 {
                tracing::debug!("Message {} content: {}", i, msg.content);
            }
        }

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

        let data: ChatResponse = response.json().await?;
        let choice = data
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;

        Ok(LlmResponse {
            message: AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text(choice.message.content),
                metadata: None,
            },
        })
    }

    async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
        let messages = to_openai_messages(&request);
        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: Some(true),
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        tracing::debug!(
            "OpenAI streaming request: model={}, messages={}",
            self.config.model,
            messages.len()
        );

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

        // Create stream from SSE response
        let stream = response.bytes_stream();
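        // Shared mutable state: the `map` closure below runs once per network
        // chunk, and the chained final chunk needs to observe the same
        // accumulator and done-flag, so all three live behind `Arc<Mutex<_>>`.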
        let accumulated_content = Arc::new(Mutex::new(String::new()));
        let buffer = Arc::new(Mutex::new(String::new()));

        let is_done = Arc::new(Mutex::new(false));

        // Clone handles for the final chunk chained after the main stream.
        let final_accumulated = accumulated_content.clone();
        let final_is_done = is_done.clone();

        let chunk_stream = stream.map(move |result| {
            // Once Done has been emitted, swallow any trailing chunks.
            if *is_done.lock().unwrap() {
                return Ok(StreamChunk::TextDelta(String::new()));
            }

            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);

                    // Append the new bytes to the carry-over buffer, then
                    // parse any complete SSE messages out of it.
                    let mut buf = buffer.lock().unwrap();
                    buf.push_str(&text);

                    let mut collected_deltas = String::new();
                    let mut found_done = false;
                    let mut found_finish = false;

                    // Complete SSE messages are separated by a blank line
                    // (\n\n); the final part of the split may be incomplete.
                    let parts: Vec<&str> = buf.split("\n\n").collect();
                    let complete_messages = if parts.len() > 1 {
                        &parts[..parts.len() - 1] // All but the last (potentially incomplete)
                    } else {
                        &[] // No complete messages yet
                    };

                    // Process each complete SSE message
                    for msg in complete_messages {
                        for line in msg.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                let json_str = data.trim();

                                // Check for the [DONE] marker
                                if json_str == "[DONE]" {
                                    found_done = true;
                                    break;
                                }

                                // Parse the JSON chunk
                                match serde_json::from_str::<StreamResponse>(json_str) {
                                    Ok(chunk) => {
                                        if let Some(choice) = chunk.choices.first() {
                                            // Collect delta content
                                            if let Some(content) = &choice.delta.content {
                                                if !content.is_empty() {
                                                    accumulated_content
                                                        .lock()
                                                        .unwrap()
                                                        .push_str(content);
                                                    collected_deltas.push_str(content);
                                                }
                                            }

                                            // Check whether the stream is finished
                                            if choice.finish_reason.is_some() {
                                                found_finish = true;
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        tracing::debug!("Failed to parse SSE message: {}", e);
                                    }
                                }
                            }
                        }
                        if found_done || found_finish {
                            break;
                        }
                    }

                    // Drop processed messages from the buffer, keeping only
                    // the trailing incomplete part.
                    if !complete_messages.is_empty() {
                        *buf = parts.last().unwrap_or(&"").to_string();
                    }

                    // Handle completion
                    if found_done || found_finish {
                        let content = accumulated_content.lock().unwrap().clone();
                        let final_message = AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        };
                        *is_done.lock().unwrap() = true;
                        buf.clear();
                        return Ok(StreamChunk::Done {
                            message: final_message,
                        });
                    }

                    // Forward the collected deltas; an empty delta means this
                    // network chunk completed no SSE message.
                    if !collected_deltas.is_empty() {
                        return Ok(StreamChunk::TextDelta(collected_deltas));
                    }

                    Ok(StreamChunk::TextDelta(String::new()))
                }
                Err(e) => {
                    // Transport error mid-stream: salvage any content already
                    // accumulated as a final message rather than dropping it.
                    if !*is_done.lock().unwrap() {
                        let content = accumulated_content.lock().unwrap().clone();
                        if !content.is_empty() {
                            let final_message = AgentMessage {
                                role: MessageRole::Agent,
                                content: MessageContent::Text(content),
                                metadata: None,
                            };
                            *is_done.lock().unwrap() = true;
                            return Ok(StreamChunk::Done {
                                message: final_message,
                            });
                        }
                    }
                    Err(anyhow::anyhow!("Stream error: {}", e))
                }
            }
        });

        // Chain a final chunk so Done is still sent if the HTTP stream ends
        // without a [DONE] marker or finish_reason.
        let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
            // Only emit Done if the map closure above never did.
            if !*final_is_done.lock().unwrap() {
                let content = final_accumulated.lock().unwrap().clone();
                if !content.is_empty() {
                    tracing::debug!(
                        "Stream ended naturally, sending final Done chunk with {} chars",
                        content.len()
                    );
                    return Ok(StreamChunk::Done {
                        message: AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        },
                    });
                }
            }
            // Return an empty delta if already done or there is no content.
            Ok(StreamChunk::TextDelta(String::new()))
        }));

        Ok(Box::pin(stream_with_finale))
    }
}
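
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity checks for the streaming deserialization types above. The
    // sample payloads mirror the chat-completions SSE chunk shape; they are
    // illustrative fixtures, not captured from a live response.
    #[test]
    fn parses_delta_chunk() {
        let json = r#"{"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let chunk: StreamResponse = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hi"));
        assert!(chunk.choices[0].finish_reason.is_none());
    }

    #[test]
    fn parses_finish_chunk() {
        let json = r#"{"choices":[{"delta":{},"finish_reason":"stop"}]}"#;
        let chunk: StreamResponse = serde_json::from_str(json).unwrap();
        assert!(chunk.choices[0].delta.content.is_none());
        assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("stop"));
    }
}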