use crate::enclave::{
proxy::{EnclaveArg, Error, Result},
DeviceProxy, VsockPortOffset,
};
use signal_hook::consts::SIGTERM;
use std::{
io::{ErrorKind, Read, Write},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use vsock::{VsockAddr, VsockListener, VsockStream, VMADDR_CID_ANY};
/// Device proxy state used to forward a SIGTERM notification over a vsock stream.
///
/// Cloning is cheap and every clone observes the same signal flag, because the
/// flag lives behind an [`Arc`].
#[derive(Clone, Debug)]
pub struct SignalHandler {
    /// Flag registered for SIGTERM in `SignalHandler::new`; shared by all clones.
    sig: Arc<AtomicBool>,
    /// One-byte scratch buffer that `rcv` reads into.
    buf: [u8; 1],
}
impl SignalHandler {
pub fn new() -> Result<Self> {
let sig = Arc::new(AtomicBool::new(false));
signal_hook::flag::register(SIGTERM, Arc::clone(&sig)).map_err(Error::SignalRegister)?;
let buf = [0u8; 1];
Ok(Self { sig, buf })
}
}
impl DeviceProxy for SignalHandler {
    /// The signal handler supplies no enclave argument.
    fn arg(&self) -> Option<EnclaveArg<'_>> {
        None
    }

    /// Returns a boxed copy of this handler.
    ///
    /// The copy shares the SIGTERM flag with `self`, since the flag is held in an `Arc`.
    fn clone(&self) -> Result<Option<Box<dyn DeviceProxy>>> {
        // Call `Clone::clone` by path: this trait method is itself named `clone`.
        Ok(Some(Box::new(Clone::clone(self))))
    }

    /// Reads at most one byte from `vsock` into the scratch buffer.
    ///
    /// The byte is not interpreted here; only the number of bytes read is returned.
    ///
    /// # Errors
    ///
    /// Returns [`Error::VsockRead`] if the read fails.
    fn rcv(&mut self, vsock: &mut VsockStream) -> Result<usize> {
        vsock.read(&mut self.buf).map_err(Error::VsockRead)
    }

    /// Sends the SIGTERM signal number over `vsock` once the SIGTERM flag is set.
    ///
    /// Returns `Ok(0)` when the flag is not set or when the peer has already
    /// closed the stream (`BrokenPipe`); otherwise returns the number of bytes
    /// written, which is the full native-endian encoding of `libc::SIGTERM`.
    ///
    /// NOTE(review): the flag is never cleared in this file, so every call made
    /// after SIGTERM arrives sends the signal number again — confirm that the
    /// caller expects this.
    ///
    /// # Errors
    ///
    /// Returns [`Error::VsockWrite`] for any write failure other than `BrokenPipe`.
    fn send(&mut self, vsock: &mut VsockStream) -> Result<usize> {
        if !self.sig.load(Ordering::Relaxed) {
            return Ok(0);
        }
        let payload = libc::SIGTERM.to_ne_bytes();
        // `write_all` retries on short writes and on `Interrupted`, so the peer
        // never sees a truncated signal number.
        match vsock.write_all(&payload) {
            Ok(()) => Ok(payload.len()),
            // A closed peer is treated as "nothing sent", not as a failure.
            Err(e) if e.kind() == ErrorKind::BrokenPipe => Ok(0),
            Err(e) => Err(Error::VsockWrite(e)),
        }
    }

    /// Listens on the signal-handler port derived from `cid` and returns the
    /// first accepted connection.
    ///
    /// The port is `cid + VsockPortOffset::SignalHandler`, bound on
    /// `VMADDR_CID_ANY`. The listener is dropped once one connection is accepted.
    ///
    /// NOTE(review): the port addition is unchecked; presumably `cid` is small
    /// enough that it cannot overflow `u32` — confirm against callers.
    ///
    /// # Errors
    ///
    /// Returns [`Error::VsockBind`] if binding fails and [`Error::VsockAccept`]
    /// if accepting the connection fails.
    fn vsock(&mut self, cid: u32) -> Result<VsockStream> {
        let port = cid + (VsockPortOffset::SignalHandler as u32);
        let listener =
            VsockListener::bind(&VsockAddr::new(VMADDR_CID_ANY, port)).map_err(Error::VsockBind)?;
        let (vsock, _) = listener.accept().map_err(Error::VsockAccept)?;
        Ok(vsock)
    }
}