1use async_trait::async_trait;
15use reqwest::Client;
16use serde_json::{Value, json};
17use tracing::{debug, instrument, warn};
18
19use crate::error::{LlmError, LlmResult};
20
21#[derive(Debug, Clone)]
23pub struct ResponsesRequest {
24 pub model: String,
26 pub input: String,
29 pub instructions: Option<String>,
31 pub tools: Option<Vec<Value>>,
34 pub tool_choice: Option<Value>,
37 pub temperature: Option<f32>,
39 pub max_output_tokens: Option<u32>,
41 pub user: Option<String>,
43 pub extra_fields: Option<Value>,
46}
47
48impl ResponsesRequest {
49 pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
51 Self {
52 model: model.into(),
53 input: input.into(),
54 instructions: None,
55 tools: None,
56 tool_choice: None,
57 temperature: None,
58 max_output_tokens: None,
59 user: None,
60 extra_fields: None,
61 }
62 }
63
64 pub fn to_wire(&self) -> Value {
66 let mut obj = serde_json::Map::new();
67 obj.insert("model".into(), Value::String(self.model.clone()));
68 obj.insert("input".into(), Value::String(self.input.clone()));
69 if let Some(ref s) = self.instructions {
70 obj.insert("instructions".into(), Value::String(s.clone()));
71 }
72 if let Some(ref tools) = self.tools {
73 obj.insert("tools".into(), Value::Array(tools.clone()));
74 }
75 if let Some(ref tc) = self.tool_choice {
76 obj.insert("tool_choice".into(), tc.clone());
77 }
78 if let Some(t) = self.temperature {
79 obj.insert(
80 "temperature".into(),
81 serde_json::Number::from_f64(t as f64)
82 .map(Value::Number)
83 .unwrap_or(Value::Null),
84 );
85 }
86 if let Some(m) = self.max_output_tokens {
87 obj.insert("max_output_tokens".into(), Value::Number(m.into()));
88 }
89 if let Some(ref u) = self.user {
90 obj.insert("user".into(), Value::String(u.clone()));
91 }
92 if let Some(Value::Object(extra)) = self.extra_fields.as_ref() {
93 for (k, v) in extra {
94 obj.insert(k.clone(), v.clone());
95 }
96 }
97 Value::Object(obj)
98 }
99}
100
101#[async_trait]
108pub trait ResponsesClient: Send + Sync {
109 async fn create_response(&self, request: &ResponsesRequest) -> LlmResult<Value>;
114
115 async fn retrieve_response(&self, response_id: &str) -> LlmResult<Value>;
118
119 async fn submit_tool_outputs(
126 &self,
127 response_id: &str,
128 tool_outputs: Vec<Value>,
129 ) -> LlmResult<Value>;
130}
131
132#[derive(Clone)]
139pub struct OpenAIResponsesClient {
140 api_key: String,
141 base_url: String,
142 client: Client,
143 network_retries: usize,
144}
145
146impl OpenAIResponsesClient {
147 pub const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
149 pub const DEFAULT_NETWORK_RETRIES: usize = 3;
151
152 pub fn new(api_key: impl Into<String>, base_url: Option<String>) -> LlmResult<Self> {
154 let client = Client::builder()
155 .timeout(std::time::Duration::from_secs(600))
156 .build()
157 .map_err(|e| LlmError::ConfigError(format!("Failed to create HTTP client: {e}")))?;
158 Ok(Self {
159 api_key: api_key.into(),
160 base_url: base_url.unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string()),
161 client,
162 network_retries: Self::DEFAULT_NETWORK_RETRIES,
163 })
164 }
165
166 pub fn with_network_retries(mut self, retries: u32) -> Self {
168 self.network_retries = usize::try_from(retries).unwrap_or(usize::MAX);
169 self
170 }
171
172 fn auth_header(&self) -> String {
173 format!("Bearer {}", self.api_key)
174 }
175
176 #[instrument(
179 name = "responses_api.post",
180 level = "info",
181 skip(self, body),
182 fields(url = tracing::field::Empty),
183 )]
184 async fn post_json(&self, path: &str, body: Value) -> LlmResult<Value> {
185 let url = format!("{}{}", self.base_url, path);
186 tracing::Span::current().record("url", url.as_str());
187 self.send_with_retries(reqwest::Method::POST, url, Some(body))
188 .await
189 }
190
191 #[instrument(
193 name = "responses_api.get",
194 level = "info",
195 skip(self),
196 fields(url = tracing::field::Empty),
197 )]
198 async fn get_json(&self, path: &str) -> LlmResult<Value> {
199 let url = format!("{}{}", self.base_url, path);
200 tracing::Span::current().record("url", url.as_str());
201 self.send_with_retries(reqwest::Method::GET, url, None)
202 .await
203 }
204
205 async fn send_with_retries(
206 &self,
207 method: reqwest::Method,
208 url: String,
209 body: Option<Value>,
210 ) -> LlmResult<Value> {
211 let mut last_error = LlmError::NetworkError("No attempt made".to_string());
212 for attempt in 0..=self.network_retries {
213 debug!(attempt, "Responses API attempt");
214 if attempt > 0 {
215 let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
216 warn!(
217 attempt,
218 delay_ms,
219 error = %last_error,
220 "Responses API request failed, retrying",
221 );
222 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
223 }
224
225 let mut builder = self
226 .client
227 .request(method.clone(), &url)
228 .header("Authorization", self.auth_header())
229 .header("Content-Type", "application/json");
230 if let Some(ref b) = body {
231 builder = builder.json(b);
232 }
233
234 let response = match builder.send().await {
235 Ok(r) => r,
236 Err(e) => {
237 last_error = LlmError::NetworkError(e.to_string());
238 continue;
239 }
240 };
241
242 let status = response.status();
243 if !status.is_success() {
244 let error_body = response
245 .text()
246 .await
247 .unwrap_or_else(|_| "Unknown error".to_string());
248 let err = match status.as_u16() {
249 401 => LlmError::AuthenticationError(error_body),
250 429 => LlmError::RateLimitExceeded(error_body),
251 400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
252 404 => LlmError::ModelNotFound(error_body),
253 _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
254 };
255 if matches!(status.as_u16(), 400 | 401 | 404) {
256 return Err(err);
257 }
258 last_error = err;
259 continue;
260 }
261
262 let body_text = response.text().await.map_err(|e| {
263 LlmError::DeserializationError(format!("Failed to read response body: {e}"))
264 })?;
265 return serde_json::from_str::<Value>(&body_text).map_err(|e| {
266 LlmError::DeserializationError(format!(
267 "Failed to parse response: {e}. Raw body: {body_text}"
268 ))
269 });
270 }
271
272 Err(LlmError::MaxRetriesExceeded(format!(
273 "Responses API request failed after {} attempt(s): {}",
274 self.network_retries + 1,
275 last_error
276 )))
277 }
278}
279
280#[async_trait]
281impl ResponsesClient for OpenAIResponsesClient {
282 async fn create_response(&self, request: &ResponsesRequest) -> LlmResult<Value> {
283 self.post_json("/responses", request.to_wire()).await
284 }
285
286 async fn retrieve_response(&self, response_id: &str) -> LlmResult<Value> {
287 self.get_json(&format!("/responses/{response_id}")).await
288 }
289
290 async fn submit_tool_outputs(
291 &self,
292 response_id: &str,
293 tool_outputs: Vec<Value>,
294 ) -> LlmResult<Value> {
295 self.post_json(
296 &format!("/responses/{response_id}/submit_tool_outputs"),
297 json!({ "tool_outputs": tool_outputs }),
298 )
299 .await
300 }
301}
302
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;

    /// A freshly built request serialises `model` and `input` and leaves out
    /// the optional keys checked here.
    #[test]
    fn request_wire_includes_only_set_fields() {
        let wire = ResponsesRequest::new("gpt-4o", "hello").to_wire();

        assert_eq!(wire["model"], "gpt-4o");
        assert_eq!(wire["input"], "hello");
        for absent in ["temperature", "tools", "tool_choice", "instructions"] {
            assert!(wire.get(absent).is_none());
        }
    }

    /// Every optional field that is set shows up under its wire name.
    #[test]
    fn request_wire_serialises_optional_fields() {
        let req = ResponsesRequest {
            temperature: Some(0.7),
            max_output_tokens: Some(128),
            tool_choice: Some(Value::String("auto".into())),
            tools: Some(vec![json!({"type":"function","name":"search"})]),
            instructions: Some("be terse".into()),
            user: Some("u-1".into()),
            ..ResponsesRequest::new("gpt-4o", "hello")
        };
        let wire = req.to_wire();

        // The f32 is widened to f64 on the wire, so compare with a tolerance.
        let t = wire["temperature"]
            .as_f64()
            .expect("temperature is a number");
        assert!((t - 0.7).abs() < 1e-3);

        assert_eq!(wire["max_output_tokens"], 128);
        assert_eq!(wire["tool_choice"], "auto");
        assert_eq!(wire["tools"][0]["name"], "search");
        assert_eq!(wire["instructions"], "be terse");
        assert_eq!(wire["user"], "u-1");
    }

    /// Object-valued `extra_fields` are flattened into the top-level body.
    #[test]
    fn extra_fields_merge_into_top_level() {
        let req = ResponsesRequest {
            extra_fields: Some(json!({"reasoning": {"effort": "low"}})),
            ..ResponsesRequest::new("gpt-4o", "hello")
        };
        let wire = req.to_wire();

        assert_eq!(wire["reasoning"]["effort"], "low");
    }
}