1use anyhow::Result;
2use arti_client::{TorClient, TorClientConfig};
3use http_body_util::{Empty, Full};
4use hyper::body::Bytes;
5use hyper::body::Incoming;
6use hyper::header::HeaderValue;
7use hyper::http::uri::Scheme;
8use hyper::{Request, Response, Uri};
9use hyper_util::rt::TokioIo;
10use std::io::Error as IoError;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio_native_tls::native_tls::TlsConnector;
13use tor_rtcompat::PreferredRuntime;
14
15pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
17
18impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite {}
19
20pub struct ClientConfig {
22 pub tls_config: TlsConnector,
24 pub tor_config: TorClientConfig,
26}
27
28pub struct ClientConfigBuilder {
30 tls_config: Option<TlsConnector>,
31 tor_config: Option<TorClientConfig>,
32}
33
34impl ClientConfigBuilder {
35 pub fn new() -> Self {
37 ClientConfigBuilder {
38 tls_config: None,
39 tor_config: None,
40 }
41 }
42
43 pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
45 self.tls_config = Some(tls_config);
46 self
47 }
48
49 pub fn tor_config(mut self, tor_config: TorClientConfig) -> Self {
51 self.tor_config = Some(tor_config);
52 self
53 }
54
55 pub fn build(self) -> Result<ClientConfig> {
57 Ok(ClientConfig {
58 tls_config: self.tls_config.unwrap_or_else(|| {
59 TlsConnector::builder()
60 .build()
61 .expect("Failed to create default TlsConnector")
62 }),
63 tor_config: self.tor_config.unwrap_or_else(|| {
64 let mut cfg_builder = TorClientConfig::builder();
65 cfg_builder.address_filter().allow_onion_addrs(true);
66 cfg_builder
67 .build()
68 .expect("Failed to create default TorClientConfig")
69 }),
70 })
71 }
72}
73
74pub struct Client {
76 tor_client: TorClient<PreferredRuntime>,
77 config: ClientConfig,
78}
79
80impl Client {
81 pub async fn with_config(config: ClientConfig) -> Result<Self> {
83 let tor_client = Self::create_tor_client(&config).await?;
84 Ok(Client { tor_client, config })
85 }
86
87 pub async fn new() -> Result<Self> {
89 let default_config = ClientConfigBuilder::new().build()?;
90 Self::with_config(default_config).await
91 }
92
93 async fn create_tor_client(config: &ClientConfig) -> Result<TorClient<PreferredRuntime>> {
95 let tor_client = TorClient::create_bootstrapped(config.tor_config.clone()).await?;
96 Ok(tor_client)
97 }
98
99 pub async fn head<T>(&self, uri: T) -> Result<Response<Incoming>>
101 where
102 Uri: TryFrom<T>,
103 <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
104 {
105 let req = Request::head(uri).body(Empty::<Bytes>::new())?;
106
107 let resp = self.send_request(req).await?;
108 Ok(resp)
109 }
110
111 pub async fn get<T>(&self, uri: T) -> Result<Response<Incoming>>
113 where
114 Uri: TryFrom<T>,
115 <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
116 {
117 let req = Request::get(uri).body(Empty::<Bytes>::new())?;
118
119 let resp = self.send_request(req).await?;
120 Ok(resp)
121 }
122
123 pub async fn post<T>(
125 &self,
126 uri: T,
127 content_type: &str,
128 body: Bytes,
129 ) -> Result<Response<Incoming>>
130 where
131 Uri: TryFrom<T>,
132 <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
133 {
134 let req = Request::post(uri)
135 .header(hyper::header::CONTENT_TYPE, content_type)
136 .body(Full::<Bytes>::from(body))?;
137
138 let resp = self.send_request(req).await?;
139 Ok(resp)
140 }
141
142 async fn send_request<B>(&self, req: Request<B>) -> Result<Response<Incoming>>
144 where
145 B: hyper::body::Body + Send + 'static, B::Data: Send, B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, {
149 let stream = self.create_stream(req.uri()).await?;
150
151 let (mut request_sender, connection) =
152 hyper::client::conn::http1::handshake(TokioIo::new(stream)).await?;
153
154 tokio::spawn(async move {
156 if let Err(e) = connection.await {
157 eprintln!("Error: {e:?}");
158 }
159 });
160
161 let mut final_req_builder = Request::builder().uri(req.uri()).method(req.method());
162
163 for (key, value) in req.headers() {
164 final_req_builder = final_req_builder.header(key, value);
165 }
166
167 if !req.headers().contains_key(hyper::header::HOST) {
168 if let Some(authority) = req.uri().authority() {
169 let host_header_value = HeaderValue::from_str(authority.as_str()).unwrap();
170 final_req_builder =
171 final_req_builder.header(hyper::header::HOST, host_header_value);
172 }
173 }
174
175 let final_req = final_req_builder.body(req.into_body())?;
176
177 let resp = request_sender.send_request(final_req).await?;
178
179 Ok(resp)
180 }
181
182 async fn create_stream(
184 &self,
185 url: &Uri,
186 ) -> Result<Box<dyn AsyncReadWrite + Unpin + Send>, IoError> {
187 let host = url
188 .host()
189 .ok_or_else(|| IoError::new(std::io::ErrorKind::InvalidInput, "Missing host"))?;
190 let https = url.scheme() == Some(&Scheme::HTTPS);
191
192 let port = match url.port_u16() {
193 Some(port) => port,
194 None if https => 443,
195 None => 80,
196 };
197
198 let stream = self
200 .tor_client
201 .connect((host, port))
202 .await
203 .map_err(|e| IoError::new(std::io::ErrorKind::Other, e))?;
204
205 if https {
206 let tls_connector = &self.config.tls_config;
208 let cx = tokio_native_tls::TlsConnector::from(tls_connector.clone());
209 let wrapped_stream = cx
210 .connect(host, stream)
211 .await
212 .map_err(|e| IoError::new(std::io::ErrorKind::Other, e))?;
213 Ok(Box::new(wrapped_stream) as Box<dyn AsyncReadWrite + Unpin + Send>)
214 } else {
215 Ok(Box::new(stream) as Box<dyn AsyncReadWrite + Unpin + Send>)
217 }
218 }
219}