agents_runtime/providers/openai.rs

use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};

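/// Configuration for the OpenAI Chat Completions provider: credentials, the
/// model id, and an optional endpoint override (useful for proxies and
/// OpenAI-compatible servers).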
#[derive(Clone)]
pub struct OpenAiConfig {
    pub api_key: String,
    pub model: String,
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
        }
    }

    pub fn with_api_url(mut self, api_url: Option<String>) -> Self {
        self.api_url = api_url;
        self
    }
}

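/// `LanguageModel` implementation backed by the OpenAI Chat Completions API,
/// supporting both one-shot and SSE-streaming generation.
///
/// A minimal construction sketch. Hedged: the module path is assumed from this
/// file's location in the crate, and the model name is only an example.
///
/// ```no_run
/// // Module path assumed from this file's location in the crate.
/// use agents_runtime::providers::openai::{OpenAiChatModel, OpenAiConfig};
///
/// # fn main() -> anyhow::Result<()> {
/// let config = OpenAiConfig::new(std::env::var("OPENAI_API_KEY")?, "gpt-4o-mini");
/// let model = OpenAiChatModel::new(config)?;
/// # Ok(())
/// # }
/// ```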
pub struct OpenAiChatModel {
    client: Client,
    config: OpenAiConfig,
}

impl OpenAiChatModel {
    pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
        Ok(Self {
            client: Client::builder()
                .user_agent("rust-deep-agents-sdk/0.1")
                .build()?,
            config,
        })
    }
}

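// Wire types for the Chat Completions request/response and its SSE streaming
// variant; only the fields this client actually reads are declared.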
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: &'a [OpenAiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: &'static str,
    content: String,
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ChoiceMessage,
}

#[derive(Deserialize)]
struct ChoiceMessage {
    content: String,
}

#[derive(Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct StreamDelta {
    content: Option<String>,
}

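/// Converts an `LlmRequest` into OpenAI's chat message format, prepending the
/// system prompt and dropping `Tool` messages that do not follow an assistant
/// tool-call turn (the API rejects such orphaned tool messages).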
fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
    let mut messages = Vec::with_capacity(request.messages.len() + 1);
    messages.push(OpenAiMessage {
        role: "system",
        content: request.system_prompt.clone(),
    });

    let mut last_was_tool_call = false;

    for msg in &request.messages {
        let role = match msg.role {
            MessageRole::User => "user",
            MessageRole::Agent => "assistant",
            MessageRole::Tool => {
                if !last_was_tool_call {
                    tracing::warn!("Skipping tool message without preceding tool_calls");
                    continue;
                }
                "tool"
            }
            MessageRole::System => "system",
        };

        let content = match &msg.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Json(value) => value.to_string(),
        };

        // Heuristic: treat an assistant message whose serialized content
        // mentions "tool_calls" as a tool-call turn.
        last_was_tool_call =
            matches!(msg.role, MessageRole::Agent) && content.contains("tool_calls");

        messages.push(OpenAiMessage { role, content });
    }
    messages
}

#[async_trait]
impl LanguageModel for OpenAiChatModel {
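    /// One-shot completion: sends the full conversation and returns the first
    /// choice as an agent message.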
    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
        let messages = to_openai_messages(&request);
        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: None,
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        tracing::debug!(
            "OpenAI request: model={}, messages={}",
            self.config.model,
            messages.len()
        );
        for (i, msg) in messages.iter().enumerate() {
            tracing::debug!(
                "Message {}: role={}, content_len={}",
                i,
                msg.role,
                msg.content.len()
            );
            // Only log full message bodies when they are reasonably short.
            if msg.content.len() < 500 {
                tracing::debug!("Message {} content: {}", i, msg.content);
            }
        }

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

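        // A non-streaming response carries the complete message in `choices[0]`.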
        let data: ChatResponse = response.json().await?;
        let choice = data
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;

        Ok(LlmResponse {
            message: AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text(choice.message.content),
                metadata: None,
            },
        })
    }

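    /// Streaming completion over Server-Sent Events: yields `TextDelta` chunks
    /// as tokens arrive and finishes with a `Done` chunk carrying the full
    /// accumulated message.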
    async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
        let messages = to_openai_messages(&request);
        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: Some(true),
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        tracing::debug!(
            "OpenAI streaming request: model={}, messages={}",
            self.config.model,
            messages.len()
        );

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

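        // Shared state for the SSE-parsing closure below: the text accumulated
        // so far, a carry-over buffer for partial SSE events split across HTTP
        // chunks, and a flag so nothing is emitted after the final `Done`.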
        let stream = response.bytes_stream();
        let accumulated_content = Arc::new(Mutex::new(String::new()));
        let buffer = Arc::new(Mutex::new(String::new()));
        let is_done = Arc::new(Mutex::new(false));

        // Handles for the finale stage chained after the byte stream ends.
        let final_accumulated = accumulated_content.clone();
        let final_is_done = is_done.clone();

        let chunk_stream = stream.map(move |result| {
            let accumulated = accumulated_content.clone();
            let buffer = buffer.clone();
            let is_done = is_done.clone();

            // Once `Done` has been emitted, swallow any trailing bytes as
            // empty deltas instead of emitting content twice.
            if *is_done.lock().unwrap() {
                return Ok(StreamChunk::TextDelta(String::new()));
            }

            match result {
                Ok(bytes) => {
                    // Assumes chunk boundaries fall on UTF-8 character
                    // boundaries; a code point split across chunks would be
                    // replaced lossily.
                    let text = String::from_utf8_lossy(&bytes);
                    let mut buf = buffer.lock().unwrap();
                    buf.push_str(&text);

                    let mut collected_deltas = String::new();
                    let mut found_done = false;
                    let mut found_finish = false;

                    // SSE events are separated by a blank line; everything
                    // before the last separator is complete, the tail may be
                    // a partial event.
                    let parts: Vec<&str> = buf.split("\n\n").collect();
                    let complete_messages = if parts.len() > 1 {
                        &parts[..parts.len() - 1]
                    } else {
                        &[]
                    };

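                    // Each complete event is one or more `data: {json}` lines,
                    // with `data: [DONE]` as the end-of-stream sentinel.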
                    for msg in complete_messages {
                        for line in msg.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                let json_str = data.trim();

                                if json_str == "[DONE]" {
                                    found_done = true;
                                    break;
                                }

                                match serde_json::from_str::<StreamResponse>(json_str) {
                                    Ok(chunk) => {
                                        if let Some(choice) = chunk.choices.first() {
                                            if let Some(content) = &choice.delta.content {
                                                if !content.is_empty() {
                                                    accumulated.lock().unwrap().push_str(content);
                                                    collected_deltas.push_str(content);
                                                }
                                            }

                                            if choice.finish_reason.is_some() {
                                                found_finish = true;
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        tracing::debug!("Failed to parse SSE message: {}", e);
                                    }
                                }
                            }
                        }
                        if found_done || found_finish {
                            break;
                        }
                    }

                    // Keep only the trailing (possibly partial) event in the
                    // buffer for the next chunk.
                    if !complete_messages.is_empty() {
                        *buf = parts.last().unwrap_or(&"").to_string();
                    }

                    if found_done || found_finish {
                        let content = accumulated.lock().unwrap().clone();
                        let final_message = AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        };
                        *is_done.lock().unwrap() = true;
                        buf.clear();
                        return Ok(StreamChunk::Done {
                            message: final_message,
                        });
                    }

                    if !collected_deltas.is_empty() {
                        return Ok(StreamChunk::TextDelta(collected_deltas));
                    }

                    // No complete event in this chunk: emit an empty delta.
                    Ok(StreamChunk::TextDelta(String::new()))
                }
                Err(e) => {
                    // On a transport error, salvage any accumulated text as a
                    // final `Done` chunk; otherwise surface the error.
                    if !*is_done.lock().unwrap() {
                        let content = accumulated.lock().unwrap().clone();
                        if !content.is_empty() {
                            let final_message = AgentMessage {
                                role: MessageRole::Agent,
                                content: MessageContent::Text(content),
                                metadata: None,
                            };
                            *is_done.lock().unwrap() = true;
                            return Ok(StreamChunk::Done {
                                message: final_message,
                            });
                        }
                    }
                    Err(anyhow::anyhow!("Stream error: {}", e))
                }
            }
        });

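        // Some servers close the connection without a `[DONE]` sentinel; chain
        // a finale that still delivers the accumulated text as a `Done` chunk.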
        let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
            if !*final_is_done.lock().unwrap() {
                let content = final_accumulated.lock().unwrap().clone();
                if !content.is_empty() {
                    tracing::debug!(
                        "Stream ended naturally, sending final Done chunk with {} bytes",
                        content.len()
                    );
                    return Ok(StreamChunk::Done {
                        message: AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        },
                    });
                }
            }
            Ok(StreamChunk::TextDelta(String::new()))
        }));

        Ok(Box::pin(stream_with_finale))
    }
}