use std::path::{Path, PathBuf};
use std::time::SystemTime;
use serde::Deserialize;
use tokio::io::AsyncReadExt;
use tokio::net::{UnixListener, UnixStream};
use crate::host_stamp;
use crate::keepalive::KeepAlive;
use crate::{
GuestDeclaration, HostStamp, StampedDeclaration, TelemetryError, VSOCK_TELEMETRY_PORT,
WIRE_CONTENT_VERSION_MAJOR,
};
/// Upper bound, in bytes, on a single telemetry frame body (64 KiB).
///
/// Checked against the little-endian `u32` length prefix before the body buffer
/// is allocated, so an oversized prefix is rejected without allocating.
pub const MAX_FRAME_BYTES: u32 = 1 << 16;
/// Host-side listener for guest telemetry connections on a per-cell Unix socket.
///
/// The socket is bound at `<vsock_uds_base>_<VSOCK_TELEMETRY_PORT>`.
/// NOTE(review): this looks like the hybrid-vsock "UDS path + `_<port>`" convention
/// used by VMMs such as Firecracker; confirm against the VMM configuration.
pub struct VsockUdsListener {
    /// Full path of the bound socket file.
    socket_path: PathBuf,
    /// Underlying tokio listener bound at `socket_path`.
    listener: UnixListener,
}

impl VsockUdsListener {
    /// Bind the telemetry socket for one cell.
    ///
    /// The socket path is `vsock_uds_base` with `_{VSOCK_TELEMETRY_PORT}` appended
    /// directly to its final component (not as a new path component).
    ///
    /// # Errors
    ///
    /// Returns [`TelemetryError::Bind`] if the Unix socket cannot be bound, e.g. the
    /// path already exists or its parent directory is missing.
    pub fn bind_for_cell(vsock_uds_base: &Path) -> Result<Self, TelemetryError> {
        let socket_path = {
            // Concatenate at the OsStr level. `Path::display()` replaces non-UTF-8
            // bytes with U+FFFD, which would make us bind a different path than the
            // one the VMM forwards guest connections to.
            let mut raw = vsock_uds_base.as_os_str().to_os_string();
            raw.push(format!("_{}", VSOCK_TELEMETRY_PORT));
            PathBuf::from(raw)
        };
        let listener = UnixListener::bind(&socket_path).map_err(|e| {
            TelemetryError::Bind(format!(
                "bind telemetry UDS at {}: {e}",
                socket_path.display()
            ))
        })?;
        Ok(Self {
            socket_path,
            listener,
        })
    }

    /// Path of the bound socket file.
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }

    /// Wait for the next guest connection and wrap it as a [`VsockUdsStream`].
    ///
    /// # Errors
    ///
    /// Returns [`TelemetryError::Bind`] (the variant existing callers match on)
    /// if the underlying `accept` fails.
    pub async fn accept(&self) -> Result<VsockUdsStream, TelemetryError> {
        let (stream, _addr) = self.listener.accept().await.map_err(|e| {
            TelemetryError::Bind(format!(
                "accept telemetry connection at {}: {e}",
                self.socket_path.display()
            ))
        })?;
        Ok(VsockUdsStream { stream })
    }
}
/// One accepted guest telemetry connection carrying length-prefixed CBOR frames.
pub struct VsockUdsStream {
    stream: UnixStream,
}

impl VsockUdsStream {
    /// Receive the next frame and attach host-side attribution to it.
    ///
    /// For each frame received, `keepalive.notify_frame()` is awaited. The frame is
    /// then stamped with a fresh [`HostStamp`] that copies `cell_id`, `run_id` and
    /// `spec_signature_hash` from `stamp` and sets `host_received_at` to the current
    /// time. Attribution therefore comes only from `stamp`, never from the guest frame.
    ///
    /// Returns `Ok(None)` when the peer closes the connection cleanly between frames.
    ///
    /// # Errors
    ///
    /// Propagates every error from [`Self::recv_guest_declaration`].
    pub async fn recv_stamped(
        &mut self,
        stamp: &HostStamp,
        keepalive: &KeepAlive,
    ) -> Result<Option<StampedDeclaration>, TelemetryError> {
        let guest = match self.recv_guest_declaration().await? {
            Some(g) => g,
            None => return Ok(None),
        };
        keepalive.notify_frame().await;
        let per_frame_stamp = HostStamp {
            cell_id: stamp.cell_id.clone(),
            run_id: stamp.run_id.clone(),
            host_received_at: SystemTime::now(),
            spec_signature_hash: stamp.spec_signature_hash.clone(),
        };
        Ok(Some(host_stamp::stamp(guest, per_frame_stamp)))
    }

