use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
/// Guest context ID (CID) used by [`VsockConnection::from_workdir`].
///
/// NOTE(review): 3 is presumably chosen because CIDs 0-2 are reserved by the vsock
/// address family (2 being the host) — confirm it matches the CID the guest is
/// configured with.
pub const DEFAULT_GUEST_CID: u32 = 3;

/// Addressing information for a vsock port reached through a host-side Unix socket.
///
/// Constructing a value opens nothing; `connect` dials `uds_path` and performs the
/// `CONNECT <port>` handshake.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct VsockConnection {
    /// Path of the host-side Unix domain socket that `connect` dials.
    pub uds_path: PathBuf,
    /// Guest context ID. Within this module it is only logged by `connect`; it is not
    /// part of the handshake message.
    pub guest_cid: u32,
    /// Vsock port requested in the `CONNECT <port>` handshake.
    pub port: u32,
}
impl VsockConnection {
#[must_use]
pub fn new(uds_path: PathBuf, guest_cid: u32, port: u32) -> Self {
Self {
uds_path,
guest_cid,
port,
}
}
#[must_use]
pub fn from_workdir(workdir: &Path, port: u32) -> Self {
Self {
uds_path: workdir.join("vsock.sock"),
guest_cid: DEFAULT_GUEST_CID,
port,
}
}
pub async fn connect(&self) -> crate::Result<tokio::net::UnixStream> {
tracing::debug!(
uds = %self.uds_path.display(),
cid = self.guest_cid,
port = self.port,
"connecting to vsock"
);
let stream = tokio::net::UnixStream::connect(&self.uds_path)
.await
.map_err(|e| {
crate::KavachError::ExecFailed(format!(
"vsock connect {}: {e}",
self.uds_path.display()
))
})?;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
let connect_msg = format!("CONNECT {}\n", self.port);
let mut stream = stream;
stream
.write_all(connect_msg.as_bytes())
.await
.map_err(|e| crate::KavachError::ExecFailed(format!("vsock handshake: {e}")))?;
let mut reader = BufReader::new(&mut stream);
let mut response = String::new();
reader
.read_line(&mut response)
.await
.map_err(|e| crate::KavachError::ExecFailed(format!("vsock response: {e}")))?;
if !response.starts_with("OK") {
return Err(crate::KavachError::ExecFailed(format!(
"vsock connect rejected: {response}"
)));
}
tracing::debug!(port = self.port, "vsock connected");
Ok(stream)
}
pub async fn send(stream: &mut tokio::net::UnixStream, data: &[u8]) -> crate::Result<()> {
use tokio::io::AsyncWriteExt;
stream
.write_all(data)
.await
.map_err(|e| crate::KavachError::ExecFailed(format!("vsock send: {e}")))?;
Ok(())
}
pub async fn recv(stream: &mut tokio::net::UnixStream, buf: &mut [u8]) -> crate::Result<usize> {
use tokio::io::AsyncReadExt;
let n = stream
.read(buf)
.await
.map_err(|e| crate::KavachError::ExecFailed(format!("vsock recv: {e}")))?;
Ok(n)
}
#[must_use]
pub fn socket_exists(&self) -> bool {
self.uds_path.exists()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_connection() {
        let conn = VsockConnection::new("/tmp/vsock.sock".into(), 3, 5000);
        assert_eq!(conn.uds_path, PathBuf::from("/tmp/vsock.sock"));
        assert_eq!(conn.guest_cid, 3);
        assert_eq!(conn.port, 5000);
    }

    #[test]
    fn from_workdir() {
        let conn = VsockConnection::from_workdir(Path::new("/tmp/fc"), 8080);
        assert_eq!(conn.uds_path, PathBuf::from("/tmp/fc/vsock.sock"));
        assert_eq!(conn.guest_cid, DEFAULT_GUEST_CID);
        assert_eq!(conn.port, 8080);
    }

    #[test]
    fn serde_roundtrip() {
        let conn = VsockConnection::new("/tmp/vs.sock".into(), 5, 9000);
        let json = serde_json::to_string(&conn).unwrap();
        let back: VsockConnection = serde_json::from_str(&json).unwrap();
        // Check every field; the path is the only one with a non-trivial encoding.
        assert_eq!(conn.uds_path, back.uds_path);
        assert_eq!(conn.guest_cid, back.guest_cid);
        assert_eq!(conn.port, back.port);
    }

    #[test]
    fn socket_exists_false_for_nonexistent() {
        // Per-process name under the platform temp dir instead of a hard-coded /tmp.
        let path = std::env::temp_dir().join(format!(
            "nonexistent_vsock_test_{}.sock",
            std::process::id()
        ));
        let conn = VsockConnection::new(path, 3, 5000);
        assert!(!conn.socket_exists());
    }

    #[test]
    fn socket_exists_true_for_existing_path() {
        // `socket_exists` only checks that the path exists, so the temp dir suffices.
        let conn = VsockConnection::new(std::env::temp_dir(), 3, 5000);
        assert!(conn.socket_exists());
    }

    #[test]
    fn default_cid() {
        assert_eq!(DEFAULT_GUEST_CID, 3);
    }
}