Skip to main content

cognee_llm/
responses_client.rs

1//! OpenAI Responses API client abstraction.
2//!
3//! This is a separate surface from the chat-completions [`Llm`](crate::Llm) trait
4//! because the Responses API has a meaningfully different shape — `input` /
5//! `output` arrays, function-call items in `output`, and a different usage
6//! payload (`input_tokens` / `output_tokens` instead of `prompt_tokens` /
7//! `completion_tokens`).
8//!
9//! Used by the HTTP server's `POST /api/v1/responses` handler. The trait
10//! deliberately models the Python `client.responses.create(...)` return shape:
11//! a JSON `Value`-shaped response with `id`, `output`, and `usage`, plus a
12//! best-effort polling hook for stored / async responses.
13
14use async_trait::async_trait;
15use reqwest::Client;
16use serde_json::{Value, json};
17use tracing::{debug, instrument, warn};
18
19use crate::error::{LlmError, LlmResult};
20
21/// Request to the OpenAI Responses API.
22#[derive(Debug, Clone)]
23pub struct ResponsesRequest {
24    /// Model identifier.
25    pub model: String,
26    /// Free-form input text. Multimodal inputs (file references etc.) are
27    /// modelled via `extra_input_items` and merged into the wire payload.
28    pub input: String,
29    /// Optional `instructions` field (system-prompt analogue).
30    pub instructions: Option<String>,
31    /// Tools array — typically `DEFAULT_TOOLS`. `None` means do not send a
32    /// `tools` field at all.
33    pub tools: Option<Vec<Value>>,
34    /// Tool selection policy. `"auto"` / `"none"` / `"required"` or an
35    /// object. Sent verbatim.
36    pub tool_choice: Option<Value>,
37    /// Sampling temperature.
38    pub temperature: Option<f32>,
39    /// Optional cap on completion tokens (`max_output_tokens` on the wire).
40    pub max_output_tokens: Option<u32>,
41    /// Optional end-user identifier forwarded for abuse-tracking.
42    pub user: Option<String>,
43    /// Extra wire fields merged into the top-level request object. Use
44    /// sparingly — exists for forward-compat with new OpenAI fields.
45    pub extra_fields: Option<Value>,
46}
47
48impl ResponsesRequest {
49    /// Build a minimal request with only `model` and `input` set.
50    pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
51        Self {
52            model: model.into(),
53            input: input.into(),
54            instructions: None,
55            tools: None,
56            tool_choice: None,
57            temperature: None,
58            max_output_tokens: None,
59            user: None,
60            extra_fields: None,
61        }
62    }
63
64    /// Render as the JSON body POSTed to `/v1/responses`.
65    pub fn to_wire(&self) -> Value {
66        let mut obj = serde_json::Map::new();
67        obj.insert("model".into(), Value::String(self.model.clone()));
68        obj.insert("input".into(), Value::String(self.input.clone()));
69        if let Some(ref s) = self.instructions {
70            obj.insert("instructions".into(), Value::String(s.clone()));
71        }
72        if let Some(ref tools) = self.tools {
73            obj.insert("tools".into(), Value::Array(tools.clone()));
74        }
75        if let Some(ref tc) = self.tool_choice {
76            obj.insert("tool_choice".into(), tc.clone());
77        }
78        if let Some(t) = self.temperature {
79            obj.insert(
80                "temperature".into(),
81                serde_json::Number::from_f64(t as f64)
82                    .map(Value::Number)
83                    .unwrap_or(Value::Null),
84            );
85        }
86        if let Some(m) = self.max_output_tokens {
87            obj.insert("max_output_tokens".into(), Value::Number(m.into()));
88        }
89        if let Some(ref u) = self.user {
90            obj.insert("user".into(), Value::String(u.clone()));
91        }
92        if let Some(Value::Object(extra)) = self.extra_fields.as_ref() {
93            for (k, v) in extra {
94                obj.insert(k.clone(), v.clone());
95            }
96        }
97        Value::Object(obj)
98    }
99}
100
101/// Object-safe trait wrapping the OpenAI Responses API.
102///
103/// Implementations return the raw `serde_json::Value` from the upstream
104/// response so the HTTP-server layer can mirror Python's
105/// `response.model_dump()` behaviour exactly without extra structural
106/// translation in the LLM crate.
107#[async_trait]
108pub trait ResponsesClient: Send + Sync {
109    /// Create a new response. Mirrors Python's
110    /// `client.responses.create(...)`. Returns the raw JSON `Value` from
111    /// the upstream API (the caller is responsible for shaping it into
112    /// the public `ResponseBodyDTO`).
113    async fn create_response(&self, request: &ResponsesRequest) -> LlmResult<Value>;
114
115    /// Retrieve a stored / async response by id. Used to poll until
116    /// completion. Mirrors `GET /v1/responses/{id}`.
117    async fn retrieve_response(&self, response_id: &str) -> LlmResult<Value>;
118
119    /// Submit tool outputs back for the given response id. Mirrors
120    /// `POST /v1/responses/{id}/submit_tool_outputs`. Returns the
121    /// updated response.
122    ///
123    /// `tool_outputs` is an array of `{"tool_call_id": "...", "output": "..."}`
124    /// objects (matching the OpenAI wire shape).
125    async fn submit_tool_outputs(
126        &self,
127        response_id: &str,
128        tool_outputs: Vec<Value>,
129    ) -> LlmResult<Value>;
130}
131
132// ─── OpenAI implementation ───────────────────────────────────────────────────
133
134/// OpenAI Responses API client.
135///
136/// Backed by the same `reqwest` client / retry semantics as
137/// [`crate::adapters::OpenAIAdapter`].
138#[derive(Clone)]
139pub struct OpenAIResponsesClient {
140    api_key: String,
141    base_url: String,
142    client: Client,
143    network_retries: usize,
144}
145
146impl OpenAIResponsesClient {
147    /// Default OpenAI API base URL.
148    pub const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
149    /// Default retry attempts for transient network/server errors.
150    pub const DEFAULT_NETWORK_RETRIES: usize = 3;
151
152    /// Construct a new client.
153    pub fn new(api_key: impl Into<String>, base_url: Option<String>) -> LlmResult<Self> {
154        let client = Client::builder()
155            .timeout(std::time::Duration::from_secs(600))
156            .build()
157            .map_err(|e| LlmError::ConfigError(format!("Failed to create HTTP client: {e}")))?;
158        Ok(Self {
159            api_key: api_key.into(),
160            base_url: base_url.unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string()),
161            client,
162            network_retries: Self::DEFAULT_NETWORK_RETRIES,
163        })
164    }
165
166    /// Configure retry attempts for transient network/server errors.
167    pub fn with_network_retries(mut self, retries: u32) -> Self {
168        self.network_retries = usize::try_from(retries).unwrap_or(usize::MAX);
169        self
170    }
171
172    fn auth_header(&self) -> String {
173        format!("Bearer {}", self.api_key)
174    }
175
176    /// POST a JSON body to the given relative URL and parse the response
177    /// as JSON. Retries on transient (5xx, 429, network) failures.
178    #[instrument(
179        name = "responses_api.post",
180        level = "info",
181        skip(self, body),
182        fields(url = tracing::field::Empty),
183    )]
184    async fn post_json(&self, path: &str, body: Value) -> LlmResult<Value> {
185        let url = format!("{}{}", self.base_url, path);
186        tracing::Span::current().record("url", url.as_str());
187        self.send_with_retries(reqwest::Method::POST, url, Some(body))
188            .await
189    }
190
191    /// GET a path. Same retry semantics as `post_json`.
192    #[instrument(
193        name = "responses_api.get",
194        level = "info",
195        skip(self),
196        fields(url = tracing::field::Empty),
197    )]
198    async fn get_json(&self, path: &str) -> LlmResult<Value> {
199        let url = format!("{}{}", self.base_url, path);
200        tracing::Span::current().record("url", url.as_str());
201        self.send_with_retries(reqwest::Method::GET, url, None)
202            .await
203    }
204
205    async fn send_with_retries(
206        &self,
207        method: reqwest::Method,
208        url: String,
209        body: Option<Value>,
210    ) -> LlmResult<Value> {
211        let mut last_error = LlmError::NetworkError("No attempt made".to_string());
212        for attempt in 0..=self.network_retries {
213            debug!(attempt, "Responses API attempt");
214            if attempt > 0 {
215                let delay_ms = (1_000u64 * 2u64.saturating_pow(attempt as u32 - 1)).min(30_000);
216                warn!(
217                    attempt,
218                    delay_ms,
219                    error = %last_error,
220                    "Responses API request failed, retrying",
221                );
222                tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
223            }
224
225            let mut builder = self
226                .client
227                .request(method.clone(), &url)
228                .header("Authorization", self.auth_header())
229                .header("Content-Type", "application/json");
230            if let Some(ref b) = body {
231                builder = builder.json(b);
232            }
233
234            let response = match builder.send().await {
235                Ok(r) => r,
236                Err(e) => {
237                    last_error = LlmError::NetworkError(e.to_string());
238                    continue;
239                }
240            };
241
242            let status = response.status();
243            if !status.is_success() {
244                let error_body = response
245                    .text()
246                    .await
247                    .unwrap_or_else(|_| "Unknown error".to_string());
248                let err = match status.as_u16() {
249                    401 => LlmError::AuthenticationError(error_body),
250                    429 => LlmError::RateLimitExceeded(error_body),
251                    400 => LlmError::InvalidResponse(format!("Bad request: {error_body}")),
252                    404 => LlmError::ModelNotFound(error_body),
253                    _ => LlmError::ApiError(format!("HTTP {status}: {error_body}")),
254                };
255                if matches!(status.as_u16(), 400 | 401 | 404) {
256                    return Err(err);
257                }
258                last_error = err;
259                continue;
260            }
261
262            let body_text = response.text().await.map_err(|e| {
263                LlmError::DeserializationError(format!("Failed to read response body: {e}"))
264            })?;
265            return serde_json::from_str::<Value>(&body_text).map_err(|e| {
266                LlmError::DeserializationError(format!(
267                    "Failed to parse response: {e}. Raw body: {body_text}"
268                ))
269            });
270        }
271
272        Err(LlmError::MaxRetriesExceeded(format!(
273            "Responses API request failed after {} attempt(s): {}",
274            self.network_retries + 1,
275            last_error
276        )))
277    }
278}
279
280#[async_trait]
281impl ResponsesClient for OpenAIResponsesClient {
282    async fn create_response(&self, request: &ResponsesRequest) -> LlmResult<Value> {
283        self.post_json("/responses", request.to_wire()).await
284    }
285
286    async fn retrieve_response(&self, response_id: &str) -> LlmResult<Value> {
287        self.get_json(&format!("/responses/{response_id}")).await
288    }
289
290    async fn submit_tool_outputs(
291        &self,
292        response_id: &str,
293        tool_outputs: Vec<Value>,
294    ) -> LlmResult<Value> {
295        self.post_json(
296            &format!("/responses/{response_id}/submit_tool_outputs"),
297            json!({ "tool_outputs": tool_outputs }),
298        )
299        .await
300    }
301}
302
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;

    #[test]
    fn request_wire_includes_only_set_fields() {
        let req = ResponsesRequest::new("gpt-4o", "hello");
        let wire = req.to_wire();
        assert_eq!(wire["model"], "gpt-4o");
        assert_eq!(wire["input"], "hello");
        assert!(wire.get("temperature").is_none());
        assert!(wire.get("tools").is_none());
        assert!(wire.get("tool_choice").is_none());
        assert!(wire.get("instructions").is_none());
        assert!(wire.get("max_output_tokens").is_none());
        assert!(wire.get("user").is_none());
        // Nothing beyond the two mandatory keys.
        assert_eq!(wire.as_object().expect("wire is an object").len(), 2);
    }

    #[test]
    fn request_wire_serialises_optional_fields() {
        let mut req = ResponsesRequest::new("gpt-4o", "hello");
        req.temperature = Some(0.7);
        req.max_output_tokens = Some(128);
        req.tool_choice = Some(Value::String("auto".into()));
        req.tools = Some(vec![json!({"type":"function","name":"search"})]);
        req.instructions = Some("be terse".into());
        req.user = Some("u-1".into());
        let wire = req.to_wire();
        let t = wire["temperature"]
            .as_f64()
            .expect("temperature is a number");
        assert!((t - 0.7).abs() < 1e-3);
        assert_eq!(wire["max_output_tokens"], 128);
        assert_eq!(wire["tool_choice"], "auto");
        assert_eq!(wire["tools"][0]["name"], "search");
        assert_eq!(wire["instructions"], "be terse");
        assert_eq!(wire["user"], "u-1");
    }

    #[test]
    fn non_finite_temperature_serialises_as_null() {
        let mut req = ResponsesRequest::new("gpt-4o", "hello");
        req.temperature = Some(f32::NAN);
        let wire = req.to_wire();
        assert!(wire.get("temperature").is_some());
        assert!(wire["temperature"].is_null());
    }

    #[test]
    fn extra_fields_merge_into_top_level() {
        let mut req = ResponsesRequest::new("gpt-4o", "hello");
        req.extra_fields = Some(json!({"reasoning": {"effort": "low"}}));
        let wire = req.to_wire();
        assert_eq!(wire["reasoning"]["effort"], "low");
    }

    #[test]
    fn extra_fields_override_typed_fields() {
        let mut req = ResponsesRequest::new("gpt-4o", "hello");
        req.extra_fields = Some(json!({
            "model": "other-model",
            "input": [{"role": "user", "content": "hi"}],
        }));
        let wire = req.to_wire();
        assert_eq!(wire["model"], "other-model");
        assert_eq!(wire["input"][0]["content"], "hi");
    }

    #[test]
    fn non_object_extra_fields_are_ignored() {
        let mut req = ResponsesRequest::new("gpt-4o", "hello");
        req.extra_fields = Some(json!("ignored"));
        let wire = req.to_wire();
        assert_eq!(wire.as_object().expect("wire is an object").len(), 2);
        assert_eq!(wire["input"], "hello");
    }
}