Skip to main content

ironflow_core/providers/http/
adapter.rs

1//! Core adapter trait and generic provider wrapper for HTTP-based LLM APIs.
2
3use std::time::{Duration, Instant};
4
5use reqwest::Client;
6use serde_json::{Value, json};
7use tracing::{debug, info, warn};
8
9use crate::error::AgentError;
10use crate::provider::{
11    AgentConfig, AgentOutput, AgentProvider, DebugMessage, DebugToolCall, DebugToolResult,
12    InvokeFuture,
13};
14use crate::providers::http::sse::{SseDelta, collect_sse_stream};
15use crate::providers::http::tools::ToolRegistry;
16
17/// Normalized result of one API turn (one HTTP request/response cycle).
18#[derive(Debug)]
19pub struct TurnResult {
20    /// Free-form text content from the model.
21    pub text: Option<String>,
22    /// Tool calls requested by the model in this turn (unused in V1 - no tool execution).
23    #[allow(dead_code)]
24    pub tool_calls: Vec<HttpToolCall>,
25    /// Whether this is the final turn.
26    pub is_final: bool,
27    /// Extracted structured JSON value when a schema was requested.
28    pub structured_value: Option<Value>,
29    /// Token usage reported by the provider.
30    pub usage: HttpUsage,
31    /// Concrete model identifier returned by the provider.
32    pub model: Option<String>,
33}
34
35/// A single tool call requested by the model.
36#[derive(Debug, Clone)]
37#[allow(dead_code)]
38pub struct HttpToolCall {
39    /// Provider-assigned call identifier.
40    pub id: String,
41    /// Tool name.
42    pub name: String,
43    /// Input arguments as JSON.
44    pub input: Value,
45}
46
/// Token usage from a single turn.
///
/// Either count may be absent when the provider does not report it. The type is plain,
/// cheap data, so it is `Copy` and comparable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct HttpUsage {
    /// Input/prompt tokens consumed.
    pub input_tokens: Option<u64>,
    /// Output/completion tokens generated.
    pub output_tokens: Option<u64>,
}
55
56/// Internal trait implemented by each HTTP provider backend.
57///
58/// The generic [`HttpAgentProvider`] calls these methods to build requests,
59/// parse responses, and configure authentication. The agentic loop, retry,
60/// and timeout are handled by the wrapper.
61pub trait HttpAgentAdapter: Send + Sync + 'static {
62    /// Provider name for logging and errors.
63    fn provider_name(&self) -> &'static str;
64
65    /// Full endpoint URL for the given model.
66    fn endpoint_url(&self, model: &str) -> String;
67
68    /// Authentication and provider-specific headers.
69    fn auth_headers(&self) -> Vec<(String, String)>;
70
71    /// Build the initial JSON request body from an [`AgentConfig`].
72    fn build_request(&self, config: &AgentConfig) -> Result<Value, AgentError>;
73
74    /// Parse a non-streaming response body into a [`TurnResult`].
75    fn parse_response(&self, body: &Value, config: &AgentConfig) -> Result<TurnResult, AgentError>;
76
77    /// Parse a single SSE `data:` line into a streaming delta.
78    fn parse_sse_line(&self, line: &str) -> Option<SseDelta>;
79
80    /// Fold accumulated SSE deltas into a complete [`TurnResult`].
81    fn fold_sse_deltas(
82        &self,
83        deltas: Vec<SseDelta>,
84        config: &AgentConfig,
85    ) -> Result<TurnResult, AgentError>;
86
87    /// Compute cost in USD from token counts. Returns `None` if unknown.
88    fn compute_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> Option<f64>;
89
90    /// Resolve model alias (e.g. "sonnet") to a provider-specific model ID.
91    fn resolve_model(&self, model: &str) -> String;
92}
93
/// Request timeout (two minutes) used until a provider overrides it via `with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
96
97/// Generic HTTP provider that wraps any [`HttpAgentAdapter`].
98///
99/// Implements [`AgentProvider`] by delegating request construction and response
100/// parsing to the adapter while handling the HTTP transport, timeout, and
101/// agentic execution loop.
102///
103/// When a [`ToolRegistry`] is attached via [`with_tools`](Self::with_tools),
104/// the provider runs a multi-turn agentic loop: executing tool calls locally
105/// and feeding results back to the model until it produces a final response
106/// (or hits `max_turns` / `max_budget_usd` limits).
107///
108/// Without a registry, the provider behaves as single-turn (backward-compatible).
109pub struct HttpAgentProvider<A: HttpAgentAdapter> {
110    adapter: A,
111    client: Client,
112    timeout: Duration,
113    tool_registry: Option<ToolRegistry>,
114}
115
116impl<A: HttpAgentAdapter> HttpAgentProvider<A> {
117    /// Create a new HTTP provider with the given adapter.
118    pub fn new(adapter: A) -> Self {
119        let client = Client::builder()
120            .timeout(DEFAULT_TIMEOUT)
121            .build()
122            .expect("failed to build reqwest client");
123        Self {
124            adapter,
125            client,
126            timeout: DEFAULT_TIMEOUT,
127            tool_registry: None,
128        }
129    }
130
131    /// Attach a tool registry to enable multi-turn agentic execution.
132    ///
133    /// When tools are registered, the provider will:
134    /// 1. Include the tools in every request (OpenAI `tools` format).
135    /// 2. Execute tool calls returned by the model.
136    /// 3. Loop until the model produces a final response or limits are hit.
137    pub fn with_tools(mut self, registry: ToolRegistry) -> Self {
138        self.tool_registry = Some(registry);
139        self
140    }
141
142    /// Override the request timeout.
143    pub fn with_timeout(mut self, timeout: Duration) -> Self {
144        self.timeout = timeout;
145        self.client = Client::builder()
146            .timeout(timeout)
147            .build()
148            .expect("failed to build reqwest client");
149        self
150    }
151
152    async fn execute_turn(
153        &self,
154        request_body: &Value,
155        config: &AgentConfig,
156    ) -> Result<TurnResult, AgentError> {
157        let model = self.adapter.resolve_model(&config.model);
158        let url = self.adapter.endpoint_url(&model);
159        let headers = self.adapter.auth_headers();
160
161        let mut req = self.client.post(&url).json(request_body);
162        for (key, value) in &headers {
163            req = req.header(key, value);
164        }
165
166        let response = tokio::time::timeout(self.timeout, req.send())
167            .await
168            .map_err(|_| AgentError::Timeout {
169                limit: self.timeout,
170            })?
171            .map_err(|e| {
172                if e.is_timeout() {
173                    AgentError::Timeout {
174                        limit: self.timeout,
175                    }
176                } else {
177                    AgentError::HttpProvider {
178                        provider: self.adapter.provider_name().to_string(),
179                        status_code: 0,
180                        message: format!("connection failed: {e}"),
181                    }
182                }
183            })?;
184
185        let status = response.status().as_u16();
186
187        if status == 429 {
188            let retry_after = response
189                .headers()
190                .get("retry-after")
191                .and_then(|v| v.to_str().ok())
192                .and_then(|v| v.parse::<u64>().ok());
193            return Err(AgentError::RateLimited {
194                provider: self.adapter.provider_name().to_string(),
195                retry_after_secs: retry_after,
196            });
197        }
198
199        if status >= 400 {
200            let body_text = response.text().await.unwrap_or_default();
201            let message = serde_json::from_str::<Value>(&body_text)
202                .ok()
203                .and_then(|v| {
204                    v.get("error")
205                        .and_then(|e| e.get("message"))
206                        .and_then(|m| m.as_str())
207                        .map(String::from)
208                })
209                .unwrap_or(body_text);
210            return Err(AgentError::HttpProvider {
211                provider: self.adapter.provider_name().to_string(),
212                status_code: status,
213                message,
214            });
215        }
216
217        if config.verbose {
218            let deltas = collect_sse_stream(&self.adapter, response, self.timeout).await?;
219            self.adapter.fold_sse_deltas(deltas, config)
220        } else {
221            let body: Value = response
222                .json()
223                .await
224                .map_err(|e| AgentError::HttpProvider {
225                    provider: self.adapter.provider_name().to_string(),
226                    status_code: 0,
227                    message: format!("failed to parse response JSON: {e}"),
228                })?;
229            self.adapter.parse_response(&body, config)
230        }
231    }
232}
233
234/// Accumulates usage across turns and builds the final [`AgentOutput`].
235struct LoopState {
236    start: Instant,
237    total_input_tokens: u64,
238    total_output_tokens: u64,
239    total_cost: f64,
240    model_name: Option<String>,
241    debug_messages: Vec<DebugMessage>,
242    verbose: bool,
243}
244
245impl LoopState {
246    fn new(start: Instant, verbose: bool) -> Self {
247        Self {
248            start,
249            total_input_tokens: 0,
250            total_output_tokens: 0,
251            total_cost: 0.0,
252            model_name: None,
253            debug_messages: Vec::new(),
254            verbose,
255        }
256    }
257
258    fn into_output(self, value: Value) -> AgentOutput {
259        AgentOutput {
260            value,
261            session_id: None,
262            cost_usd: if self.total_cost > 0.0 {
263                Some(self.total_cost)
264            } else {
265                None
266            },
267            input_tokens: Some(self.total_input_tokens),
268            output_tokens: Some(self.total_output_tokens),
269            model: self.model_name,
270            duration_ms: self.start.elapsed().as_millis() as u64,
271            debug_messages: if self.verbose {
272                Some(self.debug_messages)
273            } else {
274                None
275            },
276        }
277    }
278}
279
280/// Extract the final value from a turn result (structured or text).
281fn extract_value(turn_result: &TurnResult) -> Value {
282    if let Some(ref structured) = turn_result.structured_value {
283        structured.clone()
284    } else {
285        turn_result
286            .text
287            .as_ref()
288            .map(|t| Value::String(t.clone()))
289            .unwrap_or(Value::String(String::new()))
290    }
291}
292
293/// Extract the text value from a turn result (ignoring structured).
294fn extract_text_value(turn_result: &TurnResult) -> Value {
295    turn_result
296        .text
297        .as_ref()
298        .map(|t| Value::String(t.clone()))
299        .unwrap_or(Value::String(String::new()))
300}
301
302impl<A: HttpAgentAdapter> AgentProvider for HttpAgentProvider<A> {
303    fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
304        Box::pin(async move {
305            let mut request_body = self.adapter.build_request(config)?;
306
307            // Inject tools into the request if a registry is available
308            if let Some(ref registry) = self.tool_registry
309                && !registry.is_empty()
310            {
311                let tools_array = registry.to_openai_tools();
312                request_body["tools"] = Value::Array(tools_array);
313            }
314
315            let max_turns = config.max_turns.unwrap_or(25) as usize;
316            let max_budget = config.max_budget_usd.unwrap_or(f64::MAX);
317            let mut state = LoopState::new(Instant::now(), config.verbose);
318
319            // Messages array for multi-turn
320            let mut messages: Vec<Value> = request_body
321                .get("messages")
322                .and_then(|m| m.as_array())
323                .cloned()
324                .unwrap_or_default();
325
326            for turn in 0..max_turns {
327                request_body["messages"] = Value::Array(messages.clone());
328                let turn_result = self.execute_turn(&request_body, config).await?;
329
330                // Accumulate usage
331                let turn_input = turn_result.usage.input_tokens.unwrap_or(0);
332                let turn_output = turn_result.usage.output_tokens.unwrap_or(0);
333                state.total_input_tokens += turn_input;
334                state.total_output_tokens += turn_output;
335
336                if state.model_name.is_none() {
337                    state.model_name = turn_result.model.clone();
338                }
339
340                if let Some(ref model) = state.model_name
341                    && let Some(turn_cost) =
342                        self.adapter.compute_cost(model, turn_input, turn_output)
343                {
344                    state.total_cost += turn_cost;
345                }
346
347                // Record debug trace for this turn
348                if config.verbose {
349                    let tool_calls_debug: Vec<DebugToolCall> = turn_result
350                        .tool_calls
351                        .iter()
352                        .map(|tc| DebugToolCall {
353                            id: Some(tc.id.clone()),
354                            name: tc.name.clone(),
355                            input: tc.input.clone(),
356                        })
357                        .collect();
358
359                    state.debug_messages.push(DebugMessage {
360                        text: turn_result.text.clone(),
361                        thinking: None,
362                        thinking_redacted: false,
363                        tool_calls: tool_calls_debug,
364                        tool_results: Vec::new(),
365                        stop_reason: if turn_result.is_final {
366                            Some("end_turn".to_string())
367                        } else {
368                            Some("tool_use".to_string())
369                        },
370                        input_tokens: Some(turn_input),
371                        output_tokens: Some(turn_output),
372                    });
373                }
374
375                // Final response (no tool calls) -> return
376                if turn_result.is_final || turn_result.tool_calls.is_empty() {
377                    info!(
378                        provider = self.adapter.provider_name(),
379                        turns = turn + 1,
380                        duration_ms = state.start.elapsed().as_millis() as u64,
381                        input_tokens = state.total_input_tokens,
382                        output_tokens = state.total_output_tokens,
383                        "invocation complete"
384                    );
385                    return Ok(state.into_output(extract_value(&turn_result)));
386                }
387
388                // Tool calls but no registry -> return text (backward compat)
389                let registry = match self.tool_registry {
390                    Some(ref r) => r,
391                    None => {
392                        warn!(
393                            provider = self.adapter.provider_name(),
394                            tool_calls = turn_result.tool_calls.len(),
395                            "model requested tool calls but no registry attached, returning text"
396                        );
397                        return Ok(state.into_output(extract_text_value(&turn_result)));
398                    }
399                };
400
401                // Budget exceeded -> stop
402                if state.total_cost >= max_budget {
403                    warn!(
404                        provider = self.adapter.provider_name(),
405                        cost = state.total_cost,
406                        budget = max_budget,
407                        "budget exceeded, stopping agentic loop"
408                    );
409                    return Ok(state.into_output(extract_text_value(&turn_result)));
410                }
411
412                // Build assistant message with tool_calls for conversation history
413                let assistant_tool_calls: Vec<Value> = turn_result
414                    .tool_calls
415                    .iter()
416                    .map(|tc| {
417                        json!({
418                            "id": tc.id,
419                            "type": "function",
420                            "function": {
421                                "name": tc.name,
422                                "arguments": tc.input.to_string()
423                            }
424                        })
425                    })
426                    .collect();
427
428                let mut assistant_msg = json!({"role": "assistant"});
429                if let Some(ref text) = turn_result.text {
430                    assistant_msg["content"] = Value::String(text.clone());
431                } else {
432                    assistant_msg["content"] = Value::Null;
433                }
434                assistant_msg["tool_calls"] = Value::Array(assistant_tool_calls);
435                messages.push(assistant_msg);
436
437                // Execute each tool call
438                let mut tool_results_debug: Vec<DebugToolResult> = Vec::new();
439
440                for tc in &turn_result.tool_calls {
441                    debug!(
442                        provider = self.adapter.provider_name(),
443                        tool = %tc.name,
444                        call_id = %tc.id,
445                        "executing tool call"
446                    );
447
448                    let (content, is_error) =
449                        match registry.execute(&tc.name, tc.input.clone()).await {
450                            Some(Ok(output)) => (output.content, output.is_error),
451                            Some(Err(err)) => (format!("Tool execution error: {err}"), true),
452                            None => (format!("Unknown tool: {}", tc.name), true),
453                        };
454
455                    messages.push(json!({
456                        "role": "tool",
457                        "tool_call_id": tc.id,
458                        "content": content
459                    }));
460
461                    if config.verbose {
462                        tool_results_debug.push(DebugToolResult {
463                            tool_use_id: Some(tc.id.clone()),
464                            content: Value::String(content.clone()),
465                            is_error,
466                        });
467                    }
468                }
469
470                if config.verbose
471                    && let Some(last_msg) = state.debug_messages.last_mut()
472                {
473                    last_msg.tool_results = tool_results_debug;
474                }
475
476                info!(
477                    provider = self.adapter.provider_name(),
478                    turn = turn + 1,
479                    tools_executed = turn_result.tool_calls.len(),
480                    "turn complete, continuing loop"
481                );
482            }
483
484            warn!(
485                provider = self.adapter.provider_name(),
486                max_turns, "max turns reached, returning last state"
487            );
488            Ok(state.into_output(Value::String(String::new())))
489        })
490    }
491}