1use super::{
2 extensions::TimeoutExtension,
3 sync_client::{call_response_callbacks, from_reqwest_error, make_callback_error, make_user_agent},
4};
5use bytes::Bytes;
6use futures::future::BoxFuture;
7use futures::{ready, AsyncRead, Stream};
8use qiniu_http::{
9 AsyncRequest, AsyncResponse, AsyncResponseBody, AsyncResponseResult, HttpCaller, ResponseError, ResponseErrorKind,
10 SyncRequest, SyncResponseResult, TransferProgressInfo,
11};
12use reqwest::{
13 header::USER_AGENT, Body as AsyncBody, Client as AsyncReqwestClient, Request as AsyncReqwestRequest,
14 Response as AsyncReqwestResponse, Result as ReqwestResult, Url,
15};
16use std::{
17 error::Error,
18 fmt,
19 io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult},
20 mem::{take, transmute},
21 num::NonZeroU16,
22 pin::Pin,
23 task::{Context, Poll},
24};
25
26#[cfg_attr(feature = "docs", doc(cfg(feature = "async")))]
28#[derive(Debug, Default)]
29pub struct AsyncClient {
30 async_client: AsyncReqwestClient,
31}
32
33impl AsyncClient {
34 #[inline]
36 pub fn new(async_client: AsyncReqwestClient) -> Self {
37 Self { async_client }
38 }
39}
40
41impl From<AsyncReqwestClient> for AsyncClient {
42 #[inline]
43 fn from(async_client: AsyncReqwestClient) -> Self {
44 Self::new(async_client)
45 }
46}
47
48impl HttpCaller for AsyncClient {
49 #[inline]
50 fn call<'a>(&'a self, _request: &'a mut SyncRequest<'_>) -> SyncResponseResult {
51 unimplemented!("AsyncClient does not support blocking call")
52 }
53
54 #[cfg_attr(feature = "docs", doc(cfg(feature = "async")))]
55 fn async_call<'a>(&'a self, request: &'a mut AsyncRequest<'_>) -> BoxFuture<'a, AsyncResponseResult> {
56 Box::pin(async move {
57 let mut user_cancelled_error: Option<ResponseError> = None;
58 let reqwest_request = make_async_reqwest_request(request, &mut user_cancelled_error)?;
59 match self.async_client.execute(reqwest_request).await {
60 Ok(reqwest_response) => from_async_response(reqwest_response, request),
61 Err(err) => user_cancelled_error.map_or_else(|| Err(from_reqwest_error(err, request)), Err),
62 }
63 })
64 }
65}
66
67fn make_async_reqwest_request(
68 request: &mut AsyncRequest,
69 user_cancelled_error: &mut Option<ResponseError>,
70) -> Result<AsyncReqwestRequest, ResponseError> {
71 let url = Url::parse(&request.url().to_string()).map_err(|err| {
72 ResponseError::builder(ResponseErrorKind::InvalidUrl, err)
73 .uri(request.url())
74 .build()
75 })?;
76 let mut reqwest_request = AsyncReqwestRequest::new(request.method().to_owned(), url);
77 for (header_name, header_value) in request.headers() {
78 reqwest_request
79 .headers_mut()
80 .insert(header_name, header_value.to_owned());
81 }
82 reqwest_request
83 .headers_mut()
84 .insert(USER_AGENT, make_user_agent(request, "async")?);
85 *reqwest_request.body_mut() = Some(AsyncBody::wrap_stream(RequestBodyWithCallbacks::new(
86 request,
87 user_cancelled_error,
88 )));
89 if let Some(timeout) = request.extensions().get::<TimeoutExtension>() {
90 *reqwest_request.timeout_mut() = Some(timeout.get());
91 }
92 return Ok(reqwest_request);
93
94 struct RequestBodyWithCallbacks {
95 request: &'static mut AsyncRequest<'static>,
96 have_read: u64,
97 user_cancelled_error: &'static mut Option<ResponseError>,
98 }
99
100 impl RequestBodyWithCallbacks {
101 fn new(request: &mut AsyncRequest, user_cancelled_error: &mut Option<ResponseError>) -> Self {
102 #[allow(unsafe_code)]
103 Self {
104 have_read: 0,
105 request: unsafe { transmute(request) },
106 user_cancelled_error: unsafe { transmute(user_cancelled_error) },
107 }
108 }
109 }
110
111 impl Stream for RequestBodyWithCallbacks {
112 type Item = Result<Vec<u8>, Box<dyn Error + Send + Sync>>;
113
114 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
115 const BUF_LEN: usize = 32 * 1024;
116 let mut buf = [0u8; BUF_LEN];
117 match ready!(Pin::new(&mut self.request.body_mut()).poll_read(cx, &mut buf)) {
118 Err(err) => Poll::Ready(Some(Err(Box::new(err)))),
119 Ok(0) => Poll::Ready(None),
120 Ok(n) => {
121 let buf = &buf[..n];
122 self.have_read += n as u64;
123 if let Some(on_uploading_progress) = self.request.on_uploading_progress() {
124 if let Err(err) = on_uploading_progress(TransferProgressInfo::new(
125 self.have_read,
126 self.request.body().size(),
127 buf,
128 )) {
129 *self.user_cancelled_error = Some(make_callback_error(err, self.request));
130 return Poll::Ready(Some(Err(Box::new(IoError::new(
131 IoErrorKind::Other,
132 "on_uploading_progress() callback returns error",
133 )))));
134 }
135 }
136 Poll::Ready(Some(Ok(buf.to_vec())))
137 }
138 }
139 }
140
141 #[inline]
142 fn size_hint(&self) -> (usize, Option<usize>) {
143 (self.have_read as usize, Some(self.request.body().size() as usize))
144 }
145 }
146}
147
148fn from_async_response(mut response: AsyncReqwestResponse, request: &mut AsyncRequest) -> AsyncResponseResult {
149 call_response_callbacks(request, response.status(), response.headers())?;
150 let mut response_builder = AsyncResponse::builder();
151 response_builder
152 .status_code(response.status())
153 .version(response.version())
154 .headers(take(response.headers_mut()))
155 .extensions(take(request.extensions_mut()));
156 if let Some(port) = response.url().port_or_known_default().and_then(NonZeroU16::new) {
157 response_builder.server_port(port);
158 }
159 if let Some(remote_addr) = response.remote_addr() {
160 response_builder.server_ip(remote_addr.ip());
161 if let Some(port) = NonZeroU16::new(remote_addr.port()) {
162 response_builder.server_port(port);
163 }
164 }
165 response_builder.body(AsyncResponseBody::from_reader(AsyncReqwestResponseReadWrapper::new(
166 response.bytes_stream(),
167 )));
168 return Ok(response_builder.build());
169
170 struct AsyncReqwestResponseReadWrapper<S: Stream<Item = ReqwestResult<Bytes>>> {
171 stream: S,
172 buffer: Vec<u8>,
173 used: usize,
174 }
175
176 impl<S: Stream<Item = ReqwestResult<Bytes>>> AsyncReqwestResponseReadWrapper<S> {
177 #[inline]
178 fn new(stream: S) -> Self {
179 AsyncReqwestResponseReadWrapper {
180 stream,
181 buffer: Default::default(),
182 used: 0,
183 }
184 }
185 }
186
187 impl<S: Stream<Item = ReqwestResult<Bytes>>> fmt::Debug for AsyncReqwestResponseReadWrapper<S> {
188 #[inline]
189 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190 f.debug_struct("AsyncReqwestResponseReadWrapper")
191 .field("buffer_len", &self.buffer.len())
192 .field("buffer_cap", &self.buffer.capacity())
193 .field("used", &self.used)
194 .finish()
195 }
196 }
197
198 impl<S: Stream<Item = ReqwestResult<Bytes>>> AsyncRead for AsyncReqwestResponseReadWrapper<S> {
199 #[allow(unsafe_code)]
200 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<IoResult<usize>> {
201 let oriself = unsafe { self.get_unchecked_mut() };
202 let buffer_rested = oriself.buffer.len() - oriself.used;
203 if oriself.buffer.is_empty() {
204 let stream = unsafe { Pin::new_unchecked(&mut oriself.stream) };
205 match ready!(stream.poll_next(cx)) {
206 None => Poll::Ready(Ok(0)),
207 Some(Err(err)) => Poll::Ready(Err(IoError::new(IoErrorKind::Other, err))),
208 Some(Ok(data)) => {
209 if data.len() <= buf.len() {
210 buf[..data.len()].copy_from_slice(&data);
211 Poll::Ready(Ok(data.len()))
212 } else {
213 buf.copy_from_slice(&data[..buf.len()]);
214 oriself.buffer.extend_from_slice(&data[buf.len()..]);
215 oriself.used = 0;
216 Poll::Ready(Ok(buf.len()))
217 }
218 }
219 }
220 } else if buf.len() >= buffer_rested {
221 buf[..buffer_rested].copy_from_slice(&oriself.buffer[oriself.used..]);
222 oriself.buffer.truncate(0);
223 oriself.used = 0;
224 Poll::Ready(Ok(buffer_rested))
225 } else {
226 buf.copy_from_slice(&oriself.buffer[oriself.used..(oriself.used + buf.len())]);
227 oriself.used += buf.len();
228 Poll::Ready(Ok(buf.len()))
229 }
230 }
231 }
232}