use std::time::{Duration, Instant};
use reqwest::Client;
use serde_json::Value;
use tracing::info;
use crate::error::AgentError;
use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage, InvokeFuture};
use crate::providers::http::sse::{SseDelta, collect_sse_stream};
/// Outcome of one request/response exchange with an HTTP provider.
///
/// Produced by [`HttpAgentAdapter::parse_response`] (whole JSON body) or
/// [`HttpAgentAdapter::fold_sse_deltas`] (streamed body) and converted into an
/// `AgentOutput` by [`HttpAgentProvider`].
#[derive(Debug)]
pub struct TurnResult {
    /// Assistant text; becomes the output value when `structured_value` is `None`.
    pub text: Option<String>,

    /// Tool invocations parsed from the response by the adapter.
    // Not read by `HttpAgentProvider::invoke` (its debug message always reports
    // an empty tool-call list), hence the lint allowance.
    #[allow(dead_code)]
    pub tool_calls: Vec<HttpToolCall>,

    /// `true` is reported as stop reason `end_turn` in verbose debug output,
    /// `false` as `tool_use`.
    pub is_final: bool,

    /// Structured JSON result; takes precedence over `text` as the output value.
    pub structured_value: Option<Value>,

    /// Token counts reported by the provider.
    pub usage: HttpUsage,

