omni_llm_kit/reqwest_client/
reqwest_client.rs1use std::error::Error;
2use std::sync::{LazyLock, OnceLock};
3use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::anyhow;
6use bytes::{BufMut, Bytes, BytesMut};
7use futures::{AsyncRead, FutureExt, TryStreamExt as _};
8use url::Url;
9use reqwest::{
10 header::{HeaderMap, HeaderValue},
11 redirect,
12};
13use crate::http_client;
14use crate::http_client::RedirectPolicy;
15use crate::reqwest_client::http_client_tls::tls_config;
16
17const DEFAULT_CAPACITY: usize = 4096;
18static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
19
20pub struct ReqwestClient {
21 client: reqwest::Client,
22 proxy: Option<Url>,
23 handle: tokio::runtime::Handle,
24}
25
26impl ReqwestClient {
27 fn builder() -> reqwest::ClientBuilder {
28 reqwest::Client::builder()
29 .use_rustls_tls()
30 .connect_timeout(Duration::from_secs(10))
31 }
32
33 pub fn new() -> Self {
34 Self::builder()
35 .build()
36 .expect("Failed to initialize HTTP client")
37 .into()
38 }
39
40 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
41 let mut map = HeaderMap::new();
42 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
43 let client = Self::builder().default_headers(map).build()?;
44 Ok(client.into())
45 }
46
47 pub fn proxy_and_user_agent(proxy: Option<Url>, agent: &str) -> anyhow::Result<Self> {
48 let mut map = HeaderMap::new();
49 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
50 let mut client = Self::builder().default_headers(map);
51 let client_has_proxy;
52
53 if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
54 reqwest::Proxy::all(proxy_url.clone())
55 .inspect_err(|e| {
56 log::error!(
57 "Failed to parse proxy URL '{}': {}",
58 proxy_url,
59 e.source().unwrap_or(&e as &_)
60 )
61 })
62 .ok()
63 }) {
64 client = client.proxy(proxy.no_proxy(reqwest::NoProxy::from_env()));
66 client_has_proxy = true;
67 } else {
68 client_has_proxy = false;
69 };
70
71 let client = client
72 .use_preconfigured_tls(tls_config())
73 .build()?;
74 let mut client: ReqwestClient = client.into();
75 client.proxy = client_has_proxy.then_some(proxy).flatten();
76 Ok(client)
77 }
78}
79
80impl From<reqwest::Client> for ReqwestClient {
81 fn from(client: reqwest::Client) -> Self {
82 let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
83 log::debug!("no tokio runtime found, creating one for Reqwest...");
84 let runtime = RUNTIME.get_or_init(|| {
85 tokio::runtime::Builder::new_multi_thread()
86 .worker_threads(1)
88 .enable_all()
89 .build()
90 .expect("Failed to initialize HTTP client")
91 });
92
93 runtime.handle().clone()
94 });
95 Self {
96 client,
97 handle,
98 proxy: None,
99 }
100 }
101}
102
103struct StreamReader {
107 reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
108 buf: BytesMut,
109 capacity: usize,
110}
111
112impl StreamReader {
113 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
114 Self {
115 reader: Some(reader),
116 buf: BytesMut::new(),
117 capacity: DEFAULT_CAPACITY,
118 }
119 }
120}
121
122impl futures::Stream for StreamReader {
123 type Item = std::io::Result<Bytes>;
124
125 fn poll_next(
126 mut self: Pin<&mut Self>,
127 cx: &mut std::task::Context<'_>,
128 ) -> Poll<Option<Self::Item>> {
129 let mut this = self.as_mut();
130
131 let mut reader = match this.reader.take() {
132 Some(r) => r,
133 None => return Poll::Ready(None),
134 };
135
136 if this.buf.capacity() == 0 {
137 let capacity = this.capacity;
138 this.buf.reserve(capacity);
139 }
140
141 match poll_read_buf(&mut reader, cx, &mut this.buf) {
142 Poll::Pending => Poll::Pending,
143 Poll::Ready(Err(err)) => {
144 self.reader = None;
145
146 Poll::Ready(Some(Err(err)))
147 }
148 Poll::Ready(Ok(0)) => {
149 self.reader = None;
150 Poll::Ready(None)
151 }
152 Poll::Ready(Ok(_)) => {
153 let chunk = this.buf.split();
154 self.reader = Some(reader);
155 Poll::Ready(Some(Ok(chunk.freeze())))
156 }
157 }
158 }
159}
160
161pub fn poll_read_buf(
164 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
165 cx: &mut std::task::Context<'_>,
166 buf: &mut BytesMut,
167) -> Poll<std::io::Result<usize>> {
168 if !buf.has_remaining_mut() {
169 return Poll::Ready(Ok(0));
170 }
171
172 let n = {
173 let dst = buf.chunk_mut();
174
175 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
178 let mut buf = tokio::io::ReadBuf::uninit(dst);
179 let ptr = buf.filled().as_ptr();
180 let unfilled_portion = buf.initialize_unfilled();
181 let io_pin = unsafe { Pin::new_unchecked(io) };
183 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
184
185 assert_eq!(ptr, buf.filled().as_ptr());
187 buf.filled().len()
188 };
189
190 unsafe {
193 buf.advance_mut(n);
194 }
195
196 Poll::Ready(Ok(n))
197}
198
199
200impl http_client::HttpClient for ReqwestClient {
201 fn proxy(&self) -> Option<&Url> {
202 self.proxy.as_ref()
203 }
204
205 fn type_name(&self) -> &'static str {
206 type_name::<Self>()
207 }
208
209 fn send(
210 &self,
211 req: http::Request<http_client::AsyncBody>,
212 ) -> futures::future::BoxFuture<
213 'static,
214 anyhow::Result<http_client::Response<http_client::AsyncBody>>,
215 > {
216 let (parts, body) = req.into_parts();
217
218 let mut request = self.client.request(parts.method, parts.uri.to_string());
219 request = request.headers(parts.headers);
220 if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
221 }
228 let request = request.body(match body.0 {
229 http_client::Inner::Empty => reqwest::Body::default(),
230 http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
231 http_client::Inner::AsyncReader(stream) => {
232 reqwest::Body::wrap_stream(StreamReader::new(stream))
233 }
234 });
235
236 let handle = self.handle.clone();
237 async move {
238 let mut response = handle
239 .spawn(async { request.send().await })
240 .await??;
241
242
243 let headers = mem::take(response.headers_mut());
244 let mut builder = http::Response::builder()
245 .status(response.status().as_u16())
246 .version(response.version());
247 *builder.headers_mut().unwrap() = headers;
248
249 let bytes = response
250 .bytes_stream()
251 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
252 .into_async_read();
253 let body = http_client::AsyncBody::from_reader(bytes);
254
255 builder.body(body).map_err(|e| anyhow!(e))
256 }
257 .boxed()
258 }
259}