Skip to main content

ai_lib_rust/transport/
http.rs

1use crate::protocol::ProtocolManifest;
2use crate::{BoxStream, Result};
3use bytes::Bytes;
4use futures::TryStreamExt;
5use keyring::Entry;
6use reqwest::Proxy;
7use std::env;
8use std::time::Duration;
9
/// HTTP transport bound to a single provider endpoint and model.
///
/// Holds a pre-configured `reqwest::Client` together with the values needed to
/// build request URLs and authenticate requests.
pub struct HttpTransport {
    /// Shared HTTP client (timeouts, connection pool and optional proxy are
    /// configured when the transport is constructed).
    client: reqwest::Client,
    /// Base URL that request paths are appended to.
    base_url: String,
    /// Model name substituted for `{model}` placeholders in request paths.
    model: String,
    /// API key sent as a bearer token when present.
    api_key: Option<String>,
}
16
17impl HttpTransport {
18    /// Create a new HttpTransport from a manifest.
19    ///
20    /// If `base_url_override` is provided, it will be used instead of the manifest's base_url.
21    /// This is useful for testing with mock servers.
22    pub fn new(manifest: &ProtocolManifest, model: &str) -> Result<Self> {
23        Self::new_with_base_url(manifest, model, None)
24    }
25
26    /// Create a new HttpTransport with an optional base_url override.
27    ///
28    /// This is primarily for testing, allowing injection of mock server URLs.
29    pub fn new_with_base_url(
30        manifest: &ProtocolManifest,
31        model: &str,
32        base_url_override: Option<&str>,
33    ) -> Result<Self> {
34        let provider_id = manifest.provider_id.as_deref().unwrap_or(&manifest.id);
35        let api_key = Self::get_api_key(provider_id);
36
37        // Use override if provided, otherwise use manifest endpoint.base_url
38        let base_url = base_url_override
39            .map(|s| s.to_string())
40            .unwrap_or_else(|| manifest.get_base_url().to_string());
41
42        // Minimal production-friendly defaults (env-overridable).
43        let timeout_secs = env::var("AI_HTTP_TIMEOUT_SECS")
44            .ok()
45            .and_then(|s| s.parse::<u64>().ok())
46            .or_else(|| {
47                env::var("AI_TIMEOUT_SECS")
48                    .ok()
49                    .and_then(|s| s.parse::<u64>().ok())
50            })
51            .unwrap_or(300);
52
53        let mut builder = reqwest::Client::builder()
54            .timeout(Duration::from_secs(timeout_secs))
55            .pool_max_idle_per_host(
56                env::var("AI_HTTP_POOL_MAX_IDLE_PER_HOST")
57                    .ok()
58                    .and_then(|s| s.parse::<usize>().ok())
59                    .unwrap_or(32),
60            )
61            .pool_idle_timeout(Some(Duration::from_secs(
62                env::var("AI_HTTP_POOL_IDLE_TIMEOUT_SECS")
63                    .ok()
64                    .and_then(|s| s.parse::<u64>().ok())
65                    .unwrap_or(90),
66            )))
67            // Conservative HTTP/2 keepalive defaults for long-lived connections.
68            // (No extra env knobs for now to keep developer UI simple.)
69            .http2_adaptive_window(true)
70            .http2_keep_alive_interval(Some(Duration::from_secs(30)))
71            .http2_keep_alive_timeout(Duration::from_secs(10));
72
73        if let Ok(proxy_url) = env::var("AI_PROXY_URL") {
74            if let Ok(proxy) = Proxy::all(&proxy_url) {
75                builder = builder.proxy(proxy);
76            }
77        }
78
79        let client = builder.build().map_err(|e| {
80            crate::Error::Transport(crate::transport::TransportError::Other(e.to_string()))
81        })?;
82
83        Ok(Self {
84            client,
85            base_url,
86            model: model.to_string(),
87            api_key,
88        })
89    }
90
91    fn get_api_key(provider_id: &str) -> Option<String> {
92        // 1. Try Environment Variable (PROVIDER_API_KEY)
93        let env_var = format!("{}_API_KEY", provider_id.to_uppercase());
94        if let Ok(key) = env::var(&env_var) {
95            tracing::debug!(
96                "Loaded API key for provider '{}' from environment variable '{}'. Length: {}. First char: '{}', Last char: '{}'",
97                provider_id,
98                env_var,
99                key.len(),
100                key.chars().next().unwrap_or('?'),
101                key.chars().last().unwrap_or('?')
102            );
103            tracing::debug!("Key bytes: {:?}", key.as_bytes());
104            return Some(key);
105        }
106
107        // 2. Try Keyring
108        let entry = Entry::new("ai-protocol", provider_id).ok();
109        if let Some(entry) = entry {
110            if let Ok(key) = entry.get_password() {
111                tracing::debug!(
112                    "Loaded API key for provider '{}' from system keyring",
113                    provider_id
114                );
115                return Some(key);
116            }
117        }
118
119        tracing::warn!(
120            "No API key found for provider '{}' (checked env var '{}' and keyring)",
121            provider_id,
122            env_var
123        );
124        None
125    }
126
127    pub async fn execute_stream_response(
128        &self,
129        method: &str,
130        path: &str,
131        request_body: &serde_json::Value,
132        client_request_id: Option<&str>,
133    ) -> Result<reqwest::Response> {
134        let interpolated_path = path.replace("{model}", &self.model);
135        let url = format!("{}{}", self.base_url, interpolated_path);
136
137        let mut req = match method.to_uppercase().as_str() {
138            "POST" => self.client.post(&url).json(request_body),
139            "PUT" => self.client.put(&url).json(request_body),
140            "DELETE" => self.client.delete(&url),
141            _ => self.client.get(&url),
142        };
143
144        if let Some(key) = &self.api_key {
145            req = req.bearer_auth(key);
146        }
147
148        // Prefer SSE for providers that support it
149        req = req.header("accept", "text/event-stream");
150        if let Some(id) = client_request_id {
151            // Our own correlation id. Providers may ignore it, but applications can use it for linkage.
152            req = req.header("x-ai-protocol-request-id", id);
153        }
154
155        req.send()
156            .await
157            .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))
158    }
159
160    pub async fn execute_stream<'a>(
161        &'a self,
162        method: &str,
163        path: &str,
164        request_body: &serde_json::Value,
165    ) -> Result<BoxStream<'a, Bytes>> {
166        let resp = self
167            .execute_stream_response(method, path, request_body, None)
168            .await?;
169
170        // Convert reqwest bytes stream to our unified BoxStream
171        let byte_stream = resp
172            .bytes_stream()
173            .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)));
174        Ok(Box::pin(byte_stream))
175    }
176
177    pub async fn execute_get(&self, path: &str) -> Result<serde_json::Value> {
178        self.execute_service(path, "GET", None, None).await
179    }
180
181    pub async fn execute_service(
182        &self,
183        path: &str,
184        method: &str,
185        headers: Option<&std::collections::HashMap<String, String>>,
186        query_params: Option<&std::collections::HashMap<String, String>>,
187    ) -> Result<serde_json::Value> {
188        let interpolated_path = path.replace("{model}", &self.model);
189        let url = format!("{}{}", self.base_url, interpolated_path);
190        let mut request = match method.to_uppercase().as_str() {
191            "POST" => self.client.post(&url),
192            "PUT" => self.client.put(&url),
193            "DELETE" => self.client.delete(&url),
194            _ => self.client.get(&url),
195        };
196
197        if let Some(key) = &self.api_key {
198            request = request.bearer_auth(key);
199        }
200
201        if let Some(headers) = headers {
202            for (k, v) in headers {
203                request = request.header(k, v);
204            }
205        }
206
207        if let Some(params) = query_params {
208            request = request.query(params);
209        }
210
211        let response = request
212            .send()
213            .await
214            .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))?;
215
216        let json = response
217            .json()
218            .await
219            .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))?;
220
221        Ok(json)
222    }
223}
224
/// Errors produced by the HTTP transport layer.
#[derive(Debug, thiserror::Error)]
pub enum TransportError {
    /// A `reqwest` failure while sending a request, streaming its body, or decoding
    /// it as JSON.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// Any other transport failure, carried as a message (for example, the HTTP
    /// client could not be built).
    #[error("Transport error: {0}")]
    Other(String),
}