// aster/providers/api_client.rs

1use crate::session_context::SESSION_ID_HEADER;
2use anyhow::Result;
3use async_trait::async_trait;
4use reqwest::{
5    header::{HeaderMap, HeaderName, HeaderValue},
6    Certificate, Client, Identity, Response, StatusCode,
7};
8use serde_json::Value;
9use std::fmt;
10use std::fs::read_to_string;
11use std::path::PathBuf;
12use std::time::Duration;
13
14pub struct ApiClient {
15    client: Client,
16    host: String,
17    auth: AuthMethod,
18    default_headers: HeaderMap,
19    timeout: Duration,
20    tls_config: Option<TlsConfig>,
21}
22
23pub enum AuthMethod {
24    BearerToken(String),
25    ApiKey {
26        header_name: String,
27        key: String,
28    },
29    #[allow(dead_code)]
30    OAuth(OAuthConfig),
31    Custom(Box<dyn AuthProvider>),
32}
33
/// Filesystem locations of a client certificate and its matching private key.
#[derive(Debug, Clone)]
pub struct TlsCertKeyPair {
    /// Path to the PEM-encoded client certificate.
    pub cert_path: PathBuf,
    /// Path to the PEM-encoded private key.
    pub key_path: PathBuf,
}
39
40#[derive(Debug, Clone)]
41pub struct TlsConfig {
42    pub client_identity: Option<TlsCertKeyPair>,
43    pub ca_cert_path: Option<PathBuf>,
44}
45
46impl TlsConfig {
47    pub fn new() -> Self {
48        Self {
49            client_identity: None,
50            ca_cert_path: None,
51        }
52    }
53
54    pub fn from_config() -> Result<Option<Self>> {
55        let config = crate::config::Config::global();
56        let mut tls_config = TlsConfig::new();
57        let mut has_tls_config = false;
58
59        let client_cert_path = config.get_param::<String>("ASTER_CLIENT_CERT_PATH").ok();
60        let client_key_path = config.get_param::<String>("ASTER_CLIENT_KEY_PATH").ok();
61
62        // Validate that both cert and key are provided if either is provided
63        match (client_cert_path, client_key_path) {
64            (Some(cert_path), Some(key_path)) => {
65                tls_config = tls_config.with_client_cert_and_key(
66                    std::path::PathBuf::from(cert_path),
67                    std::path::PathBuf::from(key_path),
68                );
69                has_tls_config = true;
70            }
71            (Some(_), None) => {
72                return Err(anyhow::anyhow!(
73                    "Client certificate provided (ASTER_CLIENT_CERT_PATH) but no private key (ASTER_CLIENT_KEY_PATH)"
74                ));
75            }
76            (None, Some(_)) => {
77                return Err(anyhow::anyhow!(
78                    "Client private key provided (ASTER_CLIENT_KEY_PATH) but no certificate (ASTER_CLIENT_CERT_PATH)"
79                ));
80            }
81            (None, None) => {}
82        }
83
84        if let Ok(ca_cert_path) = config.get_param::<String>("ASTER_CA_CERT_PATH") {
85            tls_config = tls_config.with_ca_cert(std::path::PathBuf::from(ca_cert_path));
86            has_tls_config = true;
87        }
88
89        if has_tls_config {
90            Ok(Some(tls_config))
91        } else {
92            Ok(None)
93        }
94    }
95
96    pub fn with_client_cert_and_key(mut self, cert_path: PathBuf, key_path: PathBuf) -> Self {
97        self.client_identity = Some(TlsCertKeyPair {
98            cert_path,
99            key_path,
100        });
101        self
102    }
103
104    pub fn with_ca_cert(mut self, path: PathBuf) -> Self {
105        self.ca_cert_path = Some(path);
106        self
107    }
108
109    pub fn is_configured(&self) -> bool {
110        self.client_identity.is_some() || self.ca_cert_path.is_some()
111    }
112
113    pub fn load_identity(&self) -> Result<Option<Identity>> {
114        if let Some(cert_key_pair) = &self.client_identity {
115            let cert_pem = read_to_string(&cert_key_pair.cert_path)
116                .map_err(|e| anyhow::anyhow!("Failed to read client certificate: {}", e))?;
117            let key_pem = read_to_string(&cert_key_pair.key_path)
118                .map_err(|e| anyhow::anyhow!("Failed to read client private key: {}", e))?;
119
120            // Create a combined PEM file with certificate and private key
121            let combined_pem = format!("{}\n{}", cert_pem, key_pem);
122
123            let identity = Identity::from_pem(combined_pem.as_bytes()).map_err(|e| {
124                anyhow::anyhow!("Failed to create identity from cert and key: {}", e)
125            })?;
126
127            Ok(Some(identity))
128        } else {
129            Ok(None)
130        }
131    }
132
133    pub fn load_ca_certificates(&self) -> Result<Vec<Certificate>> {
134        match &self.ca_cert_path {
135            Some(ca_path) => {
136                let ca_pem = read_to_string(ca_path)
137                    .map_err(|e| anyhow::anyhow!("Failed to read CA certificate: {}", e))?;
138
139                let certs = Certificate::from_pem_bundle(ca_pem.as_bytes())
140                    .map_err(|e| anyhow::anyhow!("Failed to parse CA certificate bundle: {}", e))?;
141
142                Ok(certs)
143            }
144            None => Ok(Vec::new()),
145        }
146    }
147}
148
149impl Default for TlsConfig {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
/// Parameters handed to the OAuth token helper when authenticating with
/// `AuthMethod::OAuth`.
pub struct OAuthConfig {
    /// OAuth host passed to the token helper.
    pub host: String,
    /// Client identifier passed to the token helper.
    pub client_id: String,
    /// Redirect URL passed to the token helper.
    pub redirect_url: String,
    /// Scopes passed to the token helper.
    pub scopes: Vec<String>,
}
161
162#[async_trait]
163pub trait AuthProvider: Send + Sync {
164    async fn get_auth_header(&self) -> Result<(String, String)>;
165}
166
167pub struct ApiResponse {
168    pub status: StatusCode,
169    pub payload: Option<Value>,
170}
171
172impl fmt::Debug for AuthMethod {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        match self {
175            AuthMethod::BearerToken(_) => f.debug_tuple("BearerToken").field(&"[hidden]").finish(),
176            AuthMethod::ApiKey { header_name, .. } => f
177                .debug_struct("ApiKey")
178                .field("header_name", header_name)
179                .field("key", &"[hidden]")
180                .finish(),
181            AuthMethod::OAuth(_) => f.debug_tuple("OAuth").field(&"[config]").finish(),
182            AuthMethod::Custom(_) => f.debug_tuple("Custom").field(&"[provider]").finish(),
183        }
184    }
185}
186
187impl ApiResponse {
188    pub async fn from_response(response: Response) -> Result<Self> {
189        let status = response.status();
190        let payload = response.json().await.ok();
191        Ok(Self { status, payload })
192    }
193}
194
195pub struct ApiRequestBuilder<'a> {
196    client: &'a ApiClient,
197    path: &'a str,
198    headers: HeaderMap,
199}
200
201impl ApiClient {
202    pub fn new(host: String, auth: AuthMethod) -> Result<Self> {
203        Self::with_timeout(host, auth, Duration::from_secs(600))
204    }
205
206    pub fn with_timeout(host: String, auth: AuthMethod, timeout: Duration) -> Result<Self> {
207        let mut client_builder = Client::builder().timeout(timeout);
208
209        // Configure TLS if needed
210        let tls_config = TlsConfig::from_config()?;
211        if let Some(ref config) = tls_config {
212            client_builder = Self::configure_tls(client_builder, config)?;
213        }
214
215        let client = client_builder.build()?;
216
217        Ok(Self {
218            client,
219            host,
220            auth,
221            default_headers: HeaderMap::new(),
222            timeout,
223            tls_config,
224        })
225    }
226
227    fn rebuild_client(&mut self) -> Result<()> {
228        let mut client_builder = Client::builder()
229            .timeout(self.timeout)
230            .default_headers(self.default_headers.clone());
231
232        // Configure TLS if needed
233        if let Some(ref tls_config) = self.tls_config {
234            client_builder = Self::configure_tls(client_builder, tls_config)?;
235        }
236
237        self.client = client_builder.build()?;
238        Ok(())
239    }
240
241    /// Configure TLS settings on a reqwest ClientBuilder
242    fn configure_tls(
243        mut client_builder: reqwest::ClientBuilder,
244        tls_config: &TlsConfig,
245    ) -> Result<reqwest::ClientBuilder> {
246        if tls_config.is_configured() {
247            // Load client identity (certificate + private key)
248            if let Some(identity) = tls_config.load_identity()? {
249                client_builder = client_builder.identity(identity);
250            }
251
252            // Load CA certificates
253            let ca_certs = tls_config.load_ca_certificates()?;
254            for ca_cert in ca_certs {
255                client_builder = client_builder.add_root_certificate(ca_cert);
256            }
257        }
258        Ok(client_builder)
259    }
260
261    pub fn with_headers(mut self, headers: HeaderMap) -> Result<Self> {
262        self.default_headers = headers;
263        self.rebuild_client()?;
264        Ok(self)
265    }
266
267    pub fn with_header(mut self, key: &str, value: &str) -> Result<Self> {
268        let header_name = HeaderName::from_bytes(key.as_bytes())?;
269        let header_value = HeaderValue::from_str(value)?;
270        self.default_headers.insert(header_name, header_value);
271        self.rebuild_client()?;
272        Ok(self)
273    }
274
275    pub fn request<'a>(&'a self, path: &'a str) -> ApiRequestBuilder<'a> {
276        ApiRequestBuilder {
277            client: self,
278            path,
279            headers: HeaderMap::new(),
280        }
281    }
282
283    pub async fn api_post(&self, path: &str, payload: &Value) -> Result<ApiResponse> {
284        self.request(path).api_post(payload).await
285    }
286
287    pub async fn response_post(&self, path: &str, payload: &Value) -> Result<Response> {
288        self.request(path).response_post(payload).await
289    }
290
291    pub async fn api_get(&self, path: &str) -> Result<ApiResponse> {
292        self.request(path).api_get().await
293    }
294
295    pub async fn response_get(&self, path: &str) -> Result<Response> {
296        self.request(path).response_get().await
297    }
298
299    fn build_url(&self, path: &str) -> Result<url::Url> {
300        use url::Url;
301        let mut base_url =
302            Url::parse(&self.host).map_err(|e| anyhow::anyhow!("Invalid base URL: {}", e))?;
303
304        let base_path = base_url.path();
305        if !base_path.is_empty() && base_path != "/" && !base_path.ends_with('/') {
306            base_url.set_path(&format!("{}/", base_path));
307        }
308
309        base_url
310            .join(path)
311            .map_err(|e| anyhow::anyhow!("Failed to construct URL: {}", e))
312    }
313
314    async fn get_oauth_token(&self, config: &OAuthConfig) -> Result<String> {
315        super::oauth::get_oauth_token_async(
316            &config.host,
317            &config.client_id,
318            &config.redirect_url,
319            &config.scopes,
320        )
321        .await
322    }
323}
324
325impl<'a> ApiRequestBuilder<'a> {
326    pub fn header(mut self, key: &str, value: &str) -> Result<Self> {
327        let header_name = HeaderName::from_bytes(key.as_bytes())?;
328        let header_value = HeaderValue::from_str(value)?;
329        self.headers.insert(header_name, header_value);
330        Ok(self)
331    }
332
333    #[allow(dead_code)]
334    pub fn headers(mut self, headers: HeaderMap) -> Self {
335        self.headers.extend(headers);
336        self
337    }
338
339    pub async fn api_post(self, payload: &Value) -> Result<ApiResponse> {
340        let response = self.response_post(payload).await?;
341        ApiResponse::from_response(response).await
342    }
343
344    pub async fn response_post(self, payload: &Value) -> Result<Response> {
345        // Log the JSON payload being sent to the LLM
346        tracing::debug!(
347            "LLM_REQUEST: {}",
348            serde_json::to_string(payload).unwrap_or_else(|_| "{}".to_string())
349        );
350
351        let request = self.send_request(|url, client| client.post(url)).await?;
352        Ok(request.json(payload).send().await?)
353    }
354
355    pub async fn api_get(self) -> Result<ApiResponse> {
356        let response = self.response_get().await?;
357        ApiResponse::from_response(response).await
358    }
359
360    pub async fn response_get(self) -> Result<Response> {
361        let request = self.send_request(|url, client| client.get(url)).await?;
362        Ok(request.send().await?)
363    }
364
365    async fn send_request<F>(&self, request_builder: F) -> Result<reqwest::RequestBuilder>
366    where
367        F: FnOnce(url::Url, &Client) -> reqwest::RequestBuilder,
368    {
369        let url = self.client.build_url(self.path)?;
370        let mut request = request_builder(url, &self.client.client);
371        request = request.headers(self.headers.clone());
372
373        if let Some(session_id) = crate::session_context::current_session_id() {
374            request = request.header(SESSION_ID_HEADER, session_id);
375        }
376
377        request = match &self.client.auth {
378            AuthMethod::BearerToken(token) => {
379                request.header("Authorization", format!("Bearer {}", token))
380            }
381            AuthMethod::ApiKey { header_name, key } => request.header(header_name.as_str(), key),
382            AuthMethod::OAuth(config) => {
383                let token = self.client.get_oauth_token(config).await?;
384                request.header("Authorization", format!("Bearer {}", token))
385            }
386            AuthMethod::Custom(provider) => {
387                let (header_name, header_value) = provider.get_auth_header().await?;
388                request.header(header_name, header_value)
389            }
390        };
391
392        Ok(request)
393    }
394}
395
396impl fmt::Debug for ApiClient {
397    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398        f.debug_struct("ApiClient")
399            .field("host", &self.host)
400            .field("auth", &"[auth method]")
401            .field("timeout", &self.timeout)
402            .field("default_headers", &self.default_headers)
403            .finish_non_exhaustive()
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client against a dummy local host with a fixed bearer token.
    fn bearer_client() -> ApiClient {
        let host = "http://localhost:8080".to_string();
        let auth = AuthMethod::BearerToken("test-token".to_string());
        ApiClient::new(host, auth).unwrap()
    }

    /// Prepares a GET for `/test` and returns the headers of the built request.
    async fn prepared_headers(client: &ApiClient) -> HeaderMap {
        let prepared = client
            .request("/test")
            .send_request(|url, http| http.get(url))
            .await
            .unwrap();

        prepared.build().unwrap().headers().clone()
    }

    #[tokio::test]
    async fn test_session_id_header_injection() {
        let client = bearer_client();
        let session = Some("test-session-456".to_string());

        // The request is prepared inside the session scope so the id is visible.
        crate::session_context::with_session_id(session, async {
            let headers = prepared_headers(&client).await;

            assert!(headers.contains_key(SESSION_ID_HEADER));

            let sent = headers.get(SESSION_ID_HEADER).unwrap().to_str().unwrap();
            assert_eq!(sent, "test-session-456");
        })
        .await;
    }

    #[tokio::test]
    async fn test_no_session_id_header_when_absent() {
        let client = bearer_client();

        // No session scope is active here, so no session header may be attached.
        let headers = prepared_headers(&client).await;

        assert!(!headers.contains_key(SESSION_ID_HEADER));
    }
}