1use std::future::Future;
11#[cfg(feature = "socks")]
12use std::net::IpAddr;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use base64::engine::general_purpose::STANDARD as BASE64;
18use base64::Engine as _;
19use http::Uri;
20use hyper_util::rt::TokioIo;
21use tokio::io::{AsyncReadExt, AsyncWriteExt};
22use tokio::net::TcpStream;
23use tower_service::Service;
24
25use oxihttp_core::OxiHttpError;
26
27fn uri_host_port(uri: &Uri) -> Result<(String, u16), OxiHttpError> {
35 let host = uri
36 .host()
37 .ok_or_else(|| OxiHttpError::ConnectionPool(format!("URI has no host: {uri}")))?
38 .to_owned();
39 let port = match uri.port_u16() {
40 Some(p) => p,
41 None => match uri.scheme_str() {
42 Some("https") => 443u16,
43 Some("http") => 80u16,
44 _ => {
45 return Err(OxiHttpError::ConnectionPool(format!(
46 "URI has no port and unknown scheme: {uri}"
47 )))
48 }
49 },
50 };
51 Ok((host, port))
52}
53
54fn extract_auth(uri: &Uri) -> Option<(String, String)> {
56 let authority = uri.authority()?;
57 let userinfo = authority.as_str().split('@').next()?;
58 if !authority.as_str().contains('@') {
60 return None;
61 }
62 let (user, pass) = userinfo.split_once(':')?;
63 if user.is_empty() {
64 return None;
65 }
66 Some((user.to_owned(), pass.to_owned()))
67}
68
69#[derive(Clone, Debug)]
75pub enum ProxyKind {
76 HttpConnect(Uri),
78 #[cfg(feature = "socks")]
80 Socks5(Uri),
81}
82
83#[derive(Clone, Debug)]
97pub struct ProxyConnector {
98 proxy_uri: Uri,
99 connect_timeout: Option<Duration>,
100 auth: Option<(String, String)>,
101}
102
103impl ProxyConnector {
104 pub fn new(proxy_uri: Uri, connect_timeout: Option<Duration>) -> Self {
110 let auth = extract_auth(&proxy_uri);
111 Self {
112 proxy_uri,
113 connect_timeout,
114 auth,
115 }
116 }
117}
118
119impl Service<Uri> for ProxyConnector {
120 type Response = TokioIo<TcpStream>;
121 type Error = OxiHttpError;
122 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
123
124 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125 Poll::Ready(Ok(()))
126 }
127
128 fn call(&mut self, uri: Uri) -> Self::Future {
129 let proxy_uri = self.proxy_uri.clone();
130 let connect_timeout = self.connect_timeout;
131 let auth = self.auth.clone();
132
133 Box::pin(async move {
134 let (proxy_host, proxy_port) = uri_host_port(&proxy_uri)?;
136 let proxy_addr = format!("{proxy_host}:{proxy_port}");
137
138 let (target_host, target_port) = uri_host_port(&uri)?;
140 let target_authority = format!("{target_host}:{target_port}");
141
142 let stream = if let Some(timeout) = connect_timeout {
144 tokio::time::timeout(timeout, TcpStream::connect(&proxy_addr))
145 .await
146 .map_err(|_| {
147 OxiHttpError::Timeout(format!(
148 "proxy connect timeout after {}ms",
149 timeout.as_millis()
150 ))
151 })??
152 } else {
153 TcpStream::connect(&proxy_addr).await?
154 };
155
156 let mut stream = stream;
157
158 let mut req =
160 format!("CONNECT {target_authority} HTTP/1.1\r\nHost: {target_authority}\r\n");
161 if let Some((user, pass)) = &auth {
162 let credentials = format!("{user}:{pass}");
163 let encoded = BASE64.encode(credentials.as_bytes());
164 req.push_str(&format!("Proxy-Authorization: Basic {encoded}\r\n"));
165 }
166 req.push_str("\r\n");
167
168 stream.write_all(req.as_bytes()).await?;
169
170 let mut response_buf = Vec::with_capacity(256);
172 let mut single = [0u8; 1];
173 loop {
174 let n = stream.read(&mut single).await?;
175 if n == 0 {
176 return Err(OxiHttpError::ConnectionPool(
177 "proxy closed connection during CONNECT handshake".to_owned(),
178 ));
179 }
180 response_buf.push(single[0]);
181 if response_buf.ends_with(b"\r\n\r\n") {
182 break;
183 }
184 if response_buf.len() > 8192 {
185 return Err(OxiHttpError::ConnectionPool(
186 "proxy CONNECT response too large".to_owned(),
187 ));
188 }
189 }
190
191 let first_line = response_buf
193 .split(|&b| b == b'\n')
194 .next()
195 .and_then(|l| std::str::from_utf8(l).ok())
196 .unwrap_or("");
197
198 if !first_line.contains("200") {
199 return Err(OxiHttpError::ConnectionPool(format!(
200 "proxy CONNECT rejected: {first_line}"
201 )));
202 }
203
204 Ok(TokioIo::new(stream))
205 })
206 }
207}
208
/// Connector that reaches targets through a SOCKS5 proxy (RFC 1928).
///
/// Credentials in the proxy URI's userinfo are parsed once at construction.
/// They are used for username/password authentication (RFC 1929) when the
/// proxy asks for it.
///
/// `Debug` is implemented by hand so that the proxy password (held in both
/// `proxy_uri` and `auth`) never ends up in logs.
#[cfg(feature = "socks")]
#[derive(Clone)]
pub struct Socks5Connector {
    /// Proxy URI; may carry `user:password@` userinfo.
    proxy_uri: Uri,
    /// Limit on the TCP connect to the proxy; `None` means no limit.
    connect_timeout: Option<Duration>,
    /// `(user, password)` extracted from `proxy_uri`, if any.
    auth: Option<(String, String)>,
}

