use anyhow::{Context, Result};
use futures_core::Stream;
use reqwest::header;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Endpoint URL that `ChatGptClient::new` installs as the default `base_url`.
const RESPONSES_API_URL: &str = "https://api.openai.com/v1/responses";
/// Value of `max_output_tokens` for new clients; override with `with_max_output_tokens`.
const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;
/// Connect timeout configured on the `reqwest::Client` built in `ChatGptClient::new`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Per-request timeout used by `create_response`; the streaming path passes no timeout.
const BLOCKING_RESPONSE_TIMEOUT: Duration = Duration::from_secs(120);
/// HTTP status codes that `send_with_retry` treats as transient and retries.
const RETRYABLE_STATUSES: &[u16] = &[429, 500, 502, 503];
/// Maximum number of requests `send_with_retry` sends per call, counting the first one.
const MAX_ATTEMPTS: u32 = 3;
/// Client for the Responses API endpoint.
///
/// Bundles a configured HTTP client with the settings that `create_response`
/// and `create_response_stream` copy into every request they build.
#[derive(Clone)]
pub struct ChatGptClient {
/// HTTP client carrying the `Authorization` default header and the connect timeout.
client: reqwest::Client,
/// Model identifier sent as the request's `model` field.
model: String,
/// When `Some`, sent as `reasoning.effort`; when `None`, `reasoning` is omitted.
reasoning_effort: Option<String>,
/// When `Some`, sent as the request's `prompt_cache_key`.
prompt_cache_key: Option<String>,
/// When `Some`, sent as the request's `prompt_cache_retention`.
prompt_cache_retention: Option<String>,
/// Sent as the request's `max_output_tokens`.
max_output_tokens: u32,
/// URL every request is POSTed to; tests replace it via `new_with_base_url`.
base_url: String,
}
/// A tool entry serialized into the `tools` array of a request.
#[derive(Debug, Serialize, Clone)]
pub struct ToolDefinition {
/// Serialized under the JSON key `type`.
#[serde(rename = "type")]
pub tool_type: String,
/// Tool name.
pub name: String,
/// Free-text description of the tool.
pub description: String,
/// Arbitrary JSON describing the tool's parameters
/// (presumably a JSON Schema object — confirm against the API docs).
pub parameters: serde_json::Value,
}
/// JSON body sent to the Responses endpoint by both the blocking and the
/// streaming call paths. `Option` fields are left out of the JSON when `None`.
#[derive(Debug, Serialize)]
struct ResponseRequest {
/// Model identifier.
model: String,
/// Input items, passed through as raw JSON values (see `input_message` and
/// `input_function_call_output` for helpers that build them).
input: Vec<serde_json::Value>,
/// Optional instructions string.
#[serde(skip_serializing_if = "Option::is_none")]
instructions: Option<String>,
/// Optional tool definitions.
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ToolDefinition>>,
/// Streaming flag; `Not::not` makes serde skip the field when it is `false`,
/// so it only appears in the JSON as `"stream": true`.
#[serde(skip_serializing_if = "std::ops::Not::not")]
stream: bool,
/// Optional reasoning configuration.
#[serde(skip_serializing_if = "Option::is_none")]
reasoning: Option<ReasoningConfig>,
/// Optional id of an earlier response to continue from.
#[serde(skip_serializing_if = "Option::is_none")]
previous_response_id: Option<String>,
/// Always serialized; both call sites in this file set it to `true`.
store: bool,
/// Optional prompt cache key.
#[serde(skip_serializing_if = "Option::is_none")]
prompt_cache_key: Option<String>,
/// Optional prompt cache retention setting.
#[serde(skip_serializing_if = "Option::is_none")]
prompt_cache_retention: Option<String>,
/// Upper bound on output tokens.
max_output_tokens: u32,
/// Truncation mode; both call sites in this file pass `"auto"`.
truncation: &'static str,
}
/// Serialized as the `reasoning` object of a `ResponseRequest`.
#[derive(Debug, Serialize, Clone)]
struct ReasoningConfig {
/// Effort level string, taken from `ChatGptClient::reasoning_effort`.
effort: String,
}
/// Deserialized body of a non-streaming response.
#[derive(Debug, Deserialize)]
pub struct ApiResponse {
/// Response identifier.
pub id: String,
/// Status string; `create_response` turns the value `"failed"` into an error.
pub status: String,
/// Raw output items; `function_calls` extracts those whose `type` is `"function_call"`.
pub output: Vec<serde_json::Value>,
/// Optional text field; `None` when absent from the JSON.
#[serde(default)]
pub output_text: Option<String>,
/// Token usage; `None` when absent from the JSON.
#[serde(default)]
pub usage: Option<Usage>,
/// Error details; `None` when absent from the JSON.
#[serde(default)]
pub error: Option<ApiResponseError>,
}
/// Error object embedded in an `ApiResponse`.
#[derive(Debug, Deserialize)]
pub struct ApiResponseError {
/// Error message; `create_response` includes it in the error it returns.
pub message: String,
/// Optional error code; `None` when absent from the JSON.
#[serde(default)]
pub code: Option<String>,
}
/// An output item of type `"function_call"`, deserialized from raw JSON.
/// All five fields are required; an item missing any of them fails to parse.
#[derive(Debug, Deserialize, Clone)]
pub struct FunctionCallItem {
/// Identifier of the output item.
pub id: String,
/// Name of the function being called.
pub name: String,
/// Call identifier; `input_function_call_output` takes a `call_id` when
/// building the matching output item.
pub call_id: String,
/// Arguments as a string (presumably JSON-encoded — confirm against the API docs).
pub arguments: String,
/// Status string of the item.
pub status: String,
}
/// Events yielded by `ChatGptClient::create_response_stream`.
///
/// Derives `Debug` and `Clone` so stream items can be logged, used with
/// `Result::unwrap_err` / `expect_err`, and duplicated by consumers.
#[derive(Debug, Clone)]
pub enum ResponseStreamEvent {
    /// A fragment of output text (from a `response.output_text.delta` event).
    TextDelta(String),
    /// A finished function-call output item (from a `response.output_item.done` event).
    FunctionCall(FunctionCallItem),
    /// The response finished (from a `response.completed` event); carries the
    /// response id and, when it could be parsed, the token usage.
    ResponseCompleted { id: String, usage: Option<Usage> },
}
impl ApiResponse {
    /// Collects every output item whose `type` is `"function_call"` and that
    /// deserializes into a [`FunctionCallItem`].
    ///
    /// Items with a missing or non-string `type`, items of any other type, and
    /// function-call items that fail to deserialize are silently skipped.
    pub fn function_calls(&self) -> Vec<FunctionCallItem> {
        let mut calls = Vec::new();
        for item in &self.output {
            let kind = item.get("type").and_then(|t| t.as_str());
            if kind != Some("function_call") {
                continue;
            }
            if let Ok(call) = serde_json::from_value::<FunctionCallItem>(item.clone()) {
                calls.push(call);
            }
        }
        calls
    }
}
/// Builds a `message` input item with the given `role` and `content`.
///
/// The result is the JSON object
/// `{"type": "message", "role": <role>, "content": <content>}`.
pub fn input_message(role: &str, content: &str) -> serde_json::Value {
    let mut item = serde_json::Map::new();
    item.insert("type".to_owned(), serde_json::Value::from("message"));
    item.insert("role".to_owned(), serde_json::Value::from(role));
    item.insert("content".to_owned(), serde_json::Value::from(content));
    serde_json::Value::Object(item)
}
/// Builds a `function_call_output` input item carrying the result of a tool call.
///
/// The result is the JSON object
/// `{"type": "function_call_output", "call_id": <call_id>, "output": <output>}`.
pub fn input_function_call_output(call_id: &str, output: &str) -> serde_json::Value {
    let mut item = serde_json::Map::new();
    item.insert(
        "type".to_owned(),
        serde_json::Value::from("function_call_output"),
    );
    item.insert("call_id".to_owned(), serde_json::Value::from(call_id));
    item.insert("output".to_owned(), serde_json::Value::from(output));
    serde_json::Value::Object(item)
}
/// Optional breakdown nested inside `Usage`; only the cached-token count is read.
#[derive(Debug, Deserialize, Default, Clone, Copy)]
struct InputTokensDetails {
/// Cached token count; defaults to 0 when the field is absent from the JSON.
#[serde(default)]
cached_tokens: u32,
}
/// Token usage reported with a response. Supports `+=` for accumulating
/// usage across several responses.
#[derive(Debug, Deserialize, Default, Clone, Copy)]
pub struct Usage {
/// Input token count.
pub input_tokens: u32,
/// Output token count.
pub output_tokens: u32,
/// Total token count as reported in the JSON (not recomputed here).
pub total_tokens: u32,
/// Optional details block; kept private and read through `cached_tokens()`.
#[serde(default)]
input_tokens_details: Option<InputTokensDetails>,
}
impl Usage {
    /// Returns the cached-token count from the details block, or 0 when the
    /// details block is absent.
    pub fn cached_tokens(&self) -> u32 {
        match self.input_tokens_details {
            Some(details) => details.cached_tokens,
            None => 0,
        }
    }
}
impl std::ops::AddAssign for Usage {
    /// Adds `rhs` to `self` counter by counter.
    ///
    /// Every sum saturates at `u32::MAX`, so a long-running accumulator can
    /// neither panic (debug builds) nor wrap around (release builds).
    /// After the call `input_tokens_details` is always `Some`, even when both
    /// operands had none; the cached-token count is then 0.
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = self.input_tokens.saturating_add(rhs.input_tokens);
        self.output_tokens = self.output_tokens.saturating_add(rhs.output_tokens);
        self.total_tokens = self.total_tokens.saturating_add(rhs.total_tokens);

