agents_runtime/providers/
openai.rs1use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
/// Configuration for the OpenAI chat-completions provider.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// API key sent as an HTTP bearer token.
    pub api_key: String,
    /// Model identifier passed in the request body (e.g. "gpt-4o").
    pub model: String,
    /// Optional endpoint override; when `None` the public OpenAI
    /// chat-completions URL is used.
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    /// Builds a config for `model` authenticated with `api_key`,
    /// targeting the default endpoint (`api_url` starts as `None`).
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        Self {
            api_key,
            model,
            api_url: None,
        }
    }

    /// Sets or clears the endpoint override, returning `self` for
    /// builder-style chaining.
    pub fn with_api_url(mut self, api_url: Option<String>) -> Self {
        self.api_url = api_url;
        self
    }
}
28
29pub struct OpenAiChatModel {
30 client: Client,
31 config: OpenAiConfig,
32}
33
34impl OpenAiChatModel {
35 pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
36 Ok(Self {
37 client: Client::builder()
38 .user_agent("rust-deep-agents-sdk/0.1")
39 .build()?,
40 config,
41 })
42 }
43}
44
/// JSON request body for the chat-completions endpoint; borrows the
/// model name and message list to avoid copying per request.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: &'a [OpenAiMessage],
}
50
/// One message in the OpenAI wire format. `role` is a static string
/// ("system", "user", "assistant", or "tool") chosen in
/// `to_openai_messages`.
#[derive(Serialize)]
struct OpenAiMessage {
    role: &'static str,
    content: String,
}
56
/// Top-level chat-completions response; only `choices` is deserialized,
/// other fields in the payload are ignored.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
61
/// One completion choice; only the message body is used.
#[derive(Deserialize)]
struct Choice {
    message: ChoiceMessage,
}
66
/// Assistant message inside a choice; only the text content is kept.
/// NOTE(review): tool_calls in the response are not modeled here, so
/// structured tool-call output would be dropped — confirm intended.
#[derive(Deserialize)]
struct ChoiceMessage {
    content: String,
}
71
72fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
73 let mut messages = Vec::with_capacity(request.messages.len() + 1);
74 messages.push(OpenAiMessage {
75 role: "system",
76 content: request.system_prompt.clone(),
77 });
78
79 let mut last_was_tool_call = false;
81
82 for msg in &request.messages {
83 let role = match msg.role {
84 MessageRole::User => "user",
85 MessageRole::Agent => "assistant",
86 MessageRole::Tool => {
87 if !last_was_tool_call {
89 tracing::warn!("Skipping tool message without preceding tool_calls");
90 continue;
91 }
92 "tool"
93 }
94 MessageRole::System => "system",
95 };
96
97 let content = match &msg.content {
98 MessageContent::Text(text) => text.clone(),
99 MessageContent::Json(value) => value.to_string(),
100 };
101
102 last_was_tool_call =
104 matches!(msg.role, MessageRole::Agent) && content.contains("tool_calls");
105
106 messages.push(OpenAiMessage { role, content });
107 }
108 messages
109}
110
111#[async_trait]
112impl LanguageModel for OpenAiChatModel {
113 async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
114 let messages = to_openai_messages(&request);
115 let body = ChatRequest {
116 model: &self.config.model,
117 messages: &messages,
118 };
119 let url = self
120 .config
121 .api_url
122 .as_deref()
123 .unwrap_or("https://api.openai.com/v1/chat/completions");
124
125 tracing::debug!(
127 "OpenAI request: model={}, messages={}",
128 self.config.model,
129 messages.len()
130 );
131 for (i, msg) in messages.iter().enumerate() {
132 tracing::debug!(
133 "Message {}: role={}, content_len={}",
134 i,
135 msg.role,
136 msg.content.len()
137 );
138 if msg.content.len() < 500 {
139 tracing::debug!("Message {} content: {}", i, msg.content);
140 }
141 }
142
143 let response = self
144 .client
145 .post(url)
146 .bearer_auth(&self.config.api_key)
147 .json(&body)
148 .send()
149 .await?;
150
151 if !response.status().is_success() {
152 let status = response.status();
153 let error_text = response.text().await.unwrap_or_default();
154 tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
155 return Err(anyhow::anyhow!(
156 "OpenAI API error: {} - {}",
157 status,
158 error_text
159 ));
160 }
161
162 let data: ChatResponse = response.json().await?;
163 let choice = data
164 .choices
165 .into_iter()
166 .next()
167 .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;
168
169 Ok(LlmResponse {
170 message: AgentMessage {
171 role: MessageRole::Agent,
172 content: MessageContent::Text(choice.message.content),
173 metadata: None,
174 },
175 })
176 }
177}