use libc::{pthread_kill, pthread_t, SIGALRM};
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Condvar, Mutex, Weak};
use crate::instance::{Instance, TerminationDetails};
/// Per-instance bookkeeping that lets a [`KillSwitch`] terminate an instance.
///
/// `KillSwitch::terminate` reads and updates every field below. `exit_guest_region` and the
/// `begin_hostcall` / `end_hostcall` / `schedule` / `deschedule` methods update it as well
/// (presumably from the thread running the instance — confirm against callers).
pub struct KillState {
/// Whether the instance may currently be terminated. Both `KillSwitch::terminate` and
/// `exit_guest_region` claim it with `swap(false, ..)`, so after an `enable_termination` at
/// most one of them observes `true`. `terminable_ptr` hands out a raw pointer to this field.
terminable: AtomicBool,
/// Which kind of code the instance is executing (see `Domain`). Guarded by a mutex so the
/// hostcall transitions and `KillSwitch::terminate` see a consistent value.
execution_domain: Mutex<Domain>,
/// Thread recorded as running the instance: set by `schedule`, cleared by `deschedule`.
/// `KillSwitch::terminate` sends `SIGALRM` to this thread.
thread_id: Mutex<Option<pthread_t>>,
/// Notified after every change to `thread_id`; `KillSwitch::terminate` waits on it (with the
/// `thread_id` mutex) until the signalled thread deschedules.
tid_change_notifier: Condvar,
}
/// Leaves the guest region for `instance`, atomically clearing its `terminable` flag.
///
/// If the flag was set, clearing it is all that happens and the function returns.
///
/// If the flag was already clear, this function never returns; it spins forever. One way the
/// flag ends up clear is `KillSwitch::terminate`, which swaps it to `false` and then sends
/// `SIGALRM` to the running thread. Presumably returning here would race with that in-flight
/// signal, so the code instead relies on the signal's handler (not shown in this file) to take
/// control of the thread. The spin itself takes no locks and performs no allocation.
///
/// # Safety
///
/// `instance` must be non-null, properly aligned, and point to a live `Instance` for the whole
/// duration of the call.
pub unsafe extern "C" fn exit_guest_region(instance: *mut Instance) {
    let terminable = (*instance)
        .kill_state
        .terminable
        .swap(false, Ordering::SeqCst);
    if !terminable {
        // Busy-wait for the pending termination signal. `spin_loop` emits the CPU's spin-wait
        // hint (e.g. `pause` on x86), saving power and freeing pipeline resources for a sibling
        // hardware thread; a bare `loop {}` burns the core at full speed (clippy::empty_loop).
        loop {
            std::hint::spin_loop();
        }
    }
}
impl KillState {
    /// Builds kill state for an instance that has not started running.
    ///
    /// The instance starts out non-terminable, in [`Domain::Guest`], and with no thread
    /// recorded as running it.
    pub fn new() -> KillState {
        KillState {
            terminable: AtomicBool::new(false),
            execution_domain: Mutex::new(Domain::Guest),
            thread_id: Mutex::new(None),
            tid_change_notifier: Condvar::new(),
        }
    }

    /// Reports whether a `KillSwitch` may currently terminate the instance.
    pub fn is_terminable(&self) -> bool {
        self.terminable.load(Ordering::SeqCst)
    }

    /// Allows a `KillSwitch` to terminate the instance.
    pub fn enable_termination(&self) {
        self.terminable.store(true, Ordering::SeqCst);
    }

    /// Prevents a `KillSwitch` from terminating the instance.
    pub fn disable_termination(&self) {
        self.terminable.store(false, Ordering::SeqCst);
    }

    /// Returns a raw pointer to the `terminable` flag.
    ///
    /// The pointer is only valid while this `KillState` is alive and has not been moved.
    pub fn terminable_ptr(&self) -> *const AtomicBool {
        &self.terminable
    }

    /// Records that the instance is entering a hostcall (`Guest -> Hostcall`).
    ///
    /// # Panics
    ///
    /// Panics if the instance is not in [`Domain::Guest`] (already in a hostcall, or already
    /// marked terminated), or if the domain mutex is poisoned.
    pub fn begin_hostcall(&self) {
        let mut domain = self.execution_domain.lock().unwrap();
        match *domain {
            Domain::Guest => *domain = Domain::Hostcall,
            Domain::Hostcall => panic!(
                "Invalid state: Instance marked as in a hostcall while entering a hostcall."
            ),
            Domain::Terminated => panic!(
                "Invalid state: Instance marked as terminated while in guest code. This should be an error."
            ),
        }
    }

    /// Records that the instance is leaving a hostcall.
    ///
    /// Returns `None` in the normal case, after moving back to [`Domain::Guest`]. Returns
    /// `Some(TerminationDetails::Remote)` when a `KillSwitch` marked the instance
    /// [`Domain::Terminated`] during the hostcall; the domain is then left as `Terminated`.
    ///
    /// # Panics
    ///
    /// Panics if the instance is recorded as being in guest code (there is no hostcall to
    /// end), or if the domain mutex is poisoned.
    pub fn end_hostcall(&self) -> Option<TerminationDetails> {
        let mut domain = self.execution_domain.lock().unwrap();
        match *domain {
            Domain::Hostcall => {
                *domain = Domain::Guest;
                None
            }
            Domain::Terminated => {
                // `KillSwitch::terminate` clears the flag before it marks the domain terminated.
                debug_assert!(!self.terminable.load(Ordering::SeqCst));
                // Release the domain lock before reporting the termination.
                drop(domain);
                Some(TerminationDetails::Remote)
            }
            Domain::Guest => {
                panic!("Invalid state: Instance marked as in guest code while exiting a hostcall.")
            }
        }
    }

