use std::{collections::HashMap, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_vsock::{VsockAddr, VsockListener};
pub use crate::utils::CodingKey;
use crate::{Request, utils::Stream};
/// Wildcard context id handed to `VsockListener::bind` in `Router::serve`
/// (all bits set, matching Linux's `VMADDR_CID_ANY` of `-1U`).
const VMADDR_CID_ANY: u32 = u32::MAX;
/// Pinned, heap-allocated, `Send` future that may borrow for `'a`; the return
/// type of `Handler::handle`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
/// Errors produced while serving requests through a `Router`.
// NOTE(review): no field is marked `#[source]`/`#[from]`, so thiserror does not
// populate `std::error::Error::source` for any variant; the wrapped error is only
// visible through the `Display` text. Consider `#[source]` if callers walk chains.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// `VsockListener::bind` failed in `Router::serve`.
    #[error("Failed to bind to vsock address: {0}")]
    Bind(io::Error),
    /// `VsockListener::accept` failed in `Router::serve`; this ends the accept loop.
    #[error("Failed to accept connection: {0}")]
    Accept(io::Error),
    /// `SecureModule::try_init_global` failed in `Router::serve` (only with the
    /// `nsm` feature).
    #[cfg(feature = "nsm")]
    #[error("Failed to connect to NSM: {0}")]
    NsmConnect(io::Error),
    /// `rmp_serde::to_vec` failed to encode a handler's response.
    #[error("encoding failed: {0}")]
    Encoding(rmp_serde::encode::Error),
    /// `rmp_serde::from_slice` failed to decode a request payload.
    #[error("decoding failed: {0}")]
    Decoding(rmp_serde::decode::Error),
    /// Writing the framed part named by the `CodingKey` to the stream failed.
    #[error("failed to write {0}: {1}")]
    Writing(CodingKey, io::Error),
    /// Reading the framed part named by the `CodingKey` from the stream failed.
    #[error("failed to read {0}: {1}")]
    Reading(CodingKey, io::Error),
    /// A connection sent a request type id with no registered route.
    #[error("Unknown request type: 0x{0:08x}")]
    UnknownRequest(u32),
}
trait Handler<S>: Send + Sync {
fn handle<'a>(&'a self, stream: &'a mut Stream, state: S) -> BoxFuture<'a, Result<(), Error>>;
}
/// Adapts a typed async function `Fn(S, R) -> Fut` to the object-safe `Handler`
/// trait, so handlers for different request types can share one route table.
///
/// `R` and `S` are carried only through `PhantomData`. `Fut` is not stored in any
/// field and appears only in the bounds below, which tie it to `H`'s return type,
/// so keep these bounds on the definition.
struct TypedHandler<R, S, H, Fut>
where
    R: Request,
    H: Fn(S, R) -> Fut + Send + Sync,
    Fut: Future<Output = R::Response> + Send,
{
    /// The user-supplied handler function.
    handler: H,
    /// Type marker for the request and state types; holds no values.
    _phantom: PhantomData<(R, S)>,
}
impl<R, S, H, Fut> Handler<S> for TypedHandler<R, S, H, Fut>
where
R: Request,
S: Clone + Send + Sync + 'static,
H: Fn(S, R) -> Fut + Send + Sync,
Fut: Future<Output = R::Response> + Send,
{
fn handle<'a>(&'a self, stream: &'a mut Stream, state: S) -> BoxFuture<'a, Result<(), Error>> {
Box::pin(async move {
let len = stream
.read_u64()
.await
.map_err(|e| Error::Reading(CodingKey::Length, e))?;
let payload = stream
.read_exact(len)
.await
.map_err(|e| Error::Reading(CodingKey::Payload, e))?;
let request: R = rmp_serde::from_slice(&payload).map_err(Error::Decoding)?;
let response = (self.handler)(state, request).await;
let response_bytes = rmp_serde::to_vec(&response).map_err(Error::Encoding)?;
stream
.write_u64(response_bytes.len() as u64)
.await
.map_err(|e| Error::Writing(CodingKey::Length, e))?;
stream
.write_all(&response_bytes)
.await
.map_err(|e| Error::Writing(CodingKey::Payload, e))?;
Ok(())
})
}
}
/// Dispatches incoming vsock connections to typed request handlers.
///
/// Handlers are keyed by the `u32` returned from `Request::type_id` for their
/// request type. Every handler call receives a clone of `state`.
pub struct Router<S = ()> {
    /// Registered handlers, keyed by request type id.
    routes: HashMap<u32, Box<dyn Handler<S>>>,
    /// Shared state cloned into each handler invocation.
    state: S,
}
impl Router<()> {
#[must_use]
pub fn new() -> Self {
Self {
routes: HashMap::new(),
state: (),
}
}
}
impl Default for Router<()> {
fn default() -> Self {
Self::new()
}
}
impl<S> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Creates an empty router whose handlers each receive a clone of `state`.
    #[must_use]
    pub fn with_state(state: S) -> Self {
        Self {
            routes: HashMap::new(),
            state,
        }
    }

    /// Registers `handler` for request type `R`, keyed by `R::type_id()`.
    ///
    /// If a handler is already registered under the same type id, it is
    /// replaced and a warning is logged, because the earlier handler can no
    /// longer be reached.
    #[must_use]
    pub fn route<R, H, Fut>(mut self, handler: H) -> Self
    where
        R: Request,
        H: Fn(S, R) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R::Response> + Send + 'static,
    {
        let type_id = R::type_id();
        tracing::debug!(
            route_id = R::ROUTE_ID,
            type_id = format!("0x{:08x}", type_id),
            "Registering route"
        );
        let typed_adapter = TypedHandler {
            handler,
            _phantom: PhantomData::<(R, S)>,
        };
        let boxed: Box<dyn Handler<S>> = Box::new(typed_adapter);
        // `insert` hands back the displaced handler. Previously this was dropped
        // silently, which hid duplicate registrations and type-id collisions.
        if self.routes.insert(type_id, boxed).is_some() {
            tracing::warn!(
                route_id = R::ROUTE_ID,
                type_id = format!("0x{:08x}", type_id),
                "Route already registered for this type id; previous handler replaced"
            );
        }
        self
    }

    /// Binds a vsock listener on `port` (wildcard CID) and serves connections.
    ///
    /// With the `nsm` feature enabled, `SecureModule::try_init_global` is called
    /// after binding and before the first accept. Each accepted connection is
    /// handled on its own spawned task. Per-connection failures are logged and
    /// do not stop the server. This function never returns `Ok`; it runs until
    /// an error ends it.
    ///
    /// # Errors
    ///
    /// - `Error::Bind` if the listener cannot be bound.
    /// - `Error::NsmConnect` (feature `nsm`) if NSM initialisation fails.
    /// - `Error::Accept` if accepting a connection fails, which ends the loop.
    pub async fn serve(self, port: u32) -> Result<(), Error> {
        let listener =
            VsockListener::bind(VsockAddr::new(VMADDR_CID_ANY, port)).map_err(Error::Bind)?;
        tracing::info!("Router listening on port {port}");
        #[cfg(feature = "nsm")]
        {
            crate::nsm::SecureModule::try_init_global()
                .await
                .map_err(Error::NsmConnect)?;
        }
        // Shared read-only across all connection tasks.
        let router = Arc::new(self);
        loop {
            let (stream, _) = listener.accept().await.map_err(Error::Accept)?;
            let mut stream = Stream::new(stream);
            let router = router.clone();
            tokio::spawn(async move {
                if let Err(e) = handle_connection(&mut stream, router).await {
                    tracing::error!("Failed to handle request: {e}");
                }
            });
        }
    }
}
async fn handle_connection<S>(stream: &mut Stream, router: Arc<Router<S>>) -> Result<(), Error>
where
S: Clone + Send + Sync + 'static,
{
let type_id = stream
.read_u32()
.await
.map_err(|e| Error::Reading(CodingKey::Length, e))?;
let handler = router.routes.get(&type_id).ok_or_else(|| {
tracing::warn!(
type_id = format!("0x{:08x}", type_id),
"Unknown request type"
);
Error::UnknownRequest(type_id)
})?;
handler.handle(stream, router.state.clone()).await
}