openai_agents_rust/model/
gpt_oss_responses.rs1use crate::config::Config;
2use crate::error::AgentError;
3use crate::model::{Model, ModelResponse, ToolCall};
4use crate::utils::env::var_bool;
5use async_trait::async_trait;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8
9pub struct GptOssResponses {
10 client: Client,
11 config: Config,
12 base_url: String,
13 auth_token: Option<String>,
14}
15
16impl GptOssResponses {
17 pub fn new(config: Config) -> Self {
18 let client = Client::builder()
19 .user_agent("openai-agents-rust")
20 .build()
21 .expect("Failed to build reqwest client");
22 let auth_token = if config.api_key.is_empty() {
23 None
24 } else {
25 Some(config.api_key.clone())
26 };
27 Self {
28 client,
29 base_url: config.base_url.clone(),
30 config,
31 auth_token,
32 }
33 }
34
35 fn url(&self) -> String {
36 format!("{}/responses", self.base_url.trim_end_matches('/'))
37 }
38}
39
40#[derive(Serialize)]
41#[serde(untagged)]
42enum InputUnion {
43 Str(String),
44 Items(Vec<InputItem>),
45}
46
47#[derive(Serialize)]
48#[serde(tag = "type")]
49enum InputItem {
50 #[allow(dead_code)]
51 #[serde(rename = "message")]
52 Message { role: String, content: String },
53 #[allow(dead_code)]
54 #[serde(rename = "function_call")]
55 FunctionCall {
56 name: String,
57 arguments: String,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 id: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 call_id: Option<String>,
62 },
63 #[serde(rename = "function_call_output")]
64 FunctionCallOutput { call_id: String, output: String },
65}
66
67#[derive(Serialize)]
68struct FunctionToolDefinition {
69 #[serde(rename = "type")]
70 ty: String,
71 name: String,
72 parameters: serde_json::Value,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 description: Option<String>,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 strict: Option<bool>,
77}
78
79#[derive(Serialize)]
80struct ResponsesRequestBody {
81 #[serde(skip_serializing_if = "Option::is_none")]
82 instructions: Option<String>,
83 input: InputUnion,
84 #[serde(skip_serializing_if = "Option::is_none")]
85 model: Option<String>,
86 #[serde(skip_serializing_if = "Option::is_none")]
87 tools: Option<Vec<FunctionToolDefinition>>, #[serde(skip_serializing_if = "Option::is_none")]
89 tool_choice: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
91 parallel_tool_calls: Option<bool>,
92 #[serde(skip_serializing_if = "Option::is_none")]
93 max_output_tokens: Option<i32>,
94 #[serde(skip_serializing_if = "Option::is_none")]
95 temperature: Option<f32>,
96 #[serde(skip_serializing_if = "Option::is_none")]
97 previous_response_id: Option<String>,
98 #[serde(skip_serializing_if = "Option::is_none")]
99 store: Option<bool>,
100}
101
102#[derive(Deserialize)]
103#[serde(tag = "type")]
104enum OutputItem {
105 #[serde(rename = "message")]
106 Message {
107 #[serde(rename = "role")]
108 _role: String,
109 content: Vec<TextPart>,
110 },
111 #[serde(rename = "function_call")]
112 FunctionCall {
113 name: String,
114 arguments: String,
115 id: String,
116 call_id: String,
117 },
118 #[serde(rename = "function_call_output")]
119 FunctionCallOutput {
120 #[allow(dead_code)]
121 call_id: String,
122 #[allow(dead_code)]
123 output: String,
124 },
125 #[serde(other)]
126 Other,
127}
128
129#[derive(Deserialize)]
130struct TextPart {
131 #[allow(dead_code)]
132 #[serde(rename = "type")]
133 _ty: String,
134 text: String,
135}
136
137#[derive(Deserialize)]
138struct ResponsesObject {
139 output: Vec<OutputItem>,
140 #[allow(dead_code)]
141 id: Option<String>,
142}
143
144fn map_openai_tools_to_oss(
145 tools: Option<&[serde_json::Value]>,
146) -> Option<Vec<FunctionToolDefinition>> {
147 let mut out = Vec::new();
148 if let Some(arr) = tools {
149 for t in arr.iter() {
150 if let Some(obj) = t.as_object() {
151 if obj.get("type").and_then(|v| v.as_str()) == Some("function") {
152 if let Some(func) = obj.get("function").and_then(|v| v.as_object()) {
153 let name = func
154 .get("name")
155 .and_then(|v| v.as_str())
156 .unwrap_or("")
157 .to_string();
158 let description = func
159 .get("description")
160 .and_then(|v| v.as_str())
161 .map(|s| s.to_string());
162 let parameters = func
163 .get("parameters")
164 .cloned()
165 .unwrap_or(serde_json::json!({"type":"object"}));
166 out.push(FunctionToolDefinition {
167 ty: "function".into(),
168 name,
169 parameters,
170 description,
171 strict: Some(false),
172 });
173 }
174 }
175 }
176 }
177 }
178 if out.is_empty() { None } else { Some(out) }
179}
180
181fn adapt_messages_to_input(messages: Option<&[serde_json::Value]>) -> InputUnion {
182 if let Some(msgs) = messages {
183 let mut items: Vec<InputItem> = Vec::new();
184 for m in msgs.iter() {
185 let role = m.get("role").and_then(|v| v.as_str()).unwrap_or("");
186 match role {
187 "user" | "assistant" | "system" => {
188 if let Some(content) = m.get("content").and_then(|v| v.as_str()) {
189 items.push(InputItem::Message {
190 role: role.into(),
191 content: content.into(),
192 });
193 }
194 }
195 "tool" => {
196 if let Some(call_id) = m.get("tool_call_id").and_then(|v| v.as_str()) {
197 let out = m
198 .get("content")
199 .and_then(|v| v.as_str())
200 .unwrap_or("")
201 .to_string();
202 items.push(InputItem::FunctionCallOutput {
203 call_id: call_id.into(),
204 output: out,
205 });
206 }
207 }
208 _ => {}
209 }
210 }
213 if items.is_empty() {
214 InputUnion::Str("".into())
215 } else {
216 InputUnion::Items(items)
217 }
218 } else {
219 InputUnion::Str("".into())
220 }
221}
222
223#[async_trait]
224impl Model for GptOssResponses {
225 async fn generate(&self, prompt: &str) -> Result<String, AgentError> {
226 let mut req = self.client.post(self.url());
227 if let Some(token) = &self.auth_token {
228 req = req.bearer_auth(token);
229 }
230 let body = ResponsesRequestBody {
231 instructions: None,
232 input: InputUnion::Str(prompt.to_string()),
233 model: Some(self.config.model.clone()),
234 tools: None,
235 tool_choice: None,
236 parallel_tool_calls: None,
237 max_output_tokens: Some(512),
238 temperature: Some(0.2),
239 previous_response_id: None,
240 store: None,
241 };
242 let resp = req.json(&body).send().await.map_err(AgentError::from)?;
243 let status = resp.status();
244 let text = resp.text().await.map_err(AgentError::from)?;
245 if !status.is_success() {
246 return Err(AgentError::Other(format!(
247 "HTTP {} error: {}",
248 status, text
249 )));
250 }
251 Ok(text)
252 }
253
254 async fn get_response(
255 &self,
256 system_instructions: Option<&str>,
257 _input: &str,
258 _model_settings: Option<std::collections::HashMap<String, String>>,
259 messages: Option<&[serde_json::Value]>,
260 tools: Option<&[serde_json::Value]>,
261 tool_choice: Option<serde_json::Value>,
262 _output_schema: Option<&str>,
263 _handoffs: Option<&[String]>,
264 _tracing_enabled: bool,
265 _previous_response_id: Option<&str>,
266 _prompt_config: Option<&str>,
267 ) -> Result<ModelResponse, AgentError> {
268 let mut req = self.client.post(self.url());
269 if let Some(token) = &self.auth_token {
270 req = req.bearer_auth(token);
271 }
272
273 let input = adapt_messages_to_input(messages);
274 let tools_mapped = map_openai_tools_to_oss(tools);
275 let tool_choice_str = tool_choice.and_then(|v| v.as_str().map(|s| s.to_string()));
276 let disable_prev = var_bool("OSS_DISABLE_PREVIOUS_RESPONSE", false)
277 || var_bool("OSS_TOOL_OUTPUT_AS_TEXT", false);
278 let body = ResponsesRequestBody {
279 instructions: system_instructions.map(|s| s.to_string()),
280 input,
281 model: Some(self.config.model.clone()),
282 tools: tools_mapped,
283 tool_choice: tool_choice_str,
284 parallel_tool_calls: Some(true),
285 max_output_tokens: Some(512),
286 temperature: Some(0.2),
287 previous_response_id: if disable_prev {
288 None
289 } else {
290 _previous_response_id.map(|s| s.to_string())
291 },
292 store: if disable_prev { None } else { Some(true) },
293 };
294 if var_bool("OSS_DEBUG_PAYLOAD", false) {
295 if let Ok(j) = serde_json::to_string_pretty(&body) {
296 tracing::debug!(target = "gpt_oss_responses", payload = %j, "OSS Responses request body");
297 }
298 }
299 if var_bool("OSS_DEBUG_HTTP", false) {
300 if let Ok(j) = serde_json::to_string_pretty(&body) {
301 eprintln!("OSS Responses REQUEST: {}", j);
302 }
303 }
304 let resp = req.json(&body).send().await.map_err(AgentError::from)?;
305 let status = resp.status();
306 let body_text = resp.text().await.map_err(AgentError::from)?;
307 if var_bool("OSS_DEBUG_PAYLOAD", false) {
308 tracing::debug!(target = "gpt_oss_responses", http_status = %status, body = %body_text, "OSS Responses response");
309 }
310 if var_bool("OSS_DEBUG_HTTP", false) {
311 eprintln!("OSS Responses HTTP {} body: {}", status, body_text);
312 }
313 if !status.is_success() {
314 return Err(AgentError::Other(format!(
315 "HTTP {} error: {}",
316 status, body_text
317 )));
318 }
319 let parsed: ResponsesObject = serde_json::from_str(&body_text).map_err(AgentError::from)?;
320 let mut text: Option<String> = None;
321 let mut tool_calls: Vec<ToolCall> = Vec::new();
322 let resp_id = parsed.id.clone();
323 for item in parsed.output.into_iter() {
324 match item {
325 OutputItem::Message { _role: _, content } => {
326 let mut s = String::new();
327 for p in content {
328 s.push_str(&p.text);
329 }
330 if !s.is_empty() {
331 text = Some(s);
332 }
333 }
334 OutputItem::FunctionCall {
335 name,
336 arguments,
337 id,
338 call_id,
339 } => {
340 tool_calls.push(ToolCall {
341 id: Some(id),
342 name,
343 arguments,
344 call_id: Some(call_id),
345 });
346 }
347 _ => {}
348 }
349 }
350 Ok(ModelResponse {
351 id: resp_id,
352 text,
353 tool_calls,
354 })
355 }
356}