#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
#![doc = include_str!("../README.md")]
#![cfg_attr(not(ci_arti_stable), allow(renamed_and_removed_lints))]
#![cfg_attr(not(ci_arti_nightly), allow(unknown_lints))]
#![deny(missing_docs)]
#![warn(noop_method_call)]
#![deny(unreachable_pub)]
#![warn(clippy::all)]
#![deny(clippy::await_holding_lock)]
#![deny(clippy::cargo_common_metadata)]
#![deny(clippy::cast_lossless)]
#![deny(clippy::checked_conversions)]
#![warn(clippy::cognitive_complexity)]
#![deny(clippy::debug_assert_with_mut_call)]
#![deny(clippy::exhaustive_enums)]
#![deny(clippy::exhaustive_structs)]
#![deny(clippy::expl_impl_clone_on_copy)]
#![deny(clippy::fallible_impl_from)]
#![deny(clippy::implicit_clone)]
#![deny(clippy::large_stack_arrays)]
#![warn(clippy::manual_ok_or)]
#![deny(clippy::missing_docs_in_private_items)]
#![deny(clippy::missing_panics_doc)]
#![warn(clippy::needless_borrow)]
#![warn(clippy::needless_pass_by_value)]
#![warn(clippy::option_option)]
#![warn(clippy::rc_buffer)]
#![deny(clippy::ref_option_ref)]
#![warn(clippy::semicolon_if_nothing_returned)]
#![warn(clippy::trait_duplication_in_bounds)]
#![deny(clippy::unnecessary_wraps)]
#![warn(clippy::unseparated_literal_suffix)]
#![deny(clippy::unwrap_used)]
#![allow(clippy::let_unit_value)] #![allow(clippy::significant_drop_in_scrutinee)] #![allow(clippy::result_large_err)]
mod err;
pub mod request;
mod response;
mod util;
use tor_circmgr::{CircMgr, DirInfo};
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
#[cfg(feature = "xz")]
use async_compression::futures::bufread::XzDecoder;
use async_compression::futures::bufread::ZlibDecoder;
#[cfg(feature = "zstd")]
use async_compression::futures::bufread::ZstdDecoder;
use futures::io::{
AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
};
use futures::FutureExt;
use memchr::memchr;
use std::sync::Arc;
use std::time::Duration;
use tracing::info;
pub use err::{Error, RequestError, RequestFailedError};
pub use response::{DirResponse, SourceInfo};
/// A Result type for this crate's top-level operations, using this
/// crate's [`Error`] type.
pub type Result<T> = std::result::Result<T, Error>;
/// A Result type for a single directory request, using [`RequestError`]
/// (no directory-source information attached).
pub type RequestResult<T> = std::result::Result<T, RequestError>;
/// Fetch the directory resource described by `req` over the Tor network.
///
/// Obtains (or launches) a directory circuit from `circ_mgr` using
/// `dirinfo`, opens a directory stream on it (with a 5-second timeout on
/// stream creation), and performs the download.  If the outcome indicates
/// the circuit should no longer be used, the circuit is retired before
/// the result is returned.
///
/// Errors are tagged with the circuit's [`SourceInfo`] so callers can
/// attribute failures to a particular directory source.
pub async fn get_resource<CR, R, SP>(
req: &CR,
dirinfo: DirInfo<'_>,
runtime: &SP,
circ_mgr: Arc<CircMgr<R>>,
) -> Result<DirResponse>
where
CR: request::Requestable + ?Sized,
R: Runtime,
SP: SleepProvider,
{
let circuit = circ_mgr.get_or_launch_dir(dirinfo).await?;
// Maximum time to wait for the directory stream to open.
let begin_timeout = Duration::from_secs(5);
let source = SourceInfo::from_circuit(&circuit);
// Tag request-level errors with this circuit's source information.
let wrap_err = |error| {
Error::RequestFailed(RequestFailedError {
source: Some(source.clone()),
error,
})
};
req.check_circuit(&circuit).map_err(wrap_err)?;
// Two error layers to unwrap here: first the timeout result, then the
// stream-creation result inside it.
let mut stream = runtime
.timeout(begin_timeout, circuit.begin_dir_stream())
.await
.map_err(RequestError::from)
.map_err(wrap_err)?
.map_err(RequestError::from)
.map_err(wrap_err)?;
let r = download(runtime, req, &mut stream, Some(source.clone())).await;
// Retire the circuit if the outcome says it is no longer trustworthy.
if should_retire_circ(&r) {
retire_circ(&circ_mgr, &source, "Partial response");
}
r
}
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
match result {
Err(e) => e.should_retire_circ(),
Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
}
}
/// Send `req` over `stream` and read the response.
///
/// On a 200 status the body is decompressed (according to the response's
/// Content-Encoding header) up to the request's maximum response length.
/// A non-200 status yields a `DirResponse` with an empty body.  When the
/// request permits partial documents, a mid-body failure that arrives
/// after some data was read is reported *inside* the `DirResponse`
/// rather than as an `Err`.
pub async fn download<R, S, SP>(
runtime: &SP,
req: &R,
stream: &mut S,
source: Option<SourceInfo>,
) -> Result<DirResponse>
where
R: request::Requestable + ?Sized,
S: AsyncRead + AsyncWrite + Send + Unpin,
SP: SleepProvider,
{
// Attach the (optional) source information to any request-level error.
let wrap_err = |error| {
Error::RequestFailed(RequestFailedError {
source: source.clone(),
error,
})
};
let partial_ok = req.partial_docs_ok();
let maxlen = req.max_response_len();
let req = req.make_request().map_err(wrap_err)?;
let encoded = util::encode_request(&req);
// Write the encoded request and flush so the server sees it promptly.
stream
.write_all(encoded.as_bytes())
.await
.map_err(RequestError::from)
.map_err(wrap_err)?;
stream
.flush()
.await
.map_err(RequestError::from)
.map_err(wrap_err)?;
let mut buffered = BufReader::new(stream);
let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
// Anything other than 200 means there is no body worth reading.
if header.status != Some(200) {
return Ok(DirResponse::new(
header.status.unwrap_or(0),
None,
vec![],
source,
));
}
let mut decoder = get_decoder(buffered, header.encoding.as_deref()).map_err(wrap_err)?;
let mut result = Vec::new();
let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;
// Decide how to report a failure partway through the body.
let ok = match (partial_ok, ok, result.len()) {
(true, Err(e), n) if n > 0 => {
// Partial documents are acceptable and we got some data: keep the
// data and record the error inside the response below.
Err(e)
}
(_, Err(e), _) => {
return Err(wrap_err(e));
}
(_, Ok(()), _) => Ok(()),
};
Ok(DirResponse::new(200, ok.err(), result, source))
}
/// Read and parse HTTP-style response headers from `stream`.
///
/// Returns the status code and, for a 200 response, the value of any
/// Content-Encoding header.  Fails with `TruncatedHeaders` on EOF before
/// the header block is complete, and refuses header blocks of 16 KiB or
/// more.
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
S: AsyncBufRead + Unpin,
{
let mut buf = Vec::with_capacity(1024);
loop {
// Read one line at a time (capped at 2 KiB per read) and re-run the
// parse each time, since httparse needs the whole header block.
let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;
let mut headers = [httparse::EMPTY_HEADER; 32];
let mut response = httparse::Response::new(&mut headers);
match response.parse(&buf[..])? {
httparse::Status::Partial => {
// EOF with the headers still incomplete.
if n == 0 {
return Err(RequestError::TruncatedHeaders);
}
// Cap the total amount of header data we will buffer.
if buf.len() >= 16384 {
return Err(httparse::Error::TooManyHeaders.into());
}
}
httparse::Status::Complete(n_parsed) => {
if response.code != Some(200) {
// Non-200: no body will be read, so skip encoding lookup.
return Ok(HeaderStatus {
status: response.code,
encoding: None,
});
}
let encoding = if let Some(enc) = response
.headers
.iter()
.find(|h| h.name == "Content-Encoding")
{
Some(String::from_utf8(enc.value.to_vec())?)
} else {
None
};
// Because we read line-by-line, a complete parse should have
// consumed exactly the bytes buffered so far; anything else is
// a bug in this loop's invariant.
assert!(n_parsed == buf.len());
return Ok(HeaderStatus {
status: Some(200),
encoding,
});
}
}
if n == 0 {
return Err(RequestError::TruncatedHeaders);
}
}
}
/// Parsed summary of an HTTP response header block.
#[derive(Debug, Clone)]
struct HeaderStatus {
/// The HTTP status code, if one could be parsed.
status: Option<u16>,
/// The Content-Encoding header value; only populated for 200 responses.
encoding: Option<String>,
}
/// Read all data from `stream` into `result`, stopping at end-of-stream,
/// after `maxlen` bytes, or when the overall read timeout expires.
///
/// On error, `result` is truncated to the bytes actually read.  Note that
/// the 10-second timeout covers the whole download, not each read.
async fn read_and_decompress<S, SP>(
runtime: &SP,
mut stream: S,
maxlen: usize,
result: &mut Vec<u8>,
) -> RequestResult<()>
where
S: AsyncRead + Unpin,
SP: SleepProvider,
{
// Grow the output buffer in fixed-size windows of this many bytes.
let buffer_window_size = 1024;
let mut written_total: usize = 0;
// Single timer for the entire download; pinned so select! can poll it.
let read_timeout = Duration::from_secs(10);
let timer = runtime.sleep(read_timeout).fuse();
futures::pin_mut!(timer);
loop {
// Make room for the next window, then read directly into it.
result.resize(written_total + buffer_window_size, 0);
let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];
let status = futures::select! {
status = stream.read(buf).fuse() => status,
_ = timer => {
// Timed out: drop the unread zero padding and bail.
result.resize(written_total, 0); return Err(RequestError::DirTimeout);
}
};
let written_in_this_loop = match status {
Ok(n) => n,
Err(other) => {
// Read error: trim padding before reporting.
result.resize(written_total, 0); return Err(other.into());
}
};
written_total += written_in_this_loop;
if written_in_this_loop == 0 {
// End of stream: trim any leftover window padding and finish.
if written_total < result.len() {
result.resize(written_total, 0);
}
return Ok(());
}
if written_total > maxlen {
// Response too large: keep only the first `maxlen` bytes but
// report how many we actually saw.
result.resize(maxlen, 0);
return Err(RequestError::ResponseTooLong(written_total));
}
}
}
/// Tell the circuit manager to stop handing out the circuit identified by
/// `source_info`, logging `error` as the reason for retiring it.
fn retire_circ<R, E>(circ_mgr: &Arc<CircMgr<R>>, source_info: &SourceInfo, error: &E)
where
    R: Runtime,
    E: std::fmt::Display + ?Sized,
{
    // Look up the circuit's unique identifier, log why it is going away,
    // then hand the identifier back to the manager for retirement.
    let circ_id = source_info.unique_circ_id();
    info!(
        "{}: Retiring circuit because of directory failure: {}",
        &circ_id, &error
    );
    circ_mgr.retire_circ(circ_id);
}
/// Read from `stream` into `buf` until we have consumed an occurrence of
/// `byte`, reached end-of-input, or added `max` bytes in total.
///
/// Returns the number of bytes appended to `buf`.  If the limit is hit
/// exactly where the delimiter would be, the delimiter may remain
/// unconsumed in the stream.
async fn read_until_limited<S>(
    stream: &mut S,
    byte: u8,
    max: usize,
    buf: &mut Vec<u8>,
) -> std::io::Result<usize>
where
    S: AsyncBufRead + Unpin,
{
    let mut added = 0;
    loop {
        let chunk = stream.fill_buf().await?;
        if chunk.is_empty() {
            // End of input before delimiter or limit: report what we got.
            return Ok(added);
        }
        debug_assert!(added < max);
        let space_left = max - added;
        // How many bytes of this chunk are usable, and did we see the
        // delimiter within them?
        let (usable, hit_delimiter) = match memchr(byte, chunk) {
            Some(pos) => (pos + 1, true),
            None => (chunk.len(), false),
        };
        debug_assert!(usable >= 1);
        // Never take more than the remaining budget allows.
        let take = std::cmp::min(space_left, usable);
        buf.extend(&chunk[..take]);
        stream.consume_unpin(take);
        added += take;
        if hit_delimiter || added == max {
            return Ok(added);
        }
    }
}
/// Helper macro: build a boxed decompressor of type `$dec` over the
/// buffered stream `$s`, returning it as an `Ok(Box<...>)`.
macro_rules! decoder {
($dec:ident, $s:expr) => {{
let mut decoder = $dec::new($s);
// Accept multiple concatenated compressed members in one stream.
decoder.multiple_members(true);
Ok(Box::new(decoder))
}};
}
/// Wrap `stream` in a reader that decodes the named Content-Encoding.
///
/// `None` or `"identity"` means no decoding; `"deflate"` selects zlib;
/// `"x-tor-lzma"` and `"x-zstd"` are available only when the matching
/// cargo feature is enabled.  Any other (or disabled) encoding yields a
/// `ContentEncoding` error.
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
stream: S,
encoding: Option<&str>,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
match encoding {
None | Some("identity") => Ok(Box::new(stream)),
Some("deflate") => decoder!(ZlibDecoder, stream),
#[cfg(feature = "xz")]
Some("x-tor-lzma") => decoder!(XzDecoder, stream),
#[cfg(feature = "zstd")]
Some("x-zstd") => decoder!(ZstdDecoder, stream),
Some(other) => Err(RequestError::ContentEncoding(other.into())),
}
}
#[cfg(test)]
mod test {
#![allow(clippy::unwrap_used)]
use super::*;
use tor_rtmock::{io::stream_pair, time::MockSleepProvider};
use futures_await_test::async_test;
/// read_until_limited: stop at the delimiter, at the byte limit, and at
/// end-of-input when the delimiter never appears.
#[async_test]
async fn test_read_until_limited() -> RequestResult<()> {
let mut out = Vec::new();
let bytes = b"This line eventually ends\nthen comes another\n";
let mut s = &bytes[..];
let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
assert_eq!(res?, 26);
assert_eq!(&out[..], b"This line eventually ends\n");
let mut s = &bytes[..];
out.clear();
let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
assert_eq!(res?, 10);
assert_eq!(&out[..], b"This line ");
let mut s = &bytes[..];
out.clear();
let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
assert_eq!(res?, 45);
assert_eq!(&out[..], &bytes[..]);
Ok(())
}
/// Helper: decode `data` with the decoder selected for `encoding`, and
/// return the read status along with whatever bytes were produced.
async fn decomp_basic(
encoding: Option<&str>,
data: &[u8],
maxlen: usize,
) -> (RequestResult<()>, Vec<u8>) {
// Mock clock, so the download timeout never fires during these tests.
let mock_time = MockSleepProvider::new(std::time::SystemTime::now());
let mut output = Vec::new();
let mut stream = match get_decoder(data, encoding) {
Ok(s) => s,
Err(e) => return (Err(e), output),
};
let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;
(r, output)
}
/// Missing/"identity" encoding passes data through unchanged, and the
/// length limit is still enforced.
#[async_test]
async fn decompress_identity() -> RequestResult<()> {
let mut text = Vec::new();
for _ in 0..1000 {
text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
}
let limit = 10 << 20;
let (s, r) = decomp_basic(None, &text[..], limit).await;
s?;
assert_eq!(r, text);
let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
s?;
assert_eq!(r, text);
// A small limit truncates the output and reports an error.
let limit = 100;
let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
assert!(s.is_err());
assert_eq!(r, &text[..100]);
Ok(())
}
/// "deflate" encoding decompresses zlib data.
#[async_test]
async fn decomp_zlib() -> RequestResult<()> {
let compressed =
hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();
let limit = 10 << 20;
let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
s?;
assert_eq!(r, b"One fish Two fish Red fish Blue fish");
Ok(())
}
/// "x-zstd" encoding decompresses zstd data (feature-gated).
#[cfg(feature = "zstd")]
#[async_test]
async fn decomp_zstd() -> RequestResult<()> {
let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
let limit = 10 << 20;
let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
s?;
assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
Ok(())
}
/// "x-tor-lzma" encoding decompresses xz data (feature-gated).
#[cfg(feature = "xz")]
#[async_test]
async fn decomp_xz2() -> RequestResult<()> {
let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
let limit = 10 << 20;
let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
s?;
assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");
Ok(())
}
/// An unrecognized encoding is rejected up front.
#[async_test]
async fn decomp_unknown() {
let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
let limit = 10 << 20;
let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;
assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
}
/// Garbage fed to a real decoder surfaces as an I/O error.
#[async_test]
async fn decomp_bad_data() {
let compressed = b"This is not good zlib data";
let limit = 10 << 20;
let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;
assert!(matches!(s, Err(RequestError::IoError(_))));
}
/// Header parsing: a complete 200 (with Content-Encoding), truncated
/// headers, and a non-200 status.
#[async_test]
async fn headers_ok() -> RequestResult<()> {
let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";
let mut s = &text[..];
let h = read_headers(&mut s).await?;
assert_eq!(h.status, Some(200));
assert_eq!(h.encoding.as_deref(), Some("Waffles"));
// Cut the input off mid-header: that's a truncation error.
let mut s = &text[..15];
let h = read_headers(&mut s).await;
assert!(matches!(h, Err(RequestError::TruncatedHeaders)));
let text = b"HTTP/1.0 404 Not found\r\n\r\n";
let mut s = &text[..];
let h = read_headers(&mut s).await?;
assert_eq!(h.status, Some(404));
assert!(h.encoding.is_none());
Ok(())
}
/// Syntactically invalid headers produce a parse error.
#[async_test]
async fn headers_bogus() -> Result<()> {
let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
let mut s = &text[..];
let h = read_headers(&mut s).await;
assert!(h.is_err());
assert!(matches!(h, Err(RequestError::HttparseError(_))));
Ok(())
}
/// Helper: run `download` against a canned `response` over an in-memory
/// stream pair; return the download result and the bytes the client sent.
fn run_download_test<Req: request::Requestable>(
req: Req,
response: &[u8],
) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
let (mut s1, s2) = stream_pair();
let (mut s2_r, mut s2_w) = s2.split();
tor_rtcompat::test_with_one_runtime!(|rt| async move {
let rt2 = rt.clone();
let (v1, v2, v3): (
Result<DirResponse>,
RequestResult<Vec<u8>>,
RequestResult<()>,
) = futures::join!(
async {
// Client side: send the request, read the response.
let r = download(&rt, &req, &mut s1, None).await;
s1.close().await.map_err(|error| {
Error::RequestFailed(RequestFailedError {
source: None,
error: error.into(),
})
})?;
r
},
async {
// "Server" side: collect everything the client sent.
let mut v = Vec::new();
s2_r.read_to_end(&mut v).await?;
Ok(v)
},
async {
// "Server" side: play back the canned response, then close.
s2_w.write_all(response).await?;
rt2.sleep(Duration::from_millis(50)).await;
s2_w.close().await?;
Ok(())
}
);
assert!(v3.is_ok());
(v1, v2)
})
}
/// A successful microdescriptor download round-trip.
#[test]
fn test_download() -> RequestResult<()> {
let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
let (response, request) = run_download_test(
req,
b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
);
let request = request?;
assert!(request[..].starts_with(
b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk.z HTTP/1.0\r\n"
));
let response = response.unwrap();
assert_eq!(response.status_code(), 200);
assert!(!response.is_partial());
assert!(response.error().is_none());
assert!(response.source().is_none());
let out_ref = response.output_unchecked();
assert_eq!(out_ref, b"This is where the descs would go.");
let out = response.into_output_unchecked();
assert_eq!(&out, b"This is where the descs would go.");
Ok(())
}
/// A truncated compressed body: a hard error when partial documents are
/// not acceptable, a partial `DirResponse` when they are.
#[test]
fn test_download_truncated() {
let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
let mut response_text: Vec<u8> =
(*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
response_text.extend(
hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
);
// Second compressed member is deliberately cut short.
response_text.extend(
hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
);
let (response, request) = run_download_test(req, &response_text);
assert!(request.is_ok());
assert!(response.is_err());
// With two microdescs requested, partial results are acceptable.
let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();
let (response, request) = run_download_test(req, &response_text);
assert!(request.is_ok());
let response = response.unwrap();
assert_eq!(response.status_code(), 200);
assert!(response.error().is_some());
assert!(response.is_partial());
assert!(response.output_unchecked().len() < 37 * 2);
assert!(response.output_unchecked().starts_with(b"One fish"));
}
/// Non-200 status codes are passed through in the `DirResponse`.
#[test]
fn test_404() {
let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
let response_text = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
let (response, _request) = run_download_test(req, response_text);
assert_eq!(response.unwrap().status_code(), 418);
}
/// Headers cut off mid-stream (or an entirely empty response) are
/// reported as `TruncatedHeaders`.
#[test]
fn test_headers_truncated() {
let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
let response_text = b"HTTP/1.0 404 truncation happens here\r\n";
let (response, _request) = run_download_test(req, response_text);
assert!(matches!(
response,
Err(Error::RequestFailed(RequestFailedError {
error: RequestError::TruncatedHeaders,
..
}))
));
let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
let response_text = b"";
let (response, _request) = run_download_test(req, response_text);
assert!(matches!(
response,
Err(Error::RequestFailed(RequestFailedError {
error: RequestError::TruncatedHeaders,
..
}))
));
}
/// Headers exceeding the 16 KiB cap produce a parse error, and that
/// error is circuit-retiring.
#[test]
fn test_headers_too_long() {
let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
let mut response_text: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
response_text.resize(16384, b'A');
let (response, _request) = run_download_test(req, &response_text);
assert!(response.as_ref().unwrap_err().should_retire_circ());
assert!(matches!(
response,
Err(Error::RequestFailed(RequestFailedError {
error: RequestError::HttparseError(_),
..
}))
));
}
}