a3s_code_core/llm/
http.rs1use anyhow::{Context, Result};
4use async_trait::async_trait;
5use futures::StreamExt;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
/// A fully buffered HTTP response, as returned by [`HttpClient::post`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponse {
    /// Numeric HTTP status code (e.g. `200`, `429`).
    pub status: u16,
    /// The complete response body as a string.
    pub body: String,
}
15
16pub struct StreamingHttpResponse {
18 pub status: u16,
19 pub retry_after: Option<String>,
21 pub byte_stream: Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes>> + Send>>,
23 pub error_body: String,
25}
26
27#[async_trait]
31pub trait HttpClient: Send + Sync {
32 async fn post(
34 &self,
35 url: &str,
36 headers: Vec<(&str, &str)>,
37 body: &serde_json::Value,
38 ) -> Result<HttpResponse>;
39
40 async fn post_streaming(
42 &self,
43 url: &str,
44 headers: Vec<(&str, &str)>,
45 body: &serde_json::Value,
46 ) -> Result<StreamingHttpResponse>;
47}
48
49pub struct ReqwestHttpClient {
51 client: reqwest::Client,
52}
53
54impl ReqwestHttpClient {
55 pub fn new() -> Self {
56 Self {
57 client: build_reqwest_client(None, None).expect("failed to build default HTTP client"),
58 }
59 }
60}
61
62impl Default for ReqwestHttpClient {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68#[async_trait]
69impl HttpClient for ReqwestHttpClient {
70 async fn post(
71 &self,
72 url: &str,
73 headers: Vec<(&str, &str)>,
74 body: &serde_json::Value,
75 ) -> Result<HttpResponse> {
76 tracing::debug!(
77 "HTTP POST to {}: {}",
78 url,
79 serde_json::to_string_pretty(body)?
80 );
81
82 let mut request = self.client.post(url);
83 for (key, value) in headers {
84 request = request.header(key, value);
85 }
86 request = request.json(body);
87
88 let response = request
89 .send()
90 .await
91 .context(format!("Failed to send request to {}", url))?;
92
93 let status = response.status().as_u16();
94 let body = response.text().await?;
95
96 Ok(HttpResponse { status, body })
97 }
98
99 async fn post_streaming(
100 &self,
101 url: &str,
102 headers: Vec<(&str, &str)>,
103 body: &serde_json::Value,
104 ) -> Result<StreamingHttpResponse> {
105 let mut request = self.client.post(url);
106 for (key, value) in headers {
107 request = request.header(key, value);
108 }
109 request = request.json(body);
110
111 let response = request
112 .send()
113 .await
114 .context(format!("Failed to send streaming request to {}", url))?;
115
116 let status = response.status().as_u16();
117 let retry_after = response
118 .headers()
119 .get("retry-after")
120 .and_then(|v| v.to_str().ok())
121 .map(String::from);
122
123 if (200..300).contains(&status) {
124 let byte_stream = response
125 .bytes_stream()
126 .map(|r| r.map_err(|e| anyhow::anyhow!("Stream error: {}", e)));
127 Ok(StreamingHttpResponse {
128 status,
129 retry_after,
130 byte_stream: Box::pin(byte_stream),
131 error_body: String::new(),
132 })
133 } else {
134 let error_body = response.text().await.unwrap_or_default();
135 let empty: futures::stream::Empty<Result<bytes::Bytes>> = futures::stream::empty();
137 Ok(StreamingHttpResponse {
138 status,
139 retry_after,
140 byte_stream: Box::pin(empty),
141 error_body,
142 })
143 }
144 }
145}
146
147pub fn default_http_client() -> Arc<dyn HttpClient> {
149 Arc::new(ReqwestHttpClient::new())
150}
151
152pub(crate) fn build_reqwest_client(
159 timeout: Option<Duration>,
160 default_headers: Option<reqwest::header::HeaderMap>,
161) -> Result<reqwest::Client> {
162 let mut builder = reqwest::Client::builder().no_proxy();
163
164 if let Some(timeout) = timeout {
165 builder = builder.timeout(timeout);
166 }
167
168 if let Some(default_headers) = default_headers {
169 builder = builder.default_headers(default_headers);
170 }
171
172 builder.build().context("Failed to build reqwest client")
173}
174
/// Normalizes an API base URL by removing trailing slashes and at most one
/// trailing `/v1` path segment.
///
/// Presumably callers append their own versioned path (e.g. `/v1/...`) — confirm
/// at call sites. Only a single `/v1` is stripped, so a gateway whose real base
/// path ends in `/v1` (e.g. `https://gw/v1/v1`) keeps it.
///
/// ```text
/// "https://api.example.com/v1/" -> "https://api.example.com"
/// "https://gw.example.com/v1/v1" -> "https://gw.example.com/v1"
/// ```
pub(crate) fn normalize_base_url(base_url: &str) -> String {
    let trimmed = base_url.trim_end_matches('/');
    // `strip_suffix` removes one occurrence; `trim_end_matches("/v1")` would
    // greedily remove every repeated `/v1` segment.
    trimmed
        .strip_suffix("/v1")
        .unwrap_or(trimmed)
        .trim_end_matches('/')
        .to_string()
}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that each `(input, expected)` pair normalizes as expected and
    /// names the offending input on failure.
    fn check_cases(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(normalize_base_url(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_normalize_base_url() {
        check_cases(&[
            ("https://api.example.com", "https://api.example.com"),
            ("https://api.example.com/", "https://api.example.com"),
            ("https://api.example.com/v1", "https://api.example.com"),
            ("https://api.example.com/v1/", "https://api.example.com"),
        ]);
    }

    #[test]
    fn test_normalize_base_url_edge_cases() {
        check_cases(&[
            ("http://localhost:8080/v1", "http://localhost:8080"),
            ("http://localhost:8080", "http://localhost:8080"),
            ("https://api.example.com/v1/", "https://api.example.com"),
        ]);
    }

    #[test]
    fn test_normalize_base_url_multiple_trailing_slashes() {
        check_cases(&[("https://api.example.com//", "https://api.example.com")]);
    }

    #[test]
    fn test_normalize_base_url_with_port() {
        check_cases(&[("http://localhost:11434/v1/", "http://localhost:11434")]);
    }

    #[test]
    fn test_normalize_base_url_already_normalized() {
        check_cases(&[("https://api.openai.com", "https://api.openai.com")]);
    }

    #[test]
    fn test_normalize_base_url_empty_string() {
        check_cases(&[("", "")]);
    }

    #[test]
    fn test_default_http_client_creation() {
        // Smoke test: constructing the default client must not panic.
        let _ = default_http_client();
    }
}