omni_llm_kit/reqwest_client/
reqwest_client.rs

1use std::error::Error;
2use std::sync::{LazyLock, OnceLock};
3use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::anyhow;
6use bytes::{BufMut, Bytes, BytesMut};
7use futures::{AsyncRead, FutureExt, TryStreamExt as _};
8use url::Url;
9use reqwest::{
10    header::{HeaderMap, HeaderValue},
11    redirect,
12};
13use crate::http_client;
14use crate::http_client::RedirectPolicy;
15use crate::reqwest_client::http_client_tls::tls_config;
16
/// Number of bytes `StreamReader` reserves for its read buffer whenever the
/// buffer has no capacity left (4 KiB).
const DEFAULT_CAPACITY: usize = 4 * 1024;
18static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
19
20pub struct ReqwestClient {
21    client: reqwest::Client,
22    proxy: Option<Url>,
23    handle: tokio::runtime::Handle,
24}
25
26impl ReqwestClient {
27    fn builder() -> reqwest::ClientBuilder {
28        reqwest::Client::builder()
29            .use_rustls_tls()
30            .connect_timeout(Duration::from_secs(10))
31    }
32
33    pub fn new() -> Self {
34        Self::builder()
35            .build()
36            .expect("Failed to initialize HTTP client")
37            .into()
38    }
39
40    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
41        let mut map = HeaderMap::new();
42        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
43        let client = Self::builder().default_headers(map).build()?;
44        Ok(client.into())
45    }
46
47    pub fn proxy_and_user_agent(proxy: Option<Url>, agent: &str) -> anyhow::Result<Self> {
48        let mut map = HeaderMap::new();
49        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
50        let mut client = Self::builder().default_headers(map);
51        let client_has_proxy;
52
53        if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
54            reqwest::Proxy::all(proxy_url.clone())
55                .inspect_err(|e| {
56                    log::error!(
57                        "Failed to parse proxy URL '{}': {}",
58                        proxy_url,
59                        e.source().unwrap_or(&e as &_)
60                    )
61                })
62                .ok()
63        }) {
64            // Respect NO_PROXY env var
65            client = client.proxy(proxy.no_proxy(reqwest::NoProxy::from_env()));
66            client_has_proxy = true;
67        } else {
68            client_has_proxy = false;
69        };
70
71        let client = client
72            .use_preconfigured_tls(tls_config())
73            .build()?;
74        let mut client: ReqwestClient = client.into();
75        client.proxy = client_has_proxy.then_some(proxy).flatten();
76        Ok(client)
77    }
78}
79
80impl From<reqwest::Client> for ReqwestClient {
81    fn from(client: reqwest::Client) -> Self {
82        let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
83            log::debug!("no tokio runtime found, creating one for Reqwest...");
84            let runtime = RUNTIME.get_or_init(|| {
85                tokio::runtime::Builder::new_multi_thread()
86                    // Since we now have two executors, let's try to keep our footprint small
87                    .worker_threads(1)
88                    .enable_all()
89                    .build()
90                    .expect("Failed to initialize HTTP client")
91            });
92
93            runtime.handle().clone()
94        });
95        Self {
96            client,
97            handle,
98            proxy: None,
99        }
100    }
101}
102
103// This struct is essentially a re-implementation of
104// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
105// except outside of Tokio's aegis
106struct StreamReader {
107    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
108    buf: BytesMut,
109    capacity: usize,
110}
111
112impl StreamReader {
113    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
114        Self {
115            reader: Some(reader),
116            buf: BytesMut::new(),
117            capacity: DEFAULT_CAPACITY,
118        }
119    }
120}
121
122impl futures::Stream for StreamReader {
123    type Item = std::io::Result<Bytes>;
124
125    fn poll_next(
126        mut self: Pin<&mut Self>,
127        cx: &mut std::task::Context<'_>,
128    ) -> Poll<Option<Self::Item>> {
129        let mut this = self.as_mut();
130
131        let mut reader = match this.reader.take() {
132            Some(r) => r,
133            None => return Poll::Ready(None),
134        };
135
136        if this.buf.capacity() == 0 {
137            let capacity = this.capacity;
138            this.buf.reserve(capacity);
139        }
140
141        match poll_read_buf(&mut reader, cx, &mut this.buf) {
142            Poll::Pending => Poll::Pending,
143            Poll::Ready(Err(err)) => {
144                self.reader = None;
145
146                Poll::Ready(Some(Err(err)))
147            }
148            Poll::Ready(Ok(0)) => {
149                self.reader = None;
150                Poll::Ready(None)
151            }
152            Poll::Ready(Ok(_)) => {
153                let chunk = this.buf.split();
154                self.reader = Some(reader);
155                Poll::Ready(Some(Ok(chunk.freeze())))
156            }
157        }
158    }
159}
160
161/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
162/// Specialized for this use case
163pub fn poll_read_buf(
164    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
165    cx: &mut std::task::Context<'_>,
166    buf: &mut BytesMut,
167) -> Poll<std::io::Result<usize>> {
168    if !buf.has_remaining_mut() {
169        return Poll::Ready(Ok(0));
170    }
171
172    let n = {
173        let dst = buf.chunk_mut();
174
175        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
176        // transparent wrapper around `[MaybeUninit<u8>]`.
177        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
178        let mut buf = tokio::io::ReadBuf::uninit(dst);
179        let ptr = buf.filled().as_ptr();
180        let unfilled_portion = buf.initialize_unfilled();
181        // SAFETY: Pin projection
182        let io_pin = unsafe { Pin::new_unchecked(io) };
183        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
184
185        // Ensure the pointer does not change from under us
186        assert_eq!(ptr, buf.filled().as_ptr());
187        buf.filled().len()
188    };
189
190    // Safety: This is guaranteed to be the number of initialized (and read)
191    // bytes due to the invariants provided by `ReadBuf::filled`.
192    unsafe {
193        buf.advance_mut(n);
194    }
195
196    Poll::Ready(Ok(n))
197}
198
199
200impl http_client::HttpClient for ReqwestClient {
201    fn proxy(&self) -> Option<&Url> {
202        self.proxy.as_ref()
203    }
204
205    fn type_name(&self) -> &'static str {
206        type_name::<Self>()
207    }
208
209    fn send(
210        &self,
211        req: http::Request<http_client::AsyncBody>,
212    ) -> futures::future::BoxFuture<
213        'static,
214        anyhow::Result<http_client::Response<http_client::AsyncBody>>,
215    > {
216        let (parts, body) = req.into_parts();
217
218        let mut request = self.client.request(parts.method, parts.uri.to_string());
219        request = request.headers(parts.headers);
220        if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
221            // todo ya
222            // request = request.redirect_policy(match redirect_policy {
223            //     RedirectPolicy::NoFollow => redirect::Policy::none(),
224            //     RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
225            //     RedirectPolicy::FollowAll => redirect::Policy::limited(100),
226            // });
227        }
228        let request = request.body(match body.0 {
229            http_client::Inner::Empty => reqwest::Body::default(),
230            http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
231            http_client::Inner::AsyncReader(stream) => {
232                reqwest::Body::wrap_stream(StreamReader::new(stream))
233            }
234        });
235
236        let handle = self.handle.clone();
237        async move {
238            let mut response = handle
239                .spawn(async { request.send().await })
240                .await??;
241
242
243            let headers = mem::take(response.headers_mut());
244            let mut builder = http::Response::builder()
245                .status(response.status().as_u16())
246                .version(response.version());
247            *builder.headers_mut().unwrap() = headers;
248
249            let bytes = response
250                .bytes_stream()
251                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
252                .into_async_read();
253            let body = http_client::AsyncBody::from_reader(bytes);
254
255            builder.body(body).map_err(|e| anyhow!(e))
256        }
257        .boxed()
258    }
259}