use std::io::{stdout, BufWriter, Read, Write};
use std::net::{SocketAddr, TcpStream, UdpSocket};
use std::thread::{spawn, Builder, JoinHandle};
use mproxy_client::target_socket_interface;
use mproxy_server::upstream_socket_interface;
/// Size in bytes of the fixed receive buffers used by `forward_udp` and `proxy_tcp_udp`.
// NOTE(review): 8096 is not a power of two; 8192 may have been intended — confirm before changing.
const BUFSIZE: usize = 8096;
/// Spawns a thread that forwards every UDP datagram received on `listen_addr` to each of
/// `downstream_addrs`, optionally copying the payload to stdout.
///
/// The listening socket comes from `upstream_socket_interface` and one sending socket per
/// downstream address comes from `target_socket_interface`; both are created on the calling
/// thread before the worker is spawned. The worker thread is named after the pretty `Debug`
/// representation of the listening socket.
///
/// When `tee` is true, each received payload is written in full to a buffered stdout writer
/// and flushed right away.
///
/// Datagrams are received into a `BUFSIZE`-byte buffer; per the `UdpSocket::recv_from`
/// documentation, excess bytes of a longer datagram may be discarded.
///
/// # Panics
///
/// On the calling thread: if the listening socket or any downstream socket cannot be set up,
/// or if the worker thread cannot be spawned.
///
/// On the worker thread: if enabling broadcast fails, if sending to a downstream target
/// fails, or if writing to / flushing stdout fails. In debug builds a receive error also
/// panics; in release builds it is only logged to stderr and the loop continues.
pub fn forward_udp(listen_addr: String, downstream_addrs: &[String], tee: bool) -> JoinHandle<()> {
    let (_addr, listen_socket) =
        upstream_socket_interface(listen_addr).expect("binding server socket listener");
    let mut output_buffer = BufWriter::new(stdout());
    let targets: Vec<(SocketAddr, UdpSocket)> = downstream_addrs
        .iter()
        .map(|t| target_socket_interface(t).expect("binding client socket sender"))
        .collect();

    Builder::new()
        .name(format!("{:#?}", listen_socket))
        .spawn(move || {
            let mut buf = [0u8; BUFSIZE];
            listen_socket.set_broadcast(true).unwrap();
            loop {
                match listen_socket.recv_from(&mut buf) {
                    Ok((c, _remote_addr)) => {
                        let datagram = &buf[..c];
                        for (target_addr, target_socket) in &targets {
                            // IPv6 multicast targets use `send`, which requires a connected
                            // socket — presumably `target_socket_interface` connects those
                            // sockets; confirm against that crate.
                            let sent = if target_addr.is_ipv6() && target_addr.ip().is_multicast()
                            {
                                target_socket.send(datagram)
                            } else {
                                target_socket.send_to(datagram, target_addr)
                            };
                            sent.unwrap_or_else(|e| panic!("sending to server socket: {}", e));
                        }
                        if tee {
                            // `write_all` rather than `write`: a single `write` call may
                            // accept only part of the payload, silently truncating the copy.
                            output_buffer
                                .write_all(datagram)
                                .expect("writing to output buffer");
                            // Flush per datagram so the copy appears on stdout immediately.
                            output_buffer.flush().expect("flushing output buffer");
                        }
                    }
                    Err(err) => {
                        eprintln!("forward_udp: got an error: {}", err);
                        #[cfg(debug_assertions)]
                        panic!("forward_udp: got an error: {}", err);
                    }
                }
            }
        })
        .unwrap()
}
/// Starts one UDP forwarding thread per listen address.
///
/// Every entry of `listen_addrs` is passed to `forward_udp` together with the complete
/// `downstream_addrs` list and the `tee` flag. The returned handles are in the same order
/// as `listen_addrs`; an empty `listen_addrs` yields an empty vector.
///
/// In debug builds a line describing each forwarding route is printed to stdout before
/// the corresponding thread is started.
pub fn proxy_gateway(
    downstream_addrs: &[String],
    listen_addrs: &[String],
    tee: bool,
) -> Vec<JoinHandle<()>> {
    listen_addrs
        .iter()
        .map(|source| {
            #[cfg(debug_assertions)]
            println!(
                "proxy: forwarding {:?} -> {:?}",
                source, downstream_addrs
            );
            forward_udp(source.to_string(), downstream_addrs, tee)
        })
        .collect()
}
/// Spawns a thread that reads bytes from the TCP endpoint `upstream_tcp` and re-sends them
/// as UDP datagrams to `downstream_udp`, reconnecting forever.
///
/// Each pass of the outer loop creates the UDP target via `target_socket_interface`, opens
/// the upstream connection (through `tls_connection` when the `tls` feature is enabled, a
/// plain `TcpStream` otherwise) and then relays until something fails. Every failure —
/// target setup, connect, read error, EOF, or UDP send error — is followed by a 5 second
/// pause and a fresh attempt, so the returned handle never completes on its own.
///
/// Each chunk returned by a single `read` (at most `BUFSIZE` bytes) is sent as one
/// datagram; datagram boundaries therefore follow read sizes, not any framing of the
/// upstream data.
///
/// In debug builds the forwarding route is printed to stdout before the thread starts.
pub fn proxy_tcp_udp(upstream_tcp: String, downstream_udp: String) -> JoinHandle<()> {
    #[cfg(debug_assertions)]
    println!(
        "proxy: forwarding TCP {:?} -> UDP {:?}",
        upstream_tcp, downstream_udp
    );
    spawn(move || {
        // Pause between reconnection attempts.
        const RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(5);
        let mut buf = [0u8; BUFSIZE];
        loop {
            let (target_addr, target_socket) = match target_socket_interface(&downstream_udp) {
                Ok(target) => target,
                Err(_) => {
                    println!("Retrying...");
                    std::thread::sleep(RETRY_DELAY);
                    continue;
                }
            };

            #[cfg(feature = "tls")]
            let (mut conn, mut tcp) = match tls_connection(upstream_tcp.clone()) {
                Ok(pair) => pair,
                Err(_) => {
                    println!("Retrying...");
                    std::thread::sleep(RETRY_DELAY);
                    continue;
                }
            };
            #[cfg(feature = "tls")]
            let mut stream = TlsStream::new(&mut conn, &mut tcp);
            #[cfg(not(feature = "tls"))]
            let mut stream = match TcpStream::connect(upstream_tcp.as_str()) {
                Ok(s) => s,
                Err(_) => {
                    println!("Retrying...");
                    std::thread::sleep(RETRY_DELAY);
                    continue;
                }
            };

            // IPv6 multicast targets use `send`, which requires a connected socket —
            // presumably `target_socket_interface` connects those sockets; confirm.
            let connected_send = target_addr.is_ipv6() && target_addr.ip().is_multicast();

            loop {
                let c = match stream.read(&mut buf) {
                    Ok(0) => {
                        eprintln!("encountered EOF, disconnecting TCP proxy thread...");
                        break;
                    }
                    Ok(c) => c,
                    Err(e) => {
                        eprintln!("err: {}", e);
                        break;
                    }
                };
                let sent = if connected_send {
                    target_socket.send(&buf[..c])
                } else {
                    target_socket.send_to(&buf[..c], target_addr)
                };
                // A send failure goes through the same reconnect path as every other
                // failure here instead of panicking and killing the proxy thread.
                if let Err(e) = sent {
                    eprintln!("sending to UDP socket: {}", e);
                    break;
                }
            }

            println!("Retrying...");
            std::thread::sleep(RETRY_DELAY);
        }
    })
}
#[cfg(feature = "tls")]
use rustls::client::{ClientConfig, ClientConnection, ServerName};
#[cfg(feature = "tls")]
use rustls::Stream as TlsStream;
#[cfg(feature = "tls")]
use std::sync::Arc;
#[cfg(feature = "tls")]
use webpki_roots::TLS_SERVER_ROOTS;
#[cfg(feature = "tls")]
/// Prepares a rustls client connection and a TCP socket for `tls_connect_addr`.
///
/// The client configuration trusts the `webpki_roots` server roots and uses no client
/// certificate. The TLS server name is the part of `tls_connect_addr` before the first
/// `':'` (the whole string when there is no colon). The TCP socket is connected to the full
/// `tls_connect_addr` with `TCP_NODELAY` enabled.
///
/// This function never reads from or writes to the socket itself; the caller combines the
/// returned connection and socket (e.g. with `rustls::Stream`) to exchange data.
///
/// If the connection offers an early-data writer, an HTTP `GET /` request is written to it.
///
/// # Errors
///
/// Returns an error if the server name is rejected, the client connection cannot be
/// created, the TCP connect fails, `TCP_NODELAY` cannot be set, or the early-data write
/// fails. The underlying cause is included in the message where one is available.
pub fn tls_connection(
    tls_connect_addr: String,
) -> Result<(ClientConnection, TcpStream), Box<dyn std::error::Error>> {
    let mut root_store = rustls::RootCertStore::empty();
    root_store.add_server_trust_anchors(TLS_SERVER_ROOTS.0.iter().map(|ta| {
        rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
            ta.subject,
            ta.spki,
            ta.name_constraints,
        )
    }));
    let config = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let rc_config: Arc<ClientConfig> = Arc::new(config);

    // NOTE(review): cutting at the first ':' does not handle bracketed IPv6 literals such
    // as "[::1]:443" — confirm whether such addresses need to be supported.
    let dns_name = tls_connect_addr
        .split_once(':')
        .map_or(tls_connect_addr.as_str(), |(host, _port)| host);
    let server_name = ServerName::try_from(dns_name)
        .map_err(|_| format!("Resolving DNS for {}", dns_name))?;

    let mut conn = rustls::ClientConnection::new(rc_config, server_name)
        .map_err(|e| format!("Performing handshake: {}", e))?;

    let sock = TcpStream::connect(tls_connect_addr.as_str())
        .map_err(|e| format!("Connecting to {}: {}", tls_connect_addr, e))?;
    // Propagate socket-option failures instead of panicking inside a fallible function.
    sock.set_nodelay(true)?;

    if let Some(mut early_data) = conn.early_data() {
        // Only build the request when there is an early-data writer to receive it.
        let request = format!(
            "GET / HTTP/1.1\r\n\
             Host: {}\r\n\
             Connection: close\r\n\
             Accept-Encoding: identity\r\n\
             \r\n",
            tls_connect_addr
        );
        early_data.write_all(request.as_bytes())?;
    }
    Ok((conn, sock))
}