qiniu_reqwest/
async_client.rs

1use super::{
2    extensions::TimeoutExtension,
3    sync_client::{call_response_callbacks, from_reqwest_error, make_callback_error, make_user_agent},
4};
5use bytes::Bytes;
6use futures::future::BoxFuture;
7use futures::{ready, AsyncRead, Stream};
8use qiniu_http::{
9    AsyncRequest, AsyncResponse, AsyncResponseBody, AsyncResponseResult, HttpCaller, ResponseError, ResponseErrorKind,
10    SyncRequest, SyncResponseResult, TransferProgressInfo,
11};
12use reqwest::{
13    header::USER_AGENT, Body as AsyncBody, Client as AsyncReqwestClient, Request as AsyncReqwestRequest,
14    Response as AsyncReqwestResponse, Result as ReqwestResult, Url,
15};
16use std::{
17    error::Error,
18    fmt,
19    io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult},
20    mem::{take, transmute},
21    num::NonZeroU16,
22    pin::Pin,
23    task::{Context, Poll},
24};
25
26/// Reqwest 异步客户端
27#[cfg_attr(feature = "docs", doc(cfg(feature = "async")))]
28#[derive(Debug, Default)]
29pub struct AsyncClient {
30    async_client: AsyncReqwestClient,
31}
32
33impl AsyncClient {
34    /// 创建 Reqwest 异步客户端
35    #[inline]
36    pub fn new(async_client: AsyncReqwestClient) -> Self {
37        Self { async_client }
38    }
39}
40
41impl From<AsyncReqwestClient> for AsyncClient {
42    #[inline]
43    fn from(async_client: AsyncReqwestClient) -> Self {
44        Self::new(async_client)
45    }
46}
47
48impl HttpCaller for AsyncClient {
49    #[inline]
50    fn call<'a>(&'a self, _request: &'a mut SyncRequest<'_>) -> SyncResponseResult {
51        unimplemented!("AsyncClient does not support blocking call")
52    }
53
54    #[cfg_attr(feature = "docs", doc(cfg(feature = "async")))]
55    fn async_call<'a>(&'a self, request: &'a mut AsyncRequest<'_>) -> BoxFuture<'a, AsyncResponseResult> {
56        Box::pin(async move {
57            let mut user_cancelled_error: Option<ResponseError> = None;
58            let reqwest_request = make_async_reqwest_request(request, &mut user_cancelled_error)?;
59            match self.async_client.execute(reqwest_request).await {
60                Ok(reqwest_response) => from_async_response(reqwest_response, request),
61                Err(err) => user_cancelled_error.map_or_else(|| Err(from_reqwest_error(err, request)), Err),
62            }
63        })
64    }
65}
66
67fn make_async_reqwest_request(
68    request: &mut AsyncRequest,
69    user_cancelled_error: &mut Option<ResponseError>,
70) -> Result<AsyncReqwestRequest, ResponseError> {
71    let url = Url::parse(&request.url().to_string()).map_err(|err| {
72        ResponseError::builder(ResponseErrorKind::InvalidUrl, err)
73            .uri(request.url())
74            .build()
75    })?;
76    let mut reqwest_request = AsyncReqwestRequest::new(request.method().to_owned(), url);
77    for (header_name, header_value) in request.headers() {
78        reqwest_request
79            .headers_mut()
80            .insert(header_name, header_value.to_owned());
81    }
82    reqwest_request
83        .headers_mut()
84        .insert(USER_AGENT, make_user_agent(request, "async")?);
85    *reqwest_request.body_mut() = Some(AsyncBody::wrap_stream(RequestBodyWithCallbacks::new(
86        request,
87        user_cancelled_error,
88    )));
89    if let Some(timeout) = request.extensions().get::<TimeoutExtension>() {
90        *reqwest_request.timeout_mut() = Some(timeout.get());
91    }
92    return Ok(reqwest_request);
93
94    struct RequestBodyWithCallbacks {
95        request: &'static mut AsyncRequest<'static>,
96        have_read: u64,
97        user_cancelled_error: &'static mut Option<ResponseError>,
98    }
99
100    impl RequestBodyWithCallbacks {
101        fn new(request: &mut AsyncRequest, user_cancelled_error: &mut Option<ResponseError>) -> Self {
102            #[allow(unsafe_code)]
103            Self {
104                have_read: 0,
105                request: unsafe { transmute(request) },
106                user_cancelled_error: unsafe { transmute(user_cancelled_error) },
107            }
108        }
109    }
110
111    impl Stream for RequestBodyWithCallbacks {
112        type Item = Result<Vec<u8>, Box<dyn Error + Send + Sync>>;
113
114        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
115            const BUF_LEN: usize = 32 * 1024;
116            let mut buf = [0u8; BUF_LEN];
117            match ready!(Pin::new(&mut self.request.body_mut()).poll_read(cx, &mut buf)) {
118                Err(err) => Poll::Ready(Some(Err(Box::new(err)))),
119                Ok(0) => Poll::Ready(None),
120                Ok(n) => {
121                    let buf = &buf[..n];
122                    self.have_read += n as u64;
123                    if let Some(on_uploading_progress) = self.request.on_uploading_progress() {
124                        if let Err(err) = on_uploading_progress(TransferProgressInfo::new(
125                            self.have_read,
126                            self.request.body().size(),
127                            buf,
128                        )) {
129                            *self.user_cancelled_error = Some(make_callback_error(err, self.request));
130                            return Poll::Ready(Some(Err(Box::new(IoError::new(
131                                IoErrorKind::Other,
132                                "on_uploading_progress() callback returns error",
133                            )))));
134                        }
135                    }
136                    Poll::Ready(Some(Ok(buf.to_vec())))
137                }
138            }
139        }
140
141        #[inline]
142        fn size_hint(&self) -> (usize, Option<usize>) {
143            (self.have_read as usize, Some(self.request.body().size() as usize))
144        }
145    }
146}
147
148fn from_async_response(mut response: AsyncReqwestResponse, request: &mut AsyncRequest) -> AsyncResponseResult {
149    call_response_callbacks(request, response.status(), response.headers())?;
150    let mut response_builder = AsyncResponse::builder();
151    response_builder
152        .status_code(response.status())
153        .version(response.version())
154        .headers(take(response.headers_mut()))
155        .extensions(take(request.extensions_mut()));
156    if let Some(port) = response.url().port_or_known_default().and_then(NonZeroU16::new) {
157        response_builder.server_port(port);
158    }
159    if let Some(remote_addr) = response.remote_addr() {
160        response_builder.server_ip(remote_addr.ip());
161        if let Some(port) = NonZeroU16::new(remote_addr.port()) {
162            response_builder.server_port(port);
163        }
164    }
165    response_builder.body(AsyncResponseBody::from_reader(AsyncReqwestResponseReadWrapper::new(
166        response.bytes_stream(),
167    )));
168    return Ok(response_builder.build());
169
170    struct AsyncReqwestResponseReadWrapper<S: Stream<Item = ReqwestResult<Bytes>>> {
171        stream: S,
172        buffer: Vec<u8>,
173        used: usize,
174    }
175
176    impl<S: Stream<Item = ReqwestResult<Bytes>>> AsyncReqwestResponseReadWrapper<S> {
177        #[inline]
178        fn new(stream: S) -> Self {
179            AsyncReqwestResponseReadWrapper {
180                stream,
181                buffer: Default::default(),
182                used: 0,
183            }
184        }
185    }
186
187    impl<S: Stream<Item = ReqwestResult<Bytes>>> fmt::Debug for AsyncReqwestResponseReadWrapper<S> {
188        #[inline]
189        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190            f.debug_struct("AsyncReqwestResponseReadWrapper")
191                .field("buffer_len", &self.buffer.len())
192                .field("buffer_cap", &self.buffer.capacity())
193                .field("used", &self.used)
194                .finish()
195        }
196    }
197
198    impl<S: Stream<Item = ReqwestResult<Bytes>>> AsyncRead for AsyncReqwestResponseReadWrapper<S> {
199        #[allow(unsafe_code)]
200        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<IoResult<usize>> {
201            let oriself = unsafe { self.get_unchecked_mut() };
202            let buffer_rested = oriself.buffer.len() - oriself.used;
203            if oriself.buffer.is_empty() {
204                let stream = unsafe { Pin::new_unchecked(&mut oriself.stream) };
205                match ready!(stream.poll_next(cx)) {
206                    None => Poll::Ready(Ok(0)),
207                    Some(Err(err)) => Poll::Ready(Err(IoError::new(IoErrorKind::Other, err))),
208                    Some(Ok(data)) => {
209                        if data.len() <= buf.len() {
210                            buf[..data.len()].copy_from_slice(&data);
211                            Poll::Ready(Ok(data.len()))
212                        } else {
213                            buf.copy_from_slice(&data[..buf.len()]);
214                            oriself.buffer.extend_from_slice(&data[buf.len()..]);
215                            oriself.used = 0;
216                            Poll::Ready(Ok(buf.len()))
217                        }
218                    }
219                }
220            } else if buf.len() >= buffer_rested {
221                buf[..buffer_rested].copy_from_slice(&oriself.buffer[oriself.used..]);
222                oriself.buffer.truncate(0);
223                oriself.used = 0;
224                Poll::Ready(Ok(buffer_rested))
225            } else {
226                buf.copy_from_slice(&oriself.buffer[oriself.used..(oriself.used + buf.len())]);
227                oriself.used += buf.len();
228                Poll::Ready(Ok(buf.len()))
229            }
230        }
231    }
232}