1use crate::session_context::SESSION_ID_HEADER;
2use anyhow::Result;
3use async_trait::async_trait;
4use reqwest::{
5 header::{HeaderMap, HeaderName, HeaderValue},
6 Certificate, Client, Identity, Response, StatusCode,
7};
8use serde_json::Value;
9use std::fmt;
10use std::fs::read_to_string;
11use std::path::PathBuf;
12use std::time::Duration;
13
14pub struct ApiClient {
15 client: Client,
16 host: String,
17 auth: AuthMethod,
18 default_headers: HeaderMap,
19 timeout: Duration,
20 tls_config: Option<TlsConfig>,
21}
22
23pub enum AuthMethod {
24 BearerToken(String),
25 ApiKey {
26 header_name: String,
27 key: String,
28 },
29 #[allow(dead_code)]
30 OAuth(OAuthConfig),
31 Custom(Box<dyn AuthProvider>),
32}
33
/// File locations of a client certificate and its private key, used for mutual TLS.
///
/// Both files are read as PEM when the identity is loaded.
#[derive(Debug, Clone)]
pub struct TlsCertKeyPair {
    /// Path to the PEM-encoded client certificate.
    pub cert_path: PathBuf,
    /// Path to the PEM-encoded private key matching `cert_path`.
    pub key_path: PathBuf,
}
39
40#[derive(Debug, Clone)]
41pub struct TlsConfig {
42 pub client_identity: Option<TlsCertKeyPair>,
43 pub ca_cert_path: Option<PathBuf>,
44}
45
46impl TlsConfig {
47 pub fn new() -> Self {
48 Self {
49 client_identity: None,
50 ca_cert_path: None,
51 }
52 }
53
54 pub fn from_config() -> Result<Option<Self>> {
55 let config = crate::config::Config::global();
56 let mut tls_config = TlsConfig::new();
57 let mut has_tls_config = false;
58
59 let client_cert_path = config.get_param::<String>("ASTER_CLIENT_CERT_PATH").ok();
60 let client_key_path = config.get_param::<String>("ASTER_CLIENT_KEY_PATH").ok();
61
62 match (client_cert_path, client_key_path) {
64 (Some(cert_path), Some(key_path)) => {
65 tls_config = tls_config.with_client_cert_and_key(
66 std::path::PathBuf::from(cert_path),
67 std::path::PathBuf::from(key_path),
68 );
69 has_tls_config = true;
70 }
71 (Some(_), None) => {
72 return Err(anyhow::anyhow!(
73 "Client certificate provided (ASTER_CLIENT_CERT_PATH) but no private key (ASTER_CLIENT_KEY_PATH)"
74 ));
75 }
76 (None, Some(_)) => {
77 return Err(anyhow::anyhow!(
78 "Client private key provided (ASTER_CLIENT_KEY_PATH) but no certificate (ASTER_CLIENT_CERT_PATH)"
79 ));
80 }
81 (None, None) => {}
82 }
83
84 if let Ok(ca_cert_path) = config.get_param::<String>("ASTER_CA_CERT_PATH") {
85 tls_config = tls_config.with_ca_cert(std::path::PathBuf::from(ca_cert_path));
86 has_tls_config = true;
87 }
88
89 if has_tls_config {
90 Ok(Some(tls_config))
91 } else {
92 Ok(None)
93 }
94 }
95
96 pub fn with_client_cert_and_key(mut self, cert_path: PathBuf, key_path: PathBuf) -> Self {
97 self.client_identity = Some(TlsCertKeyPair {
98 cert_path,
99 key_path,
100 });
101 self
102 }
103
104 pub fn with_ca_cert(mut self, path: PathBuf) -> Self {
105 self.ca_cert_path = Some(path);
106 self
107 }
108
109 pub fn is_configured(&self) -> bool {
110 self.client_identity.is_some() || self.ca_cert_path.is_some()
111 }
112
113 pub fn load_identity(&self) -> Result<Option<Identity>> {
114 if let Some(cert_key_pair) = &self.client_identity {
115 let cert_pem = read_to_string(&cert_key_pair.cert_path)
116 .map_err(|e| anyhow::anyhow!("Failed to read client certificate: {}", e))?;
117 let key_pem = read_to_string(&cert_key_pair.key_path)
118 .map_err(|e| anyhow::anyhow!("Failed to read client private key: {}", e))?;
119
120 let combined_pem = format!("{}\n{}", cert_pem, key_pem);
122
123 let identity = Identity::from_pem(combined_pem.as_bytes()).map_err(|e| {
124 anyhow::anyhow!("Failed to create identity from cert and key: {}", e)
125 })?;
126
127 Ok(Some(identity))
128 } else {
129 Ok(None)
130 }
131 }
132
133 pub fn load_ca_certificates(&self) -> Result<Vec<Certificate>> {
134 match &self.ca_cert_path {
135 Some(ca_path) => {
136 let ca_pem = read_to_string(ca_path)
137 .map_err(|e| anyhow::anyhow!("Failed to read CA certificate: {}", e))?;
138
139 let certs = Certificate::from_pem_bundle(ca_pem.as_bytes())
140 .map_err(|e| anyhow::anyhow!("Failed to parse CA certificate bundle: {}", e))?;
141
142 Ok(certs)
143 }
144 None => Ok(Vec::new()),
145 }
146 }
147}
148
149impl Default for TlsConfig {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
/// Parameters for [`AuthMethod::OAuth`].
///
/// They are forwarded unchanged to `super::oauth::get_oauth_token_async`.
pub struct OAuthConfig {
    pub host: String,
    pub client_id: String,
    pub redirect_url: String,
    pub scopes: Vec<String>,
}
161
162#[async_trait]
163pub trait AuthProvider: Send + Sync {
164 async fn get_auth_header(&self) -> Result<(String, String)>;
165}
166
167pub struct ApiResponse {
168 pub status: StatusCode,
169 pub payload: Option<Value>,
170}
171
172impl fmt::Debug for AuthMethod {
173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 match self {
175 AuthMethod::BearerToken(_) => f.debug_tuple("BearerToken").field(&"[hidden]").finish(),
176 AuthMethod::ApiKey { header_name, .. } => f
177 .debug_struct("ApiKey")
178 .field("header_name", header_name)
179 .field("key", &"[hidden]")
180 .finish(),
181 AuthMethod::OAuth(_) => f.debug_tuple("OAuth").field(&"[config]").finish(),
182 AuthMethod::Custom(_) => f.debug_tuple("Custom").field(&"[provider]").finish(),
183 }
184 }
185}
186
187impl ApiResponse {
188 pub async fn from_response(response: Response) -> Result<Self> {
189 let status = response.status();
190 let payload = response.json().await.ok();
191 Ok(Self { status, payload })
192 }
193}
194
195pub struct ApiRequestBuilder<'a> {
196 client: &'a ApiClient,
197 path: &'a str,
198 headers: HeaderMap,
199}
200
201impl ApiClient {
202 pub fn new(host: String, auth: AuthMethod) -> Result<Self> {
203 Self::with_timeout(host, auth, Duration::from_secs(600))
204 }
205
206 pub fn with_timeout(host: String, auth: AuthMethod, timeout: Duration) -> Result<Self> {
207 let mut client_builder = Client::builder().timeout(timeout);
208
209 let tls_config = TlsConfig::from_config()?;
211 if let Some(ref config) = tls_config {
212 client_builder = Self::configure_tls(client_builder, config)?;
213 }
214
215 let client = client_builder.build()?;
216
217 Ok(Self {
218 client,
219 host,
220 auth,
221 default_headers: HeaderMap::new(),
222 timeout,
223 tls_config,
224 })
225 }
226
227 fn rebuild_client(&mut self) -> Result<()> {
228 let mut client_builder = Client::builder()
229 .timeout(self.timeout)
230 .default_headers(self.default_headers.clone());
231
232 if let Some(ref tls_config) = self.tls_config {
234 client_builder = Self::configure_tls(client_builder, tls_config)?;
235 }
236
237 self.client = client_builder.build()?;
238 Ok(())
239 }
240
241 fn configure_tls(
243 mut client_builder: reqwest::ClientBuilder,
244 tls_config: &TlsConfig,
245 ) -> Result<reqwest::ClientBuilder> {
246 if tls_config.is_configured() {
247 if let Some(identity) = tls_config.load_identity()? {
249 client_builder = client_builder.identity(identity);
250 }
251
252 let ca_certs = tls_config.load_ca_certificates()?;
254 for ca_cert in ca_certs {
255 client_builder = client_builder.add_root_certificate(ca_cert);
256 }
257 }
258 Ok(client_builder)
259 }
260
261 pub fn with_headers(mut self, headers: HeaderMap) -> Result<Self> {
262 self.default_headers = headers;
263 self.rebuild_client()?;
264 Ok(self)
265 }
266
267 pub fn with_header(mut self, key: &str, value: &str) -> Result<Self> {
268 let header_name = HeaderName::from_bytes(key.as_bytes())?;
269 let header_value = HeaderValue::from_str(value)?;
270 self.default_headers.insert(header_name, header_value);
271 self.rebuild_client()?;
272 Ok(self)
273 }
274
275 pub fn request<'a>(&'a self, path: &'a str) -> ApiRequestBuilder<'a> {
276 ApiRequestBuilder {
277 client: self,
278 path,
279 headers: HeaderMap::new(),
280 }
281 }
282
283 pub async fn api_post(&self, path: &str, payload: &Value) -> Result<ApiResponse> {
284 self.request(path).api_post(payload).await
285 }
286
287 pub async fn response_post(&self, path: &str, payload: &Value) -> Result<Response> {
288 self.request(path).response_post(payload).await
289 }
290
291 pub async fn api_get(&self, path: &str) -> Result<ApiResponse> {
292 self.request(path).api_get().await
293 }
294
295 pub async fn response_get(&self, path: &str) -> Result<Response> {
296 self.request(path).response_get().await
297 }
298
299 fn build_url(&self, path: &str) -> Result<url::Url> {
300 use url::Url;
301 let mut base_url =
302 Url::parse(&self.host).map_err(|e| anyhow::anyhow!("Invalid base URL: {}", e))?;
303
304 let base_path = base_url.path();
305 if !base_path.is_empty() && base_path != "/" && !base_path.ends_with('/') {
306 base_url.set_path(&format!("{}/", base_path));
307 }
308
309 base_url
310 .join(path)
311 .map_err(|e| anyhow::anyhow!("Failed to construct URL: {}", e))
312 }
313
314 async fn get_oauth_token(&self, config: &OAuthConfig) -> Result<String> {
315 super::oauth::get_oauth_token_async(
316 &config.host,
317 &config.client_id,
318 &config.redirect_url,
319 &config.scopes,
320 )
321 .await
322 }
323}
324
325impl<'a> ApiRequestBuilder<'a> {
326 pub fn header(mut self, key: &str, value: &str) -> Result<Self> {
327 let header_name = HeaderName::from_bytes(key.as_bytes())?;
328 let header_value = HeaderValue::from_str(value)?;
329 self.headers.insert(header_name, header_value);
330 Ok(self)
331 }
332
333 #[allow(dead_code)]
334 pub fn headers(mut self, headers: HeaderMap) -> Self {
335 self.headers.extend(headers);
336 self
337 }
338
339 pub async fn api_post(self, payload: &Value) -> Result<ApiResponse> {
340 let response = self.response_post(payload).await?;
341 ApiResponse::from_response(response).await
342 }
343
344 pub async fn response_post(self, payload: &Value) -> Result<Response> {
345 tracing::debug!(
347 "LLM_REQUEST: {}",
348 serde_json::to_string(payload).unwrap_or_else(|_| "{}".to_string())
349 );
350
351 let request = self.send_request(|url, client| client.post(url)).await?;
352 Ok(request.json(payload).send().await?)
353 }
354
355 pub async fn api_get(self) -> Result<ApiResponse> {
356 let response = self.response_get().await?;
357 ApiResponse::from_response(response).await
358 }
359
360 pub async fn response_get(self) -> Result<Response> {
361 let request = self.send_request(|url, client| client.get(url)).await?;
362 Ok(request.send().await?)
363 }
364
365 async fn send_request<F>(&self, request_builder: F) -> Result<reqwest::RequestBuilder>
366 where
367 F: FnOnce(url::Url, &Client) -> reqwest::RequestBuilder,
368 {
369 let url = self.client.build_url(self.path)?;
370 let mut request = request_builder(url, &self.client.client);
371 request = request.headers(self.headers.clone());
372
373 if let Some(session_id) = crate::session_context::current_session_id() {
374 request = request.header(SESSION_ID_HEADER, session_id);
375 }
376
377 request = match &self.client.auth {
378 AuthMethod::BearerToken(token) => {
379 request.header("Authorization", format!("Bearer {}", token))
380 }
381 AuthMethod::ApiKey { header_name, key } => request.header(header_name.as_str(), key),
382 AuthMethod::OAuth(config) => {
383 let token = self.client.get_oauth_token(config).await?;
384 request.header("Authorization", format!("Bearer {}", token))
385 }
386 AuthMethod::Custom(provider) => {
387 let (header_name, header_value) = provider.get_auth_header().await?;
388 request.header(header_name, header_value)
389 }
390 };
391
392 Ok(request)
393 }
394}
395
396impl fmt::Debug for ApiClient {
397 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398 f.debug_struct("ApiClient")
399 .field("host", &self.host)
400 .field("auth", &"[auth method]")
401 .field("timeout", &self.timeout)
402 .field("default_headers", &self.default_headers)
403 .finish_non_exhaustive()
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Client aimed at a local host with a dummy bearer token.
    fn local_client() -> ApiClient {
        ApiClient::new(
            "http://localhost:8080".to_string(),
            AuthMethod::BearerToken("test-token".to_string()),
        )
        .unwrap()
    }

    /// Prepares a GET to `/test` and returns the headers of the built request.
    async fn prepared_get_headers(client: &ApiClient) -> HeaderMap {
        let prepared = client
            .request("/test")
            .send_request(|url, http_client| http_client.get(url))
            .await
            .unwrap();
        prepared.build().unwrap().headers().clone()
    }

    #[tokio::test]
    async fn test_session_id_header_injection() {
        let client = local_client();

        crate::session_context::with_session_id(Some("test-session-456".to_string()), async {
            let headers = prepared_get_headers(&client).await;

            assert!(headers.contains_key(SESSION_ID_HEADER));
            let session = headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap();
            assert_eq!(session, "test-session-456");
        })
        .await;
    }

    #[tokio::test]
    async fn test_no_session_id_header_when_absent() {
        let client = local_client();

        let headers = prepared_get_headers(&client).await;

        assert!(!headers.contains_key(SESSION_ID_HEADER));
    }
}