// ai_lib_rust/transport/http.rs

1use crate::protocol::ProtocolManifest;
2use crate::{BoxStream, Result};
3use bytes::Bytes;
4use futures::TryStreamExt;
5use keyring::Entry;
6use reqwest::Proxy;
7use std::env;
8use std::time::Duration;
9
10pub struct HttpTransport {
11    client: reqwest::Client,
12    base_url: String,
13    model: String,
14    api_key: Option<String>,
15}
16
17impl HttpTransport {
18    /// Create a new HttpTransport from a manifest.
19    ///
20    /// If `base_url_override` is provided, it will be used instead of the manifest's base_url.
21    /// This is useful for testing with mock servers.
22    pub fn new(manifest: &ProtocolManifest, model: &str) -> Result<Self> {
23        Self::new_with_base_url(manifest, model, None)
24    }
25
26    /// Create a new HttpTransport with an optional base_url override.
27    ///
28    /// This is primarily for testing, allowing injection of mock server URLs.
29    pub fn new_with_base_url(
30        manifest: &ProtocolManifest,
31        model: &str,
32        base_url_override: Option<&str>,
33    ) -> Result<Self> {
34        let provider_id = manifest.provider_id.as_deref().unwrap_or(&manifest.id);
35        let api_key = Self::get_api_key(provider_id);
36
37        // Use override if provided, otherwise use manifest endpoint.base_url
38        let base_url = base_url_override
39            .map(|s| s.to_string())
40            .unwrap_or_else(|| manifest.get_base_url().to_string());
41
42        // Minimal production-friendly defaults (env-overridable).
43        let timeout_secs = env::var("AI_HTTP_TIMEOUT_SECS")
44            .ok()
45            .and_then(|s| s.parse::<u64>().ok())
46            .or_else(|| {
47                env::var("AI_TIMEOUT_SECS")
48                    .ok()
49                    .and_then(|s| s.parse::<u64>().ok())
50            })
51            .unwrap_or(300);
52
53        let mut builder = reqwest::Client::builder()
54            .timeout(Duration::from_secs(timeout_secs))
55            .pool_max_idle_per_host(
56                env::var("AI_HTTP_POOL_MAX_IDLE_PER_HOST")
57                    .ok()
58                    .and_then(|s| s.parse::<usize>().ok())
59                    .unwrap_or(32),
60            )
61            .pool_idle_timeout(Some(Duration::from_secs(
62                env::var("AI_HTTP_POOL_IDLE_TIMEOUT_SECS")
63                    .ok()
64                    .and_then(|s| s.parse::<u64>().ok())
65                    .unwrap_or(90),
66            )))
67            // Conservative HTTP/2 keepalive defaults for long-lived connections.
68            // (No extra env knobs for now to keep developer UI simple.)
69            .http2_adaptive_window(true)
70            .http2_keep_alive_interval(Some(Duration::from_secs(30)))
71            .http2_keep_alive_timeout(Duration::from_secs(10));
72
73        if let Ok(proxy_url) = env::var("AI_PROXY_URL") {
74            if let Ok(proxy) = Proxy::all(&proxy_url) {
75                builder = builder.proxy(proxy);
76            }
77        }
78
79        let client = builder.build().map_err(|e| {
80            crate::Error::Transport(crate::transport::TransportError::Other(e.to_string()))
81        })?;
82
83        Ok(Self {
84            client,
85            base_url,
86            model: model.to_string(),
87            api_key,
88        })
89    }
90
91    fn get_api_key(provider_id: &str) -> Option<String> {
92        // 1. Try Keyring
93        let entry = Entry::new("ai-protocol", provider_id).ok();
94        if let Some(entry) = entry {
95            if let Ok(key) = entry.get_password() {
96                return Some(key);
97            }
98        }
99
100        // 2. Try Environment Variable (PROVIDER_API_KEY)
101        let env_var = format!("{}_API_KEY", provider_id.to_uppercase());
102        env::var(env_var).ok()
103    }
104
105    pub async fn execute_stream_response(
106        &self,
107        method: &str,
108        path: &str,
109        request_body: &serde_json::Value,
110        client_request_id: Option<&str>,
111    ) -> Result<reqwest::Response> {
112        let interpolated_path = path.replace("{model}", &self.model);
113        let url = format!("{}{}", self.base_url, interpolated_path);
114
115        let mut req = match method.to_uppercase().as_str() {
116            "POST" => self.client.post(&url).json(request_body),
117            "PUT" => self.client.put(&url).json(request_body),
118            "DELETE" => self.client.delete(&url),
119            _ => self.client.get(&url),
120        };
121
122        if let Some(key) = &self.api_key {
123            req = req.bearer_auth(key);
124        }
125
126        // Prefer SSE for providers that support it
127        req = req.header("accept", "text/event-stream");
128        if let Some(id) = client_request_id {
129            // Our own correlation id. Providers may ignore it, but applications can use it for linkage.
130            req = req.header("x-ai-protocol-request-id", id);
131        }
132
133        req.send()
134            .await
135            .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))
136    }
137
138    pub async fn execute_stream<'a>(
139        &'a self,
140        method: &str,
141        path: &str,
142        request_body: &serde_json::Value,
143    ) -> Result<BoxStream<'a, Bytes>> {
144        let resp = self
145            .execute_stream_response(method, path, request_body, None)
146            .await?;
147
148        // Convert reqwest bytes stream to our unified BoxStream
149        let byte_stream = resp
150            .bytes_stream()
151            .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)));
152        Ok(Box::pin(byte_stream))
153    }
154
155    pub async fn execute_get(&self, path: &str) -> Result<serde_json::Value> {
156        self.execute_service(path, "GET", None, None).await
157    }
158
159    pub async fn execute_service(
160        &self,
161        path: &str,
162        method: &str,
163        headers: Option<&std::collections::HashMap<String, String>>,
164        query_params: Option<&std::collections::HashMap<String, String>>,
165    ) -> Result<serde_json::Value> {
166        let interpolated_path = path.replace("{model}", &self.model);
167        let url = format!("{}{}", self.base_url, interpolated_path);
168        let mut request = match method.to_uppercase().as_str() {
169            "POST" => self.client.post(&url),
170            "PUT" => self.client.put(&url),
171            "DELETE" => self.client.delete(&url),
172            _ => self.client.get(&url),
173        };
174
175        if let Some(key) = &self.api_key {
176            request = request.bearer_auth(key);
177        }
178
179        if let Some(headers) = headers {
180            for (k, v) in headers {
181                request = request.header(k, v);
182            }
183        }
184
185        if let Some(params) = query_params {
186            request = request.query(params);
187        }
188
189        let response = request
190            .send()
191            .await
192            .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))?;
193
194        let json = response
195            .json()
196            .await
197            .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))?;
198
199        Ok(json)
200    }
201}
202
203#[derive(Debug, thiserror::Error)]
204pub enum TransportError {
205    #[error("HTTP error: {0}")]
206    Http(#[from] reqwest::Error),
207
208    #[error("Transport error: {0}")]
209    Other(String),
210}