hyper_tor_connector/
maybe.rs

1use std::task::Poll;
2
3use super::{TorConnector, TorStream};
4use futures::future::BoxFuture;
5use futures::{FutureExt, TryFutureExt};
6use hyper::client::connect::{Connected, Connection};
7use hyper::client::HttpConnector;
8use hyper::Uri;
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio::net::TcpStream;
11use tower::{BoxError, Service};
12
13#[derive(Debug, Clone)]
14pub enum MaybeTorConnector {
15    ClearnetOnly(HttpConnector),
16    Hybrid {
17        clearnet: HttpConnector,
18        tor: TorConnector,
19    },
20    TorOnly(TorConnector),
21}
22
23#[pin_project::pin_project(project = MaybeTorStreamProj)]
24pub enum MaybeTorStream {
25    Clearnet(#[pin] TcpStream),
26    Tor(#[pin] TorStream),
27}
28impl Connection for MaybeTorStream {
29    fn connected(&self) -> Connected {
30        match self {
31            MaybeTorStream::Clearnet(a) => a.connected(),
32            MaybeTorStream::Tor(a) => a.connected(),
33        }
34    }
35}
36impl AsyncWrite for MaybeTorStream {
37    fn is_write_vectored(&self) -> bool {
38        match self {
39            MaybeTorStream::Clearnet(a) => a.is_write_vectored(),
40            MaybeTorStream::Tor(a) => a.is_write_vectored(),
41        }
42    }
43    fn poll_flush(
44        self: std::pin::Pin<&mut Self>,
45        cx: &mut std::task::Context<'_>,
46    ) -> Poll<Result<(), std::io::Error>> {
47        match self.project() {
48            MaybeTorStreamProj::Clearnet(a) => a.poll_flush(cx),
49            MaybeTorStreamProj::Tor(a) => a.poll_flush(cx),
50        }
51    }
52    fn poll_shutdown(
53        self: std::pin::Pin<&mut Self>,
54        cx: &mut std::task::Context<'_>,
55    ) -> Poll<Result<(), std::io::Error>> {
56        match self.project() {
57            MaybeTorStreamProj::Clearnet(a) => a.poll_shutdown(cx),
58            MaybeTorStreamProj::Tor(a) => a.poll_shutdown(cx),
59        }
60    }
61    fn poll_write(
62        self: std::pin::Pin<&mut Self>,
63        cx: &mut std::task::Context<'_>,
64        buf: &[u8],
65    ) -> Poll<Result<usize, std::io::Error>> {
66        match self.project() {
67            MaybeTorStreamProj::Clearnet(a) => a.poll_write(cx, buf),
68            MaybeTorStreamProj::Tor(a) => a.poll_write(cx, buf),
69        }
70    }
71    fn poll_write_vectored(
72        self: std::pin::Pin<&mut Self>,
73        cx: &mut std::task::Context<'_>,
74        bufs: &[std::io::IoSlice<'_>],
75    ) -> Poll<Result<usize, std::io::Error>> {
76        match self.project() {
77            MaybeTorStreamProj::Clearnet(a) => a.poll_write_vectored(cx, bufs),
78            MaybeTorStreamProj::Tor(a) => a.poll_write_vectored(cx, bufs),
79        }
80    }
81}
82
83impl AsyncRead for MaybeTorStream {
84    fn poll_read(
85        self: std::pin::Pin<&mut Self>,
86        cx: &mut std::task::Context<'_>,
87        buf: &mut tokio::io::ReadBuf<'_>,
88    ) -> Poll<std::io::Result<()>> {
89        match self.project() {
90            MaybeTorStreamProj::Clearnet(a) => a.poll_read(cx, buf),
91            MaybeTorStreamProj::Tor(a) => a.poll_read(cx, buf),
92        }
93    }
94}
95
96impl Service<Uri> for MaybeTorConnector {
97    type Response = MaybeTorStream;
98    type Error = BoxError;
99    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
100
101    fn poll_ready(
102        &mut self,
103        _cx: &mut std::task::Context<'_>,
104    ) -> std::task::Poll<Result<(), Self::Error>> {
105        Poll::Ready(Ok(()))
106    }
107
108    fn call(&mut self, req: Uri) -> Self::Future {
109        match self {
110            MaybeTorConnector::ClearnetOnly(clearnet) => clearnet
111                .call(req)
112                .map_ok(MaybeTorStream::Clearnet)
113                .map_err(BoxError::from)
114                .boxed(),
115            MaybeTorConnector::Hybrid { clearnet, tor } => {
116                if req.host().unwrap_or_default().ends_with(".onion") {
117                    tor.call(req)
118                        .map_ok(MaybeTorStream::Tor)
119                        .map_err(BoxError::from)
120                        .boxed()
121                } else {
122                    clearnet
123                        .call(req)
124                        .map_ok(MaybeTorStream::Clearnet)
125                        .map_err(BoxError::from)
126                        .boxed()
127                }
128            }
129            MaybeTorConnector::TorOnly(tor) => tor
130                .call(req)
131                .map_ok(MaybeTorStream::Tor)
132                .map_err(BoxError::from)
133                .boxed(),
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::client::Client;
    use hyper::Body;

    /// Network test for the `Hybrid` connector: fetches an onion page and
    /// expects a success status, then fetches a clearnet page and only
    /// expects the request to complete.
    ///
    /// With the `socks` feature the Tor connector is pointed at
    /// `127.0.0.1:9050`; with the `arti` feature it is built without
    /// arguments.
    #[tokio::test]
    async fn get_torproject_page() {
        #[cfg(feature = "socks")]
        let tor = TorConnector::new(([127, 0, 0, 1], 9050).into()).unwrap();
        #[cfg(feature = "arti")]
        let tor = TorConnector::new().unwrap();

        let connector = MaybeTorConnector::Hybrid {
            tor,
            clearnet: HttpConnector::new(),
        };
        let client = Client::builder().build::<_, Body>(connector);

        // Onion host: routed through the Tor connector.
        let onion_page: Uri =
            "http://2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion"
                .parse()
                .unwrap();
        let onion_response = client.get(onion_page).await.unwrap();
        assert!(onion_response.status().is_success());

        // Regular host: routed through the clearnet connector.
        let clearnet_page: Uri = "http://torproject.org".parse().unwrap();
        client.get(clearnet_page).await.unwrap();
    }
}