1use crate::{
6 constants::get_default_user_agent,
7 log,
8 util::errors::{self, WrappedError},
9};
10use async_trait::async_trait;
11use core::panic;
12use futures::stream::TryStreamExt;
13use hyper::{
14 header::{HeaderName, CONTENT_LENGTH},
15 http::HeaderValue,
16 HeaderMap, StatusCode,
17};
18use serde::de::DeserializeOwned;
19use std::{io, pin::Pin, str::FromStr, sync::Arc, task::Poll};
20use tokio::{
21 fs,
22 io::{AsyncRead, AsyncReadExt},
23 sync::mpsc,
24};
25use tokio_util::compat::FuturesAsyncReadCompatExt;
26
27use super::{
28 errors::{wrap, AnyError, StatusError},
29 io::{copy_async_progress, ReadBuffer, ReportCopyProgress},
30};
31
32pub async fn download_into_file<T>(
33 filename: &std::path::Path,
34 progress: T,
35 mut res: SimpleResponse,
36) -> Result<fs::File, WrappedError>
37where
38 T: ReportCopyProgress,
39{
40 let mut file = fs::File::create(filename)
41 .await
42 .map_err(|e| errors::wrap(e, "failed to create file"))?;
43
44 let content_length = res
45 .headers
46 .get(CONTENT_LENGTH)
47 .and_then(|h| h.to_str().ok())
48 .and_then(|s| s.parse::<u64>().ok())
49 .unwrap_or(0);
50
51 copy_async_progress(progress, &mut res.read, &mut file, content_length)
52 .await
53 .map_err(|e| errors::wrap(e, "failed to download file"))?;
54
55 Ok(file)
56}
57
58pub struct SimpleResponse {
59 pub status_code: StatusCode,
60 pub headers: HeaderMap,
61 pub read: Pin<Box<dyn Send + AsyncRead + 'static>>,
62 pub url: Option<url::Url>,
63}
64
65impl SimpleResponse {
66 pub fn url_path_basename(&self) -> Option<String> {
67 self.url.as_ref().and_then(|u| {
68 u.path_segments()
69 .and_then(|s| s.last().map(|s| s.to_owned()))
70 })
71 }
72}
73
74impl SimpleResponse {
75 pub fn generic_error(url: &str) -> Self {
76 let (_, rx) = mpsc::unbounded_channel();
77 SimpleResponse {
78 url: url::Url::parse(url).ok(),
79 status_code: StatusCode::INTERNAL_SERVER_ERROR,
80 headers: HeaderMap::new(),
81 read: Box::pin(DelegatedReader::new(rx)),
82 }
83 }
84
85 pub async fn into_err(mut self) -> StatusError {
87 let mut body = String::new();
88 self.read.read_to_string(&mut body).await.ok();
89
90 StatusError {
91 url: self
92 .url
93 .map(|u| u.to_string())
94 .unwrap_or_else(|| "<invalid url>".to_owned()),
95 status_code: self.status_code.as_u16(),
96 body,
97 }
98 }
99
100 pub async fn json<T: DeserializeOwned>(&mut self) -> Result<T, AnyError> {
102 let mut buf = vec![];
103
104 self.read
107 .read_to_end(&mut buf)
108 .await
109 .map_err(|e| wrap(e, "error reading response"))?;
110
111 let t = serde_json::from_slice(&buf)
112 .map_err(|e| wrap(e, format!("error decoding json from {:?}", self.url)))?;
113
114 Ok(t)
115 }
116}
117
118#[async_trait]
123pub trait SimpleHttp {
124 async fn make_request(
125 &self,
126 method: &'static str,
127 url: String,
128 ) -> Result<SimpleResponse, AnyError>;
129}
130
131pub type BoxedHttp = Arc<dyn SimpleHttp + Send + Sync + 'static>;
132
133#[derive(Clone)]
135pub struct ReqwestSimpleHttp {
136 client: reqwest::Client,
137}
138
139impl ReqwestSimpleHttp {
140 pub fn new() -> Self {
141 Self {
142 client: reqwest::ClientBuilder::new()
143 .user_agent(get_default_user_agent())
144 .build()
145 .unwrap(),
146 }
147 }
148
149 pub fn with_client(client: reqwest::Client) -> Self {
150 Self { client }
151 }
152}
153
154impl Default for ReqwestSimpleHttp {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
160#[async_trait]
161impl SimpleHttp for ReqwestSimpleHttp {
162 async fn make_request(
163 &self,
164 method: &'static str,
165 url: String,
166 ) -> Result<SimpleResponse, AnyError> {
167 let res = self
168 .client
169 .request(reqwest::Method::try_from(method).unwrap(), &url)
170 .send()
171 .await?;
172
173 Ok(SimpleResponse {
174 status_code: res.status(),
175 headers: res.headers().clone(),
176 url: Some(res.url().clone()),
177 read: Box::pin(
178 res.bytes_stream()
179 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
180 .into_async_read()
181 .compat(),
182 ),
183 })
184 }
185}
186
/// Events sent back over the per-request channel while a delegated request is
/// being answered.
enum DelegatedHttpEvent {
    /// Status code and headers of the response; expected as the first event.
    InitResponse {
        status_code: u16,
        headers: Vec<(String, String)>,
    },
    /// One chunk of the response body.
    Body(Vec<u8>),
    /// Marks the end of the response.
    End,
}
195
196pub struct DelegatedHttpRequest {
198 pub method: &'static str,
199 pub url: String,
200 ch: mpsc::UnboundedSender<DelegatedHttpEvent>,
201}
202
203impl DelegatedHttpRequest {
204 pub fn initial_response(&self, status_code: u16, headers: Vec<(String, String)>) {
205 self.ch
206 .send(DelegatedHttpEvent::InitResponse {
207 status_code,
208 headers,
209 })
210 .ok();
211 }
212
213 pub fn body(&self, chunk: Vec<u8>) {
214 self.ch.send(DelegatedHttpEvent::Body(chunk)).ok();
215 }
216
217 pub fn end(self) {}
218}
219
220impl Drop for DelegatedHttpRequest {
221 fn drop(&mut self) {
222 self.ch.send(DelegatedHttpEvent::End).ok();
223 }
224}
225
226#[derive(Clone)]
228pub struct DelegatedSimpleHttp {
229 start_request: mpsc::Sender<DelegatedHttpRequest>,
230 log: log::Logger,
231}
232
233impl DelegatedSimpleHttp {
234 pub fn new(log: log::Logger) -> (Self, mpsc::Receiver<DelegatedHttpRequest>) {
235 let (tx, rx) = mpsc::channel(4);
236 (
237 DelegatedSimpleHttp {
238 log,
239 start_request: tx,
240 },
241 rx,
242 )
243 }
244}
245
246#[async_trait]
247impl SimpleHttp for DelegatedSimpleHttp {
248 async fn make_request(
249 &self,
250 method: &'static str,
251 url: String,
252 ) -> Result<SimpleResponse, AnyError> {
253 trace!(self.log, "making delegated request to {}", url);
254 let (tx, mut rx) = mpsc::unbounded_channel();
255 let sent = self
256 .start_request
257 .send(DelegatedHttpRequest {
258 method,
259 url: url.clone(),
260 ch: tx,
261 })
262 .await;
263
264 if sent.is_err() {
265 return Ok(SimpleResponse::generic_error(&url)); }
267
268 match rx.recv().await {
269 Some(DelegatedHttpEvent::InitResponse {
270 status_code,
271 headers,
272 }) => {
273 trace!(
274 self.log,
275 "delegated request to {} resulted in status = {}",
276 url,
277 status_code
278 );
279 let mut headers_map = HeaderMap::with_capacity(headers.len());
280 for (k, v) in &headers {
281 if let (Ok(key), Ok(value)) = (
282 HeaderName::from_str(&k.to_lowercase()),
283 HeaderValue::from_str(v),
284 ) {
285 headers_map.insert(key, value);
286 }
287 }
288
289 Ok(SimpleResponse {
290 url: url::Url::parse(&url).ok(),
291 status_code: StatusCode::from_u16(status_code)
292 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
293 headers: headers_map,
294 read: Box::pin(DelegatedReader::new(rx)),
295 })
296 }
297 Some(DelegatedHttpEvent::End) => Ok(SimpleResponse::generic_error(&url)),
298 Some(_) => panic!("expected initresponse as first message from delegated http"),
299 None => Ok(SimpleResponse::generic_error(&url)), }
301 }
302}
303
304struct DelegatedReader {
305 receiver: mpsc::UnboundedReceiver<DelegatedHttpEvent>,
306 readbuf: ReadBuffer,
307}
308
309impl DelegatedReader {
310 pub fn new(rx: mpsc::UnboundedReceiver<DelegatedHttpEvent>) -> Self {
311 DelegatedReader {
312 readbuf: ReadBuffer::default(),
313 receiver: rx,
314 }
315 }
316}
317
318impl AsyncRead for DelegatedReader {
319 fn poll_read(
320 mut self: Pin<&mut Self>,
321 cx: &mut std::task::Context<'_>,
322 buf: &mut tokio::io::ReadBuf<'_>,
323 ) -> std::task::Poll<std::io::Result<()>> {
324 if let Some((v, s)) = self.readbuf.take_data() {
325 return self.readbuf.put_data(buf, v, s);
326 }
327
328 match self.receiver.poll_recv(cx) {
329 Poll::Ready(Some(DelegatedHttpEvent::Body(msg))) => self.readbuf.put_data(buf, msg, 0),
330 Poll::Ready(Some(_)) => Poll::Ready(Ok(())), Poll::Ready(None) => {
332 Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")))
333 }
334 Poll::Pending => Poll::Pending,
335 }
336 }
337}
338
339pub struct FallbackSimpleHttp {
342 native: ReqwestSimpleHttp,
343 delegated: DelegatedSimpleHttp,
344}
345
346impl FallbackSimpleHttp {
347 pub fn new(native: ReqwestSimpleHttp, delegated: DelegatedSimpleHttp) -> Self {
348 FallbackSimpleHttp { native, delegated }
349 }
350
351 pub fn native(&self) -> ReqwestSimpleHttp {
352 self.native.clone()
353 }
354
355 pub fn delegated(&self) -> DelegatedSimpleHttp {
356 self.delegated.clone()
357 }
358}
359
360#[async_trait]
361impl SimpleHttp for FallbackSimpleHttp {
362 async fn make_request(
363 &self,
364 method: &'static str,
365 url: String,
366 ) -> Result<SimpleResponse, AnyError> {
367 let r1 = self.native.make_request(method, url.clone()).await;
368 if let Ok(res) = r1 {
369 if !res.status_code.is_server_error() {
370 return Ok(res);
371 }
372 }
373
374 self.delegated.make_request(method, url).await
375 }
376}