Skip to main content

poe2_agent/
llm.rs

//! OpenAI API client — Responses API for tool calling and streaming.
//!
//! Uses the Responses API (`/v1/responses`) for both blocking requests with tool
//! calling and streaming text responses.
5
6use anyhow::{Context, Result};
7use futures_core::Stream;
8use reqwest::header;
9use serde::{Deserialize, Serialize};
10use std::time::Duration;
11
/// URL that every Responses API request is POSTed to by default.
const RESPONSES_API_URL: &str = "https://api.openai.com/v1/responses";

/// Output-token cap used for a response unless the caller overrides it.
const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;

/// Connect timeout configured on the shared HTTP client.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Per-attempt response timeout applied to blocking (non-streaming) requests.
const BLOCKING_RESPONSE_TIMEOUT: Duration = Duration::from_secs(120);

/// HTTP status codes treated as transient and therefore retried.
const RETRYABLE_STATUSES: &[u16] = &[429, 500, 502, 503];

/// Upper bound on attempts per request (the first try plus any retries).
const MAX_ATTEMPTS: u32 = 3;
29
30/// OpenAI API client.
31#[derive(Clone)]
32pub struct ChatGptClient {
33    client: reqwest::Client,
34    model: String,
35    reasoning_effort: Option<String>,
36    prompt_cache_key: Option<String>,
37    prompt_cache_retention: Option<String>,
38    max_output_tokens: u32,
39    base_url: String,
40}
41
42// -- Responses API: Tool definitions -----------------------------------------
43
44/// Tool definition for the Responses API (flattened — no `function` wrapper).
45#[derive(Debug, Serialize, Clone)]
46pub struct ToolDefinition {
47    #[serde(rename = "type")]
48    pub tool_type: String,
49    pub name: String,
50    pub description: String,
51    pub parameters: serde_json::Value,
52}
53
54// -- Responses API: Request --------------------------------------------------
55
56#[derive(Debug, Serialize)]
57struct ResponseRequest {
58    model: String,
59    input: Vec<serde_json::Value>,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    instructions: Option<String>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    tools: Option<Vec<ToolDefinition>>,
64    #[serde(skip_serializing_if = "std::ops::Not::not")]
65    stream: bool,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    reasoning: Option<ReasoningConfig>,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    previous_response_id: Option<String>,
70    store: bool,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    prompt_cache_key: Option<String>,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    prompt_cache_retention: Option<String>,
75    max_output_tokens: u32,
76    truncation: &'static str,
77}
78
79#[derive(Debug, Serialize, Clone)]
80struct ReasoningConfig {
81    effort: String,
82}
83
84// -- Responses API: Response -------------------------------------------------
85
86/// Parsed response from the Responses API.
87#[derive(Debug, Deserialize)]
88pub struct ApiResponse {
89    pub id: String,
90    pub status: String,
91    pub output: Vec<serde_json::Value>,
92    #[serde(default)]
93    pub output_text: Option<String>,
94    #[serde(default)]
95    pub usage: Option<Usage>,
96    #[serde(default)]
97    pub error: Option<ApiResponseError>,
98}
99
100/// Error details when `status == "failed"`.
101#[derive(Debug, Deserialize)]
102pub struct ApiResponseError {
103    pub message: String,
104    #[serde(default)]
105    pub code: Option<String>,
106}
107
108/// A parsed `function_call` item from the API output.
109#[derive(Debug, Deserialize, Clone)]
110pub struct FunctionCallItem {
111    pub id: String,
112    pub name: String,
113    pub call_id: String,
114    pub arguments: String,
115    pub status: String,
116}
117
118/// Events yielded by a streaming Responses API call.
119pub enum ResponseStreamEvent {
120    /// A chunk of the response text.
121    TextDelta(String),
122    /// A completed function call from the model's output.
123    FunctionCall(FunctionCallItem),
124    /// The response is complete. Carries the response ID (for chaining)
125    /// and token usage.
126    ResponseCompleted { id: String, usage: Option<Usage> },
127}
128
129impl ApiResponse {
130    /// Extract `function_call` items from the output array.
131    pub fn function_calls(&self) -> Vec<FunctionCallItem> {
132        self.output
133            .iter()
134            .filter_map(|item| {
135                if item.get("type")?.as_str()? == "function_call" {
136                    serde_json::from_value(item.clone()).ok()
137                } else {
138                    None
139                }
140            })
141            .collect()
142    }
143}
144
145// -- Responses API: Input item builders --------------------------------------
146
147/// Create an input message item.
148pub fn input_message(role: &str, content: &str) -> serde_json::Value {
149    serde_json::json!({ "type": "message", "role": role, "content": content })
150}
151
152/// Create an input item for a function call result.
153pub fn input_function_call_output(call_id: &str, output: &str) -> serde_json::Value {
154    serde_json::json!({ "type": "function_call_output", "call_id": call_id, "output": output })
155}
156
157// -- Usage -------------------------------------------------------------------
158
159/// Nested detail object for cached token reporting.
160#[derive(Debug, Deserialize, Default, Clone, Copy)]
161struct InputTokensDetails {
162    #[serde(default)]
163    cached_tokens: u32,
164}
165
166/// Token usage from an OpenAI Responses API response.
167#[derive(Debug, Deserialize, Default, Clone, Copy)]
168pub struct Usage {
169    pub input_tokens: u32,
170    pub output_tokens: u32,
171    pub total_tokens: u32,
172    #[serde(default)]
173    input_tokens_details: Option<InputTokensDetails>,
174}
175
176impl Usage {
177    /// Number of input tokens served from the prompt cache.
178    pub fn cached_tokens(&self) -> u32 {
179        self.input_tokens_details.map_or(0, |d| d.cached_tokens)
180    }
181}
182
183impl std::ops::AddAssign for Usage {
184    fn add_assign(&mut self, rhs: Self) {
185        self.input_tokens += rhs.input_tokens;
186        self.output_tokens += rhs.output_tokens;
187        self.total_tokens += rhs.total_tokens;
188        // Accumulate cached tokens into existing details or create new.
189        let prev = self.input_tokens_details.unwrap_or_default().cached_tokens;
190        let added = rhs.input_tokens_details.unwrap_or_default().cached_tokens;
191        self.input_tokens_details = Some(InputTokensDetails {
192            cached_tokens: prev + added,
193        });
194    }
195}
196
197// -- Errors ------------------------------------------------------------------
198
199#[derive(Debug, thiserror::Error)]
200pub enum LlmError {
201    #[error("OpenAI API error (HTTP {status}): {body}")]
202    Api { status: u16, body: String },
203
204    #[error(transparent)]
205    Transport(#[from] reqwest::Error),
206
207    #[error(transparent)]
208    Other(#[from] anyhow::Error),
209}
210
211// -- Retry helper ------------------------------------------------------------
212
213/// Send an HTTP POST with retry logic for transient errors.
214///
215/// Retries on 429/500/502/503 up to `MAX_ATTEMPTS` total attempts.
216/// Respects the `Retry-After` header on 429; falls back to exponential
217/// backoff (1s → 2s → …, capped at 30s).
218///
219/// `timeout` applies per-attempt to the full response (headers + body).
220/// Pass `None` for streaming requests where duration is bounded by
221/// `max_output_tokens`.
222async fn send_with_retry(
223    client: &reqwest::Client,
224    url: &str,
225    body: &serde_json::Value,
226    timeout: Option<Duration>,
227) -> Result<reqwest::Response, LlmError> {
228    let mut attempt = 0u32;
229    loop {
230        let mut req = client.post(url).json(body);
231        if let Some(t) = timeout {
232            req = req.timeout(t);
233        }
234        let response = req.send().await?;
235        let status = response.status();
236
237        if status.is_success() {
238            return Ok(response);
239        }
240
241        let status_u16 = status.as_u16();
242        let is_retryable = RETRYABLE_STATUSES.contains(&status_u16);
243        let has_attempts_remaining = attempt + 1 < MAX_ATTEMPTS;
244
245        if !is_retryable || !has_attempts_remaining {
246            let body = response.text().await.unwrap_or_default();
247            return Err(LlmError::Api {
248                status: status_u16,
249                body,
250            });
251        }
252
253        // Compute backoff: honour Retry-After on 429, otherwise exponential.
254        let backoff = if status_u16 == 429 {
255            response
256                .headers()
257                .get("retry-after")
258                .and_then(|v| v.to_str().ok())
259                .and_then(|s| s.parse::<u64>().ok())
260                .map(Duration::from_secs)
261                .unwrap_or_else(|| Duration::from_secs(1u64 << attempt))
262        } else {
263            Duration::from_secs(1u64 << attempt)
264        };
265        let backoff = backoff.min(Duration::from_secs(30));
266
267        tracing::warn!(
268            status = status_u16,
269            attempt = attempt + 1,
270            backoff_secs = backoff.as_secs_f32(),
271            "transient API error — retrying"
272        );
273
274        tokio::time::sleep(backoff).await;
275        attempt += 1;
276    }
277}
278
279// -- Client implementation ---------------------------------------------------
280
281impl ChatGptClient {
282    /// Create a new client. The API key is baked into the underlying
283    /// `reqwest::Client` as a default header so it doesn't need to be
284    /// cloned per-request.
285    pub fn new(api_key: &str, model: &str) -> Result<Self> {
286        let mut headers = header::HeaderMap::new();
287        let mut auth = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
288            .context("invalid API key characters")?;
289        auth.set_sensitive(true);
290        headers.insert(header::AUTHORIZATION, auth);
291
292        let client = reqwest::Client::builder()
293            .default_headers(headers)
294            .connect_timeout(CONNECT_TIMEOUT)
295            .build()
296            .context("failed to build HTTP client")?;
297
298        // GPT-5+ reasoning models default to "medium" reasoning effort, which
299        // generates hidden reasoning tokens. Only set for gpt-5+ models.
300        let reasoning_effort = if model.starts_with("gpt-5") || model.starts_with("gpt-6") {
301            if model.contains("nano") {
302                Some("minimal".to_owned())
303            } else if model.contains("mini") {
304                Some("low".to_owned())
305            } else {
306                Some("medium".to_owned())
307            }
308        } else {
309            None
310        };
311
312        // All models benefit from prompt_cache_key for cache pool routing.
313        let prompt_cache_key = Some("poe2-agent-v1".to_owned());
314
315        // GPT-5.1+ supports extended 24h cache retention.
316        let prompt_cache_retention = if model.starts_with("gpt-5.1")
317            || model.starts_with("gpt-5.2")
318            || model.starts_with("gpt-6")
319        {
320            Some("24h".to_owned())
321        } else {
322            None
323        };
324
325        Ok(Self {
326            client,
327            model: model.to_owned(),
328            reasoning_effort,
329            prompt_cache_key,
330            prompt_cache_retention,
331            max_output_tokens: DEFAULT_MAX_OUTPUT_TOKENS,
332            base_url: RESPONSES_API_URL.to_owned(),
333        })
334    }
335
336    /// Create a client pointing at a custom base URL (used in tests).
337    #[cfg(test)]
338    fn new_with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
339        let mut client = Self::new(api_key, model)?;
340        client.base_url = base_url.to_owned();
341        Ok(client)
342    }
343
344    /// Override the maximum number of output tokens per response.
345    pub fn with_max_output_tokens(mut self, n: u32) -> Self {
346        self.max_output_tokens = n;
347        self
348    }
349
350    /// Override the default reasoning effort level.
351    ///
352    /// Valid values: `"minimal"`, `"low"`, `"medium"`, `"high"`.
353    pub fn with_reasoning_effort(mut self, effort: &str) -> Self {
354        self.reasoning_effort = Some(effort.to_owned());
355        self
356    }
357
358    /// Returns the model name this client is configured for.
359    pub fn model(&self) -> &str {
360        &self.model
361    }
362
363    /// Send a blocking request to the Responses API.
364    ///
365    /// Returns the full parsed response including output items and usage.
366    /// The agent loop inspects `function_calls()` to decide whether to
367    /// execute tools or return the final answer.
368    pub async fn create_response(
369        &self,
370        input: &[serde_json::Value],
371        instructions: Option<&str>,
372        tools: Option<&[ToolDefinition]>,
373        previous_response_id: Option<&str>,
374    ) -> Result<ApiResponse, LlmError> {
375        let request = ResponseRequest {
376            model: self.model.clone(),
377            input: input.to_vec(),
378            instructions: instructions.map(|s| s.to_owned()),
379            tools: tools.map(|t| t.to_vec()),
380            stream: false,
381            reasoning: self
382                .reasoning_effort
383                .as_ref()
384                .map(|e| ReasoningConfig { effort: e.clone() }),
385            previous_response_id: previous_response_id.map(|s| s.to_owned()),
386            store: true,
387            prompt_cache_key: self.prompt_cache_key.clone(),
388            prompt_cache_retention: self.prompt_cache_retention.clone(),
389            max_output_tokens: self.max_output_tokens,
390            truncation: "auto",
391        };
392
393        let body = serde_json::to_value(&request).map_err(|e| LlmError::Other(e.into()))?;
394        let response = send_with_retry(
395            &self.client,
396            &self.base_url,
397            &body,
398            Some(BLOCKING_RESPONSE_TIMEOUT),
399        )
400        .await?;
401
402        let parsed: ApiResponse = response.json().await?;
403        if let Some(ref u) = parsed.usage {
404            tracing::debug!(
405                input_tokens = u.input_tokens,
406                output_tokens = u.output_tokens,
407                cached_tokens = u.cached_tokens(),
408                total_tokens = u.total_tokens,
409                "llm response usage"
410            );
411        }
412        if parsed.status == "failed" {
413            let msg = parsed
414                .error
415                .as_ref()
416                .map(|e| e.message.as_str())
417                .unwrap_or("unknown error");
418            return Err(LlmError::Other(anyhow::anyhow!(
419                "API response failed: {msg}"
420            )));
421        }
422
423        Ok(parsed)
424    }
425
426    /// Stream a response from the Responses API, yielding structured events.
427    ///
428    /// Returns a stream of `ResponseStreamEvent`s: text deltas, function calls,
429    /// and a final `ResponseCompleted` with the response ID (for chaining) and
430    /// token usage.
431    ///
432    /// The returned stream is `'static` -- it clones the HTTP client and model
433    /// name so callers don't need to worry about lifetimes.
434    pub fn create_response_stream(
435        &self,
436        input: &[serde_json::Value],
437        instructions: Option<&str>,
438        tools: Option<&[ToolDefinition]>,
439        previous_response_id: Option<&str>,
440    ) -> impl Stream<Item = Result<ResponseStreamEvent, LlmError>> + Send {
441        let client = self.client.clone();
442        let url = self.base_url.clone();
443        let request = ResponseRequest {
444            model: self.model.clone(),
445            input: input.to_vec(),
446            instructions: instructions.map(|s| s.to_owned()),
447            tools: tools.map(|t| t.to_vec()),
448            stream: true,
449            reasoning: self
450                .reasoning_effort
451                .as_ref()
452                .map(|e| ReasoningConfig { effort: e.clone() }),
453            previous_response_id: previous_response_id.map(|s| s.to_owned()),
454            store: true,
455            prompt_cache_key: self.prompt_cache_key.clone(),
456            prompt_cache_retention: self.prompt_cache_retention.clone(),
457            max_output_tokens: self.max_output_tokens,
458            truncation: "auto",
459        };
460        // Serialize once outside the stream so we can rebuild the request each
461        // retry attempt without needing to clone RequestBuilder.
462        let body =
463            serde_json::to_value(&request).expect("ResponseRequest serialization is infallible");
464
465        async_stream::try_stream! {
466            // Retry covers the initial HTTP connection only; mid-stream
467            // failures propagate to the caller.
468            let mut response = send_with_retry(&client, &url, &body, None).await?;
469
470            let mut buffer = String::new();
471            let mut event_type = String::new();
472
473            while let Some(chunk) = response.chunk().await? {
474                buffer.push_str(&String::from_utf8_lossy(&chunk));
475
476                // Process complete SSE events (delimited by double newline).
477                while let Some(pos) = buffer.find("\n\n") {
478                    let event_block = buffer[..pos].to_owned();
479                    buffer = buffer[pos + 2..].to_owned();
480
481                    // Reset event_type each block so a missing `event:` line
482                    // doesn't reuse the previous block's type.
483                    event_type.clear();
484                    let mut data_line = None;
485                    for line in event_block.lines() {
486                        if let Some(et) = line.strip_prefix("event: ") {
487                            event_type = et.trim().to_owned();
488                        } else if let Some(d) = line.strip_prefix("data: ") {
489                            data_line = Some(d.to_owned());
490                        }
491                    }
492
493                    let data = match data_line {
494                        Some(d) => d,
495                        None => continue,
496                    };
497
498                    match event_type.as_str() {
499                        "response.output_text.delta" => {
500                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
501                                if let Some(delta) = parsed.get("delta").and_then(|d| d.as_str()) {
502                                    yield ResponseStreamEvent::TextDelta(delta.to_owned());
503                                }
504                            }
505                        }
506                        "response.output_item.done" => {
507                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
508                                if let Some(item) = parsed.get("item") {
509                                    if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
510                                        if let Ok(fc) = serde_json::from_value::<FunctionCallItem>(item.clone()) {
511                                            yield ResponseStreamEvent::FunctionCall(fc);
512                                        }
513                                    }
514                                }
515                            }
516                        }
517                        "response.completed" => {
518                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
519                                let id = parsed.pointer("/response/id")
520                                    .and_then(|v| v.as_str())
521                                    .unwrap_or_default()
522                                    .to_owned();
523                                let usage = parsed.pointer("/response/usage")
524                                    .and_then(|v| serde_json::from_value::<Usage>(v.clone()).ok());
525                                if let Some(ref u) = usage {
526                                    tracing::debug!(
527                                        input_tokens = u.input_tokens,
528                                        output_tokens = u.output_tokens,
529                                        cached_tokens = u.cached_tokens(),
530                                        total_tokens = u.total_tokens,
531                                        "llm stream response usage"
532                                    );
533                                }
534                                yield ResponseStreamEvent::ResponseCompleted { id, usage };
535                            }
536                            return;
537                        }
538                        "response.failed" | "response.incomplete" => {
539                            let msg = serde_json::from_str::<serde_json::Value>(&data)
540                                .ok()
541                                .and_then(|v| {
542                                    v.pointer("/response/error/message")
543                                        .and_then(|m| m.as_str().map(|s| s.to_owned()))
544                                })
545                                .unwrap_or_else(|| format!("response {}", event_type));
546                            Err(LlmError::Other(anyhow::anyhow!("{msg}")))?;
547                        }
548                        _ => {} // Ignore all other event types.
549                    }
550                }
551            }
552        }
553    }
554}
555
556// -- Tests -------------------------------------------------------------------
557
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Smallest JSON body that deserializes into a completed `ApiResponse`.
    fn success_body() -> serde_json::Value {
        serde_json::json!({
            "id": "resp_test",
            "status": "completed",
            "output": []
        })
    }

    /// Client for a non-reasoning model whose requests go to `server`.
    fn client_for(server: &MockServer) -> ChatGptClient {
        ChatGptClient::new_with_base_url("test-key", "gpt-4o", &server.uri()).unwrap()
    }

    /// Make `server` answer the first two POSTs with `failure` and every
    /// later POST with a 200 carrying `success_body()`.
    async fn fail_twice_then_succeed(server: &MockServer, failure: ResponseTemplate) {
        Mock::given(method("POST"))
            .respond_with(failure)
            .up_to_n_times(2)
            .with_priority(1)
            .mount(server)
            .await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body()))
            .mount(server)
            .await;
    }

    /// Number of requests `server` has received so far.
    async fn request_count(server: &MockServer) -> usize {
        server.received_requests().await.unwrap().len()
    }

    #[tokio::test]
    async fn retry_on_429_respects_retry_after() {
        let server = MockServer::start().await;
        // Two 429s with Retry-After: 1, then success on the third request.
        let throttled = ResponseTemplate::new(429).insert_header("Retry-After", "1");
        fail_twice_then_succeed(&server, throttled).await;

        let result = client_for(&server)
            .create_response(&[], None, None, None)
            .await;

        assert!(result.is_ok(), "expected success after retries: {result:?}");
        assert_eq!(request_count(&server).await, 3, "expected exactly 3 requests");
    }

    #[tokio::test]
    async fn retry_on_500_uses_exponential_backoff() {
        let server = MockServer::start().await;
        // Two 500s, then success on the third request.
        fail_twice_then_succeed(&server, ResponseTemplate::new(500)).await;

        let result = client_for(&server)
            .create_response(&[], None, None, None)
            .await;

        assert!(result.is_ok(), "expected success after retries: {result:?}");
        assert_eq!(request_count(&server).await, 3, "expected exactly 3 requests");
    }

    #[tokio::test]
    async fn non_retryable_error_propagates() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(400).set_body_string("bad request"))
            .mount(&server)
            .await;

        let result = client_for(&server)
            .create_response(&[], None, None, None)
            .await;

        assert!(result.is_err());
        assert_eq!(
            request_count(&server).await,
            1,
            "non-retryable error must not be retried"
        );
        match result.unwrap_err() {
            LlmError::Api { status, .. } => assert_eq!(status, 400),
            e => panic!("expected LlmError::Api, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn max_retry_attempts_respected() {
        let server = MockServer::start().await;
        // Every request gets a 503.
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(503))
            .mount(&server)
            .await;

        let result = client_for(&server)
            .create_response(&[], None, None, None)
            .await;

        assert!(result.is_err());
        assert_eq!(
            request_count(&server).await,
            MAX_ATTEMPTS as usize,
            "must stop after MAX_ATTEMPTS"
        );
        match result.unwrap_err() {
            LlmError::Api { status, .. } => assert_eq!(status, 503),
            e => panic!("expected LlmError::Api, got {e:?}"),
        }
    }
}