        // A missing details block counts as zero cached tokens.
        let prev = self.input_tokens_details.unwrap_or_default().cached_tokens;
        let added = rhs.input_tokens_details.unwrap_or_default().cached_tokens;
        self.input_tokens_details = Some(InputTokensDetails {
            cached_tokens: prev.saturating_add(added),
        });
    }
}
/// Errors returned by the client's request methods.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
/// The server answered with a non-success HTTP status that was either not
/// retryable or still failing on the last attempt; `body` is the response
/// text (empty if it could not be read).
#[error("OpenAI API error (HTTP {status}): {body}")]
Api { status: u16, body: String },
/// A `reqwest` error, e.g. while sending the request or reading/decoding the body.
#[error(transparent)]
Transport(#[from] reqwest::Error),
/// Any other failure, such as request serialization or a response that
/// reports a failed status.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// POSTs `body` as JSON to `url`, retrying transient HTTP failures.
///
/// * `timeout` — optional per-request timeout applied to every attempt.
///
/// A response with a success status is returned as-is. A status listed in
/// `RETRYABLE_STATUSES` is retried until `MAX_ATTEMPTS` requests have been
/// sent. The wait before a retry is the server's `Retry-After` value when it
/// is a whole number of seconds (honored for every retryable status, since
/// servers send it with 503 as well as 429), otherwise `2^attempt` seconds;
/// either way it is capped at 30 seconds.
///
/// # Errors
///
/// * [`LlmError::Api`] — non-retryable status, or a retryable status on the
///   final attempt. The response body is included when it can be read.
/// * [`LlmError::Transport`] — `send()` itself failed; these errors are
///   returned immediately and are not retried.
async fn send_with_retry(
    client: &reqwest::Client,
    url: &str,
    body: &serde_json::Value,
    timeout: Option<Duration>,
) -> Result<reqwest::Response, LlmError> {
    // Ceiling for every wait, whether server-suggested or computed locally.
    const MAX_BACKOFF: Duration = Duration::from_secs(30);

    let mut attempt = 0u32;
    loop {
        let mut req = client.post(url).json(body);
        if let Some(t) = timeout {
            req = req.timeout(t);
        }
        let response = req.send().await?;
        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }

        let status_u16 = status.as_u16();
        let is_retryable = RETRYABLE_STATUSES.contains(&status_u16);
        let has_attempts_remaining = attempt + 1 < MAX_ATTEMPTS;
        if !is_retryable || !has_attempts_remaining {
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status_u16,
                body,
            });
        }

        // `checked_shl` keeps the shift well-defined even if MAX_ATTEMPTS is
        // ever raised past 64; the cap below bounds the result anyway.
        let exponential = Duration::from_secs(1u64.checked_shl(attempt).unwrap_or(u64::MAX));
        // Only the delay-seconds form of Retry-After is understood; an
        // HTTP-date value falls back to the exponential schedule.
        let backoff = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.trim().parse::<u64>().ok())
            .map(Duration::from_secs)
            .unwrap_or(exponential)
            .min(MAX_BACKOFF);

        tracing::warn!(
            status = status_u16,
            attempt = attempt + 1,
            backoff_secs = backoff.as_secs_f32(),
            "transient API error — retrying"
        );
        tokio::time::sleep(backoff).await;
        attempt += 1;
    }
}
impl ChatGptClient {
/// Creates a client for `model` that authenticates with `api_key`.
///
/// Defaults derived from the model name:
/// * reasoning effort — only for names starting with `gpt-5` or `gpt-6`:
///   `minimal` if the name contains `nano`, else `low` if it contains `mini`,
///   else `medium`;
/// * prompt cache retention — `24h` for names starting with `gpt-5.1`,
///   `gpt-5.2` or `gpt-6`, otherwise unset.
///
/// # Errors
///
/// Fails when `api_key` cannot be used in a header value or when the HTTP
/// client cannot be built.
pub fn new(api_key: &str, model: &str) -> Result<Self> {
    /// True when `model` starts with any of `prefixes`.
    fn has_any_prefix(model: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|&prefix| model.starts_with(prefix))
    }