    /// Records `tid` as the thread running the instance and wakes anyone waiting on a change.
    pub fn schedule(&self, tid: pthread_t) {
        self.set_thread_id(Some(tid));
    }

    /// Clears the recorded thread and wakes anyone waiting on a change.
    pub fn deschedule(&self) {
        self.set_thread_id(None);
    }

    /// Stores `tid` under the `thread_id` lock, releases the lock, then notifies all waiters.
    fn set_thread_id(&self, tid: Option<pthread_t>) {
        {
            let mut slot = self.thread_id.lock().unwrap();
            *slot = tid;
        }
        self.tid_change_notifier.notify_all();
    }
}
/// What an instance is executing, as tracked by `KillState::execution_domain`.
///
/// Transitions made in this file:
/// - `Guest -> Hostcall` in `KillState::begin_hostcall`;
/// - `Hostcall -> Guest` in `KillState::end_hostcall`;
/// - `Hostcall -> Terminated` in `KillSwitch::terminate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Domain {
    /// Running guest code; `KillSwitch::terminate` interrupts the thread with `SIGALRM`.
    Guest,
    /// Inside a hostcall; a termination request is deferred until the hostcall ends.
    Hostcall,
    /// Termination was requested during a hostcall; `KillState::end_hostcall` reports it as
    /// `TerminationDetails::Remote`.
    Terminated,
}
pub struct KillSwitch {
state: Weak<KillState>,
}
/// Successful outcomes of `KillSwitch::terminate`.
#[derive(PartialEq, Debug)]
pub enum KillSuccess {
    /// The instance was in guest code: its thread was sent `SIGALRM`, and `terminate` waited
    /// until that thread descheduled.
    Signalled,
    /// The instance was in a hostcall: it was marked terminated, and `end_hostcall` will report
    /// the termination when the hostcall finishes.
    Pending,
}
/// Reasons `KillSwitch::terminate` can fail.
#[derive(PartialEq, Debug)]
pub enum KillError {
    /// The instance cannot be terminated right now. Either its kill state has been dropped,
    /// its terminable flag was already clear, or it was already marked terminated.
    NotTerminable,
}
/// Outcome of `KillSwitch::terminate`: how the termination was delivered, or why it was not.
type KillResult = std::result::Result<KillSuccess, KillError>;
impl KillSwitch {
    /// Creates a switch controlling the instance whose kill state `state` refers to.
    pub(crate) fn new(state: Weak<KillState>) -> Self {
        Self { state }
    }

    /// Attempts to terminate the instance this switch controls.
    ///
    /// The terminable flag is claimed with an atomic swap to `false` before anything else.
    /// Once claimed, it stays cleared whatever the outcome.
    ///
    /// # Returns
    ///
    /// - `Ok(KillSuccess::Signalled)` if the instance was in guest code. Its recorded thread was
    ///   sent `SIGALRM`, and this call blocked until that thread descheduled.
    /// - `Ok(KillSuccess::Pending)` if the instance was in a hostcall. It is now marked
    ///   `Domain::Terminated`, which `KillState::end_hostcall` reports when the hostcall ends.
    ///
    /// # Errors
    ///
    /// Returns `KillError::NotTerminable` in three cases:
    /// - the kill state has been dropped;
    /// - the terminable flag was already clear;
    /// - the instance was already marked terminated.
    ///
    /// # Panics
    ///
    /// Panics if the instance is in guest code but no thread is recorded as running it, or if
    /// a mutex is poisoned. The `execution_domain` lock is held for the rest of the call,
    /// including while waiting for the signalled thread to deschedule.
    pub fn terminate(&self) -> KillResult {
        // A dropped state means there is no instance left to stop.
        let state = match self.state.upgrade() {
            Some(state) => state,
            None => return Err(KillError::NotTerminable),
        };

        // Claim the right to terminate; only one swapper can observe `true`.
        if !state.terminable.swap(false, Ordering::SeqCst) {
            return Err(KillError::NotTerminable);
        }

        let mut domain = state.execution_domain.lock().unwrap();
        match *domain {
            Domain::Hostcall => {
                // Defer: the hostcall notices this in `end_hostcall`.
                *domain = Domain::Terminated;
                Ok(KillSuccess::Pending)
            }
            Domain::Terminated => Err(KillError::NotTerminable),
            Domain::Guest => {
                let mut tid_slot = state.thread_id.lock().unwrap();
                let tid = match *tid_slot {
                    Some(tid) => tid,
                    None => panic!("logic error: instance is terminable but not actually running."),
                };
                unsafe {
                    pthread_kill(tid, SIGALRM);
                }
                // Block until the signalled thread clears its `thread_id` via `deschedule`.
                while tid_slot.is_some() {
                    tid_slot = state.tid_change_notifier.wait(tid_slot).unwrap();
                }
                Ok(KillSuccess::Signalled)
            }
        }
    }
}