openai_agents_rust/model/gpt_oss_responses.rs

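//! `Model` implementation targeting a GPT-OSS-compatible `/responses` endpoint.
//! Chat-style messages and Chat Completions tool definitions are adapted into
//! the Responses request shape before sending.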
use crate::config::Config;
use crate::error::AgentError;
use crate::model::{Model, ModelResponse, ToolCall};
use crate::utils::env::var_bool;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

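/// Thin HTTP wrapper around a Responses endpoint at `{base_url}/responses`.
/// A bearer token is attached only when `config.api_key` is non-empty.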
pub struct GptOssResponses {
    client: Client,
    config: Config,
    base_url: String,
    auth_token: Option<String>,
}

impl GptOssResponses {
    pub fn new(config: Config) -> Self {
        let client = Client::builder()
            .user_agent("openai-agents-rust")
            .build()
            .expect("Failed to build reqwest client");
        let auth_token = if config.api_key.is_empty() {
            None
        } else {
            Some(config.api_key.clone())
        };
        Self {
            client,
            base_url: config.base_url.clone(),
            config,
            auth_token,
        }
    }

    fn url(&self) -> String {
        format!("{}/responses", self.base_url.trim_end_matches('/'))
    }
}

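/// The request `input` field: either a bare prompt string or a list of structured items.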
#[derive(Serialize)]
#[serde(untagged)]
enum InputUnion {
    Str(String),
    Items(Vec<InputItem>),
}

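/// Structured input items in the Responses item format. `Message` and
/// `FunctionCallOutput` are produced by `adapt_messages_to_input`; `FunctionCall`
/// is defined for completeness but never constructed here.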
#[derive(Serialize)]
#[serde(tag = "type")]
enum InputItem {
    #[allow(dead_code)]
    #[serde(rename = "message")]
    Message { role: String, content: String },
    #[allow(dead_code)]
    #[serde(rename = "function_call")]
    FunctionCall {
        name: String,
        arguments: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        id: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        call_id: Option<String>,
    },
    #[serde(rename = "function_call_output")]
    FunctionCallOutput { call_id: String, output: String },
}

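/// A function tool in the flattened Responses shape (`name`/`parameters` at the
/// top level rather than nested under a `function` object).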
#[derive(Serialize)]
struct FunctionToolDefinition {
    #[serde(rename = "type")]
    ty: String,
    name: String,
    parameters: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    strict: Option<bool>,
}

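/// JSON body POSTed to `/responses`; optional fields are omitted when `None`.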
#[derive(Serialize)]
struct ResponsesRequestBody {
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
    input: InputUnion,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<FunctionToolDefinition>>, // browser/code interpreter omitted for now
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>, // "auto" | "none"
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    previous_response_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
}

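/// Output items returned by the server. Unrecognized item types fall through to
/// `Other` and are ignored.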
#[derive(Deserialize)]
#[serde(tag = "type")]
enum OutputItem {
    #[serde(rename = "message")]
    Message {
        #[serde(rename = "role")]
        _role: String,
        content: Vec<TextPart>,
    },
    #[serde(rename = "function_call")]
    FunctionCall {
        name: String,
        arguments: String,
        id: String,
        call_id: String,
    },
    #[serde(rename = "function_call_output")]
    FunctionCallOutput {
        #[allow(dead_code)]
        call_id: String,
        #[allow(dead_code)]
        output: String,
    },
    #[serde(other)]
    Other,
}

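/// A single text content part inside a `message` output item.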
#[derive(Deserialize)]
struct TextPart {
    #[allow(dead_code)]
    #[serde(rename = "type")]
    _ty: String,
    text: String,
}

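/// Minimal view of the response object: only `output` and `id` are read.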
#[derive(Deserialize)]
struct ResponsesObject {
    output: Vec<OutputItem>,
    #[allow(dead_code)]
    id: Option<String>,
}

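/// Converts Chat Completions-style tool definitions
/// (`{"type": "function", "function": {...}}`) into the flattened form used
/// here; returns `None` when no function tools are present.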
fn map_openai_tools_to_oss(
    tools: Option<&[serde_json::Value]>,
) -> Option<Vec<FunctionToolDefinition>> {
    let mut out = Vec::new();
    if let Some(arr) = tools {
        for t in arr.iter() {
            if let Some(obj) = t.as_object() {
                if obj.get("type").and_then(|v| v.as_str()) == Some("function") {
                    if let Some(func) = obj.get("function").and_then(|v| v.as_object()) {
                        let name = func
                            .get("name")
                            .and_then(|v| v.as_str())
                            .unwrap_or("")
                            .to_string();
                        let description = func
                            .get("description")
                            .and_then(|v| v.as_str())
                            .map(|s| s.to_string());
                        let parameters = func
                            .get("parameters")
                            .cloned()
                            .unwrap_or(serde_json::json!({"type":"object"}));
                        out.push(FunctionToolDefinition {
                            ty: "function".into(),
                            name,
                            parameters,
                            description,
                            strict: Some(false),
                        });
                    }
                }
            }
        }
    }
    if out.is_empty() { None } else { Some(out) }
}

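/// Builds the request `input` from chat-style messages: user/assistant/system
/// messages become `message` items and `tool` messages become
/// `function_call_output` items keyed by `tool_call_id`. Falls back to an empty
/// string when there is nothing to send.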
fn adapt_messages_to_input(messages: Option<&[serde_json::Value]>) -> InputUnion {
    if let Some(msgs) = messages {
        let mut items: Vec<InputItem> = Vec::new();
        for m in msgs.iter() {
            let role = m.get("role").and_then(|v| v.as_str()).unwrap_or("");
            match role {
                "user" | "assistant" | "system" => {
                    if let Some(content) = m.get("content").and_then(|v| v.as_str()) {
                        items.push(InputItem::Message {
                            role: role.into(),
                            content: content.into(),
                        });
                    }
                }
                "tool" => {
                    if let Some(call_id) = m.get("tool_call_id").and_then(|v| v.as_str()) {
                        let out = m
                            .get("content")
                            .and_then(|v| v.as_str())
                            .unwrap_or("")
                            .to_string();
                        items.push(InputItem::FunctionCallOutput {
                            call_id: call_id.into(),
                            output: out,
                        });
                    }
                }
                _ => {}
            }
            // Do not inject function_call items into input for OSS Responses.
            // The model server expects function_call_output linked via previous_response_id.
        }
        if items.is_empty() {
            InputUnion::Str("".into())
        } else {
            InputUnion::Items(items)
        }
    } else {
        InputUnion::Str("".into())
    }
}

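// `generate` sends a bare prompt and returns the raw response body; the richer
// `get_response` path below handles tools, debugging, and output parsing.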
#[async_trait]
impl Model for GptOssResponses {
    async fn generate(&self, prompt: &str) -> Result<String, AgentError> {
        let mut req = self.client.post(self.url());
        if let Some(token) = &self.auth_token {
            req = req.bearer_auth(token);
        }
        let body = ResponsesRequestBody {
            instructions: None,
            input: InputUnion::Str(prompt.to_string()),
            model: Some(self.config.model.clone()),
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
            max_output_tokens: Some(512),
            temperature: Some(0.2),
            previous_response_id: None,
            store: None,
        };
        let resp = req.json(&body).send().await.map_err(AgentError::from)?;
        let status = resp.status();
        let text = resp.text().await.map_err(AgentError::from)?;
        if !status.is_success() {
            return Err(AgentError::Other(format!(
                "HTTP {} error: {}",
                status, text
            )));
        }
        Ok(text)
    }

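    /// Full request/response cycle: adapts messages and tools, optionally threads
    /// `previous_response_id` (unless disabled via `OSS_DISABLE_PREVIOUS_RESPONSE`
    /// or `OSS_TOOL_OUTPUT_AS_TEXT`), and extracts text plus tool calls from the
    /// parsed output items.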
    async fn get_response(
        &self,
        system_instructions: Option<&str>,
        _input: &str,
        _model_settings: Option<std::collections::HashMap<String, String>>,
        messages: Option<&[serde_json::Value]>,
        tools: Option<&[serde_json::Value]>,
        tool_choice: Option<serde_json::Value>,
        _output_schema: Option<&str>,
        _handoffs: Option<&[String]>,
        _tracing_enabled: bool,
        previous_response_id: Option<&str>,
        _prompt_config: Option<&str>,
    ) -> Result<ModelResponse, AgentError> {
        let mut req = self.client.post(self.url());
        if let Some(token) = &self.auth_token {
            req = req.bearer_auth(token);
        }

        let input = adapt_messages_to_input(messages);
        let tools_mapped = map_openai_tools_to_oss(tools);
        let tool_choice_str = tool_choice.and_then(|v| v.as_str().map(|s| s.to_string()));
        let disable_prev = var_bool("OSS_DISABLE_PREVIOUS_RESPONSE", false)
            || var_bool("OSS_TOOL_OUTPUT_AS_TEXT", false);
        let body = ResponsesRequestBody {
            instructions: system_instructions.map(|s| s.to_string()),
            input,
            model: Some(self.config.model.clone()),
            tools: tools_mapped,
            tool_choice: tool_choice_str,
            parallel_tool_calls: Some(true),
            max_output_tokens: Some(512),
            temperature: Some(0.2),
            previous_response_id: if disable_prev {
                None
            } else {
                previous_response_id.map(|s| s.to_string())
            },
            store: if disable_prev { None } else { Some(true) },
        };
        if var_bool("OSS_DEBUG_PAYLOAD", false) {
            if let Ok(j) = serde_json::to_string_pretty(&body) {
                tracing::debug!(target = "gpt_oss_responses", payload = %j, "OSS Responses request body");
            }
        }
        if var_bool("OSS_DEBUG_HTTP", false) {
            if let Ok(j) = serde_json::to_string_pretty(&body) {
                eprintln!("OSS Responses REQUEST: {}", j);
            }
        }
        let resp = req.json(&body).send().await.map_err(AgentError::from)?;
        let status = resp.status();
        let body_text = resp.text().await.map_err(AgentError::from)?;
        if var_bool("OSS_DEBUG_PAYLOAD", false) {
            tracing::debug!(target = "gpt_oss_responses", http_status = %status, body = %body_text, "OSS Responses response");
        }
        if var_bool("OSS_DEBUG_HTTP", false) {
            eprintln!("OSS Responses HTTP {} body: {}", status, body_text);
        }
        if !status.is_success() {
            return Err(AgentError::Other(format!(
                "HTTP {} error: {}",
                status, body_text
            )));
        }
        let parsed: ResponsesObject = serde_json::from_str(&body_text).map_err(AgentError::from)?;
        let mut text: Option<String> = None;
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        let resp_id = parsed.id.clone();
        for item in parsed.output.into_iter() {
            match item {
                OutputItem::Message { _role: _, content } => {
                    let mut s = String::new();
                    for p in content {
                        s.push_str(&p.text);
                    }
                    if !s.is_empty() {
                        text = Some(s);
                    }
                }
                OutputItem::FunctionCall {
                    name,
                    arguments,
                    id,
                    call_id,
                } => {
                    tool_calls.push(ToolCall {
                        id: Some(id),
                        name,
                        arguments,
                        call_id: Some(call_id),
                    });
                }
                _ => {}
            }
        }
        Ok(ModelResponse {
            id: resp_id,
            text,
            tool_calls,
        })
    }
}