    /// Model name reported by the provider; used to look up the call's cost.
    pub model: Option<String>,
}
/// A single tool invocation extracted from a provider response.
// Nothing in this module reads these fields yet, hence the lint allowance.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct HttpToolCall {
    /// Identifier of the call (presumably assigned by the provider — confirm per adapter).
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Tool arguments as JSON.
    pub input: Value,
}
/// Token usage reported by a provider for one turn.
///
/// Each count is `None` when the provider did not report it. The type is plain
/// data, so it is `Copy` and comparable, which lets callers and tests duplicate
/// and compare usage without boilerplate.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct HttpUsage {
    /// Prompt/input tokens, if reported.
    pub input_tokens: Option<u64>,
    /// Completion/output tokens, if reported.
    pub output_tokens: Option<u64>,
}
/// Vendor-specific glue that lets [`HttpAgentProvider`] talk to one HTTP API.
///
/// The provider handles transport (client, timeouts, status codes); the adapter
/// supplies URLs, headers, request/response JSON shapes, SSE parsing and pricing.
pub trait HttpAgentAdapter: Send + Sync + 'static {
    /// Name used in `RateLimited` / `HttpProvider` errors and in the completion log.
    fn provider_name(&self) -> &'static str;

    /// URL the request is POSTed to; `model` is the output of [`Self::resolve_model`].
    fn endpoint_url(&self, model: &str) -> String;

    /// `(name, value)` pairs added as headers to every request.
    fn auth_headers(&self) -> Vec<(String, String)>;

    /// Builds the JSON request body for `config`.
    fn build_request(&self, config: &AgentConfig) -> Result<Value, AgentError>;

    /// Parses a complete JSON response body (used when `config.verbose` is off).
    fn parse_response(
        &self,
        body: &Value,
        config: &AgentConfig,
    ) -> Result<TurnResult, AgentError>;

    /// Parses one line of a server-sent-events stream.
    ///
    /// NOTE(review): presumably `None` means "no delta on this line"; confirm
    /// against `collect_sse_stream`, which receives the adapter.
    fn parse_sse_line(&self, line: &str) -> Option<SseDelta>;

    /// Combines deltas collected from a streamed response (used when
    /// `config.verbose` is on) into a single turn.
    fn fold_sse_deltas(
        &self,
        deltas: Vec<SseDelta>,
        config: &AgentConfig,
    ) -> Result<TurnResult, AgentError>;

    /// Cost of a call in USD for the response-reported `model`; `None` leaves
    /// the output's cost unset.
    fn compute_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> Option<f64>;

    /// Maps `config.model` to the model name passed to [`Self::endpoint_url`].
    fn resolve_model(&self, model: &str) -> String;
}
/// Request timeout applied by `HttpAgentProvider::new` (two minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// `AgentProvider` that drives an HTTP API through an [`HttpAgentAdapter`].
pub struct HttpAgentProvider<A: HttpAgentAdapter> {
    /// Vendor-specific request building and response parsing.
    adapter: A,
    /// HTTP client, built with `timeout` as its per-request timeout.
    client: Client,
    /// Deadline for each request; also handed to the SSE collector.
    timeout: Duration,
}
impl<A: HttpAgentAdapter> HttpAgentProvider<A> {
pub fn new(adapter: A) -> Self {
let client = Client::builder()
.timeout(DEFAULT_TIMEOUT)
.build()
.expect("failed to build reqwest client");
Self {
adapter,
client,
timeout: DEFAULT_TIMEOUT,
}
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self.client = Client::builder()
.timeout(timeout)
.build()
.expect("failed to build reqwest client");
self
}
async fn execute_turn(
&self,
request_body: &Value,
config: &AgentConfig,
) -> Result<TurnResult, AgentError> {
let model = self.adapter.resolve_model(&config.model);
let url = self.adapter.endpoint_url(&model);
let headers = self.adapter.auth_headers();
let mut req = self.client.post(&url).json(request_body);
for (key, value) in &headers {
req = req.header(key, value);
}
let response = tokio::time::timeout(self.timeout, req.send())
.await
.map_err(|_| AgentError::Timeout {
limit: self.timeout,
})?
.map_err(|e| AgentError::ProcessFailed {
exit_code: -1,
stderr: format!("HTTP request failed: {e}"),
})?;
let status = response.status().as_u16();
if status == 429 {
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok());
return Err(AgentError::RateLimited {
provider: self.adapter.provider_name().to_string(),
retry_after_secs: retry_after,
});
}
if status >= 400 {
let body_text = response.text().await.unwrap_or_default();
let message = serde_json::from_str::<Value>(&body_text)
.ok()
.and_then(|v| {
v.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.map(String::from)
})
.unwrap_or(body_text);
return Err(AgentError::HttpProvider {
provider: self.adapter.provider_name().to_string(),
status_code: status,
message,
});
}
if config.verbose {
let deltas = collect_sse_stream(&self.adapter, response, self.timeout).await?;
self.adapter.fold_sse_deltas(deltas, config)
} else {
let body: Value = response
.json()
.await
.map_err(|e| AgentError::ProcessFailed {
exit_code: -1,
stderr: format!("failed to parse response JSON: {e}"),
})?;
self.adapter.parse_response(&body, config)
}
}
}
impl<A: HttpAgentAdapter> AgentProvider for HttpAgentProvider<A> {
    /// Runs one request/response turn and converts it into an `AgentOutput`.
    ///
    /// - The output value is the turn's structured JSON when present. Otherwise
    ///   it is the turn's text, or an empty string if there is none.
    /// - Token counts are passed through exactly as reported by the provider.
    /// - A cost is computed only when the response names a model and reports
    ///   at least one token count. This avoids claiming a zero-token cost for
    ///   responses that carry no usage data.
    /// - With `config.verbose`, a single debug message describing the turn is
    ///   attached.
    ///
    /// # Errors
    ///
    /// Propagates errors from the adapter's `build_request` and from
    /// `execute_turn`.
    fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
        Box::pin(async move {
            let start = Instant::now();
            let request_body = self.adapter.build_request(config)?;
            let turn_result = self.execute_turn(&request_body, config).await?;
            // u128 -> u64 millis cannot overflow for any realistic duration.
            let duration_ms = start.elapsed().as_millis() as u64;

            let reported_input = turn_result.usage.input_tokens;
            let reported_output = turn_result.usage.output_tokens;
            let input_tokens = reported_input.unwrap_or(0);
            let output_tokens = reported_output.unwrap_or(0);
            let model_name = turn_result.model.clone();

            // Pricing a response without usage data would fabricate a cost.
            let usage_reported = reported_input.is_some() || reported_output.is_some();
            let cost = model_name
                .as_deref()
                .filter(|_| usage_reported)
                .and_then(|m| self.adapter.compute_cost(m, input_tokens, output_tokens));

            let debug_messages = if config.verbose {
                let stop_reason = if turn_result.is_final {
                    "end_turn"
                } else {
                    "tool_use"
                };
                Some(vec![DebugMessage {
                    text: turn_result.text.clone(),
                    thinking: None,
                    thinking_redacted: false,
                    tool_calls: Vec::new(),
                    tool_results: Vec::new(),
                    stop_reason: Some(stop_reason.to_string()),
                    input_tokens: reported_input,
                    output_tokens: reported_output,
                }])
            } else {
                None
            };

            let value = match turn_result.structured_value {
                Some(structured) => structured,
                None => Value::String(turn_result.text.unwrap_or_default()),
            };

            info!(
                provider = self.adapter.provider_name(),
                duration_ms, input_tokens, output_tokens, "invocation complete"
            );

            Ok(AgentOutput {
                value,
                session_id: None,
                cost_usd: cost,
                input_tokens: reported_input,
                output_tokens: reported_output,
                model: model_name,
                duration_ms,
                debug_messages,
            })
        })
    }
}