// openai_agents_rust/model/openai_chat.rs

1use crate::config::Config;
2use crate::error::AgentError;
3use crate::model::{Model, ModelResponse, ToolCall};
4use crate::utils::env::var_bool;
5use crate::utils::env::var_opt;
6use async_trait::async_trait;
7use reqwest::Client;
8use serde::Deserialize;
9use tracing::debug;
10
/// Simple OpenAI Chat model implementation.
///
/// Sends requests to an OpenAI-compatible `chat/completions` endpoint
/// (including local servers such as vLLM) over HTTP.
pub struct OpenAiChat {
    // Shared reqwest client, reused for every request.
    client: Client,
    // Source of model name, API key, and default base URL.
    config: Config,
    // Endpoint root (e.g. "https://api.openai.com/v1"); overridable via `with_base_url`.
    base_url: String,
    // Bearer token; `None` omits the Authorization header (local/open endpoints).
    auth_token: Option<String>,
}
18
19impl OpenAiChat {
20    /// Create a new instance with the given configuration.
21    pub fn new(config: Config) -> Self {
22        let client = Client::builder()
23            .user_agent("openai-agents-rust")
24            .build()
25            .expect("Failed to build reqwest client");
26        let auth_token = if config.api_key.is_empty() {
27            None
28        } else {
29            Some(config.api_key.clone())
30        };
31        let base_url = config.base_url.clone();
32        Self {
33            client,
34            config,
35            base_url,
36            auth_token,
37        }
38    }
39
40    /// Override the base URL (e.g., http://192.168.3.40:8000/v1)
41    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
42        self.base_url = base_url.into();
43        self
44    }
45
46    /// Disable authentication for local/open endpoints.
47    pub fn without_auth(mut self) -> Self {
48        self.auth_token = None;
49        self
50    }
51}
52
// Response shapes for OpenAI chat.completions
// Field names mirror the wire format, so renaming requires serde attributes.
#[derive(Deserialize)]
struct FunctionCall {
    name: String,
    // Some servers return a JSON object instead of a stringified JSON.
    arguments: serde_json::Value,
}
// One entry of the `tool_calls` array; all fields optional for lenient parsing.
#[derive(Deserialize)]
struct ToolCallJson {
    #[serde(rename = "type")]
    _type: Option<String>,
    id: Option<String>,
    call_id: Option<String>,
    function: Option<FunctionCall>,
}
// `choices[n].message` — assistant text plus any tool invocations.
#[derive(Deserialize)]
struct Message {
    content: Option<String>,
    tool_calls: Option<Vec<ToolCallJson>>,
    // Legacy function_call support (pre-tool_calls schema)
    function_call: Option<FunctionCall>,
}
#[derive(Deserialize)]
struct Choice {
    message: Message,
}
// Top-level chat.completions response; only `choices` is consumed.
#[derive(Deserialize)]
struct ChatCompletion {
    choices: Vec<Choice>,
}
83
84fn parse_chat_completion(body: ChatCompletion) -> ModelResponse {
85    let mut text: Option<String> = None;
86    let mut tool_calls: Vec<ToolCall> = Vec::new();
87    if let Some(first) = body.choices.into_iter().next() {
88        text = first.message.content;
89        if let Some(tcs) = first.message.tool_calls {
90            for tc in tcs.into_iter() {
91                if let Some(func) = tc.function {
92                    tool_calls.push(ToolCall {
93                        id: tc.id,
94                        name: func.name,
95                        arguments: match func.arguments {
96                            serde_json::Value::String(s) => s,
97                            other => other.to_string(),
98                        },
99                        call_id: tc.call_id,
100                    });
101                }
102            }
103        } else if let Some(func) = first.message.function_call {
104            // Legacy single function call
105            tool_calls.push(ToolCall {
106                id: None,
107                name: func.name,
108                arguments: match func.arguments {
109                    serde_json::Value::String(s) => s,
110                    other => other.to_string(),
111                },
112                call_id: None,
113            });
114        }
115    }
116    ModelResponse {
117        id: None,
118        text,
119        tool_calls,
120    }
121}
122
123#[async_trait]
124impl Model for OpenAiChat {
125    /// Sends a chat completion request to the OpenAI API.
126    async fn generate(&self, prompt: &str) -> Result<String, AgentError> {
127        let url = format!("{}/chat/completions", self.base_url);
128        let mut rb = self.client.post(&url);
129        if let Some(token) = &self.auth_token {
130            rb = rb.bearer_auth(token);
131        }
132        let resp = rb
133            .json(&serde_json::json!({
134                "model": self.config.model,
135                "messages": [{ "role": "user", "content": prompt }],
136            }))
137            .send()
138            .await
139            .map_err(AgentError::from)?;
140
141        let text = resp.text().await.map_err(AgentError::from)?;
142
143        Ok(text)
144    }
145
146    /// Rich response with basic parsing of text and tool calls.
147    async fn get_response(
148        &self,
149        system_instructions: Option<&str>,
150        input: &str,
151        _model_settings: Option<std::collections::HashMap<String, String>>,
152        messages: Option<&[serde_json::Value]>,
153        tools: Option<&[serde_json::Value]>,
154        tool_choice: Option<serde_json::Value>,
155        _output_schema: Option<&str>,
156        _handoffs: Option<&[String]>,
157        _tracing_enabled: bool,
158        _previous_response_id: Option<&str>,
159        _prompt_config: Option<&str>,
160    ) -> Result<ModelResponse, AgentError> {
161        let url = format!("{}/chat/completions", self.base_url);
162        // Build messages array if not provided.
163        let mut msgs: Vec<serde_json::Value> = Vec::new();
164        if let Some(provided) = messages {
165            msgs.extend_from_slice(provided);
166        } else {
167            if let Some(sys) = system_instructions {
168                msgs.push(serde_json::json!({"role": "system", "content": sys}));
169            }
170            msgs.push(serde_json::json!({"role": "user", "content": input}));
171        }
172
173        // Env toggles for compatibility
174        let minimal_payload = var_bool("VLLM_MIN_PAYLOAD", false);
175        let force_functions = var_bool("VLLM_FORCE_FUNCTIONS", false);
176        // Default to enabling parallel tool calls for Harmony unless explicitly disabled
177        let disable_parallel = var_bool("VLLM_DISABLE_PARALLEL_TOOL_CALLS", false);
178        // Optional override: Values: "auto", "none", "object:auto", "object:none"
179        let tool_choice_override = var_opt("VLLM_TOOL_CHOICE");
180
181        // Prepare payload
182        let mut payload = if minimal_payload {
183            serde_json::json!({
184                "model": self.config.model,
185                "messages": msgs,
186            })
187        } else {
188            serde_json::json!({
189                "model": self.config.model,
190                "messages": msgs,
191                "max_tokens": 512,
192                "temperature": 0.2,
193            })
194        };
195        let have_tools = if let Some(t) = tools {
196            if force_functions {
197                // Build legacy functions list
198                let mut functions: Vec<serde_json::Value> = Vec::new();
199                for tool in t.iter() {
200                    if let Some(obj) = tool.as_object() {
201                        if obj.get("type").and_then(|v| v.as_str()) == Some("function") {
202                            if let Some(func) = obj.get("function") {
203                                functions.push(func.clone());
204                            }
205                        }
206                    }
207                }
208                if !functions.is_empty() {
209                    payload["functions"] = serde_json::Value::Array(functions);
210                    payload["function_call"] = serde_json::json!("auto");
211                }
212            } else if !minimal_payload {
213                payload["tools"] = serde_json::Value::Array(t.to_vec());
214                if !disable_parallel {
215                    payload["parallel_tool_calls"] = serde_json::Value::Bool(true);
216                }
217            }
218            true
219        } else {
220            false
221        };
222        if !minimal_payload {
223            if let Some(choice) = &tool_choice {
224                payload["tool_choice"] = choice.clone();
225            } else if have_tools && !force_functions {
226                // Harmony: omit tool_choice to let server decide, unless override is provided
227                if let Some(tc) = tool_choice_override.as_deref() {
228                    match tc {
229                        "object:auto" => {
230                            payload["tool_choice"] = serde_json::json!({"type": "auto"})
231                        }
232                        "object:none" => {
233                            payload["tool_choice"] = serde_json::json!({"type": "none"})
234                        }
235                        "none" => payload["tool_choice"] = serde_json::json!("none"),
236                        "auto" => payload["tool_choice"] = serde_json::json!("auto"),
237                        _ => {}
238                    }
239                }
240            }
241        }
242
243        if var_bool("VLLM_DEBUG_PAYLOAD", false) {
244            if let Ok(pretty) = serde_json::to_string_pretty(&payload) {
245                debug!(target: "openai_chat", payload = %pretty, "request payload");
246            }
247        }
248        debug!(
249            target: "openai_chat",
250            url = %url,
251            model = %self.config.model,
252            have_tools = %have_tools,
253            force_functions = %force_functions,
254            tool_choice = %tool_choice.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "<none>".into()),
255            messages_len = payload["messages"].as_array().map(|a| a.len()).unwrap_or(0),
256            "sending chat.completions"
257        );
258        let mut req1 = self.client.post(&url);
259        if let Some(token) = &self.auth_token {
260            req1 = req1.bearer_auth(token);
261        }
262        let resp1 = req1.json(&payload).send().await.map_err(AgentError::from)?;
263        let status = resp1.status();
264        let body_text = resp1.text().await.map_err(AgentError::from)?;
265        debug!(target: "openai_chat", "request completed status={} have_tools={}", status, have_tools);
266        if !status.is_success() {
267            let truncated = if body_text.len() > 2000 {
268                format!("{}...<truncated>", &body_text[..2000])
269            } else {
270                body_text.clone()
271            };
272            return Err(AgentError::Other(format!(
273                "chat.completions failed (status: {}). The server returned an error while tools={} force_functions={}. No automatic retries are performed. Verify the endpoint supports your requested schema (tool_calls vs functions) or adjust config (e.g., VLLM_FORCE_FUNCTIONS, VLLM_TOOL_CHOICE). Response body: {}",
274                status,
275                if have_tools { "enabled" } else { "disabled" },
276                if var_bool("VLLM_FORCE_FUNCTIONS", false) {
277                    "on"
278                } else {
279                    "off"
280                },
281                truncated
282            )));
283        }
284
285        match serde_json::from_str::<ChatCompletion>(&body_text) {
286            Ok(body) => Ok(parse_chat_completion(body)),
287            Err(e) => {
288                let truncated = if body_text.len() > 2000 {
289                    format!("{}...<truncated>", &body_text[..2000])
290                } else {
291                    body_text
292                };
293                Err(AgentError::Other(format!(
294                    "Failed to parse chat.completions response: {}. Expected OpenAI chat format with choices[0].message. Body: {}",
295                    e, truncated
296                )))
297            }
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// A text-only completion yields the content and no tool calls.
    #[test]
    fn parse_text_only() {
        let body: ChatCompletion = serde_json::from_value(serde_json::json!({
            "choices": [
                { "message": { "content": "Hello", "tool_calls": null } }
            ]
        }))
        .unwrap();
        let parsed = parse_chat_completion(body);
        assert_eq!(parsed.text.as_deref(), Some("Hello"));
        assert!(parsed.tool_calls.is_empty());
    }

    /// Multiple tool calls are parsed in order with their arguments intact.
    #[test]
    fn parse_with_tool_calls() {
        let raw = serde_json::json!({
            "choices": [
                { "message": {
                    "content": null,
                    "tool_calls": [
                        { "type": "function", "function": { "name": "search", "arguments": "{\"q\":\"rust\"}" } },
                        { "type": "function", "function": { "name": "get_weather", "arguments": "{\"city\":\"NYC\"}" } }
                    ]
                }}
            ]
        });
        let body: ChatCompletion = serde_json::from_value(raw).unwrap();
        let parsed = parse_chat_completion(body);
        assert!(parsed.text.is_none());
        let names: Vec<&str> = parsed
            .tool_calls
            .iter()
            .map(|tc| tc.name.as_str())
            .collect();
        assert_eq!(names, vec!["search", "get_weather"]);
        assert_eq!(parsed.tool_calls[0].arguments, "{\"q\":\"rust\"}");
    }
}