#[cfg(feature = "socks")]
impl std::fmt::Debug for Socks5Connector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Socks5Connector")
            .field("proxy_host", &self.proxy_uri.host())
            .field("proxy_port", &self.proxy_uri.port_u16())
            .field("connect_timeout", &self.connect_timeout)
            .field("auth", &self.auth.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
225
#[cfg(feature = "socks")]
impl Socks5Connector {
    /// Creates a connector that goes through the SOCKS5 proxy at `proxy_uri`.
    ///
    /// Any `user:password@` userinfo in `proxy_uri` is parsed here, once, and
    /// offered to the proxy on every connection. `connect_timeout` limits only
    /// the TCP connect to the proxy, not the SOCKS5 handshake.
    pub fn new(proxy_uri: Uri, connect_timeout: Option<Duration>) -> Self {
        // Fields are evaluated in the order written, so the borrow for
        // `extract_auth` ends before `proxy_uri` is moved in.
        Self {
            auth: extract_auth(&proxy_uri),
            proxy_uri,
            connect_timeout,
        }
    }
}
241
/// Encodes the target host as a SOCKS5 `(ATYP, DST.ADDR)` pair.
///
/// IPv4 literals become ATYP 0x01, IPv6 literals ATYP 0x04, and anything else
/// a length-prefixed domain name (ATYP 0x03) that the proxy resolves itself.
///
/// # Errors
///
/// Returns an error if a domain name is longer than 255 bytes, because its
/// length must fit in one byte.
#[cfg(feature = "socks")]
fn socks5_target_address(host: &str) -> Result<(u8, Vec<u8>), OxiHttpError> {
    // IPv6 literals may arrive bracketed ("[::1]"); strip the brackets so they
    // parse as addresses instead of being sent as a bogus domain name.
    let bare = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    match bare.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => Ok((0x01, v4.octets().to_vec())),
        Ok(IpAddr::V6(v6)) => Ok((0x04, v6.octets().to_vec())),
        Err(_) => {
            let host_bytes = host.as_bytes();
            let len = u8::try_from(host_bytes.len()).map_err(|_| {
                OxiHttpError::ConnectionPool(format!(
                    "SOCKS5 target host name too long ({} bytes, max 255)",
                    host_bytes.len()
                ))
            })?;
            let mut encoded = Vec::with_capacity(1 + host_bytes.len());
            encoded.push(len);
            encoded.extend_from_slice(host_bytes);
            Ok((0x03, encoded))
        }
    }
}

/// Runs the RFC 1929 username/password sub-negotiation.
///
/// # Errors
///
/// - The username or password is longer than 255 bytes.
/// - The proxy reports a non-zero status.
/// - An I/O error occurs on the socket.
#[cfg(feature = "socks")]
async fn socks5_authenticate(
    stream: &mut TcpStream,
    user: &str,
    pass: &str,
) -> Result<(), OxiHttpError> {
    let too_long = |what: &str, len: usize| {
        OxiHttpError::ConnectionPool(format!("SOCKS5 {what} too long ({len} bytes, max 255)"))
    };
    let user_len = u8::try_from(user.len()).map_err(|_| too_long("username", user.len()))?;
    let pass_len = u8::try_from(pass.len()).map_err(|_| too_long("password", pass.len()))?;

    // VER=0x01, ULEN, UNAME, PLEN, PASSWD
    let mut auth_req = Vec::with_capacity(3 + user.len() + pass.len());
    auth_req.push(0x01);
    auth_req.push(user_len);
    auth_req.extend_from_slice(user.as_bytes());
    auth_req.push(pass_len);
    auth_req.extend_from_slice(pass.as_bytes());
    stream.write_all(&auth_req).await?;

    let mut auth_resp = [0u8; 2];
    stream.read_exact(&mut auth_resp).await?;
    if auth_resp[1] != 0x00 {
        return Err(OxiHttpError::ConnectionPool(
            "SOCKS5 authentication failed".to_owned(),
        ));
    }
    Ok(())
}

