agents_runtime/providers/
openai.rs1use agents_core::llm::{ChunkStream, LanguageModel, LlmRequest, LlmResponse, StreamChunk};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use agents_core::tools::ToolSchema;
4use async_trait::async_trait;
5use futures::stream::StreamExt;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::sync::{Arc, Mutex};
9
/// Connection settings for an OpenAI-compatible chat-completions endpoint.
///
/// NOTE(review): there is no `Debug` derive, presumably on purpose so `api_key`
/// cannot end up in logs — confirm before adding one.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// Secret key; the model sends it as a bearer token.
    pub api_key: String,
    /// Model identifier placed in the request body's `model` field.
    pub model: String,
    /// Endpoint override; `None` selects the provider's default endpoint.
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    /// Creates a config for `model` authenticated with `api_key`, with no endpoint override.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key: String = api_key.into();
        let model: String = model.into();
        Self {
            api_key,
            model,
            api_url: None,
        }
    }

    /// Returns the config with its endpoint override replaced by `api_url`
    /// (passing `None` clears any previous override).
    pub fn with_api_url(self, api_url: Option<String>) -> Self {
        Self { api_url, ..self }
    }
}
31
32pub struct OpenAiChatModel {
33 client: Client,
34 config: OpenAiConfig,
35}
36
37impl OpenAiChatModel {
38 pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
39 Ok(Self {
40 client: Client::builder()
41 .user_agent("rust-deep-agents-sdk/0.1")
42 .build()?,
43 config,
44 })
45 }
46}
47
48#[derive(Serialize)]
49struct ChatRequest<'a> {
50 model: &'a str,
51 messages: &'a [OpenAiMessage],
52 #[serde(skip_serializing_if = "Option::is_none")]
53 stream: Option<bool>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 tools: Option<Vec<OpenAiTool>>,
56}
57
58#[derive(Serialize)]
59struct OpenAiMessage {
60 role: &'static str,
61 content: String,
62}
63
64#[derive(Clone, Serialize)]
65struct OpenAiTool {
66 #[serde(rename = "type")]
67 tool_type: String,
68 function: OpenAiFunction,
69}
70
71#[derive(Clone, Serialize)]
72struct OpenAiFunction {
73 name: String,
74 description: String,
75 parameters: serde_json::Value,
76}
77
78#[derive(Deserialize)]
79struct ChatResponse {
80 choices: Vec<Choice>,
81}
82
83#[derive(Deserialize)]
84struct Choice {
85 message: ChoiceMessage,
86}
87
88#[derive(Deserialize)]
89struct ChoiceMessage {
90 content: Option<String>,
91 #[serde(default)]
92 tool_calls: Vec<OpenAiToolCall>,
93}
94
95#[derive(Deserialize)]
96struct OpenAiToolCall {
97 #[allow(dead_code)]
98 id: String,
99 #[serde(rename = "type")]
100 #[allow(dead_code)]
101 tool_type: String,
102 function: OpenAiFunctionCall,
103}
104
105#[derive(Deserialize)]
106struct OpenAiFunctionCall {
107 name: String,
108 arguments: String,
109}
110
111#[derive(Deserialize)]
113struct StreamResponse {
114 choices: Vec<StreamChoice>,
115}
116
117#[derive(Deserialize)]
118struct StreamChoice {
119 delta: StreamDelta,
120 finish_reason: Option<String>,
121}
122
123#[derive(Deserialize)]
124struct StreamDelta {
125 content: Option<String>,
126}
127
128fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
129 let mut messages = Vec::with_capacity(request.messages.len() + 1);
130 messages.push(OpenAiMessage {
131 role: "system",
132 content: request.system_prompt.clone(),
133 });
134
135 let mut last_was_tool_call = false;
137
138 for msg in &request.messages {
139 let role = match msg.role {
140 MessageRole::User => "user",
141 MessageRole::Agent => "assistant",
142 MessageRole::Tool => {
143 if !last_was_tool_call {
145 tracing::warn!("Skipping tool message without preceding tool_calls");
146 continue;
147 }
148 "tool"
149 }
150 MessageRole::System => "system",
151 };
152
153 let content = match &msg.content {
154 MessageContent::Text(text) => text.clone(),
155 MessageContent::Json(value) => value.to_string(),
156 };
157
158 last_was_tool_call =
160 matches!(msg.role, MessageRole::Agent) && content.contains("tool_calls");
161
162 messages.push(OpenAiMessage { role, content });
163 }
164 messages
165}
166
167fn to_openai_tools(tools: &[ToolSchema]) -> Option<Vec<OpenAiTool>> {
169 if tools.is_empty() {
170 return None;
171 }
172
173 Some(
174 tools
175 .iter()
176 .map(|tool| OpenAiTool {
177 tool_type: "function".to_string(),
178 function: OpenAiFunction {
179 name: tool.name.clone(),
180 description: tool.description.clone(),
181 parameters: serde_json::to_value(&tool.parameters)
182 .unwrap_or_else(|_| serde_json::json!({})),
183 },
184 })
185 .collect(),
186 )
187}
188
189#[async_trait]
190impl LanguageModel for OpenAiChatModel {
191 async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
192 let messages = to_openai_messages(&request);
193 let tools = to_openai_tools(&request.tools);
194
195 let body = ChatRequest {
196 model: &self.config.model,
197 messages: &messages,
198 stream: None,
199 tools: tools.clone(),
200 };
201 let url = self
202 .config
203 .api_url
204 .as_deref()
205 .unwrap_or("https://api.openai.com/v1/chat/completions");
206
207 tracing::debug!(
209 "OpenAI request: model={}, messages={}, tools={}",
210 self.config.model,
211 messages.len(),
212 tools.as_ref().map(|t| t.len()).unwrap_or(0)
213 );
214 for (i, msg) in messages.iter().enumerate() {
215 tracing::debug!(
216 "Message {}: role={}, content_len={}",
217 i,
218 msg.role,
219 msg.content.len()
220 );
221 if msg.content.len() < 500 {
222 tracing::debug!("Message {} content: {}", i, msg.content);
223 }
224 }
225
226 let response = self
227 .client
228 .post(url)
229 .bearer_auth(&self.config.api_key)
230 .json(&body)
231 .send()
232 .await?;
233
234 if !response.status().is_success() {
235 let status = response.status();
236 let error_text = response.text().await.unwrap_or_default();
237 tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
238 return Err(anyhow::anyhow!(
239 "OpenAI API error: {} - {}",
240 status,
241 error_text
242 ));
243 }
244
245 let data: ChatResponse = response.json().await?;
246 let choice = data
247 .choices
248 .into_iter()
249 .next()
250 .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;
251
252 if !choice.message.tool_calls.is_empty() {
254 let tool_calls: Vec<_> = choice
256 .message
257 .tool_calls
258 .iter()
259 .map(|tc| {
260 serde_json::json!({
261 "name": tc.function.name,
262 "args": serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
263 .unwrap_or_else(|_| serde_json::json!({}))
264 })
265 })
266 .collect();
267
268 tracing::debug!("OpenAI response contains {} tool calls", tool_calls.len());
269
270 return Ok(LlmResponse {
271 message: AgentMessage {
272 role: MessageRole::Agent,
273 content: MessageContent::Json(serde_json::json!({
274 "tool_calls": tool_calls
275 })),
276 metadata: None,
277 },
278 });
279 }
280
281 let content = choice.message.content.unwrap_or_else(|| "".to_string());
283
284 Ok(LlmResponse {
285 message: AgentMessage {
286 role: MessageRole::Agent,
287 content: MessageContent::Text(content),
288 metadata: None,
289 },
290 })
291 }
292
293 async fn generate_stream(&self, request: LlmRequest) -> anyhow::Result<ChunkStream> {
294 let messages = to_openai_messages(&request);
295 let tools = to_openai_tools(&request.tools);
296
297 let body = ChatRequest {
298 model: &self.config.model,
299 messages: &messages,
300 stream: Some(true),
301 tools,
302 };
303 let url = self
304 .config
305 .api_url
306 .as_deref()
307 .unwrap_or("https://api.openai.com/v1/chat/completions");
308
309 tracing::debug!(
310 "OpenAI streaming request: model={}, messages={}, tools={}",
311 self.config.model,
312 messages.len(),
313 request.tools.len()
314 );
315
316 let response = self
317 .client
318 .post(url)
319 .bearer_auth(&self.config.api_key)
320 .json(&body)
321 .send()
322 .await?;
323
324 if !response.status().is_success() {
325 let status = response.status();
326 let error_text = response.text().await.unwrap_or_default();
327 tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
328 return Err(anyhow::anyhow!(
329 "OpenAI API error: {} - {}",
330 status,
331 error_text
332 ));
333 }
334
335 let stream = response.bytes_stream();
337 let accumulated_content = Arc::new(Mutex::new(String::new()));
338 let buffer = Arc::new(Mutex::new(String::new()));
339
340 let is_done = Arc::new(Mutex::new(false));
341
342 let final_accumulated = accumulated_content.clone();
344 let final_is_done = is_done.clone();
345
346 let chunk_stream = stream.map(move |result| {
347 let accumulated = accumulated_content.clone();
348 let buffer = buffer.clone();
349 let is_done = is_done.clone();
350
351 if *is_done.lock().unwrap() {
353 return Ok(StreamChunk::TextDelta(String::new()));
354 }
355
356 match result {
357 Ok(bytes) => {
358 let text = String::from_utf8_lossy(&bytes);
359
360 buffer.lock().unwrap().push_str(&text);
362
363 let mut buf = buffer.lock().unwrap();
364
365 let mut collected_deltas = String::new();
367 let mut found_done = false;
368 let mut found_finish = false;
369
370 let parts: Vec<&str> = buf.split("\n\n").collect();
372 let complete_messages = if parts.len() > 1 {
373 &parts[..parts.len() - 1] } else {
375 &[] };
377
378 for msg in complete_messages {
380 for line in msg.lines() {
381 if let Some(data) = line.strip_prefix("data: ") {
382 let json_str = data.trim();
383
384 if json_str == "[DONE]" {
386 found_done = true;
387 break;
388 }
389
390 match serde_json::from_str::<StreamResponse>(json_str) {
392 Ok(chunk) => {
393 if let Some(choice) = chunk.choices.first() {
394 if let Some(content) = &choice.delta.content {
396 if !content.is_empty() {
397 accumulated.lock().unwrap().push_str(content);
398 collected_deltas.push_str(content);
399 }
400 }
401
402 if choice.finish_reason.is_some() {
404 found_finish = true;
405 }
406 }
407 }
408 Err(e) => {
409 tracing::debug!("Failed to parse SSE message: {}", e);
410 }
411 }
412 }
413 }
414 if found_done || found_finish {
415 break;
416 }
417 }
418
419 if !complete_messages.is_empty() {
421 *buf = parts.last().unwrap_or(&"").to_string();
422 }
423
424 if found_done || found_finish {
426 let content = accumulated.lock().unwrap().clone();
427 let final_message = AgentMessage {
428 role: MessageRole::Agent,
429 content: MessageContent::Text(content),
430 metadata: None,
431 };
432 *is_done.lock().unwrap() = true;
433 buf.clear();
434 return Ok(StreamChunk::Done {
435 message: final_message,
436 });
437 }
438
439 if !collected_deltas.is_empty() {
441 return Ok(StreamChunk::TextDelta(collected_deltas));
442 }
443
444 Ok(StreamChunk::TextDelta(String::new()))
445 }
446 Err(e) => {
447 if !*is_done.lock().unwrap() {
449 let content = accumulated.lock().unwrap().clone();
450 if !content.is_empty() {
451 let final_message = AgentMessage {
452 role: MessageRole::Agent,
453 content: MessageContent::Text(content),
454 metadata: None,
455 };
456 *is_done.lock().unwrap() = true;
457 return Ok(StreamChunk::Done {
458 message: final_message,
459 });
460 }
461 }
462 Err(anyhow::anyhow!("Stream error: {}", e))
463 }
464 }
465 });
466
467 let stream_with_finale = chunk_stream.chain(futures::stream::once(async move {
469 if !*final_is_done.lock().unwrap() {
471 let content = final_accumulated.lock().unwrap().clone();
472 if !content.is_empty() {
473 let final_message = AgentMessage {
474 role: MessageRole::Agent,
475 content: MessageContent::Text(content),
476 metadata: None,
477 };
478 let content_text = match &final_message.content {
479 MessageContent::Text(t) => t.as_str(),
480 _ => "non-text",
481 };
482 tracing::debug!(
483 "Stream ended naturally, sending final Done chunk with {} chars",
484 content_text.len()
485 );
486 return Ok(StreamChunk::Done {
487 message: final_message,
488 });
489 }
490 }
491 Ok(StreamChunk::TextDelta(String::new()))
493 }));
494
495 Ok(Box::pin(stream_with_finale))
496 }
497}