Skip to main content

a3s_code_core/llm/
http.rs

1//! HTTP utilities and abstraction for LLM API calls
2
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use futures::StreamExt;
6use std::pin::Pin;
7use std::sync::Arc;
8
/// HTTP response from a non-streaming POST request.
#[derive(Debug, Clone)]
pub struct HttpResponse {
    /// Numeric HTTP status code (e.g. `200`, `429`).
    pub status: u16,
    /// Full response body decoded as text.
    pub body: String,
}
14
15/// HTTP response from a streaming POST request
16pub struct StreamingHttpResponse {
17    pub status: u16,
18    /// Retry-After header value (if present)
19    pub retry_after: Option<String>,
20    /// Byte stream (valid when status is 2xx)
21    pub byte_stream: Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes>> + Send>>,
22    /// Error body (populated when status is not 2xx)
23    pub error_body: String,
24}
25
26/// Abstraction over HTTP POST requests for LLM API calls.
27///
28/// Enables dependency injection for testing without hitting real HTTP endpoints.
29#[async_trait]
30pub trait HttpClient: Send + Sync {
31    /// Make a POST request and return status + body
32    async fn post(
33        &self,
34        url: &str,
35        headers: Vec<(&str, &str)>,
36        body: &serde_json::Value,
37    ) -> Result<HttpResponse>;
38
39    /// Make a POST request and return a streaming response
40    async fn post_streaming(
41        &self,
42        url: &str,
43        headers: Vec<(&str, &str)>,
44        body: &serde_json::Value,
45    ) -> Result<StreamingHttpResponse>;
46}
47
48/// Default HTTP client backed by reqwest
49pub struct ReqwestHttpClient {
50    client: reqwest::Client,
51}
52
53impl ReqwestHttpClient {
54    pub fn new() -> Self {
55        Self {
56            client: reqwest::Client::new(),
57        }
58    }
59}
60
61impl Default for ReqwestHttpClient {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67#[async_trait]
68impl HttpClient for ReqwestHttpClient {
69    async fn post(
70        &self,
71        url: &str,
72        headers: Vec<(&str, &str)>,
73        body: &serde_json::Value,
74    ) -> Result<HttpResponse> {
75        tracing::debug!(
76            "HTTP POST to {}: {}",
77            url,
78            serde_json::to_string_pretty(body)?
79        );
80
81        let mut request = self.client.post(url);
82        for (key, value) in headers {
83            request = request.header(key, value);
84        }
85        request = request.json(body);
86
87        let response = request
88            .send()
89            .await
90            .context(format!("Failed to send request to {}", url))?;
91
92        let status = response.status().as_u16();
93        let body = response.text().await?;
94
95        Ok(HttpResponse { status, body })
96    }
97
98    async fn post_streaming(
99        &self,
100        url: &str,
101        headers: Vec<(&str, &str)>,
102        body: &serde_json::Value,
103    ) -> Result<StreamingHttpResponse> {
104        tracing::debug!(
105            "HTTP POST streaming to {}: {}",
106            url,
107            serde_json::to_string_pretty(body)?
108        );
109
110        let mut request = self.client.post(url);
111        for (key, value) in headers {
112            request = request.header(key, value);
113        }
114        request = request.json(body);
115
116        let response = request
117            .send()
118            .await
119            .context(format!("Failed to send streaming request to {}", url))?;
120
121        let status = response.status().as_u16();
122        let retry_after = response
123            .headers()
124            .get("retry-after")
125            .and_then(|v| v.to_str().ok())
126            .map(String::from);
127
128        if (200..300).contains(&status) {
129            let byte_stream = response
130                .bytes_stream()
131                .map(|r| r.map_err(|e| anyhow::anyhow!("Stream error: {}", e)));
132            Ok(StreamingHttpResponse {
133                status,
134                retry_after,
135                byte_stream: Box::pin(byte_stream),
136                error_body: String::new(),
137            })
138        } else {
139            let error_body = response.text().await.unwrap_or_default();
140            // Return an empty stream for error responses
141            let empty: futures::stream::Empty<Result<bytes::Bytes>> = futures::stream::empty();
142            Ok(StreamingHttpResponse {
143                status,
144                retry_after,
145                byte_stream: Box::pin(empty),
146                error_body,
147            })
148        }
149    }
150}
151
152/// Create a default HTTP client
153pub fn default_http_client() -> Arc<dyn HttpClient> {
154    Arc::new(ReqwestHttpClient::new())
155}
156
/// Normalize a base URL by stripping trailing slashes and a single trailing
/// `/v1` path segment.
///
/// Examples: `https://api.example.com/v1/` → `https://api.example.com`;
/// `https://gw.example.com/v1/v1` → `https://gw.example.com/v1` (only one
/// `/v1` is removed, so a gateway mounted under `/v1` keeps its prefix).
pub(crate) fn normalize_base_url(base_url: &str) -> String {
    let trimmed = base_url.trim_end_matches('/');
    // `strip_suffix` removes at most one occurrence; `trim_end_matches("/v1")`
    // would keep stripping repeated `/v1` segments.
    let without_version = trimmed.strip_suffix("/v1").unwrap_or(trimmed);
    without_version.trim_end_matches('/').to_string()
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(input, expected)` pair against `normalize_base_url`.
    fn assert_normalized(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(normalize_base_url(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_normalize_base_url() {
        assert_normalized(&[
            ("https://api.example.com", "https://api.example.com"),
            ("https://api.example.com/", "https://api.example.com"),
            ("https://api.example.com/v1", "https://api.example.com"),
            ("https://api.example.com/v1/", "https://api.example.com"),
        ]);
    }

    #[test]
    fn test_normalize_base_url_edge_cases() {
        assert_normalized(&[
            ("http://localhost:8080/v1", "http://localhost:8080"),
            ("http://localhost:8080", "http://localhost:8080"),
            ("https://api.example.com/v1/", "https://api.example.com"),
        ]);
    }

    #[test]
    fn test_normalize_base_url_multiple_trailing_slashes() {
        assert_normalized(&[("https://api.example.com//", "https://api.example.com")]);
    }

    #[test]
    fn test_normalize_base_url_with_port() {
        assert_normalized(&[("http://localhost:11434/v1/", "http://localhost:11434")]);
    }

    #[test]
    fn test_normalize_base_url_already_normalized() {
        assert_normalized(&[("https://api.openai.com", "https://api.openai.com")]);
    }

    #[test]
    fn test_normalize_base_url_empty_string() {
        assert_normalized(&[("", "")]);
    }

    #[test]
    fn test_default_http_client_creation() {
        // Constructing the default client must not panic.
        let _client = default_http_client();
    }
}