use anyhow::{bail, Context, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::{BoxStream, Target};
/// Establishes a tunnel to `target` through an HTTP proxy using `CONNECT`.
///
/// Writes a `CONNECT <host>:<port> HTTP/1.0` request with a matching `Host`
/// header to `stream`, then reads the proxy's response headers and inspects the
/// status code. A `Proxy-Authorization: Basic ...` header is sent only when
/// *both* `username` and `password` are `Some`; if either is `None`, no
/// credentials are sent.
///
/// The response is read one byte at a time and reading stops right after the
/// blank line that terminates the headers, so this function does not read
/// anything that follows the headers.
///
/// # Errors
///
/// Returns an error when:
/// - the formatted target contains whitespace or ASCII control characters
///   (these would corrupt the request line or allow header injection),
/// - writing the request or reading the response fails,
/// - the response headers exceed the size limit before the terminating blank
///   line is seen,
/// - the status line is not UTF-8, has no status code, or the code is not a
///   valid `u16`,
/// - the status code is outside `200..300`.
pub async fn connect(
    mut stream: BoxStream,
    target: &Target,
    username: Option<&str>,
    password: Option<&str>,
) -> Result<BoxStream> {
    /// Upper bound on buffered response header bytes before giving up.
    const MAX_RESPONSE_HEADER_BYTES: usize = 8192;
    /// Byte sequence that terminates the response header section.
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    let host_port = match target {
        Target::Ip(addr, port) => {
            let ip = format!("{addr}");
            // An IPv6 literal contains ':' and must be wrapped in brackets,
            // otherwise the port separator is ambiguous (`::1:443`).
            if ip.contains(':') && !ip.starts_with('[') {
                format!("[{ip}]:{port}")
            } else {
                format!("{ip}:{port}")
            }
        }
        Target::Host(host, port) => format!("{host}:{port}"),
    };

    // The target is interpolated verbatim into the request line and the Host
    // header; reject anything that could split the line or inject headers.
    if host_port
        .chars()
        .any(|c| c.is_ascii_whitespace() || c.is_ascii_control())
    {
        bail!("http: invalid character in CONNECT target");
    }

    let auth_header = match (username, password) {
        (Some(u), Some(p)) => {
            let credentials = base64_encode(&format!("{u}:{p}"));
            format!("Proxy-Authorization: Basic {credentials}\r\n")
        }
        _ => String::new(),
    };
    let req = format!("CONNECT {host_port} HTTP/1.0\r\nHost: {host_port}\r\n{auth_header}\r\n");

    stream
        .write_all(req.as_bytes())
        .await
        .context("http: write CONNECT")?;

    // Read byte-by-byte so that nothing past the header terminator is pulled
    // off the stream by this function.
    let mut response = Vec::<u8>::with_capacity(256);
    while !response.ends_with(HEADER_TERMINATOR) {
        if response.len() > MAX_RESPONSE_HEADER_BYTES {
            bail!("http: response headers too large");
        }
        let b = stream.read_u8().await.context("http: read response")?;
        response.push(b);
    }

    // Status line is everything up to the first LF; `trim` drops the CR.
    let header_line = response
        .split(|&b| b == b'\n')
        .next()
        .context("http: empty response")?;
    let header_str = std::str::from_utf8(header_line)
        .context("http: non-UTF-8 status line")?
        .trim();

    // Status line shape: `<version> <code> [reason]`; the code is token #2.
    let status_code: u16 = header_str
        .split_whitespace()
        .nth(1)
        .context("http: missing status code")?
        .parse()
        .context("http: invalid status code")?;

    if !(200..300).contains(&status_code) {
        bail!("http: CONNECT failed with status {status_code}");
    }
    Ok(stream)
}
/// Encodes the UTF-8 bytes of `s` as Base64 using the standard alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`) with `=` padding.
///
/// Every group of up to three input bytes produces exactly four output
/// characters; an empty input yields an empty string.
fn base64_encode(s: &str) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const PAD: char = '=';

    let bytes = s.as_bytes();
    let mut encoded = String::with_capacity((bytes.len() + 2) / 3 * 4);

    for chunk in bytes.chunks(3) {
        // Pack the available bytes big-endian, then left-align them in a
        // 24-bit group so missing trailing bytes count as zero.
        let packed = chunk
            .iter()
            .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
        let group = packed << (8 * (3 - chunk.len()));

        // `k` input bytes carry `k + 1` meaningful sextets; the rest is padding.
        for position in 0..4 {
            if position <= chunk.len() {
                let index = (group >> (18 - 6 * position)) & 0x3f;
                encoded.push(char::from(ALPHABET[index as usize]));
            } else {
                encoded.push(PAD);
            }
        }
    }

    encoded
}