1use anyhow::{Context, Result};
7use futures_core::Stream;
8use reqwest::header;
9use serde::{Deserialize, Serialize};
10use std::time::Duration;
11
/// Endpoint of the OpenAI Responses API that every request is POSTed to.
const RESPONSES_API_URL: &str = "https://api.openai.com/v1/responses";

/// Default `max_output_tokens` sent with each request unless overridden.
const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;

/// Time limit for establishing a connection (reqwest `connect_timeout`).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Whole-request time limit for non-streaming calls (two minutes).
const BLOCKING_RESPONSE_TIMEOUT: Duration = Duration::from_secs(2 * 60);

/// HTTP statuses considered transient and therefore retried:
/// rate limiting (429) plus server/gateway failures (500, 502, 503).
const RETRYABLE_STATUSES: &[u16] = &[429, 500, 502, 503];

/// Total attempts per request, counting the initial one.
const MAX_ATTEMPTS: u32 = 3;
29
30#[derive(Clone)]
32pub struct ChatGptClient {
33 client: reqwest::Client,
34 model: String,
35 reasoning_effort: Option<String>,
36 prompt_cache_key: Option<String>,
37 prompt_cache_retention: Option<String>,
38 max_output_tokens: u32,
39 base_url: String,
40}
41
42#[derive(Debug, Serialize, Clone)]
46pub struct ToolDefinition {
47 #[serde(rename = "type")]
48 pub tool_type: String,
49 pub name: String,
50 pub description: String,
51 pub parameters: serde_json::Value,
52}
53
54#[derive(Debug, Serialize)]
57struct ResponseRequest {
58 model: String,
59 input: Vec<serde_json::Value>,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 instructions: Option<String>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 tools: Option<Vec<ToolDefinition>>,
64 #[serde(skip_serializing_if = "std::ops::Not::not")]
65 stream: bool,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 reasoning: Option<ReasoningConfig>,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 previous_response_id: Option<String>,
70 store: bool,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 prompt_cache_key: Option<String>,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 prompt_cache_retention: Option<String>,
75 max_output_tokens: u32,
76 truncation: &'static str,
77}
78
79#[derive(Debug, Serialize, Clone)]
80struct ReasoningConfig {
81 effort: String,
82}
83
84#[derive(Debug, Deserialize)]
88pub struct ApiResponse {
89 pub id: String,
90 pub status: String,
91 pub output: Vec<serde_json::Value>,
92 #[serde(default)]
93 pub output_text: Option<String>,
94 #[serde(default)]
95 pub usage: Option<Usage>,
96 #[serde(default)]
97 pub error: Option<ApiResponseError>,
98}
99
100#[derive(Debug, Deserialize)]
102pub struct ApiResponseError {
103 pub message: String,
104 #[serde(default)]
105 pub code: Option<String>,
106}
107
108#[derive(Debug, Deserialize, Clone)]
110pub struct FunctionCallItem {
111 pub id: String,
112 pub name: String,
113 pub call_id: String,
114 pub arguments: String,
115 pub status: String,
116}
117
118pub enum ResponseStreamEvent {
120 TextDelta(String),
122 FunctionCall(FunctionCallItem),
124 ResponseCompleted { id: String, usage: Option<Usage> },
127}
128
129impl ApiResponse {
130 pub fn function_calls(&self) -> Vec<FunctionCallItem> {
132 self.output
133 .iter()
134 .filter_map(|item| {
135 if item.get("type")?.as_str()? == "function_call" {
136 serde_json::from_value(item.clone()).ok()
137 } else {
138 None
139 }
140 })
141 .collect()
142 }
143}
144
145pub fn input_message(role: &str, content: &str) -> serde_json::Value {
149 serde_json::json!({ "type": "message", "role": role, "content": content })
150}
151
152pub fn input_function_call_output(call_id: &str, output: &str) -> serde_json::Value {
154 serde_json::json!({ "type": "function_call_output", "call_id": call_id, "output": output })
155}
156
157#[derive(Debug, Deserialize, Default, Clone, Copy)]
161struct InputTokensDetails {
162 #[serde(default)]
163 cached_tokens: u32,
164}
165
166#[derive(Debug, Deserialize, Default, Clone, Copy)]
168pub struct Usage {
169 pub input_tokens: u32,
170 pub output_tokens: u32,
171 pub total_tokens: u32,
172 #[serde(default)]
173 input_tokens_details: Option<InputTokensDetails>,
174}
175
176impl Usage {
177 pub fn cached_tokens(&self) -> u32 {
179 self.input_tokens_details.map_or(0, |d| d.cached_tokens)
180 }
181}
182
183impl std::ops::AddAssign for Usage {
184 fn add_assign(&mut self, rhs: Self) {
185 self.input_tokens += rhs.input_tokens;
186 self.output_tokens += rhs.output_tokens;
187 self.total_tokens += rhs.total_tokens;
188 let prev = self.input_tokens_details.unwrap_or_default().cached_tokens;
190 let added = rhs.input_tokens_details.unwrap_or_default().cached_tokens;
191 self.input_tokens_details = Some(InputTokensDetails {
192 cached_tokens: prev + added,
193 });
194 }
195}
196
197#[derive(Debug, thiserror::Error)]
200pub enum LlmError {
201 #[error("OpenAI API error (HTTP {status}): {body}")]
202 Api { status: u16, body: String },
203
204 #[error(transparent)]
205 Transport(#[from] reqwest::Error),
206
207 #[error(transparent)]
208 Other(#[from] anyhow::Error),
209}
210
211async fn send_with_retry(
223 client: &reqwest::Client,
224 url: &str,
225 body: &serde_json::Value,
226 timeout: Option<Duration>,
227) -> Result<reqwest::Response, LlmError> {
228 let mut attempt = 0u32;
229 loop {
230 let mut req = client.post(url).json(body);
231 if let Some(t) = timeout {
232 req = req.timeout(t);
233 }
234 let response = req.send().await?;
235 let status = response.status();
236
237 if status.is_success() {
238 return Ok(response);
239 }
240
241 let status_u16 = status.as_u16();
242 let is_retryable = RETRYABLE_STATUSES.contains(&status_u16);
243 let has_attempts_remaining = attempt + 1 < MAX_ATTEMPTS;
244
245 if !is_retryable || !has_attempts_remaining {
246 let body = response.text().await.unwrap_or_default();
247 return Err(LlmError::Api {
248 status: status_u16,
249 body,
250 });
251 }
252
253 let backoff = if status_u16 == 429 {
255 response
256 .headers()
257 .get("retry-after")
258 .and_then(|v| v.to_str().ok())
259 .and_then(|s| s.parse::<u64>().ok())
260 .map(Duration::from_secs)
261 .unwrap_or_else(|| Duration::from_secs(1u64 << attempt))
262 } else {
263 Duration::from_secs(1u64 << attempt)
264 };
265 let backoff = backoff.min(Duration::from_secs(30));
266
267 tracing::warn!(
268 status = status_u16,
269 attempt = attempt + 1,
270 backoff_secs = backoff.as_secs_f32(),
271 "transient API error — retrying"
272 );
273
274 tokio::time::sleep(backoff).await;
275 attempt += 1;
276 }
277}
278
279impl ChatGptClient {
282 pub fn new(api_key: &str, model: &str) -> Result<Self> {
286 let mut headers = header::HeaderMap::new();
287 let mut auth = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
288 .context("invalid API key characters")?;
289 auth.set_sensitive(true);
290 headers.insert(header::AUTHORIZATION, auth);
291
292 let client = reqwest::Client::builder()
293 .default_headers(headers)
294 .connect_timeout(CONNECT_TIMEOUT)
295 .build()
296 .context("failed to build HTTP client")?;
297
298 let reasoning_effort = if model.starts_with("gpt-5") || model.starts_with("gpt-6") {
301 if model.contains("nano") {
302 Some("minimal".to_owned())
303 } else if model.contains("mini") {
304 Some("low".to_owned())
305 } else {
306 Some("medium".to_owned())
307 }
308 } else {
309 None
310 };
311
312 let prompt_cache_key = Some("poe2-agent-v1".to_owned());
314
315 let prompt_cache_retention = if model.starts_with("gpt-5.1")
317 || model.starts_with("gpt-5.2")
318 || model.starts_with("gpt-6")
319 {
320 Some("24h".to_owned())
321 } else {
322 None
323 };
324
325 Ok(Self {
326 client,
327 model: model.to_owned(),
328 reasoning_effort,
329 prompt_cache_key,
330 prompt_cache_retention,
331 max_output_tokens: DEFAULT_MAX_OUTPUT_TOKENS,
332 base_url: RESPONSES_API_URL.to_owned(),
333 })
334 }
335
336 #[cfg(test)]
338 fn new_with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
339 let mut client = Self::new(api_key, model)?;
340 client.base_url = base_url.to_owned();
341 Ok(client)
342 }
343
344 pub fn with_max_output_tokens(mut self, n: u32) -> Self {
346 self.max_output_tokens = n;
347 self
348 }
349
350 pub fn with_reasoning_effort(mut self, effort: &str) -> Self {
354 self.reasoning_effort = Some(effort.to_owned());
355 self
356 }
357
358 pub fn model(&self) -> &str {
360 &self.model
361 }
362
363 pub async fn create_response(
369 &self,
370 input: &[serde_json::Value],
371 instructions: Option<&str>,
372 tools: Option<&[ToolDefinition]>,
373 previous_response_id: Option<&str>,
374 ) -> Result<ApiResponse, LlmError> {
375 let request = ResponseRequest {
376 model: self.model.clone(),
377 input: input.to_vec(),
378 instructions: instructions.map(|s| s.to_owned()),
379 tools: tools.map(|t| t.to_vec()),
380 stream: false,
381 reasoning: self
382 .reasoning_effort
383 .as_ref()
384 .map(|e| ReasoningConfig { effort: e.clone() }),
385 previous_response_id: previous_response_id.map(|s| s.to_owned()),
386 store: true,
387 prompt_cache_key: self.prompt_cache_key.clone(),
388 prompt_cache_retention: self.prompt_cache_retention.clone(),
389 max_output_tokens: self.max_output_tokens,
390 truncation: "auto",
391 };
392
393 let body = serde_json::to_value(&request).map_err(|e| LlmError::Other(e.into()))?;
394 let response = send_with_retry(
395 &self.client,
396 &self.base_url,
397 &body,
398 Some(BLOCKING_RESPONSE_TIMEOUT),
399 )
400 .await?;
401
402 let parsed: ApiResponse = response.json().await?;
403 if let Some(ref u) = parsed.usage {
404 tracing::debug!(
405 input_tokens = u.input_tokens,
406 output_tokens = u.output_tokens,
407 cached_tokens = u.cached_tokens(),
408 total_tokens = u.total_tokens,
409 "llm response usage"
410 );
411 }
412 if parsed.status == "failed" {
413 let msg = parsed
414 .error
415 .as_ref()
416 .map(|e| e.message.as_str())
417 .unwrap_or("unknown error");
418 return Err(LlmError::Other(anyhow::anyhow!(
419 "API response failed: {msg}"
420 )));
421 }
422
423 Ok(parsed)
424 }
425
426 pub fn create_response_stream(
435 &self,
436 input: &[serde_json::Value],
437 instructions: Option<&str>,
438 tools: Option<&[ToolDefinition]>,
439 previous_response_id: Option<&str>,
440 ) -> impl Stream<Item = Result<ResponseStreamEvent, LlmError>> + Send {
441 let client = self.client.clone();
442 let url = self.base_url.clone();
443 let request = ResponseRequest {
444 model: self.model.clone(),
445 input: input.to_vec(),
446 instructions: instructions.map(|s| s.to_owned()),
447 tools: tools.map(|t| t.to_vec()),
448 stream: true,
449 reasoning: self
450 .reasoning_effort
451 .as_ref()
452 .map(|e| ReasoningConfig { effort: e.clone() }),
453 previous_response_id: previous_response_id.map(|s| s.to_owned()),
454 store: true,
455 prompt_cache_key: self.prompt_cache_key.clone(),
456 prompt_cache_retention: self.prompt_cache_retention.clone(),
457 max_output_tokens: self.max_output_tokens,
458 truncation: "auto",
459 };
460 let body =
463 serde_json::to_value(&request).expect("ResponseRequest serialization is infallible");
464
465 async_stream::try_stream! {
466 let mut response = send_with_retry(&client, &url, &body, None).await?;
469
470 let mut buffer = String::new();
471 let mut event_type = String::new();
472
473 while let Some(chunk) = response.chunk().await? {
474 buffer.push_str(&String::from_utf8_lossy(&chunk));
475
476 while let Some(pos) = buffer.find("\n\n") {
478 let event_block = buffer[..pos].to_owned();
479 buffer = buffer[pos + 2..].to_owned();
480
481 event_type.clear();
484 let mut data_line = None;
485 for line in event_block.lines() {
486 if let Some(et) = line.strip_prefix("event: ") {
487 event_type = et.trim().to_owned();
488 } else if let Some(d) = line.strip_prefix("data: ") {
489 data_line = Some(d.to_owned());
490 }
491 }
492
493 let data = match data_line {
494 Some(d) => d,
495 None => continue,
496 };
497
498 match event_type.as_str() {
499 "response.output_text.delta" => {
500 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
501 if let Some(delta) = parsed.get("delta").and_then(|d| d.as_str()) {
502 yield ResponseStreamEvent::TextDelta(delta.to_owned());
503 }
504 }
505 }
506 "response.output_item.done" => {
507 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
508 if let Some(item) = parsed.get("item") {
509 if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
510 if let Ok(fc) = serde_json::from_value::<FunctionCallItem>(item.clone()) {
511 yield ResponseStreamEvent::FunctionCall(fc);
512 }
513 }
514 }
515 }
516 }
517 "response.completed" => {
518 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
519 let id = parsed.pointer("/response/id")
520 .and_then(|v| v.as_str())
521 .unwrap_or_default()
522 .to_owned();
523 let usage = parsed.pointer("/response/usage")
524 .and_then(|v| serde_json::from_value::<Usage>(v.clone()).ok());
525 if let Some(ref u) = usage {
526 tracing::debug!(
527 input_tokens = u.input_tokens,
528 output_tokens = u.output_tokens,
529 cached_tokens = u.cached_tokens(),
530 total_tokens = u.total_tokens,
531 "llm stream response usage"
532 );
533 }
534 yield ResponseStreamEvent::ResponseCompleted { id, usage };
535 }
536 return;
537 }
538 "response.failed" | "response.incomplete" => {
539 let msg = serde_json::from_str::<serde_json::Value>(&data)
540 .ok()
541 .and_then(|v| {
542 v.pointer("/response/error/message")
543 .and_then(|m| m.as_str().map(|s| s.to_owned()))
544 })
545 .unwrap_or_else(|| format!("response {}", event_type));
546 Err(LlmError::Other(anyhow::anyhow!("{msg}")))?;
547 }
548 _ => {} }
550 }
551 }
552 }
553 }
554}
555
#[cfg(test)]
// Retry-policy tests for `send_with_retry`, exercised end-to-end through
// `ChatGptClient::create_response` against a local wiremock server.
mod tests {
    use super::*;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Minimal body that deserializes as a successful `ApiResponse`.
    fn success_body() -> serde_json::Value {
        serde_json::json!({
            "id": "resp_test",
            "status": "completed",
            "output": []
        })
    }