    let bearer = format!("Bearer {api_key}");
    let mut auth_value =
        header::HeaderValue::from_str(&bearer).context("invalid API key characters")?;
    // Flag the credential as sensitive on the header value itself.
    auth_value.set_sensitive(true);

    let mut default_headers = header::HeaderMap::new();
    default_headers.insert(header::AUTHORIZATION, auth_value);

    let http = reqwest::Client::builder()
        .default_headers(default_headers)
        .connect_timeout(CONNECT_TIMEOUT)
        .build()
        .context("failed to build HTTP client")?;

    let reasoning_effort = has_any_prefix(model, &["gpt-5", "gpt-6"]).then(|| {
        let level = if model.contains("nano") {
            "minimal"
        } else if model.contains("mini") {
            "low"
        } else {
            "medium"
        };
        level.to_owned()
    });

    let prompt_cache_retention =
        has_any_prefix(model, &["gpt-5.1", "gpt-5.2", "gpt-6"]).then(|| "24h".to_owned());

    Ok(Self {
        client: http,
        model: model.to_owned(),
        reasoning_effort,
        prompt_cache_key: Some("poe2-agent-v1".to_owned()),
        prompt_cache_retention,
        max_output_tokens: DEFAULT_MAX_OUTPUT_TOKENS,
        base_url: RESPONSES_API_URL.to_owned(),
    })
}
/// Test-only constructor: same as `new`, but requests go to `base_url`.
#[cfg(test)]
fn new_with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
    Self::new(api_key, model).map(|mut built| {
        built.base_url = base_url.to_owned();
        built
    })
}
pub fn with_max_output_tokens(mut self, n: u32) -> Self {
self.max_output_tokens = n;
self
}
pub fn with_reasoning_effort(mut self, effort: &str) -> Self {
self.reasoning_effort = Some(effort.to_owned());
self
}
/// Returns the model identifier this client was created with.
pub fn model(&self) -> &str {
    self.model.as_str()
}
/// Sends a non-streaming request and returns the parsed response.
///
/// * `input` — input items, sent as-is.
/// * `instructions`, `tools`, `previous_response_id` — optional request
///   fields; omitted from the JSON when `None`.
///
/// Each attempt is limited by `BLOCKING_RESPONSE_TIMEOUT`, and transient
/// HTTP failures are retried by `send_with_retry`. Token usage, when
/// present, is logged at debug level.
///
/// # Errors
///
/// Propagates errors from `send_with_retry` and from decoding the body, and
/// returns [`LlmError::Other`] when the response's status is `"failed"`.
pub async fn create_response(
    &self,
    input: &[serde_json::Value],
    instructions: Option<&str>,
    tools: Option<&[ToolDefinition]>,
    previous_response_id: Option<&str>,
) -> Result<ApiResponse, LlmError> {
    let reasoning = self
        .reasoning_effort
        .clone()
        .map(|effort| ReasoningConfig { effort });

    let payload = ResponseRequest {
        model: self.model.clone(),
        input: input.to_vec(),
        instructions: instructions.map(str::to_owned),
        tools: tools.map(|defs| defs.to_vec()),
        stream: false,
        reasoning,
        previous_response_id: previous_response_id.map(str::to_owned),
        store: true,
        prompt_cache_key: self.prompt_cache_key.clone(),
        prompt_cache_retention: self.prompt_cache_retention.clone(),
        max_output_tokens: self.max_output_tokens,
        truncation: "auto",
    };
    let body = match serde_json::to_value(&payload) {
        Ok(value) => value,
        Err(e) => return Err(LlmError::Other(e.into())),
    };

    let http_response = send_with_retry(
        &self.client,
        &self.base_url,
        &body,
        Some(BLOCKING_RESPONSE_TIMEOUT),
    )
    .await?;
    let parsed: ApiResponse = http_response.json().await?;

    if let Some(usage) = parsed.usage.as_ref() {
        tracing::debug!(
            input_tokens = usage.input_tokens,
            output_tokens = usage.output_tokens,
            cached_tokens = usage.cached_tokens(),
            total_tokens = usage.total_tokens,
            "llm response usage"
        );
    }

    // Anything other than an explicit "failed" status is handed to the caller.
    if parsed.status != "failed" {
        return Ok(parsed);
    }
    let msg = match &parsed.error {
        Some(err) => err.message.as_str(),
        None => "unknown error",
    };
    Err(LlmError::Other(anyhow::anyhow!(
        "API response failed: {msg}"
    )))
}
pub fn create_response_stream(
&self,
input: &[serde_json::Value],
instructions: Option<&str>,
tools: Option<&[ToolDefinition]>,
previous_response_id: Option<&str>,
) -> impl Stream<Item = Result<ResponseStreamEvent, LlmError>> + Send {
let client = self.client.clone();
let url = self.base_url.clone();
let request = ResponseRequest {
model: self.model.clone(),
input: input.to_vec(),
instructions: instructions.map(|s| s.to_owned()),
tools: tools.map(|t| t.to_vec()),
stream: true,
reasoning: self
.reasoning_effort
.as_ref()
.map(|e| ReasoningConfig { effort: e.clone() }),
previous_response_id: previous_response_id.map(|s| s.to_owned()),
store: true,
prompt_cache_key: self.prompt_cache_key.clone(),
prompt_cache_retention: self.prompt_cache_retention.clone(),
max_output_tokens: self.max_output_tokens,
truncation: "auto",
};
let body =
serde_json::to_value(&request).expect("ResponseRequest serialization is infallible");
async_stream::try_stream! {
let mut response = send_with_retry(&client, &url, &body, None).await?;
let mut buffer = String::new();
let mut event_type = String::new();
while let Some(chunk) = response.chunk().await? {
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find("\n\n") {
let event_block = buffer[..pos].to_owned();
buffer = buffer[pos + 2..].to_owned();
event_type.clear();
let mut data_line = None;
for line in event_block.lines() {
if let Some(et) = line.strip_prefix("event: ") {
event_type = et.trim().to_owned();
} else if let Some(d) = line.strip_prefix("data: ") {
data_line = Some(d.to_owned());
}
}
let data = match data_line {
Some(d) => d,
None => continue,
};
match event_type.as_str() {
"response.output_text.delta" => {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
if let Some(delta) = parsed.get("delta").and_then(|d| d.as_str()) {
yield ResponseStreamEvent::TextDelta(delta.to_owned());
}
}
}
"response.output_item.done" => {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
if let Some(item) = parsed.get("item") {
if item.get("type").and_then(|t| t.as_str()) == Some("function_call") {
if let Ok(fc) = serde_json::from_value::<FunctionCallItem>(item.clone()) {
yield ResponseStreamEvent::FunctionCall(fc);
}
}
}
}
}
"response.completed" => {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) {
let id = parsed.pointer("/response/id")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_owned();
let usage = parsed.pointer("/response/usage")
.and_then(|v| serde_json::from_value::<Usage>(v.clone()).ok());
if let Some(ref u) = usage {
tracing::debug!(
input_tokens = u.input_tokens,
output_tokens = u.output_tokens,
cached_tokens = u.cached_tokens(),
total_tokens = u.total_tokens,
"llm stream response usage"
);
}
yield ResponseStreamEvent::ResponseCompleted { id, usage };
}
return;
}
"response.failed" | "response.incomplete" => {
let msg = serde_json::from_str::<serde_json::Value>(&data)
.ok()
.and_then(|v| {
v.pointer("/response/error/message")
.and_then(|m| m.as_str().map(|s| s.to_owned()))
})
.unwrap_or_else(|| format!("response {}", event_type));
Err(LlmError::Other(anyhow::anyhow!("{msg}")))?;
}
_ => {} }
}
}
}
}
}
#[cfg(test)]
mod tests {
    //! Retry-behavior tests that run the client against a local mock server.

