agents_runtime/providers/
openai.rs1use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use agents_core::tools::ToolSchema;
4use async_trait::async_trait;
5use futures::stream::StreamExt;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::sync::{Arc, Mutex};
9
/// Connection settings for an OpenAI-compatible chat-completions endpoint.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// Secret sent as a bearer token with every request.
    pub api_key: String,
    /// Model identifier placed in the request body (for example `"gpt-4"`).
    pub model: String,
    /// Full chat-completions URL override; `None` selects OpenAI's public endpoint.
    pub api_url: Option<String>,
    /// Extra `(name, value)` HTTP headers attached to every request.
    pub custom_headers: Vec<(String, String)>,
}

impl OpenAiConfig {
    /// Creates a config for `model` authenticated with `api_key`, using the
    /// default endpoint and no extra headers.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        Self {
            api_key,
            model,
            api_url: None,
            custom_headers: Vec::new(),
        }
    }

    /// Replaces the endpoint override; pass `None` to go back to the default URL.
    pub fn with_api_url(self, api_url: Option<String>) -> Self {
        Self { api_url, ..self }
    }

    /// Replaces the whole set of custom headers.
    pub fn with_custom_headers(self, headers: Vec<(String, String)>) -> Self {
        Self {
            custom_headers: headers,
            ..self
        }
    }
}

