//! HTTP transport layer: the [`HttpClient`] abstraction, its reqwest-backed
//! implementation [`ReqwestHttpClient`], and base-URL helpers.
//!
//! Source path: `a3s_code_core/llm/http.rs`.

use std::fmt;
use std::pin::Pin;
use std::sync::Arc;

use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;

/// A fully-buffered HTTP response, as returned by [`HttpClient::post`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponse {
    /// Numeric HTTP status code (e.g. `200`, `429`).
    pub status: u16,
    /// The complete response body, decoded as text.
    pub body: String,
}

15pub struct StreamingHttpResponse {
17 pub status: u16,
18 pub retry_after: Option<String>,
20 pub byte_stream: Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes>> + Send>>,
22 pub error_body: String,
24}
25
26#[async_trait]
30pub trait HttpClient: Send + Sync {
31 async fn post(
33 &self,
34 url: &str,
35 headers: Vec<(&str, &str)>,
36 body: &serde_json::Value,
37 ) -> Result<HttpResponse>;
38
39 async fn post_streaming(
41 &self,
42 url: &str,
43 headers: Vec<(&str, &str)>,
44 body: &serde_json::Value,
45 ) -> Result<StreamingHttpResponse>;
46}
47
48pub struct ReqwestHttpClient {
50 client: reqwest::Client,
51}
52
53impl ReqwestHttpClient {
54 pub fn new() -> Self {
55 Self {
56 client: reqwest::Client::new(),
57 }
58 }
59}
60
61impl Default for ReqwestHttpClient {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67#[async_trait]
68impl HttpClient for ReqwestHttpClient {
69 async fn post(
70 &self,
71 url: &str,
72 headers: Vec<(&str, &str)>,
73 body: &serde_json::Value,
74 ) -> Result<HttpResponse> {
75 tracing::debug!(
76 "HTTP POST to {}: {}",
77 url,
78 serde_json::to_string_pretty(body)?
79 );
80
81 let mut request = self.client.post(url);
82 for (key, value) in headers {
83 request = request.header(key, value);
84 }
85 request = request.json(body);
86
87 let response = request
88 .send()
89 .await
90 .context(format!("Failed to send request to {}", url))?;
91
92 let status = response.status().as_u16();
93 let body = response.text().await?;
94
95 Ok(HttpResponse { status, body })
96 }
97
98 async fn post_streaming(
99 &self,
100 url: &str,
101 headers: Vec<(&str, &str)>,
102 body: &serde_json::Value,
103 ) -> Result<StreamingHttpResponse> {
104 let mut request = self.client.post(url);
105 for (key, value) in headers {
106 request = request.header(key, value);
107 }
108 request = request.json(body);
109
110 let response = request
111 .send()
112 .await
113 .context(format!("Failed to send streaming request to {}", url))?;
114
115 let status = response.status().as_u16();
116 let retry_after = response
117 .headers()
118 .get("retry-after")
119 .and_then(|v| v.to_str().ok())
120 .map(String::from);
121
122 if (200..300).contains(&status) {
123 let byte_stream = response
124 .bytes_stream()
125 .map(|r| r.map_err(|e| anyhow::anyhow!("Stream error: {}", e)));
126 Ok(StreamingHttpResponse {
127 status,
128 retry_after,
129 byte_stream: Box::pin(byte_stream),
130 error_body: String::new(),
131 })
132 } else {
133 let error_body = response.text().await.unwrap_or_default();
134 let empty: futures::stream::Empty<Result<bytes::Bytes>> = futures::stream::empty();
136 Ok(StreamingHttpResponse {
137 status,
138 retry_after,
139 byte_stream: Box::pin(empty),
140 error_body,
141 })
142 }
143 }
144}
145
146pub fn default_http_client() -> Arc<dyn HttpClient> {
148 Arc::new(ReqwestHttpClient::new())
149}
150
/// Strips trailing slashes and trailing `/v1` segments from a base URL.
///
/// Processing order:
/// 1. Remove trailing `/` characters.
/// 2. Remove every trailing `/v1` segment; repeated segments are all removed.
/// 3. Remove any slashes that step 2 exposed.
///
/// ```text
/// https://api.example.com/v1/  ->  https://api.example.com
/// http://localhost:8080        ->  http://localhost:8080
/// ```
pub(crate) fn normalize_base_url(base_url: &str) -> String {
    let mut remaining = base_url.trim_end_matches('/');
    while let Some(shorter) = remaining.strip_suffix("/v1") {
        remaining = shorter;
    }
    remaining.trim_end_matches('/').to_owned()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_base_url() {
        assert_eq!(
            normalize_base_url("https://api.example.com"),
            "https://api.example.com"
        );
        assert_eq!(
            normalize_base_url("https://api.example.com/"),
            "https://api.example.com"
        );
        assert_eq!(
            normalize_base_url("https://api.example.com/v1"),
            "https://api.example.com"
        );
        assert_eq!(
            normalize_base_url("https://api.example.com/v1/"),
            "https://api.example.com"
        );
    }

    #[test]
    fn test_normalize_base_url_edge_cases() {
        assert_eq!(
            normalize_base_url("http://localhost:8080/v1"),
            "http://localhost:8080"
        );
        assert_eq!(
            normalize_base_url("http://localhost:8080"),
            "http://localhost:8080"
        );
        // Several slashes after the version segment are all removed.
        assert_eq!(
            normalize_base_url("https://api.example.com/v1//"),
            "https://api.example.com"
        );
    }

    #[test]
    fn test_normalize_base_url_multiple_trailing_slashes() {
        assert_eq!(
            normalize_base_url("https://api.example.com//"),
            "https://api.example.com"
        );
    }

    #[test]
    fn test_normalize_base_url_with_port() {
        assert_eq!(
            normalize_base_url("http://localhost:11434/v1/"),
            "http://localhost:11434"
        );
    }

    #[test]
    fn test_normalize_base_url_already_normalized() {
        assert_eq!(
            normalize_base_url("https://api.openai.com"),
            "https://api.openai.com"
        );
    }

    #[test]
    fn test_normalize_base_url_empty_string() {
        assert_eq!(normalize_base_url(""), "");
    }

    #[test]
    fn test_normalize_base_url_preserves_other_path_segments() {
        // Only the trailing `/v1` is stripped; any path prefix before it stays.
        assert_eq!(
            normalize_base_url("https://api.example.com/openai/v1"),
            "https://api.example.com/openai"
        );
        // Other version-like suffixes are not `/v1` and must be left intact.
        assert_eq!(
            normalize_base_url("https://generativelanguage.googleapis.com/v1beta"),
            "https://generativelanguage.googleapis.com/v1beta"
        );
        assert_eq!(
            normalize_base_url("https://api.example.com/v2"),
            "https://api.example.com/v2"
        );
    }

    #[test]
    fn test_default_http_client_creation() {
        let _client = default_http_client();
    }
}