// agents_runtime/providers/openai.rs

use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
use agents_core::tools::ToolSchema;
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};

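/// Connection settings for the OpenAI Chat Completions endpoint: API key,
/// model name, and an optional override of the API URL (useful for proxies
/// and OpenAI-compatible servers).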
#[derive(Clone)]
pub struct OpenAiConfig {
    pub api_key: String,
    pub model: String,
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            api_url: None,
        }
    }

    pub fn with_api_url(mut self, api_url: Option<String>) -> Self {
        self.api_url = api_url;
        self
    }
}

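/// A `LanguageModel` backed by the OpenAI Chat Completions API.
///
/// A minimal construction sketch (the model name is illustrative, and reading
/// the key from an `OPENAI_API_KEY` environment variable is an assumption of
/// this example, not something this module requires):
///
/// ```ignore
/// let config = OpenAiConfig::new(std::env::var("OPENAI_API_KEY")?, "gpt-4o-mini")
///     .with_api_url(None);
/// let model = OpenAiChatModel::new(config)?;
/// ```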
pub struct OpenAiChatModel {
    client: Client,
    config: OpenAiConfig,
}

impl OpenAiChatModel {
    pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
        Ok(Self {
            client: Client::builder()
                .user_agent("rust-deep-agents-sdk/0.1")
                .build()?,
            config,
        })
    }
}

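// Wire-format types for the Chat Completions API. Only the fields this
// provider actually sends or reads are modeled; serde ignores the rest of
// the response.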
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: &'a [OpenAiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: &'static str,
    content: String,
}

#[derive(Clone, Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunction,
}

#[derive(Clone, Serialize)]
struct OpenAiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ChoiceMessage,
}

#[derive(Deserialize)]
struct ChoiceMessage {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<OpenAiToolCall>,
}

#[derive(Deserialize)]
struct OpenAiToolCall {
    #[allow(dead_code)]
    id: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    tool_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

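// Incremental (SSE) response types: each `data:` frame carries a delta with
// an optional content fragment and, on the final frame, a `finish_reason`.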
#[derive(Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct StreamDelta {
    content: Option<String>,
}

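/// Flattens an `LlmRequest` into the OpenAI message list: the system prompt
/// is prepended as a `system` message, then each history message is mapped to
/// the closest Chat Completions role.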
fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
    let mut messages = Vec::with_capacity(request.messages.len() + 1);
    messages.push(OpenAiMessage {
        role: "system",
        content: request.system_prompt.clone(),
    });

    for msg in &request.messages {
        let role = match msg.role {
            MessageRole::User => "user",
            MessageRole::Agent => "assistant",
            // Tool results are relayed as user messages, since this provider
            // does not track the `tool_call_id` that a real `tool` role
            // message would require.
            MessageRole::Tool => "user",
            MessageRole::System => "system",
        };

        let content = match &msg.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Json(value) => value.to_string(),
        };

        messages.push(OpenAiMessage { role, content });
    }
    messages
}

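/// Converts tool schemas to OpenAI's `function` tool format, or `None` when
/// there are no tools (so the field is omitted from the request body).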
fn to_openai_tools(tools: &[ToolSchema]) -> Option<Vec<OpenAiTool>> {
    if tools.is_empty() {
        return None;
    }

    Some(
        tools
            .iter()
            .map(|tool| OpenAiTool {
                tool_type: "function".to_string(),
                function: OpenAiFunction {
                    name: tool.name.clone(),
                    description: tool.description.clone(),
                    parameters: serde_json::to_value(&tool.parameters)
                        .unwrap_or_else(|_| serde_json::json!({})),
                },
            })
            .collect(),
    )
}

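// `generate` performs a single request/response round trip; `generate_stream`
// consumes the server-sent-events stream and re-emits it as `StreamChunk`s.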
#[async_trait]
impl LanguageModel for OpenAiChatModel {
    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
        let messages = to_openai_messages(&request);
        let tools = to_openai_tools(&request.tools);

        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: None,
            tools: tools.clone(),
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        tracing::debug!(
            "OpenAI request: model={}, messages={}, tools={}",
            self.config.model,
            messages.len(),
            tools.as_ref().map(|t| t.len()).unwrap_or(0)
        );
        for (i, msg) in messages.iter().enumerate() {
            tracing::debug!(
                "Message {}: role={}, content_len={}",
                i,
                msg.role,
                msg.content.len()
            );
            if msg.content.len() < 500 {
                tracing::debug!("Message {} content: {}", i, msg.content);
            }
        }

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

        let data: ChatResponse = response.json().await?;
        let choice = data
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;

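        // Tool-call path: surface the calls as a JSON payload so the runtime
        // can dispatch them; malformed argument strings degrade to `{}`.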
        if !choice.message.tool_calls.is_empty() {
            let tool_calls: Vec<_> = choice
                .message
                .tool_calls
                .iter()
                .map(|tc| {
                    serde_json::json!({
                        "name": tc.function.name,
                        "args": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                            .unwrap_or_else(|_| serde_json::json!({}))
                    })
                })
                .collect();

            let tool_names: Vec<&str> = choice
                .message
                .tool_calls
                .iter()
                .map(|tc| tc.function.name.as_str())
                .collect();

            tracing::warn!(
                "🔧 LLM CALLED {} TOOL(S): {:?}",
                tool_calls.len(),
                tool_names
            );

            for (i, tc) in choice.message.tool_calls.iter().enumerate() {
                tracing::debug!(
                    "Tool call {}: {} with {} bytes of arguments",
                    i + 1,
                    tc.function.name,
                    tc.function.arguments.len()
                );
            }

            return Ok(LlmResponse {
                message: AgentMessage {
                    role: MessageRole::Agent,
                    content: MessageContent::Json(serde_json::json!({
                        "tool_calls": tool_calls
                    })),
                    metadata: None,
                },
            });
        }

        let content = choice.message.content.unwrap_or_default();

        Ok(LlmResponse {
            message: AgentMessage {
                role: MessageRole::Agent,
                content: MessageContent::Text(content),
                metadata: None,
            },
        })
    }

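    /// Streams a completion over SSE: bytes are buffered until a blank-line
    /// frame delimiter, each complete `data:` payload is parsed, and text
    /// deltas are re-emitted. A `[DONE]` marker or `finish_reason` yields a
    /// final `StreamChunk::Done` carrying the full accumulated message.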
    async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
        let messages = to_openai_messages(&request);
        let tools = to_openai_tools(&request.tools);

        let body = ChatRequest {
            model: &self.config.model,
            messages: &messages,
            stream: Some(true),
            tools,
        };
        let url = self
            .config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/chat/completions");

        tracing::debug!(
            "OpenAI streaming request: model={}, messages={}, tools={}",
            self.config.model,
            messages.len(),
            request.tools.len()
        );

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.config.api_key)
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
            return Err(anyhow::anyhow!(
                "OpenAI API error: {} - {}",
                status,
                error_text
            ));
        }

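        // Shared state for the SSE parser: `accumulated_content` collects the
        // full reply, `buffer` holds partial frames between byte chunks, and
        // `is_done` suppresses output after the final chunk.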
        let stream = response.bytes_stream();
        let accumulated_content = Arc::new(Mutex::new(String::new()));
        let buffer = Arc::new(Mutex::new(String::new()));
        let is_done = Arc::new(Mutex::new(false));

        let final_accumulated = accumulated_content.clone();
        let final_is_done = is_done.clone();

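        // Map each byte chunk to at most one `StreamChunk`. A single network
        // chunk can contain several SSE messages (or a fraction of one), so
        // deltas are coalesced per chunk and leftovers stay in `buffer`.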
        let chunk_stream = stream.map(move |result| {
            let accumulated = accumulated_content.clone();
            let buffer = buffer.clone();
            let is_done = is_done.clone();

            if *is_done.lock().unwrap() {
                return Ok(StreamChunk::TextDelta(String::new()));
            }

            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);

                    let mut buf = buffer.lock().unwrap();
                    buf.push_str(&text);

                    let mut collected_deltas = String::new();
                    let mut found_done = false;
                    let mut found_finish = false;

                    // SSE messages are delimited by a blank line; the last
                    // split part may be incomplete and is kept for later.
                    let parts: Vec<&str> = buf.split("\n\n").collect();
                    let complete_messages = if parts.len() > 1 {
                        &parts[..parts.len() - 1]
                    } else {
                        &[]
                    };

                    for msg in complete_messages {
                        for line in msg.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                let json_str = data.trim();

                                if json_str == "[DONE]" {
                                    found_done = true;
                                    break;
                                }

                                match serde_json::from_str::<StreamResponse>(json_str) {
                                    Ok(chunk) => {
                                        if let Some(choice) = chunk.choices.first() {
                                            if let Some(content) = &choice.delta.content {
                                                if !content.is_empty() {
                                                    accumulated.lock().unwrap().push_str(content);
                                                    collected_deltas.push_str(content);
                                                }
                                            }

                                            if choice.finish_reason.is_some() {
                                                found_finish = true;
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        tracing::debug!("Failed to parse SSE message: {}", e);
                                    }
                                }
                            }
                        }
                        if found_done || found_finish {
                            break;
                        }
                    }

                    if !complete_messages.is_empty() {
                        *buf = parts.last().unwrap_or(&"").to_string();
                    }

                    if found_done || found_finish {
                        let content = accumulated.lock().unwrap().clone();
                        let final_message = AgentMessage {
                            role: MessageRole::Agent,
                            content: MessageContent::Text(content),
                            metadata: None,
                        };
                        *is_done.lock().unwrap() = true;
                        buf.clear();
                        return Ok(StreamChunk::Done {
                            message: final_message,
                        });
                    }

                    if !collected_deltas.is_empty() {
                        return Ok(StreamChunk::TextDelta(collected_deltas));
                    }

                    Ok(StreamChunk::TextDelta(String::new()))
                }
                Err(e) => {
                    // On a transport error, salvage whatever content has
                    // accumulated before surfacing the failure.
                    if !*is_done.lock().unwrap() {
                        let content = accumulated.lock().unwrap().clone();
                        if !content.is_empty() {
                            let final_message = AgentMessage {
                                role: MessageRole::Agent,
                                content: MessageContent::Text(content),
                                metadata: None,
                            };
                            *is_done.lock().unwrap() = true;
                            return Ok(StreamChunk::Done {
                                message: final_message,
                            });
                        }
                    }
                    Err(anyhow::anyhow!("Stream error: {}", e))
                }
            }
        });

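        // Some servers close the SSE stream without an explicit [DONE] frame;
        // chain one final element so callers still receive a `Done` chunk
        // carrying whatever content accumulated.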
        let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
            if !*final_is_done.lock().unwrap() {
                let content = final_accumulated.lock().unwrap().clone();
                if !content.is_empty() {
                    tracing::debug!(
                        "Stream ended naturally, sending final Done chunk with {} chars",
                        content.len()
                    );
                    let final_message = AgentMessage {
                        role: MessageRole::Agent,
                        content: MessageContent::Text(content),
                        metadata: None,
                    };
                    return Ok(StreamChunk::Done {
                        message: final_message,
                    });
                }
            }
            Ok(StreamChunk::TextDelta(String::new()))
        }));

        Ok(Box::pin(stream_with_finale))
    }
}