use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use tokio_rustls::TlsAcceptor;
/// Starts a TCP echo server on an ephemeral `127.0.0.1` port.
///
/// Returns the bound port together with a counter that is incremented once per
/// accepted connection. Every connection is served on its own task, which writes
/// back each chunk it reads until the peer closes, a read fails, or a write fails.
pub async fn echo_server() -> (u16, Arc<AtomicU32>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    let connections = Arc::new(AtomicU32::new(0));
    let counter = Arc::clone(&connections);
    tokio::spawn(async move {
        loop {
            let (mut socket, _peer) = listener.accept().await.unwrap();
            counter.fetch_add(1, Ordering::SeqCst);
            tokio::spawn(async move {
                let mut chunk = vec![0u8; 4096];
                loop {
                    // A read error is treated exactly like EOF: stop echoing.
                    let len = match socket.read(&mut chunk).await {
                        Ok(0) | Err(_) => break,
                        Ok(len) => len,
                    };
                    if socket.write_all(&chunk[..len]).await.is_err() {
                        break;
                    }
                }
            });
        }
    });
    (port, connections)
}
pub async fn socks4_server(hits: Arc<AtomicU32>) -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
tokio::spawn(async move {
loop {
let (mut client, _) = listener.accept().await.unwrap();
hits.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let mut hdr = [0u8; 8];
client.read_exact(&mut hdr).await.unwrap();
assert_eq!(hdr[0], 4, "SOCKS version");
assert_eq!(hdr[1], 1, "CONNECT command");
let port = u16::from_be_bytes([hdr[2], hdr[3]]);
let ip_bytes = &hdr[4..8];
let mut username = Vec::new();
loop {
let b = client.read_u8().await.unwrap();
if b == 0 {
break;
}
username.push(b);
}
let target_addr = if ip_bytes[0] == 0
&& ip_bytes[1] == 0
&& ip_bytes[2] == 0
&& ip_bytes[3] != 0
{
let mut hostname = Vec::new();
loop {
let b = client.read_u8().await.unwrap();
if b == 0 {
break;
}
hostname.push(b);
}
let hostname = String::from_utf8(hostname).unwrap();
format!("{hostname}:{port}")
} else {
let ip =
std::net::Ipv4Addr::new(ip_bytes[0], ip_bytes[1], ip_bytes[2], ip_bytes[3]);
format!("{ip}:{port}")
};
client.write_all(&[0, 90, 0, 0, 0, 0, 0, 0]).await.unwrap();
let mut upstream = TcpStream::connect(&target_addr).await.unwrap();
tokio::io::copy_bidirectional(&mut client, &mut upstream)
.await
.ok();
});
}
});
port
}
/// Starts a minimal SOCKS5 proxy (RFC 1928) on an ephemeral `127.0.0.1` port and
/// returns that port.
///
/// `hits` is incremented once per accepted client connection. Authentication:
/// - When `require_user` is `Some((user, pass))`, the proxy always selects
///   username/password authentication (method 0x02, RFC 1929). It rejects
///   mismatching credentials with status 1 and closes the connection.
/// - Otherwise it always selects "no authentication" (0x00).
///
/// The client's offered method list is read but not consulted. Only CONNECT is
/// supported. IPv4 (1), domain-name (3) and IPv6 (4) destinations are accepted.
///
/// The upstream connection is opened *before* replying, so the client sees success
/// (REP 0x00) only when the tunnel exists. A refused connection is reported as REP
/// 0x05, and any other connect error as general failure (REP 0x01).
///
/// # Panics
/// The per-connection task panics on protocol violations (wrong version or command,
/// unknown address type, non-UTF-8 hostname) and on client I/O errors.
pub async fn socks5_server(
    require_user: Option<(&'static str, &'static str)>,
    hits: Arc<AtomicU32>,
) -> u16 {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    tokio::spawn(async move {
        loop {
            let (mut client, _) = listener.accept().await.unwrap();
            hits.fetch_add(1, Ordering::SeqCst);
            tokio::spawn(async move {
                // Greeting: VER, NMETHODS, METHODS[NMETHODS].
                let ver = client.read_u8().await.unwrap();
                assert_eq!(ver, 5);
                let n_methods = client.read_u8().await.unwrap() as usize;
                let mut methods = vec![0u8; n_methods];
                client.read_exact(&mut methods).await.unwrap();

                if let Some((exp_user, exp_pass)) = require_user {
                    client.write_all(&[5, 2]).await.unwrap();
                    // RFC 1929 sub-negotiation: VER, ULEN, UNAME, PLEN, PASSWD.
                    let _ver = client.read_u8().await.unwrap();
                    let ulen = client.read_u8().await.unwrap() as usize;
                    let mut user = vec![0u8; ulen];
                    client.read_exact(&mut user).await.unwrap();
                    let plen = client.read_u8().await.unwrap() as usize;
                    let mut pass = vec![0u8; plen];
                    client.read_exact(&mut pass).await.unwrap();
                    if user != exp_user.as_bytes() || pass != exp_pass.as_bytes() {
                        client.write_all(&[1, 1]).await.unwrap();
                        return;
                    }
                    client.write_all(&[1, 0]).await.unwrap();
                } else {
                    client.write_all(&[5, 0]).await.unwrap();
                }

                // Request: VER, CMD, RSV, ATYP, DST.ADDR, DST.PORT.
                let mut req = [0u8; 4];
                client.read_exact(&mut req).await.unwrap();
                assert_eq!(req[0], 5);
                assert_eq!(req[1], 1);
                let atype = req[3];
                let host = match atype {
                    1 => {
                        let mut ip = [0u8; 4];
                        client.read_exact(&mut ip).await.unwrap();
                        std::net::Ipv4Addr::from(ip).to_string()
                    }
                    4 => {
                        let mut ip = [0u8; 16];
                        client.read_exact(&mut ip).await.unwrap();
                        // Brackets so the later "host:port" string stays parseable.
                        format!("[{}]", std::net::Ipv6Addr::from(ip))
                    }
                    3 => {
                        let len = client.read_u8().await.unwrap() as usize;
                        let mut name = vec![0u8; len];
                        client.read_exact(&mut name).await.unwrap();
                        String::from_utf8(name).unwrap()
                    }
                    _ => panic!("unknown atype {atype}"),
                };
                // DST.PORT follows every address form, in network byte order.
                let mut port_bytes = [0u8; 2];
                client.read_exact(&mut port_bytes).await.unwrap();
                let target = format!("{host}:{}", u16::from_be_bytes(port_bytes));

                // Connect first so the reply reflects whether the tunnel actually exists.
                let mut upstream = match TcpStream::connect(&target).await {
                    Ok(stream) => stream,
                    Err(err) => {
                        let rep = if err.kind() == std::io::ErrorKind::ConnectionRefused {
                            5 // connection refused
                        } else {
                            1 // general SOCKS server failure
                        };
                        // Best effort: the client may already have gone away.
                        let _ = client
                            .write_all(&[5, rep, 0, 1, 0, 0, 0, 0, 0, 0])
                            .await;
                        return;
                    }
                };
                client
                    .write_all(&[5, 0, 0, 1, 0, 0, 0, 0, 0, 0])
                    .await
                    .unwrap();
                tokio::io::copy_bidirectional(&mut client, &mut upstream)
                    .await
                    .ok();
            });
        }
    });
    port
}
/// Starts a minimal HTTP `CONNECT` proxy over plain TCP on an ephemeral `127.0.0.1`
/// port and returns that port.
///
/// `hits` is incremented once per accepted client connection. The request head is
/// read up to the first blank line. Its first line must be `CONNECT <host:port> ...`
/// (asserted). When `require_auth` is `Some((user, pass))`, the first
/// `Proxy-Authorization` header must equal `Basic base64(user:pass)`. Otherwise the
/// proxy answers `407` (with a `Proxy-Authenticate` challenge) and closes.
///
/// The upstream connection is opened *before* answering, so `200` is sent only when
/// the tunnel exists. A failed connect is answered with `502 Bad Gateway`.
///
/// # Panics
/// The per-connection task panics on a malformed request head (non-UTF-8, missing
/// request line or target, non-CONNECT method) or on client I/O errors.
pub async fn http_connect_server(
    require_auth: Option<(&'static str, &'static str)>,
    hits: Arc<AtomicU32>,
) -> u16 {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    tokio::spawn(async move {
        loop {
            let (mut client, _) = listener.accept().await.unwrap();
            hits.fetch_add(1, Ordering::SeqCst);
            tokio::spawn(async move {
                // Read the head one byte at a time on purpose: buffering could swallow
                // bytes that belong to the tunnelled stream.
                let mut lines: Vec<String> = Vec::new();
                let mut buf = Vec::new();
                loop {
                    let b = client.read_u8().await.unwrap();
                    if b != b'\n' {
                        buf.push(b);
                        continue;
                    }
                    let line = String::from_utf8(std::mem::take(&mut buf)).unwrap();
                    let line = line.trim_end_matches('\r');
                    if line.is_empty() {
                        break;
                    }
                    lines.push(line.to_owned());
                }

                let req_line = lines.first().unwrap();
                let mut parts = req_line.split_whitespace();
                assert_eq!(parts.next().unwrap(), "CONNECT");
                let target = parts.next().unwrap().to_owned();

                if let Some((exp_user, exp_pass)) = require_auth {
                    let expected =
                        format!("Basic {}", base64_encode(&format!("{exp_user}:{exp_pass}")));
                    // Only the first Proxy-Authorization header is considered.
                    let authed = lines
                        .iter()
                        .find(|l| l.to_ascii_lowercase().starts_with("proxy-authorization:"))
                        .and_then(|h| h.split_once(':'))
                        .map(|(_, value)| value.trim() == expected)
                        .unwrap_or(false);
                    if !authed {
                        // RFC 9110: a 407 response must carry a Proxy-Authenticate challenge.
                        client
                            .write_all(
                                b"HTTP/1.0 407 Proxy Authentication Required\r\n\
                                  Proxy-Authenticate: Basic realm=\"proxy\"\r\n\r\n",
                            )
                            .await
                            .unwrap();
                        return;
                    }
                }

                // Connect first so "200" is only sent for a tunnel that actually exists.
                let mut upstream = match TcpStream::connect(&target).await {
                    Ok(stream) => stream,
                    Err(_) => {
                        // Best effort: the client may already have gone away.
                        let _ = client
                            .write_all(b"HTTP/1.0 502 Bad Gateway\r\n\r\n")
                            .await;
                        return;
                    }
                };
                client
                    .write_all(b"HTTP/1.0 200 Connection established\r\n\r\n")
                    .await
                    .unwrap();
                tokio::io::copy_bidirectional(&mut client, &mut upstream)
                    .await
                    .ok();
            });
        }
    });
    port
}
/// Starts an HTTP `CONNECT` proxy that speaks TLS to its clients, on an ephemeral
/// `127.0.0.1` port, and returns that port.
///
/// A fresh self-signed certificate for `localhost` is generated on every call, so
/// clients must be configured to trust it (or skip verification). `hits` is
/// incremented once per accepted TCP connection, before the TLS handshake.
///
/// After the handshake the request head is read up to the first blank line. Its first
/// line must be `CONNECT <host:port> ...` (asserted). When `require_auth` is
/// `Some((user, pass))`, some `Proxy-Authorization` header must equal
/// `Basic base64(user:pass)`. Otherwise the proxy answers `407` (with a
/// `Proxy-Authenticate` challenge) and closes.
///
/// The upstream (plain TCP) connection is opened *before* answering, so `200` is sent
/// only when the tunnel exists. A failed connect is answered with `502 Bad Gateway`.
///
/// # Panics
/// Panics if certificate generation or TLS configuration fails. The per-connection
/// task panics on a malformed request head or on I/O errors after the handshake.
/// A failed TLS handshake just drops the connection.
pub async fn tls_connect_server(
    require_auth: Option<(&'static str, &'static str)>,
    hits: Arc<AtomicU32>,
) -> u16 {
    use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
    use rustls::ServerConfig;
    let rcgen::CertifiedKey { cert, key_pair } =
        rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
    let cert_der = CertificateDer::from(cert.der().to_vec());
    let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
    let server_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .unwrap();
    let acceptor = TlsAcceptor::from(Arc::new(server_config));
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    tokio::spawn(async move {
        loop {
            let (tcp, _) = listener.accept().await.unwrap();
            hits.fetch_add(1, Ordering::SeqCst);
            let acceptor = acceptor.clone();
            tokio::spawn(async move {
                // A client rejecting the self-signed certificate is an expected test
                // scenario, not a proxy bug: drop the connection quietly.
                let mut stream = match acceptor.accept(tcp).await {
                    Ok(stream) => stream,
                    Err(_) => return,
                };

                // Read the head one byte at a time on purpose: buffering could swallow
                // bytes that belong to the tunnelled stream.
                let mut lines: Vec<String> = Vec::new();
                let mut buf: Vec<u8> = Vec::new();
                loop {
                    let b = stream.read_u8().await.unwrap();
                    if b != b'\n' {
                        buf.push(b);
                        continue;
                    }
                    let line = String::from_utf8(std::mem::take(&mut buf)).unwrap();
                    let line = line.trim_end_matches('\r');
                    if line.is_empty() {
                        break;
                    }
                    lines.push(line.to_owned());
                }

                let req_line = lines.first().unwrap();
                let mut parts = req_line.split_whitespace();
                assert_eq!(parts.next().unwrap(), "CONNECT");
                let target = parts.next().unwrap().to_owned();

                if let Some((exp_user, exp_pass)) = require_auth {
                    let expected =
                        format!("Basic {}", base64_encode(&format!("{exp_user}:{exp_pass}")));
                    // Any matching Proxy-Authorization header is accepted.
                    let authed = lines.iter().any(|l| {
                        l.to_ascii_lowercase().starts_with("proxy-authorization:")
                            && l.split_once(':')
                                .map(|(_, v)| v.trim() == expected)
                                .unwrap_or(false)
                    });
                    if !authed {
                        // RFC 9110: a 407 response must carry a Proxy-Authenticate challenge.
                        stream
                            .write_all(
                                b"HTTP/1.0 407 Proxy Authentication Required\r\n\
                                  Proxy-Authenticate: Basic realm=\"proxy\"\r\n\r\n",
                            )
                            .await
                            .unwrap();
                        return;
                    }
                }

                // Connect first so "200" is only sent for a tunnel that actually exists.
                let mut upstream = match TcpStream::connect(&target).await {
                    Ok(upstream) => upstream,
                    Err(_) => {
                        // Best effort: the client may already have gone away.
                        let _ = stream
                            .write_all(b"HTTP/1.0 502 Bad Gateway\r\n\r\n")
                            .await;
                        return;
                    }
                };
                stream
                    .write_all(b"HTTP/1.0 200 Connection established\r\n\r\n")
                    .await
                    .unwrap();
                tokio::io::copy_bidirectional(&mut stream, &mut upstream)
                    .await
                    .ok();
            });
        }
    });
    port
}
/// Encodes the UTF-8 bytes of `s` as standard base64 (RFC 4648 alphabet, `=` padding).
///
/// Used to build the expected `Basic` credential for `Proxy-Authorization` checks.
fn base64_encode(s: &str) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Picks the 6-bit group of `triple` that starts `shift` bits from the right.
    let sextet = |triple: u32, shift: u32| ALPHABET[((triple >> shift) & 0x3f) as usize] as char;
    s.as_bytes()
        .chunks(3)
        .flat_map(|group| {
            // Pack up to three bytes big-endian into 24 bits, zero-filling a short tail.
            let triple = group
                .iter()
                .chain(std::iter::repeat(&0u8))
                .take(3)
                .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
            let third = if group.len() >= 2 { sextet(triple, 6) } else { '=' };
            let fourth = if group.len() == 3 { sextet(triple, 0) } else { '=' };
            [sextet(triple, 18), sextet(triple, 12), third, fourth]
        })
        .collect()
}