    use super::*;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Minimal JSON body that deserializes into a completed `ApiResponse`.
    fn success_body() -> serde_json::Value {
        serde_json::json!({
            "id": "resp_test",
            "status": "completed",
            "output": []
        })
    }

    /// Mounts a priority-1 POST mock that answers with `template` for the
    /// first two requests only.
    async fn mount_two_failures(server: &MockServer, template: ResponseTemplate) {
        Mock::given(method("POST"))
            .respond_with(template)
            .up_to_n_times(2)
            .with_priority(1)
            .mount(server)
            .await;
    }

    /// Mounts a default-priority POST mock that always answers with `template`.
    async fn mount_always(server: &MockServer, template: ResponseTemplate) {
        Mock::given(method("POST"))
            .respond_with(template)
            .mount(server)
            .await;
    }

    /// Builds a client whose requests go to `server`.
    fn client_for(server: &MockServer) -> ChatGptClient {
        ChatGptClient::new_with_base_url("test-key", "gpt-4o", &server.uri()).unwrap()
    }

    /// Number of requests `server` has received so far.
    async fn request_count(server: &MockServer) -> usize {
        server.received_requests().await.unwrap().len()
    }

    /// Asserts that `err` is `LlmError::Api` carrying the `expected` status.
    fn assert_api_status(err: LlmError, expected: u16) {
        match err {
            LlmError::Api { status, .. } => assert_eq!(status, expected),
            e => panic!("expected LlmError::Api, got {e:?}"),
        }
    }

