use core::net::SocketAddr;
use std::{
collections::HashSet,
error, fs,
io::{self, Read, Write},
path::Path,
sync::Arc,
};
use quinn::{Endpoint, Incoming, ServerConfig, crypto::rustls::QuicServerConfig};
use tracing::error;
use xitca_io::{
io::{AsyncIo, Interest},
net::TcpStream,
};
use xitca_unsafe_collection::futures::{Select, SelectOutput};
use super::driver::quic::QUIC_ALPN;
/// Boxed, thread-safe error type shared by the proxy's fallible operations.
pub type Error = Box<dyn error::Error + Send + Sync>;

/// A QUIC front-end that forwards each accepted client connection to a TCP upstream.
///
/// Build one with [`Proxy::with_cert`] or [`Proxy::with_config`], adjust it with the
/// chaining setters, then call [`Proxy::run`] to start serving.
pub struct Proxy {
    /// Server config, or the error produced while building it (reported by [`Proxy::run`]).
    cfg: Result<ServerConfig, Error>,
    /// TCP address every proxied connection is forwarded to (default `127.0.0.1:5432`).
    upstream_addr: SocketAddr,
    /// Local address the QUIC endpoint binds to (default `0.0.0.0:5433`).
    listen_addr: SocketAddr,
    /// When `Some`, only clients whose exact remote socket address is in the set are served.
    white_list: Option<HashSet<SocketAddr>>,
}

impl Proxy {
    /// Shared constructor: keeps `cfg` and fills every address setting with its default.
    fn new(cfg: Result<ServerConfig, Error>) -> Self {
        let upstream_addr = SocketAddr::from(([127, 0, 0, 1], 5432));
        let listen_addr = SocketAddr::from(([0, 0, 0, 0], 5433));
        Self {
            cfg,
            upstream_addr,
            listen_addr,
            white_list: None,
        }
    }

