Skip to main content

ironflow_core/providers/http/
adapter.rs

1//! Core adapter trait and generic provider wrapper for HTTP-based LLM APIs.
2
3use std::time::{Duration, Instant};
4
5use reqwest::Client;
6use serde_json::Value;
7use tracing::info;
8
9use crate::error::AgentError;
10use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage, InvokeFuture};
11use crate::providers::http::sse::{SseDelta, collect_sse_stream};
12
13/// Normalized result of one API turn (one HTTP request/response cycle).
14#[derive(Debug)]
15pub struct TurnResult {
16    /// Free-form text content from the model.
17    pub text: Option<String>,
18    /// Tool calls requested by the model in this turn (unused in V1 - no tool execution).
19    #[allow(dead_code)]
20    pub tool_calls: Vec<HttpToolCall>,
21    /// Whether this is the final turn.
22    pub is_final: bool,
23    /// Extracted structured JSON value when a schema was requested.
24    pub structured_value: Option<Value>,
25    /// Token usage reported by the provider.
26    pub usage: HttpUsage,
27    /// Concrete model identifier returned by the provider.
28    pub model: Option<String>,
29}
30
31/// A single tool call requested by the model.
32#[derive(Debug, Clone)]
33#[allow(dead_code)]
34pub struct HttpToolCall {
35    /// Provider-assigned call identifier.
36    pub id: String,
37    /// Tool name.
38    pub name: String,
39    /// Input arguments as JSON.
40    pub input: Value,
41}
42
/// Token usage reported by the provider for a single turn.
///
/// Both counts are optional because a provider response may omit them;
/// callers decide how to treat a missing value. The type is a pair of
/// `Option<u64>`, so it is cheap to copy and compare.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct HttpUsage {
    /// Input (prompt) tokens consumed, if reported.
    pub input_tokens: Option<u64>,
    /// Output (completion) tokens generated, if reported.
    pub output_tokens: Option<u64>,
}
51
52/// Internal trait implemented by each HTTP provider backend.
53///
54/// The generic [`HttpAgentProvider`] calls these methods to build requests,
55/// parse responses, and configure authentication. The agentic loop, retry,
56/// and timeout are handled by the wrapper.
57pub trait HttpAgentAdapter: Send + Sync + 'static {
58    /// Provider name for logging and errors.
59    fn provider_name(&self) -> &'static str;
60
61    /// Full endpoint URL for the given model.
62    fn endpoint_url(&self, model: &str) -> String;
63
64    /// Authentication and provider-specific headers.
65    fn auth_headers(&self) -> Vec<(String, String)>;
66
67    /// Build the initial JSON request body from an [`AgentConfig`].
68    fn build_request(&self, config: &AgentConfig) -> Result<Value, AgentError>;
69
70    /// Parse a non-streaming response body into a [`TurnResult`].
71    fn parse_response(&self, body: &Value, config: &AgentConfig) -> Result<TurnResult, AgentError>;
72
73    /// Parse a single SSE `data:` line into a streaming delta.
74    fn parse_sse_line(&self, line: &str) -> Option<SseDelta>;
75
76    /// Fold accumulated SSE deltas into a complete [`TurnResult`].
77    fn fold_sse_deltas(
78        &self,
79        deltas: Vec<SseDelta>,
80        config: &AgentConfig,
81    ) -> Result<TurnResult, AgentError>;
82
83    /// Compute cost in USD from token counts. Returns `None` if unknown.
84    fn compute_cost(&self, model: &str, input_tokens: u64, output_tokens: u64) -> Option<f64>;
85
86    /// Resolve model alias (e.g. "sonnet") to a provider-specific model ID.
87    fn resolve_model(&self, model: &str) -> String;
88}
89
/// Request timeout (two minutes) used unless the provider is configured
/// with an explicit override.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
92
93/// Generic HTTP provider that wraps any [`HttpAgentAdapter`].
94///
95/// Implements [`AgentProvider`] by delegating request construction and response
96/// parsing to the adapter while handling the HTTP transport, timeout, and
97/// single-turn execution loop.
98pub struct HttpAgentProvider<A: HttpAgentAdapter> {
99    adapter: A,
100    client: Client,
101    timeout: Duration,
102}
103
104impl<A: HttpAgentAdapter> HttpAgentProvider<A> {
105    /// Create a new HTTP provider with the given adapter.
106    pub fn new(adapter: A) -> Self {
107        let client = Client::builder()
108            .timeout(DEFAULT_TIMEOUT)
109            .build()
110            .expect("failed to build reqwest client");
111        Self {
112            adapter,
113            client,
114            timeout: DEFAULT_TIMEOUT,
115        }
116    }
117
118    /// Override the request timeout.
119    pub fn with_timeout(mut self, timeout: Duration) -> Self {
120        self.timeout = timeout;
121        self.client = Client::builder()
122            .timeout(timeout)
123            .build()
124            .expect("failed to build reqwest client");
125        self
126    }
127
128    async fn execute_turn(
129        &self,
130        request_body: &Value,
131        config: &AgentConfig,
132    ) -> Result<TurnResult, AgentError> {
133        let model = self.adapter.resolve_model(&config.model);
134        let url = self.adapter.endpoint_url(&model);
135        let headers = self.adapter.auth_headers();
136
137        let mut req = self.client.post(&url).json(request_body);
138        for (key, value) in &headers {
139            req = req.header(key, value);
140        }
141
142        let response = tokio::time::timeout(self.timeout, req.send())
143            .await
144            .map_err(|_| AgentError::Timeout {
145                limit: self.timeout,
146            })?
147            .map_err(|e| {
148                if e.is_timeout() {
149                    AgentError::Timeout {
150                        limit: self.timeout,
151                    }
152                } else {
153                    AgentError::HttpProvider {
154                        provider: self.adapter.provider_name().to_string(),
155                        status_code: 0,
156                        message: format!("connection failed: {e}"),
157                    }
158                }
159            })?;
160
161        let status = response.status().as_u16();
162
163        if status == 429 {
164            let retry_after = response
165                .headers()
166                .get("retry-after")
167                .and_then(|v| v.to_str().ok())
168                .and_then(|v| v.parse::<u64>().ok());
169            return Err(AgentError::RateLimited {
170                provider: self.adapter.provider_name().to_string(),
171                retry_after_secs: retry_after,
172            });
173        }
174
175        if status >= 400 {
176            let body_text = response.text().await.unwrap_or_default();
177            let message = serde_json::from_str::<Value>(&body_text)
178                .ok()
179                .and_then(|v| {
180                    v.get("error")
181                        .and_then(|e| e.get("message"))
182                        .and_then(|m| m.as_str())
183                        .map(String::from)
184                })
185                .unwrap_or(body_text);
186            return Err(AgentError::HttpProvider {
187                provider: self.adapter.provider_name().to_string(),
188                status_code: status,
189                message,
190            });
191        }
192
193        if config.verbose {
194            let deltas = collect_sse_stream(&self.adapter, response, self.timeout).await?;
195            self.adapter.fold_sse_deltas(deltas, config)
196        } else {
197            let body: Value = response
198                .json()
199                .await
200                .map_err(|e| AgentError::HttpProvider {
201                    provider: self.adapter.provider_name().to_string(),
202                    status_code: 0,
203                    message: format!("failed to parse response JSON: {e}"),
204                })?;
205            self.adapter.parse_response(&body, config)
206        }
207    }
208}
209
210impl<A: HttpAgentAdapter> AgentProvider for HttpAgentProvider<A> {
211    fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
212        Box::pin(async move {
213            let start = Instant::now();
214
215            let request_body = self.adapter.build_request(config)?;
216            let turn_result = self.execute_turn(&request_body, config).await?;
217
218            let duration_ms = start.elapsed().as_millis() as u64;
219            let input_tokens = turn_result.usage.input_tokens.unwrap_or(0);
220            let output_tokens = turn_result.usage.output_tokens.unwrap_or(0);
221            let model_name = turn_result.model.clone();
222
223            let cost = model_name
224                .as_deref()
225                .and_then(|m| self.adapter.compute_cost(m, input_tokens, output_tokens));
226
227            let debug_messages = if config.verbose {
228                Some(vec![DebugMessage {
229                    text: turn_result.text.clone(),
230                    thinking: None,
231                    thinking_redacted: false,
232                    tool_calls: Vec::new(),
233                    tool_results: Vec::new(),
234                    stop_reason: if turn_result.is_final {
235                        Some("end_turn".to_string())
236                    } else {
237                        Some("tool_use".to_string())
238                    },
239                    input_tokens: turn_result.usage.input_tokens,
240                    output_tokens: turn_result.usage.output_tokens,
241                }])
242            } else {
243                None
244            };
245
246            let value = if let Some(structured) = turn_result.structured_value {
247                structured
248            } else {
249                turn_result
250                    .text
251                    .map(Value::String)
252                    .unwrap_or(Value::String(String::new()))
253            };
254
255            info!(
256                provider = self.adapter.provider_name(),
257                duration_ms, input_tokens, output_tokens, "invocation complete"
258            );
259
260            Ok(AgentOutput {
261                value,
262                session_id: None,
263                cost_usd: cost,
264                input_tokens: Some(input_tokens),
265                output_tokens: Some(output_tokens),
266                model: model_name,
267                duration_ms,
268                debug_messages,
269            })
270        })
271    }
272}