    #[tokio::test]
    async fn retry_on_429_respects_retry_after() {
        let server = MockServer::start().await;
        mount_two_failures(
            &server,
            ResponseTemplate::new(429).insert_header("Retry-After", "1"),
        )
        .await;
        mount_always(
            &server,
            ResponseTemplate::new(200).set_body_json(success_body()),
        )
        .await;

        let result = client_for(&server)
            .create_response(&[], None, None, None)
            .await;

        assert!(result.is_ok(), "expected success after retries: {result:?}");
        assert_eq!(
            request_count(&server).await,
            3,
            "expected exactly 3 requests"
        );
    }

    #[tokio::test]
    async fn retry_on_500_uses_exponential_backoff() {
        let server = MockServer::start().await;
        mount_two_failures(&server, ResponseTemplate::new(500)).await;
        mount_always(
            &server,
            ResponseTemplate::new(200).set_body_json(success_body()),
        )
        .await;

        let result = client_for(&server)
            .create_response(&[], None, None, None)
            .await;

        assert!(result.is_ok(), "expected success after retries: {result:?}");
        assert_eq!(
            request_count(&server).await,
            3,
            "expected exactly 3 requests"
        );
    }

    #[tokio::test]
    async fn non_retryable_error_propagates() {
        let server = MockServer::start().await;
        mount_always(
            &server,
            ResponseTemplate::new(400).set_body_string("bad request"),
        )
        .await;

        let result = client_for(&server)
            .create_response(&[], None, None, None)
            .await;

        assert!(result.is_err());
        assert_eq!(
            request_count(&server).await,
            1,
            "non-retryable error must not be retried"
        );
        assert_api_status(result.unwrap_err(), 400);
    }

    #[tokio::test]
    async fn max_retry_attempts_respected() {
        let server = MockServer::start().await;
        mount_always(&server, ResponseTemplate::new(503)).await;

        let result = client_for(&server)
            .create_response(&[], None, None, None)
            .await;

        assert!(result.is_err());
        assert_eq!(
            request_count(&server).await,
            MAX_ATTEMPTS as usize,
            "must stop after MAX_ATTEMPTS"
        );
        assert_api_status(result.unwrap_err(), 503);
    }
}