    // Two 429s carrying `Retry-After: 1`, then a 200: expects success on the
    // third request. The test spends about two seconds sleeping on the
    // server-requested delays.
    #[tokio::test]
    async fn retry_on_429_respects_retry_after() {
        let mock_server = MockServer::start().await;

        // Higher-priority mock answers the first two requests only.
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(429).insert_header("Retry-After", "1"))
            .up_to_n_times(2)
            .with_priority(1)
            .mount(&mock_server)
            .await;

        // Fallback mock answers everything after that.
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body()))
            .mount(&mock_server)
            .await;

        let client =
            ChatGptClient::new_with_base_url("test-key", "gpt-4o", &mock_server.uri()).unwrap();
        let result = client.create_response(&[], None, None, None).await;

        assert!(result.is_ok(), "expected success after retries: {result:?}");
        let requests = mock_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 3, "expected exactly 3 requests");
    }

    // Two 500s without a retry hint, then a 200: the client falls back to
    // exponential backoff and succeeds on the third request.
    #[tokio::test]
    async fn retry_on_500_uses_exponential_backoff() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(2)
            .with_priority(1)
            .mount(&mock_server)
            .await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(success_body()))
            .mount(&mock_server)
            .await;

        let client =
            ChatGptClient::new_with_base_url("test-key", "gpt-4o", &mock_server.uri()).unwrap();
        let result = client.create_response(&[], None, None, None).await;

        assert!(result.is_ok(), "expected success after retries: {result:?}");
        let requests = mock_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 3, "expected exactly 3 requests");
    }

    // A 400 is not in RETRYABLE_STATUSES: exactly one request, surfaced as
    // `LlmError::Api` with the original status.
    #[tokio::test]
    async fn non_retryable_error_propagates() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(400).set_body_string("bad request"))
            .mount(&mock_server)
            .await;

        let client =
            ChatGptClient::new_with_base_url("test-key", "gpt-4o", &mock_server.uri()).unwrap();
        let result = client.create_response(&[], None, None, None).await;

        assert!(result.is_err());
        let requests = mock_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1, "non-retryable error must not be retried");
        match result.unwrap_err() {
            LlmError::Api { status, .. } => assert_eq!(status, 400),
            e => panic!("expected LlmError::Api, got {e:?}"),
        }
    }

    // A persistent 503 is retried until MAX_ATTEMPTS requests have been made,
    // after which the last status is returned as `LlmError::Api`.
    #[tokio::test]
    async fn max_retry_attempts_respected() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(503))
            .mount(&mock_server)
            .await;

        let client =
            ChatGptClient::new_with_base_url("test-key", "gpt-4o", &mock_server.uri()).unwrap();
        let result = client.create_response(&[], None, None, None).await;

        assert!(result.is_err());
        let requests = mock_server.received_requests().await.unwrap();
        assert_eq!(
            requests.len(),
            MAX_ATTEMPTS as usize,
            "must stop after MAX_ATTEMPTS"
        );
        match result.unwrap_err() {
            LlmError::Api { status, .. } => assert_eq!(status, 503),
            e => panic!("expected LlmError::Api, got {e:?}"),
        }
    }
}