use std::{net::SocketAddr, path::Path, fmt::Debug};
use anyhow::Ok;
use crate::coding::{Rframe, Tframe};
use s2n_quic::{
client::{Client, Connect},
provider::tls,
};
use slog_scope::{debug, error};
use crate::async_wire_format::AsyncWireFormatExt;
/// Settings needed to open a mutually-authenticated QUIC connection to an
/// upstream server.
///
/// The value is cheap to clone, so one instance can be shared as a template
/// and cloned for every outgoing connection.
#[derive(Clone, Debug)]
pub struct DialQuic {
    /// Upstream host; combined with `port` as `host:port` and parsed as a socket address.
    host: String,
    /// Upstream UDP port.
    port: u16,
    /// Path to the certificate presented as the client identity.
    client_cert: Box<Path>,
    /// Path to the private key that belongs to `client_cert`.
    key: Box<Path>,
    /// Path to the CA certificate handed to the TLS client as its trusted certificate.
    ca_cert: Box<Path>,
    /// Server name sent with the connection attempt.
    hostname: String,
}

impl DialQuic {
    /// Bundles the connection settings into a `DialQuic`.
    ///
    /// `cert` and `key` are the client identity, `ca_cert` is the trusted
    /// certificate, and `hostname` is the server name used when connecting
    /// to `host:port`. No file is touched and no I/O happens here.
    pub fn new(
        host: String,
        port: u16,
        cert: Box<Path>,
        key: Box<Path>,
        ca_cert: Box<Path>,
        hostname: String,
    ) -> Self {
        // The parameter is called `cert`; the field it fills is `client_cert`.
        let client_cert = cert;
        DialQuic {
            hostname,
            ca_cert,
            key,
            client_cert,
            port,
            host,
        }
    }
}
impl DialQuic {
async fn dial(self) -> anyhow::Result<s2n_quic::Connection> {
let ca_cert = self.ca_cert.to_str().unwrap();
let client_cert = self.client_cert.to_str().unwrap();
let client_key = self.key.to_str().unwrap();
let tls = tls::default::Client::builder()
.with_certificate(Path::new(ca_cert))?
.with_client_identity(
Path::new(client_cert),
Path::new(client_key),
)?
.build()?;
let client = Client::builder()
.with_tls(tls)?
.with_io("0.0.0.0:0")?
.start()?;
let host_port = format!("{}:{}", self.host, self.port);
let addr: SocketAddr = host_port.parse()?;
let connect = Connect::new(addr).with_server_name(&*self.hostname);
let mut connection = client.connect(connect).await?;
connection.keep_alive(true)?;
Ok(connection)
}
}
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_vsock::{VsockListener, VsockAddr};
/// A listening endpoint that yields byte streams for incoming connections.
///
/// Implementations must be `Send + Sync + Debug + 'static`.
#[async_trait::async_trait]
pub trait ListenerStream: Send + Sync + Debug+ 'static{
/// Bidirectional byte stream produced for each accepted connection.
type Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync;
/// Peer address type returned alongside the stream; only `Debug` is required of it.
type Addr: std::fmt::Debug;
/// Waits for the next incoming connection and returns its stream together
/// with the peer address, or the I/O error reported by the listener.
async fn accept(&mut self) -> std::io::Result<(Self::Stream, Self::Addr)>;
}
/// Lets a Tokio Unix-domain socket listener be used as a `ListenerStream`.
#[async_trait::async_trait]
impl ListenerStream for tokio::net::UnixListener {
type Stream = tokio::net::UnixStream;
type Addr = tokio::net::unix::SocketAddr;
/// Accepts the next Unix-domain connection.
async fn accept(&mut self) -> std::io::Result<(Self::Stream, Self::Addr)> {
// Called by fully qualified path, presumably so the listener's own
// inherent `accept` is selected rather than this trait method.
tokio::net::UnixListener::accept(self).await
}
}
/// Lets a `tokio_vsock` listener be used as a `ListenerStream`.
#[async_trait::async_trait]
impl ListenerStream for VsockListener {
type Stream = tokio_vsock::VsockStream;
type Addr = VsockAddr;
/// Accepts the next vsock connection.
async fn accept(&mut self) -> std::io::Result<(Self::Stream, Self::Addr)> {
// Called by type-qualified path, presumably so the listener's own
// inherent `accept` is selected rather than this trait method.
VsockListener::accept(self).await
}
}
/// Pairs a local listener with the settings used to dial the upstream QUIC
/// server.
pub struct Proxy<L: ListenerStream> {
    /// Template for upstream connections.
    dial: DialQuic,
    /// Source of incoming downstream connections.
    listener: L,
}
impl<L> Proxy<L>
where
L: ListenerStream,
{
pub fn new(dial: DialQuic, listener: L) -> Self {
Self { dial, listener }
}
pub async fn run(&mut self) {
debug!("Listening on {:?}", self.listener);
while let std::result::Result::Ok((down_stream, _)) = self.listener.accept().await {
debug!("Accepted connection from");
let down_stream = down_stream;
let dial = self.dial.clone();
tokio::spawn(async move {
debug!("Dialing {:?}", dial);
let mut dial = dial.clone().dial().await.unwrap();
debug!("Connected to {:?}", dial.remote_addr());
let up_stream = dial.open_bidirectional_stream().await.unwrap();
let (rx, mut tx) = up_stream.split();
let (read, mut write) = tokio::io::split(down_stream);
let mut upstream_reader = tokio::io::BufReader::new(rx);
let mut downstream_reader = tokio::io::BufReader::new(read);
loop {
{
debug!("Reading from down_stream");
let tframe = Tframe::decode_async(&mut downstream_reader).await;
if let Err(e) = tframe {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
break;
} else {
error!("Error decoding from down_stream: {:?}", e);
break;
}
} else if let std::io::Result::Ok(tframe) = tframe {
debug!("Sending to up_stream {:?}", tframe);
tframe.encode_async(&mut tx).await.unwrap();
}
}
{
debug!("Reading from up_stream");
let rframe = Rframe::decode_async(&mut upstream_reader).await.unwrap();
debug!("Sending to down_stream");
rframe.encode_async(&mut write).await.unwrap();
}
}
});
}
}
}