ai_lib_rust/transport/
http.rs1use crate::protocol::ProtocolManifest;
2use crate::{BoxStream, Result};
3use bytes::Bytes;
4use futures::TryStreamExt;
5use keyring::Entry;
6use reqwest::Proxy;
7use std::env;
8use std::time::Duration;
9
10pub struct HttpTransport {
11 client: reqwest::Client,
12 base_url: String,
13 model: String,
14 api_key: Option<String>,
15}
16
17impl HttpTransport {
18 pub fn new(manifest: &ProtocolManifest, model: &str) -> Result<Self> {
23 Self::new_with_base_url(manifest, model, None)
24 }
25
26 pub fn new_with_base_url(
30 manifest: &ProtocolManifest,
31 model: &str,
32 base_url_override: Option<&str>,
33 ) -> Result<Self> {
34 let provider_id = manifest.provider_id.as_deref().unwrap_or(&manifest.id);
35 let api_key = Self::get_api_key(provider_id);
36
37 let base_url = base_url_override
39 .map(|s| s.to_string())
40 .unwrap_or_else(|| manifest.get_base_url().to_string());
41
42 let timeout_secs = env::var("AI_HTTP_TIMEOUT_SECS")
44 .ok()
45 .and_then(|s| s.parse::<u64>().ok())
46 .or_else(|| {
47 env::var("AI_TIMEOUT_SECS")
48 .ok()
49 .and_then(|s| s.parse::<u64>().ok())
50 })
51 .unwrap_or(300);
52
53 let mut builder = reqwest::Client::builder()
54 .timeout(Duration::from_secs(timeout_secs))
55 .pool_max_idle_per_host(
56 env::var("AI_HTTP_POOL_MAX_IDLE_PER_HOST")
57 .ok()
58 .and_then(|s| s.parse::<usize>().ok())
59 .unwrap_or(32),
60 )
61 .pool_idle_timeout(Some(Duration::from_secs(
62 env::var("AI_HTTP_POOL_IDLE_TIMEOUT_SECS")
63 .ok()
64 .and_then(|s| s.parse::<u64>().ok())
65 .unwrap_or(90),
66 )))
67 .http2_adaptive_window(true)
70 .http2_keep_alive_interval(Some(Duration::from_secs(30)))
71 .http2_keep_alive_timeout(Duration::from_secs(10));
72
73 if let Ok(proxy_url) = env::var("AI_PROXY_URL") {
74 if let Ok(proxy) = Proxy::all(&proxy_url) {
75 builder = builder.proxy(proxy);
76 }
77 }
78
79 let client = builder.build().map_err(|e| {
80 crate::Error::Transport(crate::transport::TransportError::Other(e.to_string()))
81 })?;
82
83 Ok(Self {
84 client,
85 base_url,
86 model: model.to_string(),
87 api_key,
88 })
89 }
90
91 fn get_api_key(provider_id: &str) -> Option<String> {
92 let entry = Entry::new("ai-protocol", provider_id).ok();
94 if let Some(entry) = entry {
95 if let Ok(key) = entry.get_password() {
96 return Some(key);
97 }
98 }
99
100 let env_var = format!("{}_API_KEY", provider_id.to_uppercase());
102 env::var(env_var).ok()
103 }
104
105 pub async fn execute_stream_response(
106 &self,
107 method: &str,
108 path: &str,
109 request_body: &serde_json::Value,
110 client_request_id: Option<&str>,
111 ) -> Result<reqwest::Response> {
112 let interpolated_path = path.replace("{model}", &self.model);
113 let url = format!("{}{}", self.base_url, interpolated_path);
114
115 let mut req = match method.to_uppercase().as_str() {
116 "POST" => self.client.post(&url).json(request_body),
117 "PUT" => self.client.put(&url).json(request_body),
118 "DELETE" => self.client.delete(&url),
119 _ => self.client.get(&url),
120 };
121
122 if let Some(key) = &self.api_key {
123 req = req.bearer_auth(key);
124 }
125
126 req = req.header("accept", "text/event-stream");
128 if let Some(id) = client_request_id {
129 req = req.header("x-ai-protocol-request-id", id);
131 }
132
133 req.send()
134 .await
135 .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))
136 }
137
138 pub async fn execute_stream<'a>(
139 &'a self,
140 method: &str,
141 path: &str,
142 request_body: &serde_json::Value,
143 ) -> Result<BoxStream<'a, Bytes>> {
144 let resp = self
145 .execute_stream_response(method, path, request_body, None)
146 .await?;
147
148 let byte_stream = resp
150 .bytes_stream()
151 .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)));
152 Ok(Box::pin(byte_stream))
153 }
154
155 pub async fn execute_get(&self, path: &str) -> Result<serde_json::Value> {
156 self.execute_service(path, "GET", None, None).await
157 }
158
159 pub async fn execute_service(
160 &self,
161 path: &str,
162 method: &str,
163 headers: Option<&std::collections::HashMap<String, String>>,
164 query_params: Option<&std::collections::HashMap<String, String>>,
165 ) -> Result<serde_json::Value> {
166 let interpolated_path = path.replace("{model}", &self.model);
167 let url = format!("{}{}", self.base_url, interpolated_path);
168 let mut request = match method.to_uppercase().as_str() {
169 "POST" => self.client.post(&url),
170 "PUT" => self.client.put(&url),
171 "DELETE" => self.client.delete(&url),
172 _ => self.client.get(&url),
173 };
174
175 if let Some(key) = &self.api_key {
176 request = request.bearer_auth(key);
177 }
178
179 if let Some(headers) = headers {
180 for (k, v) in headers {
181 request = request.header(k, v);
182 }
183 }
184
185 if let Some(params) = query_params {
186 request = request.query(params);
187 }
188
189 let response = request
190 .send()
191 .await
192 .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))?;
193
194 let json = response
195 .json()
196 .await
197 .map_err(|e| crate::Error::Transport(crate::transport::TransportError::Http(e)))?;
198
199 Ok(json)
200 }
201}
202
203#[derive(Debug, thiserror::Error)]
204pub enum TransportError {
205 #[error("HTTP error: {0}")]
206 Http(#[from] reqwest::Error),
207
208 #[error("Transport error: {0}")]
209 Other(String),
210}