39pub struct OpenAiChatModel {
40 client: Client,
41 config: OpenAiConfig,
42}
43
44impl OpenAiChatModel {
45 pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
46 Ok(Self {
47 client: Client::builder()
48 .user_agent("rust-deep-agents-sdk/0.1")
49 .build()?,
50 config,
51 })
52 }
53}
54
55#[derive(Serialize)]
56struct ChatRequest<'a> {
57 model: &'a str,
58 messages: &'a [OpenAiMessage],
59 #[serde(skip_serializing_if = "Option::is_none")]
60 stream: Option<bool>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 tools: Option<Vec<OpenAiTool>>,
63}
64
65#[derive(Serialize)]
66struct OpenAiMessage {
67 role: &'static str,
68 content: String,
69}
70
71#[derive(Clone, Serialize)]
72struct OpenAiTool {
73 #[serde(rename = "type")]
74 tool_type: String,
75 function: OpenAiFunction,
76}
77
78#[derive(Clone, Serialize)]
79struct OpenAiFunction {
80 name: String,
81 description: String,
82 parameters: serde_json::Value,
83}
84
85#[derive(Deserialize)]
86struct ChatResponse {
87 choices: Vec<Choice>,
88}
89
90#[derive(Deserialize)]
91struct Choice {
92 message: ChoiceMessage,
93}
94
95#[derive(Deserialize)]
96struct ChoiceMessage {
97 content: Option<String>,
98 #[serde(default)]
99 tool_calls: Vec<OpenAiToolCall>,
100}
101
102#[derive(Deserialize)]
103struct OpenAiToolCall {
104 #[allow(dead_code)]
105 id: String,
106 #[serde(rename = "type")]
107 #[allow(dead_code)]
108 tool_type: String,
109 function: OpenAiFunctionCall,
110}
111
112#[derive(Deserialize)]
113struct OpenAiFunctionCall {
114 name: String,
115 arguments: String,
116}
117
118#[derive(Deserialize)]
120struct StreamResponse {
121 choices: Vec<StreamChoice>,
122}
123
124#[derive(Deserialize)]
125struct StreamChoice {
126 delta: StreamDelta,
127 finish_reason: Option<String>,
128}
129
130#[derive(Deserialize)]
131struct StreamDelta {
132 content: Option<String>,
133}
134
135fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
136 let mut messages = Vec::with_capacity(request.messages.len() + 1);
137 messages.push(OpenAiMessage {
138 role: "system",
139 content: request.system_prompt.clone(),
140 });
141
142 for msg in &request.messages {
146 let role = match msg.role {
147 MessageRole::User => "user",
148 MessageRole::Agent => "assistant",
149 MessageRole::Tool => "user", MessageRole::System => "system",
151 };
152
153 let content = match &msg.content {
154 MessageContent::Text(text) => text.clone(),
155 MessageContent::Json(value) => value.to_string(),
156 };
157
158 messages.push(OpenAiMessage { role, content });
159 }
160 messages
161}
162
163fn to_openai_tools(tools: &[ToolSchema]) -> Option<Vec<OpenAiTool>> {
165 if tools.is_empty() {
166 return None;
167 }
168
169 Some(
170 tools
171 .iter()
172 .map(|tool| OpenAiTool {
173 tool_type: "function".to_string(),
174 function: OpenAiFunction {
175 name: tool.name.clone(),
176 description: tool.description.clone(),
177 parameters: serde_json::to_value(&tool.parameters)
178 .unwrap_or_else(|_| serde_json::json!({})),
179 },
180 })
181 .collect(),
182 )
183}
184
185#[async_trait]
186impl LanguageModel for OpenAiChatModel {
187 async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
188 let messages = to_openai_messages(&request);
189 let tools = to_openai_tools(&request.tools);
190
191 let body = ChatRequest {
192 model: &self.config.model,
193 messages: &messages,
194 stream: None,
195 tools: tools.clone(),
196 };
197 let url = self
198 .config
199 .api_url
200 .as_deref()
201 .unwrap_or("https://api.openai.com/v1/chat/completions");
202
203 tracing::debug!(
205 "OpenAI request: model={}, messages={}, tools={}",
206 self.config.model,
207 messages.len(),
208 tools.as_ref().map(|t| t.len()).unwrap_or(0)
209 );
210 for (i, msg) in messages.iter().enumerate() {
211 tracing::debug!(
212 "Message {}: role={}, content_len={}",
213 i,
214 msg.role,
215 msg.content.len()
216 );
217 if msg.content.len() < 500 {
218 tracing::debug!("Message {} content: {}", i, msg.content);
219 }
220 }
221
222 let mut request = self.client.post(url).bearer_auth(&self.config.api_key);
223
224 for (key, value) in &self.config.custom_headers {
225 request = request.header(key, value);
226 }
227
228 let response = request.json(&body).send().await?;
229
230 if !response.status().is_success() {
231 let status = response.status();
232 let error_text = response.text().await.unwrap_or_default();
233 tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
234 return Err(anyhow::anyhow!(
235 "OpenAI API error: {} - {}",
236 status,
237 error_text
238 ));
239 }
240
241 let data: ChatResponse = response.json().await?;
242 let choice = data
243 .choices
244 .into_iter()
245 .next()
246 .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;
247
248 if !choice.message.tool_calls.is_empty() {
250 let tool_calls: Vec<_> = choice
252 .message
253 .tool_calls
254 .iter()
255 .map(|tc| {
256 serde_json::json!({
257 "name": tc.function.name,
258 "args": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
259 .unwrap_or_else(|_| serde_json::json!({}))
260 })
261 })
262 .collect();
263
264 let tool_names: Vec<&str> = choice
266 .message
267 .tool_calls
268 .iter()
269 .map(|tc| tc.function.name.as_str())
270 .collect();
271
272 tracing::warn!(
273 "🔧 LLM CALLED {} TOOL(S): {:?}",
274 tool_calls.len(),
275 tool_names
276 );
277
278 for (i, tc) in choice.message.tool_calls.iter().enumerate() {
280 tracing::debug!(
281 "Tool call {}: {} with {} bytes of arguments",
282 i + 1,
283 tc.function.name,
284 tc.function.arguments.len()
285 );
286 }
287
288 return Ok(LlmResponse {
289 message: AgentMessage {
290 role: MessageRole::Agent,
291 content: MessageContent::Json(serde_json::json!({
292 "tool_calls": tool_calls
293 })),
294 metadata: None,
295 },
296 });
297 }
298
299 let content = choice.message.content.unwrap_or_else(|| "".to_string());
301
302 Ok(LlmResponse {
303 message: AgentMessage {
304 role: MessageRole::Agent,
305 content: MessageContent::Text(content),
306 metadata: None,
307 },
308 })
309 }
310
311 async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
312 let messages = to_openai_messages(&request);
313 let tools = to_openai_tools(&request.tools);
314
315 let body = ChatRequest {
316 model: &self.config.model,
317 messages: &messages,
318 stream: Some(true),
319 tools,
320 };
321 let url = self
322 .config
323 .api_url
324 .as_deref()
325 .unwrap_or("https://api.openai.com/v1/chat/completions");
326
327 tracing::debug!(
328 "OpenAI streaming request: model={}, messages={}, tools={}",
329 self.config.model,
330 messages.len(),
331 request.tools.len()
332 );
333
334 let mut http_request = self.client.post(url).bearer_auth(&self.config.api_key);
335
336 for (key, value) in &self.config.custom_headers {
337 http_request = http_request.header(key, value);
338 }
339
340 let response = http_request.json(&body).send().await?;
341
342 if !response.status().is_success() {
343 let status = response.status();
344 let error_text = response.text().await.unwrap_or_default();
345 tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
346 return Err(anyhow::anyhow!(
347 "OpenAI API error: {} - {}",
348 status,
349 error_text
350 ));
351 }
352
353 let stream = response.bytes_stream();
355 let accumulated_content = Arc::new(Mutex::new(String::new()));
356 let buffer = Arc::new(Mutex::new(String::new()));
357
358 let is_done = Arc::new(Mutex::new(false));
359
360 let final_accumulated = accumulated_content.clone();
362 let final_is_done = is_done.clone();
363
364 let chunk_stream = stream.map(move |result| {
365 let accumulated = accumulated_content.clone();
366 let buffer = buffer.clone();
367 let is_done = is_done.clone();
368
369 if *is_done.lock().unwrap() {
371 return Ok(StreamChunk::TextDelta(String::new()));
372 }
373
374 match result {
375 Ok(bytes) => {
376 let text = String::from_utf8_lossy(&bytes);
377
378 buffer.lock().unwrap().push_str(&text);
380
381 let mut buf = buffer.lock().unwrap();
382
383 let mut collected_deltas = String::new();
385 let mut found_done = false;
386 let mut found_finish = false;
387
388 let parts: Vec<&str> = buf.split("\n\n").collect();
390 let complete_messages = if parts.len() > 1 {
391 &parts[..parts.len() - 1] } else {
393 &[] };
395
396 for msg in complete_messages {
398 for line in msg.lines() {
399 if let Some(data) = line.strip_prefix("data: ") {
400 let json_str = data.trim();
401
402 if json_str == "[DONE]" {
404 found_done = true;
405 break;
406 }
407
408 match serde_json::from_str::<StreamResponse>(json_str) {
410 Ok(chunk) => {
411 if let Some(choice) = chunk.choices.first() {
412 if let Some(content) = &choice.delta.content {
414 if !content.is_empty() {
415 accumulated.lock().unwrap().push_str(content);
416 collected_deltas.push_str(content);
417 }
418 }
419
420 if choice.finish_reason.is_some() {
422 found_finish = true;
423 }
424 }
425 }
426 Err(e) => {
427 tracing::debug!("Failed to parse SSE message: {}", e);
428 }
429 }
430 }
431 }
432 if found_done || found_finish {
433 break;
434 }
435 }
436
437 if !complete_messages.is_empty() {
439 *buf = parts.last().unwrap_or(&"").to_string();
440 }
441
442 if found_done || found_finish {
444 let content = accumulated.lock().unwrap().clone();
445 let final_message = AgentMessage {
446 role: MessageRole::Agent,
447 content: MessageContent::Text(content),
448 metadata: None,
449 };
450 *is_done.lock().unwrap() = true;
451 buf.clear();
452 return Ok(StreamChunk::Done {
453 message: final_message,
454 });
455 }
456
457 if !collected_deltas.is_empty() {
459 return Ok(StreamChunk::TextDelta(collected_deltas));
460 }
461
462 Ok(StreamChunk::TextDelta(String::new()))
463 }
464 Err(e) => {
465 if !*is_done.lock().unwrap() {
467 let content = accumulated.lock().unwrap().clone();
468 if !content.is_empty() {
469 let final_message = AgentMessage {
470 role: MessageRole::Agent,
471 content: MessageContent::Text(content),
472 metadata: None,
473 };
474 *is_done.lock().unwrap() = true;
475 return Ok(StreamChunk::Done {
476 message: final_message,
477 });
478 }
479 }
480 Err(anyhow::anyhow!("Stream error: {}", e))
481 }
482 }
483 });
484
485 let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
487 if !*final_is_done.lock().unwrap() {
489 let content = final_accumulated.lock().unwrap().clone();
490 if !content.is_empty() {
491 let final_message = AgentMessage {
492 role: MessageRole::Agent,
493 content: MessageContent::Text(content),
494 metadata: None,
495 };
496 let content_text = match &final_message.content {
497 MessageContent::Text(t) => t.as_str(),
498 _ => "non-text",
499 };
500 tracing::debug!(
501 "Stream ended naturally, sending final Done chunk with {} chars",
502 content_text.len()
503 );
504 return Ok(StreamChunk::Done {
505 message: final_message,
506 });
507 }
508 }
509 Ok(StreamChunk::TextDelta(String::new()))
511 }));
512
513 Ok(Box::pin(stream_with_finale))
514 }
515}
516
#[cfg(test)]
mod tests {
    //! Unit tests for the `OpenAiConfig` builder.
    use super::*;

    #[test]
    fn openai_config_new_initializes_empty_custom_headers() {
        let config = OpenAiConfig::new("test-key", "gpt-4");
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.model, "gpt-4");
        assert!(config.custom_headers.is_empty());
        assert!(config.api_url.is_none());
    }

    #[test]
    fn openai_config_with_custom_headers_sets_headers() {
        let headers = vec![
            ("X-Custom-Header".to_string(), "value1".to_string()),
            ("X-Another-Header".to_string(), "value2".to_string()),
        ];
        let config = OpenAiConfig::new("test-key", "gpt-4").with_custom_headers(headers.clone());

        // Order and contents must be preserved exactly.
        assert_eq!(config.custom_headers, headers);
    }

    #[test]
    fn openai_config_with_api_url_overrides_and_clears_endpoint() {
        let url = "http://localhost:11434/v1/chat/completions".to_string();
        let config = OpenAiConfig::new("test-key", "gpt-4").with_api_url(Some(url.clone()));
        assert_eq!(config.api_url.as_deref(), Some(url.as_str()));

        let cleared = config.with_api_url(None);
        assert!(cleared.api_url.is_none());
        assert_eq!(cleared.model, "gpt-4");
    }
}