ironflow_core/providers/http/
adapter.rs1use std::time::{Duration, Instant};
4
5use reqwest::Client;
6use serde_json::{Value, json};
7use tracing::{debug, info, warn};
8
9use crate::error::AgentError;
10use crate::provider::{
11 AgentConfig, AgentOutput, AgentProvider, DebugMessage, DebugToolCall, DebugToolResult,
12 InvokeFuture,
13};
14use crate::providers::http::sse::{SseDelta, collect_sse_stream};
15use crate::providers::http::tools::ToolRegistry;
16
17#[derive(Debug)]
19pub struct TurnResult {
20 pub text: Option<String>,
22 #[allow(dead_code)]
24 pub tool_calls: Vec<HttpToolCall>,
25 pub is_final: bool,
27 pub structured_value: Option<Value>,
29 pub usage: HttpUsage,
31 pub model: Option<String>,
33}
34
35#[derive(Debug, Clone)]
37#[allow(dead_code)]
38pub struct HttpToolCall {
39 pub id: String,
41 pub name: String,
43 pub input: Value,
45}
46
/// Token counts reported for a single turn; a field is `None` when the provider omitted it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct HttpUsage {
    /// Input (prompt) tokens consumed by the turn.
    pub input_tokens: Option<u64>,
    /// Output (completion) tokens produced by the turn.
    pub output_tokens: Option<u64>,
}
55
56pub trait HttpAgentAdapter: Send + Sync + 'static {
62 fn provider_name(&self) -> &'static str;
64
65 fn endpoint_url(&self, model: &str) -> String;
67
68 fn auth_headers(&self) -> Vec<(String, String)>;
70
71 fn build_request(&self, config: &AgentConfig) -> Result<Value, AgentError>;
73
74 fn parse_response(&self, body: &Value, config: &AgentConfig) -> Result<TurnResult, AgentError>;
76
77 fn parse_sse_line(&self, line: &str) -> Option<SseDelta>;
79
80 fn fold_sse_deltas(
82 &self,
83 deltas: Vec<SseDelta>,
84 config: &AgentConfig,
85 ) -> Result<TurnResult, AgentError>;
86
87 fn compute_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> Option<f64>;
89
90 fn resolve_model(&self, model: &str) -> String;
92}
93
/// Default request timeout (two minutes). It is used both as the reqwest client timeout
/// and as the limit for awaiting the response; `with_timeout` overrides it.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
96
97pub struct HttpAgentProvider<A: HttpAgentAdapter> {
110 adapter: A,
111 client: Client,
112 timeout: Duration,
113 tool_registry: Option<ToolRegistry>,
114}
115
116impl<A: HttpAgentAdapter> HttpAgentProvider<A> {
117 pub fn new(adapter: A) -> Self {
119 let client = Client::builder()
120 .timeout(DEFAULT_TIMEOUT)
121 .build()
122 .expect("failed to build reqwest client");
123 Self {
124 adapter,
125 client,
126 timeout: DEFAULT_TIMEOUT,
127 tool_registry: None,
128 }
129 }
130
131 pub fn with_tools(mut self, registry: ToolRegistry) -> Self {
138 self.tool_registry = Some(registry);
139 self
140 }
141
142 pub fn with_timeout(mut self, timeout: Duration) -> Self {
144 self.timeout = timeout;
145 self.client = Client::builder()
146 .timeout(timeout)
147 .build()
148 .expect("failed to build reqwest client");
149 self
150 }
151
152 async fn execute_turn(
153 &self,
154 request_body: &Value,
155 config: &AgentConfig,
156 ) -> Result<TurnResult, AgentError> {
157 let model = self.adapter.resolve_model(&config.model);
158 let url = self.adapter.endpoint_url(&model);
159 let headers = self.adapter.auth_headers();
160
161 let mut req = self.client.post(&url).json(request_body);
162 for (key, value) in &headers {
163 req = req.header(key, value);
164 }
165
166 let response = tokio::time::timeout(self.timeout, req.send())
167 .await
168 .map_err(|_| AgentError::Timeout {
169 limit: self.timeout,
170 })?
171 .map_err(|e| {
172 if e.is_timeout() {
173 AgentError::Timeout {
174 limit: self.timeout,
175 }
176 } else {
177 AgentError::HttpProvider {
178 provider: self.adapter.provider_name().to_string(),
179 status_code: 0,
180 message: format!("connection failed: {e}"),
181 }
182 }
183 })?;
184
185 let status = response.status().as_u16();
186
187 if status == 429 {
188 let retry_after = response
189 .headers()
190 .get("retry-after")
191 .and_then(|v| v.to_str().ok())
192 .and_then(|v| v.parse::<u64>().ok());
193 return Err(AgentError::RateLimited {
194 provider: self.adapter.provider_name().to_string(),
195 retry_after_secs: retry_after,
196 });
197 }
198
199 if status >= 400 {
200 let body_text = response.text().await.unwrap_or_default();
201 let message = serde_json::from_str::<Value>(&body_text)
202 .ok()
203 .and_then(|v| {
204 v.get("error")
205 .and_then(|e| e.get("message"))
206 .and_then(|m| m.as_str())
207 .map(String::from)
208 })
209 .unwrap_or(body_text);
210 return Err(AgentError::HttpProvider {
211 provider: self.adapter.provider_name().to_string(),
212 status_code: status,
213 message,
214 });
215 }
216
217 if config.verbose {
218 let deltas = collect_sse_stream(&self.adapter, response, self.timeout).await?;
219 self.adapter.fold_sse_deltas(deltas, config)
220 } else {
221 let body: Value = response
222 .json()
223 .await
224 .map_err(|e| AgentError::HttpProvider {
225 provider: self.adapter.provider_name().to_string(),
226 status_code: 0,
227 message: format!("failed to parse response JSON: {e}"),
228 })?;
229 self.adapter.parse_response(&body, config)
230 }
231 }
232}
233
234struct LoopState {
236 start: Instant,
237 total_input_tokens: u64,
238 total_output_tokens: u64,
239 total_cost: f64,
240 model_name: Option<String>,
241 debug_messages: Vec<DebugMessage>,
242 verbose: bool,
243}
244
245impl LoopState {
246 fn new(start: Instant, verbose: bool) -> Self {
247 Self {
248 start,
249 total_input_tokens: 0,
250 total_output_tokens: 0,
251 total_cost: 0.0,
252 model_name: None,
253 debug_messages: Vec::new(),
254 verbose,
255 }
256 }
257
258 fn into_output(self, value: Value) -> AgentOutput {
259 AgentOutput {
260 value,
261 session_id: None,
262 cost_usd: if self.total_cost > 0.0 {
263 Some(self.total_cost)
264 } else {
265 None
266 },
267 input_tokens: Some(self.total_input_tokens),
268 output_tokens: Some(self.total_output_tokens),
269 model: self.model_name,
270 duration_ms: self.start.elapsed().as_millis() as u64,
271 debug_messages: if self.verbose {
272 Some(self.debug_messages)
273 } else {
274 None
275 },
276 }
277 }
278}
279
280fn extract_value(turn_result: &TurnResult) -> Value {
282 if let Some(ref structured) = turn_result.structured_value {
283 structured.clone()
284 } else {
285 turn_result
286 .text
287 .as_ref()
288 .map(|t| Value::String(t.clone()))
289 .unwrap_or(Value::String(String::new()))
290 }
291}
292
293fn extract_text_value(turn_result: &TurnResult) -> Value {
295 turn_result
296 .text
297 .as_ref()
298 .map(|t| Value::String(t.clone()))
299 .unwrap_or(Value::String(String::new()))
300}
301
302impl<A: HttpAgentAdapter> AgentProvider for HttpAgentProvider<A> {
303 fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
304 Box::pin(async move {
305 let mut request_body = self.adapter.build_request(config)?;
306
307 if let Some(ref registry) = self.tool_registry
309 && !registry.is_empty()
310 {
311 let tools_array = registry.to_openai_tools();
312 request_body["tools"] = Value::Array(tools_array);
313 }
314
315 let max_turns = config.max_turns.unwrap_or(25) as usize;
316 let max_budget = config.max_budget_usd.unwrap_or(f64::MAX);
317 let mut state = LoopState::new(Instant::now(), config.verbose);
318
319 let mut messages: Vec<Value> = request_body
321 .get("messages")
322 .and_then(|m| m.as_array())
323 .cloned()
324 .unwrap_or_default();
325
326 for turn in 0..max_turns {
327 request_body["messages"] = Value::Array(messages.clone());
328 let turn_result = self.execute_turn(&request_body, config).await?;
329
330 let turn_input = turn_result.usage.input_tokens.unwrap_or(0);
332 let turn_output = turn_result.usage.output_tokens.unwrap_or(0);
333 state.total_input_tokens += turn_input;
334 state.total_output_tokens += turn_output;
335
336 if state.model_name.is_none() {
337 state.model_name = turn_result.model.clone();
338 }
339
340 if let Some(ref model) = state.model_name
341 && let Some(turn_cost) =
342 self.adapter.compute_cost(model, turn_input, turn_output)
343 {
344 state.total_cost += turn_cost;
345 }
346
347 if config.verbose {
349 let tool_calls_debug: Vec<DebugToolCall> = turn_result
350 .tool_calls
351 .iter()
352 .map(|tc| DebugToolCall {
353 id: Some(tc.id.clone()),
354 name: tc.name.clone(),
355 input: tc.input.clone(),
356 })
357 .collect();
358
359 state.debug_messages.push(DebugMessage {
360 text: turn_result.text.clone(),
361 thinking: None,
362 thinking_redacted: false,
363 tool_calls: tool_calls_debug,
364 tool_results: Vec::new(),
365 stop_reason: if turn_result.is_final {
366 Some("end_turn".to_string())
367 } else {
368 Some("tool_use".to_string())
369 },
370 input_tokens: Some(turn_input),
371 output_tokens: Some(turn_output),
372 });
373 }
374
375 if turn_result.is_final || turn_result.tool_calls.is_empty() {
377 info!(
378 provider = self.adapter.provider_name(),
379 turns = turn + 1,
380 duration_ms = state.start.elapsed().as_millis() as u64,
381 input_tokens = state.total_input_tokens,
382 output_tokens = state.total_output_tokens,
383 "invocation complete"
384 );
385 return Ok(state.into_output(extract_value(&turn_result)));
386 }
387
388 let registry = match self.tool_registry {
390 Some(ref r) => r,
391 None => {
392 warn!(
393 provider = self.adapter.provider_name(),
394 tool_calls = turn_result.tool_calls.len(),
395 "model requested tool calls but no registry attached, returning text"
396 );
397 return Ok(state.into_output(extract_text_value(&turn_result)));
398 }
399 };
400
401 if state.total_cost >= max_budget {
403 warn!(
404 provider = self.adapter.provider_name(),
405 cost = state.total_cost,
406 budget = max_budget,
407 "budget exceeded, stopping agentic loop"
408 );
409 return Ok(state.into_output(extract_text_value(&turn_result)));
410 }
411
412 let assistant_tool_calls: Vec<Value> = turn_result
414 .tool_calls
415 .iter()
416 .map(|tc| {
417 json!({
418 "id": tc.id,
419 "type": "function",
420 "function": {
421 "name": tc.name,
422 "arguments": tc.input.to_string()
423 }
424 })
425 })
426 .collect();
427
428 let mut assistant_msg = json!({"role": "assistant"});
429 if let Some(ref text) = turn_result.text {
430 assistant_msg["content"] = Value::String(text.clone());
431 } else {
432 assistant_msg["content"] = Value::Null;
433 }
434 assistant_msg["tool_calls"] = Value::Array(assistant_tool_calls);
435 messages.push(assistant_msg);
436
437 let mut tool_results_debug: Vec<DebugToolResult> = Vec::new();
439
440 for tc in &turn_result.tool_calls {
441 debug!(
442 provider = self.adapter.provider_name(),
443 tool = %tc.name,
444 call_id = %tc.id,
445 "executing tool call"
446 );
447
448 let (content, is_error) =
449 match registry.execute(&tc.name, tc.input.clone()).await {
450 Some(Ok(output)) => (output.content, output.is_error),
451 Some(Err(err)) => (format!("Tool execution error: {err}"), true),
452 None => (format!("Unknown tool: {}", tc.name), true),
453 };
454
455 messages.push(json!({
456 "role": "tool",
457 "tool_call_id": tc.id,
458 "content": content
459 }));
460
461 if config.verbose {
462 tool_results_debug.push(DebugToolResult {
463 tool_use_id: Some(tc.id.clone()),
464 content: Value::String(content.clone()),
465 is_error,
466 });
467 }
468 }
469
470 if config.verbose
471 && let Some(last_msg) = state.debug_messages.last_mut()
472 {
473 last_msg.tool_results = tool_results_debug;
474 }
475
476 info!(
477 provider = self.adapter.provider_name(),
478 turn = turn + 1,
479 tools_executed = turn_result.tool_calls.len(),
480 "turn complete, continuing loop"
481 );
482 }
483
484 warn!(
485 provider = self.adapter.provider_name(),
486 max_turns, "max turns reached, returning last state"
487 );
488 Ok(state.into_output(Value::String(String::new())))
489 })
490 }
491}