//! `a3s_code_core/llm/http.rs`: a pluggable [`HttpClient`] abstraction and its
//! reqwest-backed default implementation.

use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
8
/// A fully buffered HTTP response.
///
/// Receiving one does not imply success: the status code is carried as-is and
/// callers must inspect `status` themselves.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponse {
    /// Numeric HTTP status code (e.g. `200`, `429`).
    pub status: u16,
    /// Response body decoded as text.
    pub body: String,
}
14
15pub struct StreamingHttpResponse {
17 pub status: u16,
18 pub retry_after: Option<String>,
20 pub byte_stream: Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes>> + Send>>,
22 pub error_body: String,
24}
25
26#[async_trait]
30pub trait HttpClient: Send + Sync {
31 async fn post(
33 &self,
34 url: &str,
35 headers: Vec<(&str, &str)>,
36 body: &serde_json::Value,
37 ) -> Result<HttpResponse>;
38
39 async fn post_streaming(
41 &self,
42 url: &str,
43 headers: Vec<(&str, &str)>,
44 body: &serde_json::Value,
45 ) -> Result<StreamingHttpResponse>;
46}
47
48pub struct ReqwestHttpClient {
50 client: reqwest::Client,
51}
52
53impl ReqwestHttpClient {
54 pub fn new() -> Self {
55 Self {
56 client: reqwest::Client::new(),
57 }
58 }
59}
60
61impl Default for ReqwestHttpClient {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67#[async_trait]
68impl HttpClient for ReqwestHttpClient {
69 async fn post(
70 &self,
71 url: &str,
72 headers: Vec<(&str, &str)>,
73 body: &serde_json::Value,
74 ) -> Result<HttpResponse> {
75 tracing::debug!(
76 "HTTP POST to {}: {}",
77 url,
78 serde_json::to_string_pretty(body)?
79 );
80
81 let mut request = self.client.post(url);
82 for (key, value) in headers {
83 request = request.header(key, value);
84 }
85 request = request.json(body);
86
87 let response = request
88 .send()
89 .await
90 .context(format!("Failed to send request to {}", url))?;
91
92 let status = response.status().as_u16();
93 let body = response.text().await?;
94
95 Ok(HttpResponse { status, body })
96 }
97
98 async fn post_streaming(
99 &self,
100 url: &str,
101 headers: Vec<(&str, &str)>,
102 body: &serde_json::Value,
103 ) -> Result<StreamingHttpResponse> {
104 tracing::debug!(
105 "HTTP POST streaming to {}: {}",
106 url,
107 serde_json::to_string_pretty(body)?
108 );
109
110 let mut request = self.client.post(url);
111 for (key, value) in headers {
112 request = request.header(key, value);
113 }
114 request = request.json(body);
115
116 let response = request
117 .send()
118 .await
119 .context(format!("Failed to send streaming request to {}", url))?;
120
121 let status = response.status().as_u16();
122 let retry_after = response
123 .headers()
124 .get("retry-after")
125 .and_then(|v| v.to_str().ok())
126 .map(String::from);
127
128 if (200..300).contains(&status) {
129 let byte_stream = response
130 .bytes_stream()
131 .map(|r| r.map_err(|e| anyhow::anyhow!("Stream error: {}", e)));
132 Ok(StreamingHttpResponse {
133 status,
134 retry_after,
135 byte_stream: Box::pin(byte_stream),
136 error_body: String::new(),
137 })
138 } else {
139 let error_body = response.text().await.unwrap_or_default();
140 let empty: futures::stream::Empty<Result<bytes::Bytes>> = futures::stream::empty();
142 Ok(StreamingHttpResponse {
143 status,
144 retry_after,
145 byte_stream: Box::pin(empty),
146 error_body,
147 })
148 }
149 }
150}
151
152pub fn default_http_client() -> Arc<dyn HttpClient> {
154 Arc::new(ReqwestHttpClient::new())
155}
156
/// Normalizes an API base URL by dropping trailing slashes and `/v1` suffixes.
///
/// Every trailing `/v1` segment is removed (so `/v1/v1` collapses as well),
/// with slashes trimmed before and after. Other suffixes such as `/v1beta` or
/// `/v2` are left intact.
pub(crate) fn normalize_base_url(base_url: &str) -> String {
    let mut rest = base_url.trim_end_matches('/');
    while let Some(shorter) = rest.strip_suffix("/v1") {
        rest = shorter;
    }
    rest.trim_end_matches('/').to_owned()
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `normalize_base_url(input) == expected` for every pair, naming
    /// the offending input in the failure message.
    fn assert_normalizes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(normalize_base_url(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_normalize_base_url() {
        assert_normalizes(&[
            ("https://api.example.com", "https://api.example.com"),
            ("https://api.example.com/", "https://api.example.com"),
            ("https://api.example.com/v1", "https://api.example.com"),
            ("https://api.example.com/v1/", "https://api.example.com"),
        ]);
    }

    #[test]
    fn test_normalize_base_url_edge_cases() {
        assert_normalizes(&[
            ("http://localhost:8080/v1", "http://localhost:8080"),
            ("http://localhost:8080", "http://localhost:8080"),
        ]);
    }

    #[test]
    fn test_normalize_base_url_multiple_trailing_slashes() {
        assert_normalizes(&[("https://api.example.com//", "https://api.example.com")]);
    }

    #[test]
    fn test_normalize_base_url_with_port() {
        assert_normalizes(&[("http://localhost:11434/v1/", "http://localhost:11434")]);
    }

    #[test]
    fn test_normalize_base_url_already_normalized() {
        assert_normalizes(&[("https://api.openai.com", "https://api.openai.com")]);
    }

    #[test]
    fn test_normalize_base_url_empty_string() {
        assert_normalizes(&[("", "")]);
    }

    #[test]
    fn test_normalize_base_url_keeps_path_prefix() {
        // Only the trailing `/v1` segment goes; a proxy path prefix must stay.
        assert_normalizes(&[
            ("https://proxy.example.com/api/v1", "https://proxy.example.com/api"),
            ("https://proxy.example.com/api/v1/", "https://proxy.example.com/api"),
        ]);
    }

    #[test]
    fn test_normalize_base_url_keeps_other_versions() {
        // `/v1beta` and `/v2` are not `/v1` and must survive untouched.
        assert_normalizes(&[
            (
                "https://generative.example.com/v1beta",
                "https://generative.example.com/v1beta",
            ),
            ("https://api.example.com/v2/", "https://api.example.com/v2"),
        ]);
    }

    #[test]
    fn test_default_http_client_creation() {
        let _client = default_http_client();
    }
}