use serde::{Deserialize, Serialize};
use std::{io, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_vsock::{VsockAddr, VsockListener};
pub use crate::utils::CodingKey;
use crate::utils::Stream;
/// Wildcard vsock context ID (all bits set), used so the listener accepts
/// connections addressed to any CID of this machine.
const VMADDR_CID_ANY: u32 = u32::MAX;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Failed to bind to vsock address: {0}")]
Bind(io::Error),
#[cfg(feature = "nsm")]
#[error("Failed to connect to NSM: {0}")]
NsmConnect(io::Error),
#[error("encoding failed: {0}")]
Encoding(rmp_serde::encode::Error),
#[error("decoding failed: {0}")]
Decoding(rmp_serde::decode::Error),
#[error("failed to write {0}: {1}")]
Writing(CodingKey, io::Error),
#[error("failed to read {0}: {1}")]
Reading(CodingKey, io::Error),
}
/// Serves requests on vsock `port` without any shared context.
///
/// This is a convenience wrapper around [`listen_with_ctx`]. It passes the unit
/// value `()` as the context, ignores it, and forwards each decoded request to
/// `process`.
///
/// # Errors
///
/// Same as [`listen_with_ctx`].
pub async fn listen<Req, Res, Fut>(
    port: u32,
    process: impl Fn(Req) -> Fut + Send + Sync + 'static,
) -> Result<(), Error>
where
    Req: for<'de> Deserialize<'de>,
    Res: Serialize + Send,
    Fut: Future<Output = Res> + Send,
{
    // Adapt the one-argument handler to the (request, context) shape expected
    // by `listen_with_ctx`; the unit context carries no information.
    let handler = move |request: Req, _unit: ()| process(request);
    listen_with_ctx(port, (), handler).await
}
/// Binds a vsock listener on `port` (any CID) and serves connections forever.
///
/// Each accepted connection is handled on its own spawned task. That task
/// receives a fresh clone of `context` and a shared handle to `process`. It
/// then reads one request, awaits `process(request, context)`, and writes the
/// response back (see `process_request`). Per-connection failures are logged
/// and never stop the listener.
///
/// With the `nsm` feature enabled, a secure-module connection is opened before
/// serving and stored in `crate::nsm::SECURE_MODULE_GLOBAL`.
///
/// # Errors
///
/// Returns [`Error::Bind`] if the vsock listener cannot be bound. With the
/// `nsm` feature, it returns `Error::NsmConnect` if connecting to the secure
/// module fails. After listening starts, the function never returns.
pub async fn listen_with_ctx<Req, Res, Ctx, Fut>(
    port: u32,
    context: Ctx,
    process: impl Fn(Req, Ctx) -> Fut + Send + Sync + 'static,
) -> Result<(), Error>
where
    Res: Serialize + Send,
    Ctx: Clone + Send + 'static,
    Req: for<'de> Deserialize<'de>,
    Fut: Future<Output = Res> + Send,
{
    let listener =
        VsockListener::bind(VsockAddr::new(VMADDR_CID_ANY, port)).map_err(Error::Bind)?;
    tracing::info!("started listening on port {port}");

    #[cfg(feature = "nsm")]
    {
        let nsm = crate::SecureModule::connect().map_err(Error::NsmConnect)?;
        // NOTE(review): if the global was already initialised (e.g. by an
        // earlier listener), this new connection is presumably just dropped.
        crate::nsm::SECURE_MODULE_GLOBAL
            .get_or_init(|| async { nsm })
            .await;
    }

    let process = Arc::new(process);
    loop {
        let stream = match listener.accept().await {
            Ok((stream, _peer)) => Stream::new(stream),
            Err(e) => {
                // Accept failures such as resource exhaustion are operationally
                // relevant. Surface them instead of hiding them at debug level.
                // The loop retries immediately.
                tracing::warn!("failed to accept connection: {e}");
                continue;
            },
        };
        tracing::trace!("spawning new task to handle connection");
        // Each task owns its own context clone; the handler is shared via Arc.
        let context = context.clone();
        let process = Arc::clone(&process);
        tokio::spawn(async move {
            match process_request(stream, context, process).await {
                Ok(()) => tracing::debug!("request processed successfully"),
                Err(e) => tracing::error!("failed to process request: {e}"),
            }
        });
    }
}
async fn process_request<Req, Res, Ctx, Fut>(
mut stream: Stream,
context: Ctx,
process: Arc<impl Fn(Req, Ctx) -> Fut + Send + Sync>,
) -> Result<(), Error>
where
Ctx: Clone,
Res: Serialize + Send,
Req: for<'de> Deserialize<'de>,
Fut: Future<Output = Res> + Send,
{
let len = stream
.read_u64()
.await
.map_err(|e| Error::Reading(CodingKey::Length, e))?;
let payload = stream
.read_exact(len)
.await
.map_err(|e| Error::Reading(CodingKey::Payload, e))?;
let request = rmp_serde::from_slice(&payload).map_err(Error::Decoding)?;
let response = process(request, context).await;
let payload = rmp_serde::to_vec(&response).map_err(Error::Encoding)?;
stream
.write_u64(payload.len() as u64)
.await
.map_err(|e| Error::Writing(CodingKey::Length, e))?;
stream
.write_all(&payload)
.await
.map_err(|e| Error::Writing(CodingKey::Payload, e))
}