    /// Read one frame: a 4-byte little-endian length, then that many bytes of CBOR.
    ///
    /// The length must be in `1..=MAX_FRAME_BYTES`. It is checked before the body
    /// buffer is allocated.
    ///
    /// Returns `Ok(None)` only if the peer closes the connection before sending any
    /// byte of a new frame.
    ///
    /// # Errors
    ///
    /// - [`TelemetryError::Wire`] on I/O failure, on EOF partway through the length
    ///   prefix or body, on an out-of-range length, or on CBOR decode failure.
    /// - [`TelemetryError::UnsupportedVersion`] from [`decode_frame`] on a
    ///   major-version mismatch.
    pub async fn recv_guest_declaration(
        &mut self,
    ) -> Result<Option<GuestDeclaration>, TelemetryError> {
        let mut len_buf = [0u8; 4];
        // Read the first prefix byte on its own. EOF here is a clean close on a frame
        // boundary. EOF after it means the peer truncated a frame and must surface as
        // an error, not be mistaken for end-of-stream.
        match self.stream.read_exact(&mut len_buf[..1]).await {
            Ok(_) => {}
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(TelemetryError::Wire(format!("read length prefix: {e}"))),
        }
        self.stream
            .read_exact(&mut len_buf[1..])
            .await
            .map_err(|e| TelemetryError::Wire(format!("read length prefix: {e}")))?;
        let len = u32::from_le_bytes(len_buf);
        if len == 0 || len > MAX_FRAME_BYTES {
            return Err(TelemetryError::Wire(format!(
                "frame length {len} out of bounds (max {MAX_FRAME_BYTES})"
            )));
        }
        // Bounded by MAX_FRAME_BYTES above, so this allocation is at most 64 KiB.
        let mut body = vec![0u8; len as usize];
        self.stream
            .read_exact(&mut body)
            .await
            .map_err(|e| TelemetryError::Wire(format!("read frame body: {e}")))?;
        decode_frame(&body).map(Some)
    }
}
/// On-the-wire shape of a guest telemetry frame, decoded from CBOR.
///
/// Contains only guest-declared fields. There is no `deny_unknown_fields`, so serde
/// silently ignores any other keys. Attribution keys a guest might add, such as
/// `cell_id`/`run_id`, are therefore dropped during decode.
///
/// NOTE(review): keep the field order stable. If a peer ever encodes this struct as a
/// CBOR array instead of a map, serde deserializes it positionally.
#[derive(Debug, Deserialize)]
struct WireFrame {
    // Packed wire version; `decode_frame` treats the high byte as the major version.
    content_version: u16,
    // Probe name, e.g. "process.spawned" in the tests.
    probe_source: String,
    // PID as reported by the guest (not verified by the host).
    guest_pid: u32,
    // Presumably the guest process's `comm` name, as reported by the guest.
    guest_comm: String,
    // Guest-side timestamp; nanoseconds of a monotonic clock per the field name — confirm.
    guest_monotonic_ns: u64,
}
pub fn decode_frame(body: &[u8]) -> Result<GuestDeclaration, TelemetryError> {
let frame: WireFrame = ciborium::de::from_reader(body)
.map_err(|e| TelemetryError::Wire(format!("CBOR decode: {e}")))?;
let frame_major = (frame.content_version >> 8) as u8;
let host_major = (WIRE_CONTENT_VERSION_MAJOR >> 8) as u8;
if frame_major != host_major {
return Err(TelemetryError::UnsupportedVersion(frame.content_version));
}
Ok(GuestDeclaration {
probe_source: frame.probe_source,
guest_pid: frame.guest_pid,
guest_comm: frame.guest_comm,
guest_monotonic_ns: frame.guest_monotonic_ns,
})
}
/// Build a length-prefixed CBOR frame the way a guest would send it (test helper).
///
/// The result is the 4-byte little-endian body length followed by a CBOR map of the
/// given fields. This is the format `VsockUdsStream::recv_guest_declaration` reads.
///
/// # Panics
///
/// Panics if CBOR encoding fails or if the encoded body does not fit in a `u32`
/// length prefix.
#[doc(hidden)]
pub fn encode_test_frame(
    content_version: u16,
    probe_source: &str,
    guest_pid: u32,
    guest_comm: &str,
    guest_monotonic_ns: u64,
) -> Vec<u8> {
    let body = serde_json::json!({
        "content_version": content_version,
        "probe_source": probe_source,
        "guest_pid": guest_pid,
        "guest_comm": guest_comm,
        "guest_monotonic_ns": guest_monotonic_ns,
    });
    let mut cbor_bytes: Vec<u8> = Vec::new();
    ciborium::ser::into_writer(&body, &mut cbor_bytes).expect("CBOR encode");
    // Checked conversion: a silent `as u32` truncation would produce a prefix that
    // disagrees with the actual body length.
    let len = u32::try_from(cbor_bytes.len()).expect("CBOR body length exceeds u32::MAX");
    let mut framed = Vec::with_capacity(len_prefix_size() + cbor_bytes.len());
    framed.extend_from_slice(&len.to_le_bytes());
    framed.extend_from_slice(&cbor_bytes);
    framed
}

