ironflow_core/providers/http/
adapter.rs1use std::time::{Duration, Instant};
4
5use reqwest::Client;
6use serde_json::Value;
7use tracing::info;
8
9use crate::error::AgentError;
10use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage, InvokeFuture};
11use crate::providers::http::sse::{SseDelta, collect_sse_stream};
12
/// Normalised outcome of one HTTP request/response turn, as produced by an adapter's
/// `parse_response` (non-streaming) or `fold_sse_deltas` (SSE streaming).
#[derive(Debug)]
pub struct TurnResult {
    /// Plain-text output of the turn, if any; used as the invocation's value when
    /// `structured_value` is `None`.
    pub text: Option<String>,
    /// Tool calls parsed from the turn.
    // No reader of this field is visible in this module, hence `allow(dead_code)`.
    #[allow(dead_code)]
    pub tool_calls: Vec<HttpToolCall>,
    /// Marks the final turn; surfaced only as the debug `stop_reason`
    /// ("end_turn" when true, "tool_use" when false).
    pub is_final: bool,
    /// Structured output; when present it is returned instead of `text`.
    pub structured_value: Option<Value>,
    /// Token usage for this turn.
    pub usage: HttpUsage,
    /// Model name associated with the turn (presumably echoed by the provider); used for
    /// cost lookup and reported as the output's model.
    pub model: Option<String>,
}
30
/// A tool call extracted by an adapter from a provider turn.
#[derive(Debug, Clone)]
// The fields have no visible reader in this module, hence `allow(dead_code)`.
#[allow(dead_code)]
pub struct HttpToolCall {
    /// Identifier of the call (presumably assigned by the provider).
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as a JSON value.
    pub input: Value,
}
42
/// Token accounting for a single turn.
///
/// Each count is `None` when it is not available for the turn. The type is two small
/// `Option<u64>` fields, so it is `Copy` and can be compared directly.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct HttpUsage {
    /// Input (prompt) tokens, if known.
    pub input_tokens: Option<u64>,
    /// Output (completion) tokens, if known.
    pub output_tokens: Option<u64>,
}
51
/// Provider-specific glue used by [`HttpAgentProvider`] to talk to one HTTP API.
///
/// The generic provider owns the transport, timeouts and HTTP status handling.
/// Implementors supply the endpoint URL, headers, request body and response parsing.
pub trait HttpAgentAdapter: Send + Sync + 'static {
    /// Short provider identifier, used in error values and log fields.
    fn provider_name(&self) -> &'static str;

    /// URL the request is POSTed to, given the model name returned by `resolve_model`.
    fn endpoint_url(&self, model: &str) -> String;

    /// Header name/value pairs added to every request.
    fn auth_headers(&self) -> Vec<(String, String)>;

    /// Builds the JSON request body from `config`.
    ///
    /// An error aborts the invocation before any network call is made.
    fn build_request(&self, config: &AgentConfig) -> Result<Value, AgentError>;

    /// Converts a complete JSON response body into a [`TurnResult`].
    ///
    /// Used when `config.verbose` is false.
    fn parse_response(&self, body: &Value, config: &AgentConfig) -> Result<TurnResult, AgentError>;

    /// Parses one line of a server-sent-events stream.
    // NOTE(review): `None` presumably means "nothing to record for this line";
    // confirm against `collect_sse_stream`.
    fn parse_sse_line(&self, line: &str) -> Option<SseDelta>;

    /// Folds the deltas collected from an SSE stream into a single [`TurnResult`].
    ///
    /// Used when `config.verbose` is true.
    fn fold_sse_deltas(
        &self,
        deltas: Vec<SseDelta>,
        config: &AgentConfig,
    ) -> Result<TurnResult, AgentError>;

    /// Cost in USD for the given token counts of `model`, or `None` when the adapter
    /// cannot price it.
    fn compute_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> Option<f64>;

    /// Maps the configured model name to the name passed to `endpoint_url`.
    fn resolve_model(&self, model: &str) -> String;
}
89
/// Request timeout used unless it is overridden with `HttpAgentProvider::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120);
92
/// `AgentProvider` that drives a single HTTP API through an [`HttpAgentAdapter`].
pub struct HttpAgentProvider<A: HttpAgentAdapter> {
    /// Provider-specific request building and response parsing.
    adapter: A,
    /// HTTP client, built with `timeout` as its request timeout.
    client: Client,
    /// Deadline applied when sending each request; also passed to SSE stream collection.
    timeout: Duration,
}
103
104impl<A: HttpAgentAdapter> HttpAgentProvider<A> {
105 pub fn new(adapter: A) -> Self {
107 let client = Client::builder()
108 .timeout(DEFAULT_TIMEOUT)
109 .build()
110 .expect("failed to build reqwest client");
111 Self {
112 adapter,
113 client,
114 timeout: DEFAULT_TIMEOUT,
115 }
116 }
117
118 pub fn with_timeout(mut self, timeout: Duration) -> Self {
120 self.timeout = timeout;
121 self.client = Client::builder()
122 .timeout(timeout)
123 .build()
124 .expect("failed to build reqwest client");
125 self
126 }
127
128 async fn execute_turn(
129 &self,
130 request_body: &Value,
131 config: &AgentConfig,
132 ) -> Result<TurnResult, AgentError> {
133 let model = self.adapter.resolve_model(&config.model);
134 let url = self.adapter.endpoint_url(&model);
135 let headers = self.adapter.auth_headers();
136
137 let mut req = self.client.post(&url).json(request_body);
138 for (key, value) in &headers {
139 req = req.header(key, value);
140 }
141
142 let response = tokio::time::timeout(self.timeout, req.send())
143 .await
144 .map_err(|_| AgentError::Timeout {
145 limit: self.timeout,
146 })?
147 .map_err(|e| {
148 if e.is_timeout() {
149 AgentError::Timeout {
150 limit: self.timeout,
151 }
152 } else {
153 AgentError::HttpProvider {
154 provider: self.adapter.provider_name().to_string(),
155 status_code: 0,
156 message: format!("connection failed: {e}"),
157 }
158 }
159 })?;
160
161 let status = response.status().as_u16();
162
163 if status == 429 {
164 let retry_after = response
165 .headers()
166 .get("retry-after")
167 .and_then(|v| v.to_str().ok())
168 .and_then(|v| v.parse::<u64>().ok());
169 return Err(AgentError::RateLimited {
170 provider: self.adapter.provider_name().to_string(),
171 retry_after_secs: retry_after,
172 });
173 }
174
175 if status >= 400 {
176 let body_text = response.text().await.unwrap_or_default();
177 let message = serde_json::from_str::<Value>(&body_text)
178 .ok()
179 .and_then(|v| {
180 v.get("error")
181 .and_then(|e| e.get("message"))
182 .and_then(|m| m.as_str())
183 .map(String::from)
184 })
185 .unwrap_or(body_text);
186 return Err(AgentError::HttpProvider {
187 provider: self.adapter.provider_name().to_string(),
188 status_code: status,
189 message,
190 });
191 }
192
193 if config.verbose {
194 let deltas = collect_sse_stream(&self.adapter, response, self.timeout).await?;
195 self.adapter.fold_sse_deltas(deltas, config)
196 } else {
197 let body: Value = response
198 .json()
199 .await
200 .map_err(|e| AgentError::HttpProvider {
201 provider: self.adapter.provider_name().to_string(),
202 status_code: 0,
203 message: format!("failed to parse response JSON: {e}"),
204 })?;
205 self.adapter.parse_response(&body, config)
206 }
207 }
208}
209
210impl<A: HttpAgentAdapter> AgentProvider for HttpAgentProvider<A> {
211 fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
212 Box::pin(async move {
213 let start = Instant::now();
214
215 let request_body = self.adapter.build_request(config)?;
216 let turn_result = self.execute_turn(&request_body, config).await?;
217
218 let duration_ms = start.elapsed().as_millis() as u64;
219 let input_tokens = turn_result.usage.input_tokens.unwrap_or(0);
220 let output_tokens = turn_result.usage.output_tokens.unwrap_or(0);
221 let model_name = turn_result.model.clone();
222
223 let cost = model_name
224 .as_deref()
225 .and_then(|m| self.adapter.compute_cost(m, input_tokens, output_tokens));
226
227 let debug_messages = if config.verbose {
228 Some(vec![DebugMessage {
229 text: turn_result.text.clone(),
230 thinking: None,
231 thinking_redacted: false,
232 tool_calls: Vec::new(),
233 tool_results: Vec::new(),
234 stop_reason: if turn_result.is_final {
235 Some("end_turn".to_string())
236 } else {
237 Some("tool_use".to_string())
238 },
239 input_tokens: turn_result.usage.input_tokens,
240 output_tokens: turn_result.usage.output_tokens,
241 }])
242 } else {
243 None
244 };
245
246 let value = if let Some(structured) = turn_result.structured_value {
247 structured
248 } else {
249 turn_result
250 .text
251 .map(Value::String)
252 .unwrap_or(Value::String(String::new()))
253 };
254
255 info!(
256 provider = self.adapter.provider_name(),
257 duration_ms, input_tokens, output_tokens, "invocation complete"
258 );
259
260 Ok(AgentOutput {
261 value,
262 session_id: None,
263 cost_usd: cost,
264 input_tokens: Some(input_tokens),
265 output_tokens: Some(output_tokens),
266 model: model_name,
267 duration_ms,
268 debug_messages,
269 })
270 })
271 }
272}