hypertor/
lib.rs

1use anyhow::Result;
2use arti_client::{TorClient, TorClientConfig};
3use http_body_util::{Empty, Full};
4use hyper::body::Bytes;
5use hyper::body::Incoming;
6use hyper::header::HeaderValue;
7use hyper::http::uri::Scheme;
8use hyper::{Request, Response, Uri};
9use hyper_util::rt::TokioIo;
10use std::io::Error as IoError;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio_native_tls::native_tls::TlsConnector;
13use tor_rtcompat::PreferredRuntime;
14
15/// A trait for types that implement both `AsyncRead` and `AsyncWrite`.
16pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
17
18impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite {}
19
20/// Configuration for the `Client`.
21pub struct ClientConfig {
22    /// TLS configuration for HTTPS connections.
23    pub tls_config: TlsConnector,
24    /// Tor client configuration for routing through the Tor network.
25    pub tor_config: TorClientConfig,
26}
27
28/// Builder for creating a `ClientConfig`.
29pub struct ClientConfigBuilder {
30    tls_config: Option<TlsConnector>,
31    tor_config: Option<TorClientConfig>,
32}
33
34impl ClientConfigBuilder {
35    /// Creates a new `ClientConfigBuilder`.
36    pub fn new() -> Self {
37        ClientConfigBuilder {
38            tls_config: None,
39            tor_config: None,
40        }
41    }
42
43    /// Sets the TLS configuration for the `ClientConfigBuilder`.
44    pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
45        self.tls_config = Some(tls_config);
46        self
47    }
48
49    /// Sets the Tor configuration for the `ClientConfigBuilder`.
50    pub fn tor_config(mut self, tor_config: TorClientConfig) -> Self {
51        self.tor_config = Some(tor_config);
52        self
53    }
54
55    /// Builds the `ClientConfig` from the `ClientConfigBuilder`.
56    pub fn build(self) -> Result<ClientConfig> {
57        Ok(ClientConfig {
58            tls_config: self.tls_config.unwrap_or_else(|| {
59                TlsConnector::builder()
60                    .build()
61                    .expect("Failed to create default TlsConnector")
62            }),
63            tor_config: self.tor_config.unwrap_or_else(|| {
64                let mut cfg_builder = TorClientConfig::builder();
65                cfg_builder.address_filter().allow_onion_addrs(true);
66                cfg_builder
67                    .build()
68                    .expect("Failed to create default TorClientConfig")
69            }),
70        })
71    }
72}
73
74/// A client for making HTTP requests over Tor with optional TLS.
75pub struct Client {
76    tor_client: TorClient<PreferredRuntime>,
77    config: ClientConfig,
78}
79
80impl Client {
81    /// Creates a new `Client` with the provided `ClientConfig`.
82    pub async fn with_config(config: ClientConfig) -> Result<Self> {
83        let tor_client = Self::create_tor_client(&config).await?;
84        Ok(Client { tor_client, config })
85    }
86
87    /// Creates a new `Client` with default configuration.
88    pub async fn new() -> Result<Self> {
89        let default_config = ClientConfigBuilder::new().build()?;
90        Self::with_config(default_config).await
91    }
92
93    /// Creates a Tor client using the given configuration.
94    async fn create_tor_client(config: &ClientConfig) -> Result<TorClient<PreferredRuntime>> {
95        let tor_client = TorClient::create_bootstrapped(config.tor_config.clone()).await?;
96        Ok(tor_client)
97    }
98
99    /// Sends an HTTP HEAD request to the specified URI.
100    pub async fn head<T>(&self, uri: T) -> Result<Response<Incoming>>
101    where
102        Uri: TryFrom<T>,
103        <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
104    {
105        let req = Request::head(uri).body(Empty::<Bytes>::new())?;
106
107        let resp = self.send_request(req).await?;
108        Ok(resp)
109    }
110
111    /// Sends an HTTP GET request to the specified URI.
112    pub async fn get<T>(&self, uri: T) -> Result<Response<Incoming>>
113    where
114        Uri: TryFrom<T>,
115        <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
116    {
117        let req = Request::get(uri).body(Empty::<Bytes>::new())?;
118
119        let resp = self.send_request(req).await?;
120        Ok(resp)
121    }
122
123    /// Sends an HTTP POST request to the specified URI with the given content type and body.
124    pub async fn post<T>(
125        &self,
126        uri: T,
127        content_type: &str,
128        body: Bytes,
129    ) -> Result<Response<Incoming>>
130    where
131        Uri: TryFrom<T>,
132        <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
133    {
134        let req = Request::post(uri)
135            .header(hyper::header::CONTENT_TYPE, content_type)
136            .body(Full::<Bytes>::from(body))?;
137
138        let resp = self.send_request(req).await?;
139        Ok(resp)
140    }
141
142    /// Sends an HTTP request and returns the response.
143    async fn send_request<B>(&self, req: Request<B>) -> Result<Response<Incoming>>
144    where
145        B: hyper::body::Body + Send + 'static, // B must implement Body and be sendable
146        B::Data: Send,                         // B::Data must be sendable
147        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, // B::Error must be convertible to a boxed error
148    {
149        let stream = self.create_stream(req.uri()).await?;
150
151        let (mut request_sender, connection) =
152            hyper::client::conn::http1::handshake(TokioIo::new(stream)).await?;
153
154        // Spawn a task to poll the connection and drive the HTTP state
155        tokio::spawn(async move {
156            if let Err(e) = connection.await {
157                eprintln!("Error: {e:?}");
158            }
159        });
160
161        let mut final_req_builder = Request::builder().uri(req.uri()).method(req.method());
162
163        for (key, value) in req.headers() {
164            final_req_builder = final_req_builder.header(key, value);
165        }
166
167        if !req.headers().contains_key(hyper::header::HOST) {
168            if let Some(authority) = req.uri().authority() {
169                let host_header_value = HeaderValue::from_str(authority.as_str()).unwrap();
170                final_req_builder =
171                    final_req_builder.header(hyper::header::HOST, host_header_value);
172            }
173        }
174
175        let final_req = final_req_builder.body(req.into_body())?;
176
177        let resp = request_sender.send_request(final_req).await?;
178
179        Ok(resp)
180    }
181
182    /// Creates a stream for the specified URI, optionally wrapping it with TLS.
183    async fn create_stream(
184        &self,
185        url: &Uri,
186    ) -> Result<Box<dyn AsyncReadWrite + Unpin + Send>, IoError> {
187        let host = url
188            .host()
189            .ok_or_else(|| IoError::new(std::io::ErrorKind::InvalidInput, "Missing host"))?;
190        let https = url.scheme() == Some(&Scheme::HTTPS);
191
192        let port = match url.port_u16() {
193            Some(port) => port,
194            None if https => 443,
195            None => 80,
196        };
197
198        // Establish the initial stream connection
199        let stream = self
200            .tor_client
201            .connect((host, port))
202            .await
203            .map_err(|e| IoError::new(std::io::ErrorKind::Other, e))?;
204
205        if https {
206            // Wrap the stream with TLS
207            let tls_connector = &self.config.tls_config;
208            let cx = tokio_native_tls::TlsConnector::from(tls_connector.clone());
209            let wrapped_stream = cx
210                .connect(host, stream)
211                .await
212                .map_err(|e| IoError::new(std::io::ErrorKind::Other, e))?;
213            Ok(Box::new(wrapped_stream) as Box<dyn AsyncReadWrite + Unpin + Send>)
214        } else {
215            // Return the unwrapped stream directly for HTTP
216            Ok(Box::new(stream) as Box<dyn AsyncReadWrite + Unpin + Send>)
217        }
218    }
219}