/// Size in bytes of the little-endian `u32` frame length prefix.
const fn len_prefix_size() -> usize {
    std::mem::size_of::<u32>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::io::AsyncWriteExt;
    use tokio::net::UnixStream;

    /// CBOR-encode a JSON value as a bare frame body (no length prefix).
    fn cbor_body(value: &serde_json::Value) -> Vec<u8> {
        let mut out: Vec<u8> = Vec::new();
        ciborium::ser::into_writer(value, &mut out).unwrap();
        out
    }

    #[test]
    fn decode_rejects_unknown_major() {
        let bytes = cbor_body(&serde_json::json!({
            "content_version": 0x0100u16,
            "probe_source": "process.spawned",
            "guest_pid": 1u32,
            "guest_comm": "x",
            "guest_monotonic_ns": 0u64,
        }));
        let outcome = decode_frame(&bytes);
        assert!(
            matches!(outcome, Err(TelemetryError::UnsupportedVersion(0x0100))),
            "expected UnsupportedVersion, got {outcome:?}"
        );
    }

    #[test]
    fn decode_accepts_known_major_and_drops_unknown_fields() {
        // cell_id / run_id are guest-forged attribution and must be ignored.
        let bytes = cbor_body(&serde_json::json!({
            "content_version": 1u16,
            "probe_source": "process.spawned",
            "guest_pid": 7u32,
            "guest_comm": "workload",
            "guest_monotonic_ns": 42u64,
            "cell_id": "FORGED-CELL-ID",
            "run_id": "FORGED-RUN-ID",
        }));
        let decl = decode_frame(&bytes).expect("decode");
        assert_eq!(
            (
                decl.probe_source.as_str(),
                decl.guest_pid,
                decl.guest_comm.as_str(),
                decl.guest_monotonic_ns,
            ),
            ("process.spawned", 7, "workload", 42)
        );
    }

    #[test]
    fn decode_rejects_garbage() {
        let garbage = [0xffu8; 3];
        assert!(matches!(
            decode_frame(&garbage),
            Err(TelemetryError::Wire(_))
        ));
    }

    #[tokio::test]
    async fn bind_creates_uds_at_expected_path() {
        let tmp = tempdir().unwrap();
        let base = tmp.path().join("cellos-vsock-test.socket");
        let expected = PathBuf::from(format!("{}_{}", base.display(), VSOCK_TELEMETRY_PORT));
        let bound = VsockUdsListener::bind_for_cell(&base).expect("bind");
        assert_eq!(bound.socket_path(), expected.as_path());
        assert!(expected.exists(), "UDS file should exist after bind");
    }

    #[tokio::test]
    async fn end_to_end_frame_round_trip_stamps_attribution() {
        let tmp = tempdir().unwrap();
        let base = tmp.path().join("cellos-vsock-rt.socket");
        let listener = VsockUdsListener::bind_for_cell(&base).expect("bind");
        let path = listener.socket_path().to_path_buf();

        // The server task owns the listener; the socket is already bound, so the
        // client below can connect regardless of when the task first runs.
        let server = tokio::spawn(async move {
            let mut conn = listener.accept().await.expect("accept");
            let host_stamp = HostStamp {
                cell_id: "cell-A".into(),
                run_id: "run-1".into(),
                host_received_at: SystemTime::now(),
                spec_signature_hash: "sha256:abc".into(),
            };
            let keepalive = KeepAlive::new(std::time::Duration::from_secs(10));
            conn.recv_stamped(&host_stamp, &keepalive)
                .await
                .expect("recv_stamped")
        });

        let mut guest = UnixStream::connect(&path).await.expect("connect");
        guest
            .write_all(&encode_test_frame(1, "process.spawned", 9, "evil", 7))
            .await
            .expect("write");
        guest.shutdown().await.ok();

        let got = server.await.expect("task").expect("got frame");
        // Attribution comes from the host stamp, payload from the guest frame.
        assert_eq!(got.cell_id, "cell-A");
        assert_eq!(got.run_id, "run-1");
        assert_eq!(got.spec_signature_hash, "sha256:abc");
        assert_eq!(got.probe_source, "process.spawned");
        assert_eq!(got.guest_pid, 9);
        assert_eq!(got.guest_comm, "evil");
    }
}