1#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![allow(renamed_and_removed_lints)] #![allow(unknown_lints)] #![warn(missing_docs)]
7#![warn(noop_method_call)]
8#![warn(unreachable_pub)]
9#![warn(clippy::all)]
10#![deny(clippy::await_holding_lock)]
11#![deny(clippy::cargo_common_metadata)]
12#![deny(clippy::cast_lossless)]
13#![deny(clippy::checked_conversions)]
14#![warn(clippy::cognitive_complexity)]
15#![deny(clippy::debug_assert_with_mut_call)]
16#![deny(clippy::exhaustive_enums)]
17#![deny(clippy::exhaustive_structs)]
18#![deny(clippy::expl_impl_clone_on_copy)]
19#![deny(clippy::fallible_impl_from)]
20#![deny(clippy::implicit_clone)]
21#![deny(clippy::large_stack_arrays)]
22#![warn(clippy::manual_ok_or)]
23#![deny(clippy::missing_docs_in_private_items)]
24#![warn(clippy::needless_borrow)]
25#![warn(clippy::needless_pass_by_value)]
26#![warn(clippy::option_option)]
27#![deny(clippy::print_stderr)]
28#![deny(clippy::print_stdout)]
29#![warn(clippy::rc_buffer)]
30#![deny(clippy::ref_option_ref)]
31#![warn(clippy::semicolon_if_nothing_returned)]
32#![warn(clippy::trait_duplication_in_bounds)]
33#![deny(clippy::unchecked_duration_subtraction)]
34#![deny(clippy::unnecessary_wraps)]
35#![warn(clippy::unseparated_literal_suffix)]
36#![deny(clippy::unwrap_used)]
37#![allow(clippy::let_unit_value)] #![allow(clippy::uninlined_format_args)]
39#![allow(clippy::significant_drop_in_scrutinee)] #![allow(clippy::result_large_err)] #![allow(clippy::needless_raw_string_hashes)] use std::future::Future;
45use std::io::Error;
46use std::pin::Pin;
47use std::sync::Arc;
48use std::task::{Context, Poll};
49
50use arti_client::{DataStream, IntoTorAddr, TorClient};
51use educe::Educe;
52use hyper::client::connect::{Connected, Connection};
53use hyper::http::uri::Scheme;
54use hyper::http::Uri;
55use hyper::service::Service;
56use pin_project::pin_project;
57use thiserror::Error;
58use tls_api::TlsConnector as TlsConn; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
60use tor_rtcompat::Runtime;
61
62#[derive(Error, Clone, Debug)]
66#[non_exhaustive]
67pub enum ConnectionError {
68 #[error("unsupported URI scheme in {uri:?}")]
70 UnsupportedUriScheme {
71 uri: Uri,
73 },
74
75 #[error("Missing hostname in {uri:?}")]
77 MissingHostname {
78 uri: Uri,
80 },
81
82 #[error("Tor connection failed")]
84 Arti(#[from] arti_client::Error),
85
86 #[error("TLS connection failed")]
88 TLS(#[source] Arc<anyhow::Error>),
89}
90
91impl tor_error::HasKind for ConnectionError {
93 #[rustfmt::skip]
94 fn kind(&self) -> tor_error::ErrorKind {
95 use ConnectionError as CE;
96 use tor_error::ErrorKind as EK;
97 match self {
98 CE::UnsupportedUriScheme{..} => EK::NotImplemented,
99 CE::MissingHostname{..} => EK::BadApiUsage,
100 CE::Arti(e) => e.kind(),
101 CE::TLS(_) => EK::RemoteProtocolViolation,
102 }
103 }
104}
105
106#[derive(Educe)]
119#[educe(Clone)] pub struct ArtiHttpConnector<R: Runtime, TC: TlsConn> {
121 client: TorClient<R>,
123
124 tls_conn: Arc<TC>,
126}
127
128impl<R: Runtime, TC: TlsConn> ArtiHttpConnector<R, TC> {
131 pub fn new(client: TorClient<R>, tls_conn: TC) -> Self {
133 let tls_conn = tls_conn.into();
134 Self { client, tls_conn }
135 }
136}
137
138#[pin_project]
149pub struct ArtiHttpConnection<TC: TlsConn> {
150 #[pin]
152 inner: MaybeHttpsStream<TC>,
153}
154
155#[pin_project(project = MaybeHttpsStreamProj)]
157enum MaybeHttpsStream<TC: TlsConn> {
158 Http(Pin<Box<DataStream>>), Https(#[pin] TC::TlsStream),
163}
164
165impl<TC: TlsConn> Connection for ArtiHttpConnection<TC> {
166 fn connected(&self) -> Connected {
167 Connected::new()
168 }
169}
170
171impl<TC: TlsConn> AsyncRead for ArtiHttpConnection<TC> {
174 fn poll_read(
175 self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 buf: &mut ReadBuf<'_>,
178 ) -> Poll<Result<(), std::io::Error>> {
179 match self.project().inner.project() {
180 MaybeHttpsStreamProj::Http(ds) => ds.as_mut().poll_read(cx, buf),
181 MaybeHttpsStreamProj::Https(t) => t.poll_read(cx, buf),
182 }
183 }
184}
185
186impl<TC: TlsConn> AsyncWrite for ArtiHttpConnection<TC> {
187 fn poll_write(
188 self: Pin<&mut Self>,
189 cx: &mut Context<'_>,
190 buf: &[u8],
191 ) -> Poll<Result<usize, Error>> {
192 match self.project().inner.project() {
193 MaybeHttpsStreamProj::Http(ds) => ds.as_mut().poll_write(cx, buf),
194 MaybeHttpsStreamProj::Https(t) => t.poll_write(cx, buf),
195 }
196 }
197
198 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
199 match self.project().inner.project() {
200 MaybeHttpsStreamProj::Http(ds) => ds.as_mut().poll_flush(cx),
201 MaybeHttpsStreamProj::Https(t) => t.poll_flush(cx),
202 }
203 }
204
205 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
206 match self.project().inner.project() {
207 MaybeHttpsStreamProj::Http(ds) => ds.as_mut().poll_shutdown(cx),
208 MaybeHttpsStreamProj::Https(t) => t.poll_shutdown(cx),
209 }
210 }
211}
212
/// Whether a connection should be wrapped in TLS.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum UseTls {
    /// Use the Tor stream as-is, without TLS (`http` URIs).
    Bare,

    /// Wrap the Tor stream in TLS (`https` URIs).
    Tls,
}
222
223fn uri_to_host_port_tls(uri: Uri) -> Result<(String, u16, UseTls), ConnectionError> {
225 let use_tls = {
226 let scheme = uri.scheme();
228 if scheme == Some(&Scheme::HTTP) {
229 UseTls::Bare
230 } else if scheme == Some(&Scheme::HTTPS) {
231 UseTls::Tls
232 } else {
233 return Err(ConnectionError::UnsupportedUriScheme { uri });
234 }
235 };
236 let host = match uri.host() {
237 Some(h) => h,
238 _ => return Err(ConnectionError::MissingHostname { uri }),
239 };
240 let port = uri.port().map(|x| x.as_u16()).unwrap_or(match use_tls {
241 UseTls::Tls => 443,
242 UseTls::Bare => 80,
243 });
244
245 Ok((host.to_owned(), port, use_tls))
246}
247
248impl<R: Runtime, TC: TlsConn> Service<Uri> for ArtiHttpConnector<R, TC> {
249 type Response = ArtiHttpConnection<TC>;
250 type Error = ConnectionError;
251 #[allow(clippy::type_complexity)]
252 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
253
254 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255 Poll::Ready(Ok(()))
256 }
257
258 fn call(&mut self, req: Uri) -> Self::Future {
259 let client = self.client.clone();
263 let tls_conn = self.tls_conn.clone();
264 Box::pin(async move {
265 let (host, port, use_tls) = uri_to_host_port_tls(req)?;
267 let addr = (&host as &str, port)
269 .into_tor_addr()
270 .map_err(arti_client::Error::from)?;
271 let ds = client.connect(addr).await?;
272
273 let inner = match use_tls {
274 UseTls::Tls => {
275 let conn = tls_conn
276 .connect_impl_tls_stream(&host, ds)
277 .await
278 .map_err(|e| ConnectionError::TLS(e.into()))?;
279 MaybeHttpsStream::Https(conn)
280 }
281 UseTls::Bare => MaybeHttpsStream::Http(Box::new(ds).into()),
282 };
283
284 Ok(ArtiHttpConnection { inner })
285 })
286 }
287}
288
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    /// Parse `url` into a [`Uri`], panicking if it is malformed.
    fn make_uri(url: &str) -> Uri {
        url.parse::<Uri>().expect("Unable to parse uri")
    }

    /// Run `uri_to_host_port_tls` on `url`, panicking if it returns an error.
    fn parse_ok(url: &str) -> (String, u16, UseTls) {
        uri_to_host_port_tls(make_uri(url)).expect("function should return Result")
    }

    #[test]
    fn check_supported_uri_schemes() {
        let rejected = [
            "wss://torproject.org",
            "file://torproject.org",
            "ftp://torproject.org",
            "vnc://torproject.org",
            "/no/scheme",
        ];
        rejected
            .iter()
            .for_each(|url| assert!(uri_to_host_port_tls(make_uri(url)).is_err()));

        let accepted = [
            ("https://torproject.org", UseTls::Tls),
            ("http://torproject.org", UseTls::Bare),
        ];
        for (url, want_tls) in accepted {
            let (_, _, got_tls) = parse_ok(url);
            assert_eq!(got_tls, want_tls);
        }
    }

    #[test]
    fn get_correct_port_and_tls_from_uri() {
        let cases = [
            ("https://torproject.org:999", 999, UseTls::Tls),
            ("https://torproject.org:80", 80, UseTls::Tls),
            ("https://torproject.org", 443, UseTls::Tls),
            ("http://torproject.org:999", 999, UseTls::Bare),
            ("http://torproject.org:443", 443, UseTls::Bare),
            ("http://torproject.org", 80, UseTls::Bare),
        ];

        for (url, want_port, want_tls) in cases {
            let (_, got_port, got_tls) = parse_ok(url);
            assert_eq!(got_port, want_port);
            assert_eq!(got_tls, want_tls);
        }
    }

    #[test]
    fn get_correct_host_from_uri() {
        let cases = [
            ("https://torproject.org", "torproject.org"),
            ("http://torproject.org", "torproject.org"),
        ];

        for (url, want_host) in cases {
            let (got_host, _, _) = parse_ok(url);
            assert_eq!(got_host, want_host);
        }
    }
}