use self::collective::*;
use super::*;
#[doc(hidden)]
/// Pairs a machine with a timestamp taken at construction, so the time a
/// task spent waiting can be measured later (see `ExecutorStats.time_on_queue`).
pub struct Task {
    /// Instant the task was created.
    pub start: Instant,
    /// Shared handle to the machine this task will run.
    pub machine: ShareableMachine,
}
impl Task {
pub fn new(machine: &ShareableMachine) -> Self {
Self {
start: std::time::Instant::now(),
machine: Arc::clone(machine),
}
}
}
/// A schedulable unit of work: a machine identified by key, together with
/// the instant it was queued for scheduling.
pub struct SchedTask {
    /// Instant the task was created.
    pub start: Instant,
    /// Key identifying the machine to schedule.
    pub machine_key: usize,
}
impl SchedTask {
    /// Builds a task for `machine_key`, stamped with the current time.
    pub fn new(machine_key: usize) -> Self {
        let start = Instant::now();
        Self { machine_key, start }
    }
}
/// Per-executor counters, suitable for copying out for monitoring.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct ExecutorStats {
    /// Id of the executor these stats describe.
    pub id: usize,
    /// Number of tasks this executor has executed.
    pub tasks_executed: u128,
    /// Number of instructions sent — presumably message sends; confirm against the executor loop.
    pub instructs_sent: u128,
    /// Cumulative count of senders that became blocked.
    pub blocked_senders: u128,
    /// High-water mark of simultaneously blocked senders.
    pub max_blocked_senders: usize,
    /// Times a task ran out of its time slice — NOTE(review): inferred from the name; verify.
    pub exhausted_slice: u128,
    /// Total time spent receiving — presumably waiting on the task queue's recv; confirm.
    pub recv_time: std::time::Duration,
    /// Total time tasks spent queued before being picked up.
    pub time_on_queue: std::time::Duration,
}
/// Lifecycle state of an executor, shared via `SharedExecutorInfo`.
#[derive(Copy, Clone, Eq, PartialEq, SmartDefault, Debug)]
pub enum ExecutorState {
    /// Not yet running (the default).
    #[default]
    Init,
    /// Draining blocked senders while its own machine is send-blocked.
    Drain,
    /// Parked after failing to make progress draining.
    Parked,
    /// Executing normally.
    Running,
}
#[doc(hidden)]
/// Outcome of a failed non-blocking send attempt; mirrors the usual
/// channel try-send error split (full vs disconnected).
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum TrySendError {
    /// Channel is at capacity; the send may succeed later.
    Full,
    /// Receiver is gone; the send can never succeed.
    Disconnected,
}
#[doc(hidden)]
/// A send-blocked sender parked by an executor, carrying enough identity
/// (id, key, machine state) to resume or reschedule its machine when the
/// send finally completes.
pub struct SharedCollectiveSenderAdapter {
    /// Id of the associated machine's sender.
    pub id: Uuid,
    /// Key of the associated machine (used by the scheduler).
    pub key: usize,
    /// State handle of the blocked machine (set to SendBlock while parked).
    pub state: MachineState,
    /// The type-erased adapter that actually performs the send retry.
    pub normalized_adapter: CommonCollectiveSenderAdapter,
}
impl SharedCollectiveSenderAdapter {
    /// Returns the adapter's id.
    pub const fn get_id(&self) -> Uuid {
        self.id
    }
    /// Returns the adapter's machine key.
    pub const fn get_key(&self) -> usize {
        self.key
    }
    /// Retries the blocked send, delegating to the boxed adapter.
    pub fn try_send(&mut self) -> Result<(), TrySendError> {
        self.normalized_adapter.try_send()
    }
}
#[doc(hidden)]
/// Abstraction over a sender whose blocked send an executor retries.
pub trait CollectiveSenderAdapter {
    /// Returns the adapter's id.
    fn get_id(&self) -> Uuid;
    /// Returns the adapter's machine key.
    fn get_key(&self) -> usize;
    /// Retries the send without blocking.
    fn try_send(&mut self) -> Result<(), TrySendError>;
}
/// Boxed, type-erased adapter as stored in `SharedCollectiveSenderAdapter`.
pub type CommonCollectiveSenderAdapter = Box<dyn CollectiveSenderAdapter>;
/// Executor state plus an idle timer, intended to live behind an
/// `Arc<Mutex<_>>` so it can be inspected from outside the executor thread.
#[derive(Debug)]
pub struct SharedExecutorInfo {
    state: ExecutorState,
    start_idle: Instant,
}
impl SharedExecutorInfo {
    /// Restarts the idle timer and returns the instant it was restarted.
    pub fn set_idle(&mut self) -> Instant {
        let now = Instant::now();
        self.start_idle = now;
        now
    }
    /// Unconditionally replaces the executor state.
    pub fn set_state(&mut self, new: ExecutorState) { self.state = new }
    /// Returns the current executor state.
    pub const fn get_state(&self) -> ExecutorState { self.state }
    /// Replaces the state with `new` only if the current state equals `old`.
    pub fn compare_set_state(&mut self, old: ExecutorState, new: ExecutorState) {
        let unchanged_since = self.state == old;
        if unchanged_since {
            self.state = new;
        }
    }
    /// Returns the current state together with how long the idle timer has run.
    pub fn get_state_and_elapsed(&self) -> (ExecutorState, Duration) {
        let idle_for = self.start_idle.elapsed();
        (self.state, idle_for)
    }
}
impl Default for SharedExecutorInfo {
    /// Starts in `Init` with the idle timer running from now.
    fn default() -> Self {
        Self {
            state: ExecutorState::Init,
            start_idle: Instant::now(),
        }
    }
}
#[doc(hidden)]
/// Mutable per-executor state, stored thread-locally in `tls_executor_data`.
#[derive(Default)]
pub struct ExecutorData {
    /// Executor id; 0 is the default value on threads that are not executors.
    pub id: usize,
    /// The machine currently bound to this executor (Uninitialized until set).
    pub machine: ExecutorDataField,
    /// Senders parked here after a full-channel send; retried by `drain`.
    pub blocked_senders: Vec<SharedCollectiveSenderAdapter>,
    /// NOTE(review): not read or written in this chunk — presumably the
    /// blocked-sender count from a previous pass; confirm against callers.
    pub last_blocked_send_len: usize,
    /// Scheduler callback used to report parking and runnable machines.
    pub notifier: ExecutorDataField,
    /// State + idle timer shared with code outside the executor thread.
    pub shared_info: Arc<Mutex<SharedExecutorInfo>>,
}
impl ExecutorData {
pub fn block_or_continue() {
tls_executor_data.with(|t| {
let mut tls = t.borrow_mut();
if tls.id == 0 {
return;
}
if let ExecutorDataField::Machine(machine) = &tls.machine {
if machine.state.get() != CollectiveState::Running {
tls.recursive_block();
}
}
});
}
pub fn recursive_block(&mut self) {
self.shared_info
.lock()
.as_mut()
.unwrap()
.compare_set_state(ExecutorState::Running, ExecutorState::Drain);
self.drain();
let mut mutable = self.shared_info.lock().unwrap();
mutable.compare_set_state(ExecutorState::Drain, ExecutorState::Running);
mutable.set_idle();
}
pub fn sender_blocked(&mut self, channel_id: usize, adapter: SharedCollectiveSenderAdapter) {
if adapter.state.get() == CollectiveState::SendBlock {
log::info!(
"Executor {} detected recursive send block, this should not happen",
self.id
);
unreachable!(
"block_or_continue() should be called to prevent entering sender_blocked with a blocked machine"
)
} else {
log::trace!("executor {} parking sender {}", self.id, channel_id);
adapter.state.set(CollectiveState::SendBlock);
self.blocked_senders.push(adapter);
}
}
fn drain(&mut self) {
let backoff = crossbeam::utils::Backoff::new();
let (machine_key, machine_state) = match &self.machine {
ExecutorDataField::Machine(machine) => (machine.key, machine.state.clone()),
_ => panic!("machine field was not set prior to running"),
};
while !self.blocked_senders.is_empty() {
self.shared_info.lock().as_mut().unwrap().set_idle();
let mut still_blocked: Vec<SharedCollectiveSenderAdapter> = Vec::with_capacity(self.blocked_senders.len());
let mut handled_recursive_sender = false;
for mut sender in self.blocked_senders.drain(..) {
match sender.try_send() {
Ok(()) if sender.key == machine_key => {
backoff.reset();
machine_state.set(CollectiveState::Running);
handled_recursive_sender = true;
},
Err(TrySendError::Disconnected) if sender.key == machine_key => {
backoff.reset();
machine_state.set(CollectiveState::Running);
handled_recursive_sender = true;
},
Ok(()) => {
backoff.reset();
match &self.notifier {
ExecutorDataField::Notifier(obj) => obj.notify_can_schedule(sender.key),
_ => log::error!("can't notify scheduler!!!"),
};
},
Err(TrySendError::Disconnected) => {
backoff.reset();
match &self.notifier {
ExecutorDataField::Notifier(obj) => obj.notify_can_schedule(sender.key),
_ => log::error!("can't notify scheduler!!!"),
};
},
Err(TrySendError::Full) => {
still_blocked.push(sender);
},
}
}
self.blocked_senders = still_blocked;
if handled_recursive_sender {
break;
}
if backoff.is_completed() && self.shared_info.lock().unwrap().get_state() != ExecutorState::Parked {
self.shared_info
.lock()
.as_mut()
.unwrap()
.set_state(ExecutorState::Parked);
match &self.notifier {
ExecutorDataField::Notifier(obj) => obj.notify_parked(self.id),
_ => log::error!("Executor {} doesn't have a notifier", self.id),
};
}
backoff.snooze();
}
log::debug!("drained recursive sender, allowing send to continue");
}
}
#[doc(hidden)]
/// Late-bound fields of `ExecutorData`: each starts `Uninitialized` and is
/// replaced with a notifier or machine before the executor runs.
#[derive(SmartDefault)]
pub enum ExecutorDataField {
    /// Field not yet set (the default).
    #[default]
    Uninitialized,
    /// Scheduler notifier handle.
    Notifier(ExecutorNotifierObj),
    /// Machine handle.
    Machine(ShareableMachine),
}
/// Callbacks an executor uses to inform the scheduler of state changes.
pub trait ExecutorNotifier: Send + Sync + 'static {
    /// Notifies that executor `executor_id` has parked itself.
    fn notify_parked(&self, executor_id: usize);
    /// Notifies that machine `machine_key` is no longer send-blocked and may
    /// be scheduled again.
    fn notify_can_schedule(&self, machine_key: usize);
}
/// Shared, thread-safe notifier handle.
pub type ExecutorNotifierObj = std::sync::Arc<dyn ExecutorNotifier>;
thread_local! {
    // Per-thread executor context. Executor threads populate it with their id,
    // machine, notifier, and shared info; other threads see the default
    // instance (id 0), which `ExecutorData::block_or_continue` treats as a
    // no-op.
    #[doc(hidden)]
    #[allow(non_upper_case_globals)]
    pub static tls_executor_data: RefCell<ExecutorData> = RefCell::new(ExecutorData::default());
}
#[cfg(test)]
mod tests {
    // No tests yet; the blanket import is kept ready for future tests.
    #[allow(unused_imports)] use super::*;
}