openai_agents_rust/model/
openai_chat.rs1use crate::config::Config;
2use crate::error::AgentError;
3use crate::model::{Model, ModelResponse, ToolCall};
4use crate::utils::env::var_bool;
5use crate::utils::env::var_opt;
6use async_trait::async_trait;
7use reqwest::Client;
8use serde::Deserialize;
9use tracing::debug;
10
/// Chat-completions client for OpenAI-compatible HTTP endpoints.
pub struct OpenAiChat {
    // Reused reqwest client (custom user agent, connection pooling).
    client: Client,
    // Agent configuration; `model`, `api_key`, and `base_url` are read here.
    config: Config,
    // API root the paths are appended to (requests go to
    // `{base_url}/chat/completions`).
    base_url: String,
    // Bearer token for the Authorization header; `None` sends no auth.
    auth_token: Option<String>,
}
18
19impl OpenAiChat {
20 pub fn new(config: Config) -> Self {
22 let client = Client::builder()
23 .user_agent("openai-agents-rust")
24 .build()
25 .expect("Failed to build reqwest client");
26 let auth_token = if config.api_key.is_empty() {
27 None
28 } else {
29 Some(config.api_key.clone())
30 };
31 let base_url = config.base_url.clone();
32 Self {
33 client,
34 config,
35 base_url,
36 auth_token,
37 }
38 }
39
40 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
42 self.base_url = base_url.into();
43 self
44 }
45
46 pub fn without_auth(mut self) -> Self {
48 self.auth_token = None;
49 self
50 }
51}
52
53#[derive(Deserialize)]
55struct FunctionCall {
56 name: String,
57 arguments: serde_json::Value,
59}
60#[derive(Deserialize)]
61struct ToolCallJson {
62 #[serde(rename = "type")]
63 _type: Option<String>,
64 id: Option<String>,
65 call_id: Option<String>,
66 function: Option<FunctionCall>,
67}
68#[derive(Deserialize)]
69struct Message {
70 content: Option<String>,
71 tool_calls: Option<Vec<ToolCallJson>>,
72 function_call: Option<FunctionCall>,
74}
75#[derive(Deserialize)]
76struct Choice {
77 message: Message,
78}
79#[derive(Deserialize)]
80struct ChatCompletion {
81 choices: Vec<Choice>,
82}
83
84fn parse_chat_completion(body: ChatCompletion) -> ModelResponse {
85 let mut text: Option<String> = None;
86 let mut tool_calls: Vec<ToolCall> = Vec::new();
87 if let Some(first) = body.choices.into_iter().next() {
88 text = first.message.content;
89 if let Some(tcs) = first.message.tool_calls {
90 for tc in tcs.into_iter() {
91 if let Some(func) = tc.function {
92 tool_calls.push(ToolCall {
93 id: tc.id,
94 name: func.name,
95 arguments: match func.arguments {
96 serde_json::Value::String(s) => s,
97 other => other.to_string(),
98 },
99 call_id: tc.call_id,
100 });
101 }
102 }
103 } else if let Some(func) = first.message.function_call {
104 tool_calls.push(ToolCall {
106 id: None,
107 name: func.name,
108 arguments: match func.arguments {
109 serde_json::Value::String(s) => s,
110 other => other.to_string(),
111 },
112 call_id: None,
113 });
114 }
115 }
116 ModelResponse {
117 id: None,
118 text,
119 tool_calls,
120 }
121}
122
123#[async_trait]
124impl Model for OpenAiChat {
125 async fn generate(&self, prompt: &str) -> Result<String, AgentError> {
127 let url = format!("{}/chat/completions", self.base_url);
128 let mut rb = self.client.post(&url);
129 if let Some(token) = &self.auth_token {
130 rb = rb.bearer_auth(token);
131 }
132 let resp = rb
133 .json(&serde_json::json!({
134 "model": self.config.model,
135 "messages": [{ "role": "user", "content": prompt }],
136 }))
137 .send()
138 .await
139 .map_err(AgentError::from)?;
140
141 let text = resp.text().await.map_err(AgentError::from)?;
142
143 Ok(text)
144 }
145
146 async fn get_response(
148 &self,
149 system_instructions: Option<&str>,
150 input: &str,
151 _model_settings: Option<std::collections::HashMap<String, String>>,
152 messages: Option<&[serde_json::Value]>,
153 tools: Option<&[serde_json::Value]>,
154 tool_choice: Option<serde_json::Value>,
155 _output_schema: Option<&str>,
156 _handoffs: Option<&[String]>,
157 _tracing_enabled: bool,
158 _previous_response_id: Option<&str>,
159 _prompt_config: Option<&str>,
160 ) -> Result<ModelResponse, AgentError> {
161 let url = format!("{}/chat/completions", self.base_url);
162 let mut msgs: Vec<serde_json::Value> = Vec::new();
164 if let Some(provided) = messages {
165 msgs.extend_from_slice(provided);
166 } else {
167 if let Some(sys) = system_instructions {
168 msgs.push(serde_json::json!({"role": "system", "content": sys}));
169 }
170 msgs.push(serde_json::json!({"role": "user", "content": input}));
171 }
172
173 let minimal_payload = var_bool("VLLM_MIN_PAYLOAD", false);
175 let force_functions = var_bool("VLLM_FORCE_FUNCTIONS", false);
176 let disable_parallel = var_bool("VLLM_DISABLE_PARALLEL_TOOL_CALLS", false);
178 let tool_choice_override = var_opt("VLLM_TOOL_CHOICE");
180
181 let mut payload = if minimal_payload {
183 serde_json::json!({
184 "model": self.config.model,
185 "messages": msgs,
186 })
187 } else {
188 serde_json::json!({
189 "model": self.config.model,
190 "messages": msgs,
191 "max_tokens": 512,
192 "temperature": 0.2,
193 })
194 };
195 let have_tools = if let Some(t) = tools {
196 if force_functions {
197 let mut functions: Vec<serde_json::Value> = Vec::new();
199 for tool in t.iter() {
200 if let Some(obj) = tool.as_object() {
201 if obj.get("type").and_then(|v| v.as_str()) == Some("function") {
202 if let Some(func) = obj.get("function") {
203 functions.push(func.clone());
204 }
205 }
206 }
207 }
208 if !functions.is_empty() {
209 payload["functions"] = serde_json::Value::Array(functions);
210 payload["function_call"] = serde_json::json!("auto");
211 }
212 } else if !minimal_payload {
213 payload["tools"] = serde_json::Value::Array(t.to_vec());
214 if !disable_parallel {
215 payload["parallel_tool_calls"] = serde_json::Value::Bool(true);
216 }
217 }
218 true
219 } else {
220 false
221 };
222 if !minimal_payload {
223 if let Some(choice) = &tool_choice {
224 payload["tool_choice"] = choice.clone();
225 } else if have_tools && !force_functions {
226 if let Some(tc) = tool_choice_override.as_deref() {
228 match tc {
229 "object:auto" => {
230 payload["tool_choice"] = serde_json::json!({"type": "auto"})
231 }
232 "object:none" => {
233 payload["tool_choice"] = serde_json::json!({"type": "none"})
234 }
235 "none" => payload["tool_choice"] = serde_json::json!("none"),
236 "auto" => payload["tool_choice"] = serde_json::json!("auto"),
237 _ => {}
238 }
239 }
240 }
241 }
242
243 if var_bool("VLLM_DEBUG_PAYLOAD", false) {
244 if let Ok(pretty) = serde_json::to_string_pretty(&payload) {
245 debug!(target: "openai_chat", payload = %pretty, "request payload");
246 }
247 }
248 debug!(
249 target: "openai_chat",
250 url = %url,
251 model = %self.config.model,
252 have_tools = %have_tools,
253 force_functions = %force_functions,
254 tool_choice = %tool_choice.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "<none>".into()),
255 messages_len = payload["messages"].as_array().map(|a| a.len()).unwrap_or(0),
256 "sending chat.completions"
257 );
258 let mut req1 = self.client.post(&url);
259 if let Some(token) = &self.auth_token {
260 req1 = req1.bearer_auth(token);
261 }
262 let resp1 = req1.json(&payload).send().await.map_err(AgentError::from)?;
263 let status = resp1.status();
264 let body_text = resp1.text().await.map_err(AgentError::from)?;
265 debug!(target: "openai_chat", "request completed status={} have_tools={}", status, have_tools);
266 if !status.is_success() {
267 let truncated = if body_text.len() > 2000 {
268 format!("{}...<truncated>", &body_text[..2000])
269 } else {
270 body_text.clone()
271 };
272 return Err(AgentError::Other(format!(
273 "chat.completions failed (status: {}). The server returned an error while tools={} force_functions={}. No automatic retries are performed. Verify the endpoint supports your requested schema (tool_calls vs functions) or adjust config (e.g., VLLM_FORCE_FUNCTIONS, VLLM_TOOL_CHOICE). Response body: {}",
274 status,
275 if have_tools { "enabled" } else { "disabled" },
276 if var_bool("VLLM_FORCE_FUNCTIONS", false) {
277 "on"
278 } else {
279 "off"
280 },
281 truncated
282 )));
283 }
284
285 match serde_json::from_str::<ChatCompletion>(&body_text) {
286 Ok(body) => Ok(parse_chat_completion(body)),
287 Err(e) => {
288 let truncated = if body_text.len() > 2000 {
289 format!("{}...<truncated>", &body_text[..2000])
290 } else {
291 body_text
292 };
293 Err(AgentError::Other(format!(
294 "Failed to parse chat.completions response: {}. Expected OpenAI chat format with choices[0].message. Body: {}",
295 e, truncated
296 )))
297 }
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain text reply maps to `text` and produces no tool calls.
    #[test]
    fn parse_text_only() {
        let raw = r#"{"choices":[{"message":{"content":"Hello","tool_calls":null}}]}"#;
        let body: ChatCompletion = serde_json::from_str(raw).unwrap();
        let res = parse_chat_completion(body);
        assert_eq!(res.text.as_deref(), Some("Hello"));
        assert!(res.tool_calls.is_empty());
    }

    /// Multiple tool calls convert in order, preserving the function names
    /// and the raw JSON argument strings.
    #[test]
    fn parse_with_tool_calls() {
        let raw = r#"{
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [
                        {"type": "function", "function": {"name": "search", "arguments": "{\"q\":\"rust\"}"}},
                        {"type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"NYC\"}"}}
                    ]
                }
            }]
        }"#;
        let body: ChatCompletion = serde_json::from_str(raw).unwrap();
        let res = parse_chat_completion(body);
        assert!(res.text.is_none());
        assert_eq!(res.tool_calls.len(), 2);
        assert_eq!(res.tool_calls[0].name, "search");
        assert_eq!(res.tool_calls[0].arguments, r#"{"q":"rust"}"#);
        assert_eq!(res.tool_calls[1].name, "get_weather");
    }
}