cli/util/
http.rs

1/*---------------------------------------------------------------------------------------------
2 *  Copyright (c) Microsoft Corporation. All rights reserved.
3 *  Licensed under the MIT License. See License.txt in the project root for license information.
4 *--------------------------------------------------------------------------------------------*/
5use crate::{
6	constants::get_default_user_agent,
7	log,
8	util::errors::{self, WrappedError},
9};
10use async_trait::async_trait;
11use core::panic;
12use futures::stream::TryStreamExt;
13use hyper::{
14	header::{HeaderName, CONTENT_LENGTH},
15	http::HeaderValue,
16	HeaderMap, StatusCode,
17};
18use serde::de::DeserializeOwned;
19use std::{io, pin::Pin, str::FromStr, sync::Arc, task::Poll};
use tokio::{
	fs,
	io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
	sync::mpsc,
};
25use tokio_util::compat::FuturesAsyncReadCompatExt;
26
27use super::{
28	errors::{wrap, AnyError, StatusError},
29	io::{copy_async_progress, ReadBuffer, ReportCopyProgress},
30};
31
32pub async fn download_into_file<T>(
33	filename: &std::path::Path,
34	progress: T,
35	mut res: SimpleResponse,
36) -> Result<fs::File, WrappedError>
37where
38	T: ReportCopyProgress,
39{
40	let mut file = fs::File::create(filename)
41		.await
42		.map_err(|e| errors::wrap(e, "failed to create file"))?;
43
44	let content_length = res
45		.headers
46		.get(CONTENT_LENGTH)
47		.and_then(|h| h.to_str().ok())
48		.and_then(|s| s.parse::<u64>().ok())
49		.unwrap_or(0);
50
51	copy_async_progress(progress, &mut res.read, &mut file, content_length)
52		.await
53		.map_err(|e| errors::wrap(e, "failed to download file"))?;
54
55	Ok(file)
56}
57
58pub struct SimpleResponse {
59	pub status_code: StatusCode,
60	pub headers: HeaderMap,
61	pub read: Pin<Box<dyn Send + AsyncRead + 'static>>,
62	pub url: Option<url::Url>,
63}
64
65impl SimpleResponse {
66	pub fn url_path_basename(&self) -> Option<String> {
67		self.url.as_ref().and_then(|u| {
68			u.path_segments()
69				.and_then(|s| s.last().map(|s| s.to_owned()))
70		})
71	}
72}
73
74impl SimpleResponse {
75	pub fn generic_error(url: &str) -> Self {
76		let (_, rx) = mpsc::unbounded_channel();
77		SimpleResponse {
78			url: url::Url::parse(url).ok(),
79			status_code: StatusCode::INTERNAL_SERVER_ERROR,
80			headers: HeaderMap::new(),
81			read: Box::pin(DelegatedReader::new(rx)),
82		}
83	}
84
85	/// Converts the response into a StatusError
86	pub async fn into_err(mut self) -> StatusError {
87		let mut body = String::new();
88		self.read.read_to_string(&mut body).await.ok();
89
90		StatusError {
91			url: self
92				.url
93				.map(|u| u.to_string())
94				.unwrap_or_else(|| "<invalid url>".to_owned()),
95			status_code: self.status_code.as_u16(),
96			body,
97		}
98	}
99
100	/// Deserializes the response body as JSON
101	pub async fn json<T: DeserializeOwned>(&mut self) -> Result<T, AnyError> {
102		let mut buf = vec![];
103
104		// ideally serde would deserialize a stream, but it does not appear that
105		// is supported. reqwest itself reads and decodes separately like we do here:
106		self.read
107			.read_to_end(&mut buf)
108			.await
109			.map_err(|e| wrap(e, "error reading response"))?;
110
111		let t = serde_json::from_slice(&buf)
112			.map_err(|e| wrap(e, format!("error decoding json from {:?}", self.url)))?;
113
114		Ok(t)
115	}
116}
117
118/// *Very* simple HTTP implementation. In most cases, this will just delegate to
119/// the request library on the server (i.e. `reqwest`) but it can also be used
120/// to make update/download requests on the client rather than the server,
121/// similar to SSH's `remote.SSH.localServerDownload` setting.
122#[async_trait]
123pub trait SimpleHttp {
124	async fn make_request(
125		&self,
126		method: &'static str,
127		url: String,
128	) -> Result<SimpleResponse, AnyError>;
129}
130
131pub type BoxedHttp = Arc<dyn SimpleHttp + Send + Sync + 'static>;
132
133// Implementation of SimpleHttp that uses a reqwest client.
134#[derive(Clone)]
135pub struct ReqwestSimpleHttp {
136	client: reqwest::Client,
137}
138
139impl ReqwestSimpleHttp {
140	pub fn new() -> Self {
141		Self {
142			client: reqwest::ClientBuilder::new()
143				.user_agent(get_default_user_agent())
144				.build()
145				.unwrap(),
146		}
147	}
148
149	pub fn with_client(client: reqwest::Client) -> Self {
150		Self { client }
151	}
152}
153
154impl Default for ReqwestSimpleHttp {
155	fn default() -> Self {
156		Self::new()
157	}
158}
159
160#[async_trait]
161impl SimpleHttp for ReqwestSimpleHttp {
162	async fn make_request(
163		&self,
164		method: &'static str,
165		url: String,
166	) -> Result<SimpleResponse, AnyError> {
167		let res = self
168			.client
169			.request(reqwest::Method::try_from(method).unwrap(), &url)
170			.send()
171			.await?;
172
173		Ok(SimpleResponse {
174			status_code: res.status(),
175			headers: res.headers().clone(),
176			url: Some(res.url().clone()),
177			read: Box::pin(
178				res.bytes_stream()
179					.map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
180					.into_async_read()
181					.compat(),
182			),
183		})
184	}
185}
186
/// Messages pushed from a delegated request handler to the code waiting on
/// the response.
enum DelegatedHttpEvent {
	/// Response head; the requester expects this to be the first message.
	InitResponse {
		status_code: u16,
		headers: Vec<(String, String)>,
	},
	/// One chunk of response body bytes.
	Body(Vec<u8>),
	/// Marks the end of the response.
	End,
}
195
196// Handle for a delegated request that allows manually issuing and response.
197pub struct DelegatedHttpRequest {
198	pub method: &'static str,
199	pub url: String,
200	ch: mpsc::UnboundedSender<DelegatedHttpEvent>,
201}
202
203impl DelegatedHttpRequest {
204	pub fn initial_response(&self, status_code: u16, headers: Vec<(String, String)>) {
205		self.ch
206			.send(DelegatedHttpEvent::InitResponse {
207				status_code,
208				headers,
209			})
210			.ok();
211	}
212
213	pub fn body(&self, chunk: Vec<u8>) {
214		self.ch.send(DelegatedHttpEvent::Body(chunk)).ok();
215	}
216
217	pub fn end(self) {}
218}
219
220impl Drop for DelegatedHttpRequest {
221	fn drop(&mut self) {
222		self.ch.send(DelegatedHttpEvent::End).ok();
223	}
224}
225
226/// Implementation of SimpleHttp that allows manually controlling responses.
227#[derive(Clone)]
228pub struct DelegatedSimpleHttp {
229	start_request: mpsc::Sender<DelegatedHttpRequest>,
230	log: log::Logger,
231}
232
233impl DelegatedSimpleHttp {
234	pub fn new(log: log::Logger) -> (Self, mpsc::Receiver<DelegatedHttpRequest>) {
235		let (tx, rx) = mpsc::channel(4);
236		(
237			DelegatedSimpleHttp {
238				log,
239				start_request: tx,
240			},
241			rx,
242		)
243	}
244}
245
246#[async_trait]
247impl SimpleHttp for DelegatedSimpleHttp {
248	async fn make_request(
249		&self,
250		method: &'static str,
251		url: String,
252	) -> Result<SimpleResponse, AnyError> {
253		trace!(self.log, "making delegated request to {}", url);
254		let (tx, mut rx) = mpsc::unbounded_channel();
255		let sent = self
256			.start_request
257			.send(DelegatedHttpRequest {
258				method,
259				url: url.clone(),
260				ch: tx,
261			})
262			.await;
263
264		if sent.is_err() {
265			return Ok(SimpleResponse::generic_error(&url)); // sender shut down
266		}
267
268		match rx.recv().await {
269			Some(DelegatedHttpEvent::InitResponse {
270				status_code,
271				headers,
272			}) => {
273				trace!(
274					self.log,
275					"delegated request to {} resulted in status = {}",
276					url,
277					status_code
278				);
279				let mut headers_map = HeaderMap::with_capacity(headers.len());
280				for (k, v) in &headers {
281					if let (Ok(key), Ok(value)) = (
282						HeaderName::from_str(&k.to_lowercase()),
283						HeaderValue::from_str(v),
284					) {
285						headers_map.insert(key, value);
286					}
287				}
288
289				Ok(SimpleResponse {
290					url: url::Url::parse(&url).ok(),
291					status_code: StatusCode::from_u16(status_code)
292						.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
293					headers: headers_map,
294					read: Box::pin(DelegatedReader::new(rx)),
295				})
296			}
297			Some(DelegatedHttpEvent::End) => Ok(SimpleResponse::generic_error(&url)),
298			Some(_) => panic!("expected initresponse as first message from delegated http"),
299			None => Ok(SimpleResponse::generic_error(&url)), // sender shut down
300		}
301	}
302}
303
304struct DelegatedReader {
305	receiver: mpsc::UnboundedReceiver<DelegatedHttpEvent>,
306	readbuf: ReadBuffer,
307}
308
309impl DelegatedReader {
310	pub fn new(rx: mpsc::UnboundedReceiver<DelegatedHttpEvent>) -> Self {
311		DelegatedReader {
312			readbuf: ReadBuffer::default(),
313			receiver: rx,
314		}
315	}
316}
317
318impl AsyncRead for DelegatedReader {
319	fn poll_read(
320		mut self: Pin<&mut Self>,
321		cx: &mut std::task::Context<'_>,
322		buf: &mut tokio::io::ReadBuf<'_>,
323	) -> std::task::Poll<std::io::Result<()>> {
324		if let Some((v, s)) = self.readbuf.take_data() {
325			return self.readbuf.put_data(buf, v, s);
326		}
327
328		match self.receiver.poll_recv(cx) {
329			Poll::Ready(Some(DelegatedHttpEvent::Body(msg))) => self.readbuf.put_data(buf, msg, 0),
330			Poll::Ready(Some(_)) => Poll::Ready(Ok(())), // EOF
331			Poll::Ready(None) => {
332				Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")))
333			}
334			Poll::Pending => Poll::Pending,
335		}
336	}
337}
338
339/// Simple http implementation that falls back to delegated http if
340/// making a direct reqwest fails.
341pub struct FallbackSimpleHttp {
342	native: ReqwestSimpleHttp,
343	delegated: DelegatedSimpleHttp,
344}
345
346impl FallbackSimpleHttp {
347	pub fn new(native: ReqwestSimpleHttp, delegated: DelegatedSimpleHttp) -> Self {
348		FallbackSimpleHttp { native, delegated }
349	}
350
351	pub fn native(&self) -> ReqwestSimpleHttp {
352		self.native.clone()
353	}
354
355	pub fn delegated(&self) -> DelegatedSimpleHttp {
356		self.delegated.clone()
357	}
358}
359
360#[async_trait]
361impl SimpleHttp for FallbackSimpleHttp {
362	async fn make_request(
363		&self,
364		method: &'static str,
365		url: String,
366	) -> Result<SimpleResponse, AnyError> {
367		let r1 = self.native.make_request(method, url.clone()).await;
368		if let Ok(res) = r1 {
369			if !res.status_code.is_server_error() {
370				return Ok(res);
371			}
372		}
373
374		self.delegated.make_request(method, url).await
375	}
376}