use std::net::SocketAddr;
use futures::future::join_all;
use tokio::{
io::copy_bidirectional,
net::{TcpListener, TcpStream},
task::JoinHandle,
};
use crate::io::SocketAddress;
use super::{IOError, Listener, Stream, StreamPool};
/// Host-side bridge between TCP sockets and the streams of a [`StreamPool`].
///
/// Each entry of the pool is paired with one TCP port: the first entry uses
/// `host_addr` as given, and every following entry uses the next port up.
pub struct HostBridge {
/// Pool whose entries are bridged, one TCP port per entry.
stream_pool: StreamPool,
/// TCP address paired with the first pool entry; the port is incremented
/// by one for each subsequent entry.
host_addr: SocketAddr,
}
impl HostBridge {
    /// Creates a bridge that pairs every entry of `stream_pool` with a TCP
    /// port, starting at the port of `host_addr` and counting upwards.
    ///
    /// # Panics
    ///
    /// Panics when the pool size added to the starting port is not strictly
    /// below `u16::MAX`, i.e. when the consecutive port range would not fit.
    #[must_use]
    pub fn new(stream_pool: StreamPool, host_addr: SocketAddr) -> Self {
        // Guarantees that the `port() + 1` increments performed by the bridge
        // loops below can never overflow a `u16`.
        assert!(
            stream_pool.len() + usize::from(host_addr.port()) < u16::MAX.into()
        );

        Self {
            stream_pool,
            host_addr,
        }
    }

    /// Starts forwarding host TCP connections into the pool's streams.
    ///
    /// Consumes the bridge and spawns a detached Tokio task (so it must be
    /// called from within a Tokio runtime). That task spawns one
    /// [`tcp_to_vsock`] worker per pool stream, each bound to its own
    /// consecutive TCP port, and then waits on all of them via [`await_all`].
    pub fn tcp_to_vsock(self) {
        println!("starting tcp to vsock host bridge @ {}", self.host_addr);

        tokio::spawn(async move {
            let enclave_streams = self.stream_pool.to_streams();
            let mut next_addr = self.host_addr;
            let mut handles = Vec::new();

            for enclave_stream in enclave_streams {
                println!("tcp to vsock bridge listening on {next_addr}");
                let worker = tokio::spawn(tcp_to_vsock(enclave_stream, next_addr));
                handles.push(worker);
                // The next pool entry is served on the following port.
                next_addr.set_port(next_addr.port() + 1);
            }

            await_all(handles).await;
        });
    }

    /// Starts forwarding connections accepted on the pool's listeners to
    /// host TCP addresses.
    ///
    /// Consumes the bridge and spawns a detached Tokio task (so it must be
    /// called from within a Tokio runtime). That task spawns one
    /// [`vsock_to_tcp`] worker per listener, each targeting its own
    /// consecutive TCP port, and then waits on all of them via [`await_all`].
    ///
    /// # Panics
    ///
    /// The spawned task panics if the pool fails to start listening.
    pub fn vsock_to_tcp(self) {
        println!("starting vsock to tcp host bridge @ {}", self.host_addr);

        tokio::spawn(async move {
            let enclave_listeners = self
                .stream_pool
                .listen()
                .expect("unable to listen on vsock connections");
            let mut target_addr = self.host_addr;
            let mut handles = Vec::new();

            for enclave_listener in enclave_listeners {
                println!(
                    "vsock to tcp bridge listening on {}",
                    enclave_listener.addr()
                );
                let worker = tokio::spawn(vsock_to_tcp(enclave_listener, target_addr));
                handles.push(worker);
                // The next listener forwards to the following port.
                target_addr.set_port(target_addr.port() + 1);
            }

            await_all(handles).await;
        });
    }
}
/// Waits for every bridge task to finish and logs how each one ended.
///
/// Used for both bridge directions, so the messages are direction-neutral.
/// Three outcomes are reported per task:
/// - the task could not be joined (it panicked or was cancelled),
/// - the task returned an error,
/// - the task returned `Ok(())`, which is unexpected because bridge tasks are
///   meant to loop forever.
async fn await_all(tasks: Vec<JoinHandle<Result<(), IOError>>>) {
    for outcome in join_all(tasks).await {
        match outcome {
            Err(join_err) => eprintln!("error on task joining: {join_err:?}"),
            Ok(Err(task_err)) => eprintln!("error in task: {task_err:?}"),
            Ok(Ok(())) => {
                println!("host bridge task exit, no errors. This shouldn't happen");
            }
        }
    }
}
async fn tcp_to_vsock(
enclave_stream: Stream,
host_addr: SocketAddr,
) -> Result<(), IOError> {
let listener = match TcpListener::bind(host_addr).await {
Ok(value) => value,
Err(err) => panic!("error binding to {host_addr}: {err}"),
};
loop {
let mut tcp_stream = match listener.accept().await {
Ok((value, _)) => value,
Err(err) => {
eprintln!(
"error accepting connection on tcp addr {host_addr}: {err:?}"
);
continue;
}
};
let mut stream = Stream::from(&enclave_stream);
tokio::spawn(async move {
if let Err(err) = stream.connect().await {
eprintln!(
"error connecting to VSOCK @ {} error: {err:?}",
stream
.address()
.unwrap_or(&SocketAddress::new_unix("unknown")),
);
return;
}
if let Err(err) =
copy_bidirectional(&mut tcp_stream, &mut stream).await
{
eprintln!("error on tcp to vsock stream bridge: {err:?}");
}
});
}
}
async fn vsock_to_tcp(
enclave_listener: Listener,
host_addr: SocketAddr,
) -> Result<(), IOError> {
loop {
let mut enclave_stream = match enclave_listener.accept().await {
Ok(value) => value,
Err(err) => {
eprintln!("error accepting connection on vsock: {err:?}");
continue;
}
};
tokio::spawn(async move {
let mut tcp_stream = match TcpStream::connect(host_addr).await {
Ok(value) => value,
Err(err) => {
eprintln!(
"error connecting to tcp addr {host_addr}: {err:?}"
);
return;
}
};
if let Err(err) =
copy_bidirectional(&mut enclave_stream, &mut tcp_stream).await
{
eprintln!("error on vsock to tcp stream bridge: {err:?}");
}
});
}
}