/// Reads and discards BND.ADDR and BND.PORT from a SOCKS5 reply, given the
/// reply's ATYP byte.
#[cfg(feature = "socks")]
async fn socks5_skip_bound_address(stream: &mut TcpStream, atyp: u8) -> Result<(), OxiHttpError> {
    let addr_len = match atyp {
        0x01 => 4,
        0x04 => 16,
        0x03 => {
            let mut len_buf = [0u8; 1];
            stream.read_exact(&mut len_buf).await?;
            usize::from(len_buf[0])
        }
        other => {
            return Err(OxiHttpError::ConnectionPool(format!(
                "SOCKS5 reply has unknown ATYP {other:#04x}"
            )));
        }
    };
    // The address is followed by the 2-byte bound port; neither is needed.
    let mut discard = vec![0u8; addr_len + 2];
    stream.read_exact(&mut discard).await?;
    Ok(())
}

#[cfg(feature = "socks")]
impl Service<Uri> for Socks5Connector {
    type Response = TokioIo<TcpStream>;
    type Error = OxiHttpError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Each call opens its own TCP connection, so there is nothing to wait for.
        Poll::Ready(Ok(()))
    }

    /// Connects to the SOCKS5 proxy and asks it to CONNECT to `uri`'s host and port.
    ///
    /// The returned stream is positioned right after the proxy's reply, ready
    /// to carry the target connection.
    fn call(&mut self, uri: Uri) -> Self::Future {
        let proxy_uri = self.proxy_uri.clone();
        let connect_timeout = self.connect_timeout;
        let auth = self.auth.clone();

        Box::pin(async move {
            let (proxy_host, proxy_port) = uri_host_port(&proxy_uri)?;
            let proxy_addr = format!("{proxy_host}:{proxy_port}");
            let (target_host, target_port) = uri_host_port(&uri)?;

            // The timeout covers only the TCP connect to the proxy.
            let mut stream = match connect_timeout {
                Some(timeout) => tokio::time::timeout(timeout, TcpStream::connect(&proxy_addr))
                    .await
                    .map_err(|_| {
                        OxiHttpError::Timeout(format!(
                            "SOCKS5 proxy connect timeout after {}ms",
                            timeout.as_millis()
                        ))
                    })??,
                None => TcpStream::connect(&proxy_addr).await?,
            };

            // Greeting: VER=5, NMETHODS, METHODS. 0x00 = no auth,
            // 0x02 = username/password (offered only when credentials exist).
            let greeting: &[u8] = if auth.is_some() {
                &[0x05, 0x02, 0x00, 0x02]
            } else {
                &[0x05, 0x01, 0x00]
            };
            stream.write_all(greeting).await?;

            let mut method_resp = [0u8; 2];
            stream.read_exact(&mut method_resp).await?;
            if method_resp[0] != 0x05 {
                return Err(OxiHttpError::ConnectionPool(
                    "SOCKS5 greeting response has wrong version byte".to_owned(),
                ));
            }
            match method_resp[1] {
                0x00 => {}
                0x02 => {
                    let (user, pass) = auth.as_ref().ok_or_else(|| {
                        OxiHttpError::ConnectionPool(
                            "SOCKS5 proxy requires authentication but none configured".to_owned(),
                        )
                    })?;
                    socks5_authenticate(&mut stream, user, pass).await?;
                }
                0xFF => {
                    return Err(OxiHttpError::ConnectionPool(
                        "SOCKS5 proxy rejected all authentication methods".to_owned(),
                    ));
                }
                // Any other method was never offered; continuing would be a
                // protocol violation.
                other => {
                    return Err(OxiHttpError::ConnectionPool(format!(
                        "SOCKS5 proxy selected unsupported authentication method {other:#04x}"
                    )));
                }
            }

            // CONNECT request: VER=5, CMD=CONNECT, RSV, ATYP, DST.ADDR, DST.PORT (big-endian).
            let (atyp, addr_bytes) = socks5_target_address(&target_host)?;
            let mut connect_req = Vec::with_capacity(6 + addr_bytes.len());
            connect_req.extend_from_slice(&[0x05, 0x01, 0x00, atyp]);
            connect_req.extend_from_slice(&addr_bytes);
            connect_req.extend_from_slice(&target_port.to_be_bytes());
            stream.write_all(&connect_req).await?;

            // Reply header: VER, REP, RSV, ATYP.
            let mut reply_hdr = [0u8; 4];
            stream.read_exact(&mut reply_hdr).await?;
            if reply_hdr[0] != 0x05 {
                return Err(OxiHttpError::ConnectionPool(
                    "SOCKS5 reply has wrong version byte".to_owned(),
                ));
            }
            let rep = reply_hdr[1];
            if rep != 0x00 {
                return Err(OxiHttpError::ConnectionPool(format!(
                    "SOCKS5 error code {rep:#04x}"
                )));
            }
            socks5_skip_bound_address(&mut stream, reply_hdr[3]).await?;

            Ok(TokioIo::new(stream))
        })
    }
}