    /// Create a proxy whose TLS identity is loaded from the PEM `cert` and `key` files.
    ///
    /// The files are read immediately. Any failure is stored and only surfaces
    /// when [`Proxy::run`] is called.
    pub fn with_cert(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> Self {
        let cfg = cfg_from_cert(cert, key);
        Self::new(cfg)
    }

    /// Create a proxy from an already-built quinn [`ServerConfig`].
    pub fn with_config(cfg: ServerConfig) -> Self {
        Self::new(Ok(cfg))
    }

    /// Set the TCP address that client traffic is forwarded to.
    pub fn upstream_addr(self, addr: SocketAddr) -> Self {
        Self {
            upstream_addr: addr,
            ..self
        }
    }

    /// Set the local address the QUIC endpoint listens on.
    pub fn listen_addr(self, addr: SocketAddr) -> Self {
        Self {
            listen_addr: addr,
            ..self
        }
    }

    /// Add client addresses to the allow-list. Repeated calls accumulate.
    ///
    /// The list is only created once at least one address is supplied. Passing an empty
    /// iterator to a proxy that has no list yet therefore leaves every client allowed.
    /// NOTE(review): matching compares the full socket address, port included.
    pub fn white_list(mut self, addrs: impl IntoIterator<Item = SocketAddr>) -> Self {
        addrs.into_iter().for_each(|addr| {
            self.white_list.get_or_insert_with(HashSet::new).insert(addr);
        });
        self
    }

    /// Bind the QUIC endpoint and serve clients until it stops yielding connections.
    ///
    /// Each allowed incoming connection is handled by its own `tokio` task running
    /// `listen_task`. Per-connection failures are logged, not returned. Clients missing
    /// from a configured allow-list are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the server config could not be built or the endpoint
    /// cannot be created.
    pub async fn run(self) -> Result<(), Error> {
        let Self {
            cfg,
            upstream_addr,
            listen_addr,
            white_list,
        } = self;
        let endpoint = Endpoint::server(cfg?, listen_addr)?;

        while let Some(incoming) = endpoint.accept().await {
            let allowed = match &white_list {
                Some(list) => list.contains(&incoming.remote_address()),
                None => true,
            };
            if !allowed {
                continue;
            }
            tokio::spawn(async move {
                if let Err(e) = listen_task(incoming, upstream_addr).await {
                    error!("Proxy listen error: {e}");
                }
            });
        }
        Ok(())
    }
}
/// Build a QUIC [`ServerConfig`] from a PEM certificate chain file and a PEM private key file.
///
/// The first PKCS#8 private key in `key` is used. The TLS config advertises `QUIC_ALPN`
/// as its only ALPN protocol and does not request client certificates.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - either file cannot be read;
/// - either file contains malformed PEM;
/// - no PKCS#8 key or no certificate is found;
/// - rustls or quinn reject the resulting configuration.
///
/// These inputs are user-supplied paths, so failures are reported, never panicked on.
fn cfg_from_cert(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> Result<ServerConfig, Error> {
    let cert = fs::read(cert)?;
    let key = fs::read(key)?;

    let Some(key) = rustls_pemfile::pkcs8_private_keys(&mut &*key).next() else {
        return Err("no PKCS#8 private key found in key file".into());
    };
    // The inner `Result` reports malformed PEM for that key entry.
    let key = quinn::rustls::pki_types::PrivateKeyDer::from(key?);

    let cert = rustls_pemfile::certs(&mut &*cert).collect::<Result<Vec<_>, _>>()?;
    if cert.is_empty() {
        return Err("no certificate found in certificate file".into());
    }

    let mut config = quinn::rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert, key)?;
    config.alpn_protocols = vec![QUIC_ALPN.to_vec()];

    // Fails when the rustls config lacks a cipher suite usable for QUIC's initial keys.
    let config = QuicServerConfig::try_from(config)?;
    Ok(ServerConfig::with_crypto(Arc::new(config)))
}
/// Relay bytes between one QUIC client and a new TCP connection to `addr`.
///
/// Steps:
/// 1. Complete the QUIC handshake and connect to the upstream.
/// 2. Accept the first bidirectional stream opened by the client.
/// 3. Copy data in both directions through a single 4 KiB buffer until either side
///    closes or errors.
///
/// # Errors
///
/// Returns any handshake, connect, stream, or socket I/O error encountered.
async fn listen_task(conn: Incoming, addr: SocketAddr) -> Result<(), Error> {
    let conn = conn.await?;
    let mut upstream = TcpStream::connect(addr).await?;
    let (mut tx, mut rx) = conn.accept_bi().await?;
    let mut buf = [0; 4096];
    loop {
        // Wait for whichever happens first: data arriving on the client stream, or
        // the upstream socket becoming readable.
        match rx.read(&mut buf).select(upstream.ready(Interest::READABLE)).await {
            // Client -> upstream: write the whole chunk. When the non-blocking socket
            // would block, wait until it is writable.
            SelectOutput::A(Ok(Some(len))) => {
                let mut off = 0;
                while off != len {
                    match upstream.write(&buf[off..len]) {
                        Ok(0) => return Ok(()),
                        Ok(n) => off += n,
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                            upstream.ready(Interest::WRITABLE).await?;
                        }
                        Err(e) => return Err(e.into()),
                    }
                }
            }
            // The client finished its sending side of the stream.
            SelectOutput::A(Ok(None)) => return Ok(()),
            SelectOutput::A(Err(e)) => return Err(e.into()),
            // Upstream -> client: drain the socket until it would block.
            SelectOutput::B(Ok(_)) => 'inner: loop {
                match upstream.read(&mut buf) {
                    Ok(0) => {
                        // Upstream closed. Returning drops the last connection handle,
                        // and per quinn's `Connection` docs an implicit close abandons
                        // delivery of unsent stream data, such as a final error message.
                        // Finish the stream and wait for the peer to acknowledge it
                        // (or stop it) before tearing the connection down.
                        if tx.finish().is_ok() {
                            let _ = tx.stopped().await;
                        }
                        return Ok(());
                    }
                    Ok(n) => tx.write_all(&buf[..n]).await?,
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => break 'inner,
                    Err(e) => return Err(e.into()),
                }
            },
            SelectOutput::B(Err(e)) => return Err(e.into()),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Chains every builder setter. `with_cert` gets empty paths, so reading the files
    /// fails. That failure is stored in the proxy rather than raised, so construction
    /// must still succeed.
    #[test]
    fn construct() {
        let local = "127.0.0.1:0".parse().unwrap();
        let proxy = Proxy::with_cert("", "")
            .upstream_addr(local)
            .listen_addr(local);
        let _ = proxy.white_list(vec![local]).white_list([local]);
    }
}