use std::{
cell::{Ref, RefCell, RefMut},
collections::{BTreeSet, HashMap, HashSet},
panic,
sync::{Arc, Mutex},
};
use crate::{
Position,
circuit::{
Circuit, GlobalNodeId, NodeId,
circuit_builder::CircuitMetadata,
runtime::{Broadcast, Runtime},
schedule::{
CommitProgress, Error, Scheduler,
util::{circuit_graph, ownership_constraints},
},
trace::SchedulerEvent,
},
};
use petgraph::algo::toposort;
use tokio::{select, sync::Notify, task::JoinSet};
/// Per-node progress of the flush performed while a transaction is being committed.
///
/// A node is asked to flush only after all of its predecessors (within the scheduled node
/// set) have completed their own flush.
#[derive(Debug)]
pub enum FlushState {
/// The node has not started flushing; the payload is the number of predecessors that have
/// not yet completed their flush. Flushing may start once this reaches 0.
UnflushedDependencies(usize),
/// `flush_node` has been called for the node, but the circuit does not yet report the flush
/// as complete; the payload is the latest progress reported by the node's evaluation.
Started(Option<Position>),
/// The circuit reported the node's flush as complete; the payload is the progress reported
/// by the evaluation that completed it.
Completed(Option<Position>),
}
/// Scheduling state of a single circuit node.
struct Task {
/// The node this task evaluates.
node_id: NodeId,
/// Number of incoming dependency edges (stream edges plus ownership constraints) whose
/// source is also in the scheduled node set. Duplicate edges are counted separately.
num_predecessors: usize,
/// Targets of this node's outgoing dependency edges within the scheduled node set; may
/// contain duplicates, matching how `num_predecessors` is counted.
successors: Vec<NodeId>,
/// Whether the circuit reports this node as asynchronous (`Circuit::is_async_node`).
is_async: bool,
/// Predecessors that have not yet completed in the current step; reset to
/// `num_predecessors` at the start of every step.
unsatisfied_dependencies: usize,
/// Whether the node can be evaluated. Always true for synchronous nodes; for async nodes it
/// is set when a readiness notification is processed and cleared after each evaluation.
is_ready: bool,
/// Whether the node has already been spawned in the current step.
scheduled: bool,
/// Progress of this node's flush in the current transaction.
flush_state: FlushState,
}
/// Readiness notifications from async nodes, shared between the scheduler and the ready
/// callbacks registered with the circuit.
#[derive(Clone)]
struct Notifications {
/// Ids of async nodes that signaled readiness since the scheduler last drained this set.
nodes: Arc<Mutex<HashSet<NodeId>>>,
/// Wakes the scheduler loop when a node id is added to `nodes`.
notify: Arc<Notify>,
}
impl Notifications {
    /// Builds an empty notification set pre-sized for `size` node ids.
    fn new(size: usize) -> Self {
        let nodes = Arc::new(Mutex::new(HashSet::with_capacity(size)));
        let notify = Arc::new(Notify::new());
        Self { nodes, notify }
    }

    /// Records that `node_id` signaled readiness, then wakes one waiter.
    ///
    /// The mutex guard is released before the waiter is woken.
    fn notify(&self, node_id: NodeId) {
        {
            let mut pending = self.nodes.lock().unwrap();
            pending.insert(node_id);
        }
        self.notify.notify_one();
    }

    /// Resolves once `notify` has been called.
    ///
    /// With tokio's `Notify`, a `notify_one` issued while nobody is waiting stores a permit,
    /// so a notification sent before this call is not lost.
    async fn wait(&self) {
        let notified = self.notify.notified();
        notified.await
    }
}
/// Lifecycle of the current transaction as seen by this scheduler.
enum TransactionPhase {
/// Initial state: no transaction has been started yet.
Idle,
/// `start_transaction` has been called; steps may run.
Started,
/// Commit requested; the payload counts the nodes whose flush has not completed yet.
Committing(usize),
/// All participants in the commit consensus reported that their flush is complete.
CommitComplete,
}
/// Mutable state of the dynamic scheduler.
struct Inner {
/// Scheduling state for every node in the scheduled node set.
tasks: HashMap<NodeId, Task>,
/// Readiness notifications from async nodes.
notifications: Notifications,
/// Spawned node evaluations; each yields the node id and the evaluation result.
handles: JoinSet<(NodeId, Result<Option<Position>, Error>)>,
/// True after a `wait_start` event was logged because no evaluation was running; cleared
/// (with a matching `wait_end` event) when a new task is spawned.
waiting: bool,
/// Current phase of the transaction.
transaction_phase: TransactionPhase,
/// Exchanges `(local flush complete, local flush requested)` while committing.
global_commit_consensus: Broadcast<(bool, bool)>,
/// Exchanges circuit metadata around each step (only when `root_scope() == 0`).
metadata_broadcast: Broadcast<CircuitMetadata>,
/// True until the first call to `step`, which performs an extra metadata exchange.
before_first_step: bool,
/// True when a flush has been requested via `flush()` and not yet cleared by a completed
/// commit.
flush_state: bool,
}
impl Inner {
    /// Propagates the completion of `node_id` to its successors for the current step.
    ///
    /// Each successor's `unsatisfied_dependencies` counter is decremented. If `flush_complete`
    /// is true (meaning `node_id` finished flushing in the evaluation that just completed),
    /// each successor's [`FlushState::UnflushedDependencies`] counter is decremented too.
    /// Successors that have no remaining unsatisfied dependencies and are ready get spawned.
    ///
    /// # Panics
    ///
    /// Panics if `flush_complete` is true and a successor is not in the
    /// `UnflushedDependencies` state, which is an internal scheduler invariant violation.
    fn schedule_successors<C>(&mut self, circuit: &C, node_id: NodeId, flush_complete: bool)
    where
        C: Circuit,
    {
        // Iterate by index: `spawn_task` needs `&mut self`, so the successor list cannot stay
        // borrowed across the loop body.
        for i in 0..self.tasks[&node_id].successors.len() {
            let succ_id = self.tasks[&node_id].successors[i];
            debug_assert!(self.tasks.contains_key(&succ_id));
            let successor = self.tasks.get_mut(&succ_id).unwrap();
            debug_assert_ne!(successor.unsatisfied_dependencies, 0);
            successor.unsatisfied_dependencies -= 1;
            if flush_complete {
                let FlushState::UnflushedDependencies(n) = &mut successor.flush_state else {
                    // The reported state belongs to the successor, so the message must name the
                    // successor rather than the predecessor that just completed.
                    panic!(
                        "Internal scheduler error: node {succ_id} is in state {:?} while it still has unflushed dependencies",
                        successor.flush_state
                    );
                };
                debug_assert!(*n > 0);
                *n -= 1;
            }
            if successor.unsatisfied_dependencies == 0 && successor.is_ready {
                self.spawn_task(circuit, succ_id);
            }
        }
    }

    /// Handles pending readiness notifications from async nodes.
    ///
    /// For each notified node that is not already marked ready, asks the circuit whether it is
    /// ready. If it is, the node is marked ready and spawned, provided all its dependencies in
    /// the current step are satisfied and it has not been spawned yet.
    fn process_notifications<C>(&mut self, circuit: &C)
    where
        C: Circuit,
    {
        // The mutex guard is a temporary of this statement, so it is released before the loop
        // runs; callbacks that fire meanwhile only contend for this brief swap.
        let nodes = std::mem::take(&mut *self.notifications.nodes.lock().unwrap());
        for id in nodes {
            let task = self.tasks.get_mut(&id).unwrap();
            debug_assert!(task.is_async);
            if task.is_ready {
                continue;
            }
            if circuit.ready(id) {
                task.is_ready = true;
                let node_id = task.node_id;
                if task.unsatisfied_dependencies == 0 && !task.scheduled {
                    self.spawn_task(circuit, node_id);
                }
            }
        }
    }

    /// Builds the scheduler state for `circuit`.
    ///
    /// Schedules the nodes in `nodes`, or all circuit nodes when `nodes` is `None`.
    /// Dependencies are the circuit's stream edges plus the ownership constraints returned by
    /// `ownership_constraints`, restricted to edges whose endpoints are both scheduled. Stream
    /// consumer counts are reset and re-registered for the retained stream edges. Ready
    /// callbacks are registered for async nodes, and async nodes that are already ready are
    /// queued as notifications.
    ///
    /// # Errors
    ///
    /// Returns the error from `ownership_constraints`, or [`Error::CyclicCircuit`] if the
    /// circuit graph with ownership constraints added contains a cycle.
    fn prepare<C>(circuit: &C, nodes: Option<&BTreeSet<NodeId>>) -> Result<Self, Error>
    where
        C: Circuit,
    {
        let nodes: BTreeSet<NodeId> = nodes
            .cloned()
            .unwrap_or_else(|| BTreeSet::from_iter(circuit.node_ids()));
        let mut g: petgraph::prelude::GraphMap<NodeId, (), petgraph::Directed> =
            circuit_graph(circuit);
        let extra_constraints = ownership_constraints(circuit)?;
        for (from, to) in extra_constraints.iter() {
            g.add_edge(*from, *to, ());
        }
        // Only the cycle check matters here; the ordering itself is not used.
        toposort(&g, None).map_err(|e| Error::CyclicCircuit {
            node_id: GlobalNodeId::child_of(circuit, e.node_id()),
        })?;
        let num_nodes = nodes.len();
        let mut successors: HashMap<NodeId, Vec<NodeId>> = HashMap::with_capacity(num_nodes);
        // Only the number of predecessors is needed, not their identities.
        let mut predecessor_counts: HashMap<NodeId, usize> = HashMap::with_capacity(num_nodes);
        circuit.edges().iter().for_each(|edge| {
            if let Some(stream) = &edge.stream {
                stream.clear_consumer_count();
            }
        });
        for edge in circuit.edges().iter() {
            if nodes.contains(&edge.to) && nodes.contains(&edge.from) {
                successors.entry(edge.from).or_default().push(edge.to);
                *predecessor_counts.entry(edge.to).or_default() += 1;
                if let Some(stream) = &edge.stream {
                    stream.register_consumer();
                }
            }
        }
        for (from, to) in extra_constraints.into_iter() {
            if nodes.contains(&to) && nodes.contains(&from) {
                successors.entry(from).or_default().push(to);
                *predecessor_counts.entry(to).or_default() += 1;
            }
        }
        let mut tasks = HashMap::new();
        let mut num_async_nodes = 0;
        for &node_id in nodes.iter() {
            let num_predecessors = predecessor_counts.get(&node_id).copied().unwrap_or(0);
            let is_async = circuit.is_async_node(node_id);
            if is_async {
                num_async_nodes += 1;
            }
            tasks.insert(
                node_id,
                Task {
                    node_id,
                    num_predecessors,
                    // Each node id is visited once, so the list can be moved out of the map.
                    successors: successors.remove(&node_id).unwrap_or_default(),
                    is_async,
                    unsatisfied_dependencies: num_predecessors,
                    is_ready: !is_async,
                    scheduled: false,
                    flush_state: FlushState::UnflushedDependencies(num_predecessors),
                },
            );
        }
        let scheduler = Self {
            tasks,
            notifications: Notifications::new(num_async_nodes),
            handles: JoinSet::new(),
            waiting: false,
            transaction_phase: TransactionPhase::Idle,
            global_commit_consensus: Broadcast::new(),
            metadata_broadcast: Broadcast::new(),
            before_first_step: true,
            flush_state: false,
        };
        for &node_id in nodes.iter() {
            if circuit.is_async_node(node_id) {
                let notifications = scheduler.notifications.clone();
                circuit.register_ready_callback(
                    node_id,
                    Box::new(move || notifications.notify(node_id)),
                );
                // The node may have become ready before its callback was registered; record
                // that readiness now so it is not missed.
                if circuit.ready(node_id) {
                    scheduler.notifications.notify(node_id);
                }
            }
        }
        if circuit.root_scope() == 0 {
            circuit.balancer().prepare(circuit);
        }
        Ok(scheduler)
    }

    /// True once the commit of the current transaction has completed.
    fn commit_complete(&self) -> bool {
        matches!(self.transaction_phase, TransactionPhase::CommitComplete)
    }

    /// True if a transaction has been started and its commit has not been requested yet.
    fn transaction_started(&self) -> bool {
        matches!(self.transaction_phase, TransactionPhase::Started)
    }

    /// True while a transaction is running or being committed.
    fn transaction_in_progress(&self) -> bool {
        matches!(
            self.transaction_phase,
            TransactionPhase::Started | TransactionPhase::Committing(_)
        )
    }

    /// Spawns the evaluation of `node_id` on the local task set.
    ///
    /// The node must have no unsatisfied dependencies, be ready, and not already be spawned in
    /// the current step. While committing, a node whose predecessors have all finished flushing
    /// is asked to flush (`flush_node`) before its evaluation is spawned. If the scheduler was
    /// idle waiting for notifications, a `wait_end` event is logged.
    fn spawn_task<C>(&mut self, circuit: &C, node_id: NodeId)
    where
        C: Circuit,
    {
        let task = self.tasks.get_mut(&node_id).unwrap();
        debug_assert_eq!(task.unsatisfied_dependencies, 0);
        debug_assert!(task.is_ready);
        debug_assert!(!task.scheduled);
        task.scheduled = true;
        if self.handles.is_empty() && self.waiting {
            self.waiting = false;
            circuit.log_scheduler_event(&SchedulerEvent::wait_end(circuit.global_id()));
        }
        let circuit = circuit.clone();
        let committing = matches!(self.transaction_phase, TransactionPhase::Committing(_));
        if committing && matches!(task.flush_state, FlushState::UnflushedDependencies(0)) {
            circuit.flush_node(node_id);
            task.flush_state = FlushState::Started(None);
        }
        self.handles.spawn_local(async move {
            let result = circuit.eval_node(node_id).await;
            (node_id, result)
        });
    }

    /// Waits for every outstanding evaluation to finish, discarding the results.
    ///
    /// Tasks are not cancelled; this only drains the join set so that no evaluation outlives
    /// an aborted step.
    async fn abort(&mut self) {
        while !self.handles.is_empty() {
            let _ = self.handles.join_next().await;
        }
    }

    /// Starts a new transaction.
    ///
    /// Resets every node's flush state, enters [`TransactionPhase::Started`], notifies the
    /// circuit, and, when `circuit.root_scope()` is 0, notifies the balancer.
    fn start_transaction<C>(&mut self, circuit: &C)
    where
        C: Circuit,
    {
        for task in self.tasks.values_mut() {
            task.flush_state = FlushState::UnflushedDependencies(task.num_predecessors);
        }
        self.transaction_phase = TransactionPhase::Started;
        circuit.notify_start_transaction();
        if circuit.root_scope() == 0 {
            circuit.balancer().start_transaction();
        }
    }

    /// Requests a commit of the current transaction.
    ///
    /// # Errors
    ///
    /// Returns [`Error::CommitWithoutTransaction`] unless a transaction has been started and
    /// its commit has not been requested yet.
    fn start_commit_transaction(&mut self) -> Result<(), Error> {
        if !self.transaction_started() {
            return Err(Error::CommitWithoutTransaction);
        }
        // Every node has to flush before the commit can complete.
        self.transaction_phase = TransactionPhase::Committing(self.tasks.len());
        Ok(())
    }

    /// True once the commit of the current transaction has completed.
    fn is_commit_complete(&self) -> bool {
        self.commit_complete()
    }

    /// Reports, per node, whether its flush is pending, in progress, or completed.
    fn commit_progress(&self) -> CommitProgress {
        let mut commit_progress = CommitProgress::new();
        for (node_id, task) in self.tasks.iter() {
            match &task.flush_state {
                FlushState::UnflushedDependencies(_) => commit_progress.add_remaining(*node_id),
                FlushState::Completed(progress) => {
                    commit_progress.add_completed(*node_id, progress.clone())
                }
                FlushState::Started(progress) => {
                    commit_progress.add_in_progress(*node_id, progress.clone())
                }
            }
        }
        commit_progress
    }

    /// Exchanges the local circuit metadata through `metadata_broadcast` and stores the
    /// collected result as the global metadata. This is a no-op unless `circuit.root_scope()`
    /// is 0.
    async fn exchange_metadata<C>(&mut self, circuit: &C) -> Result<(), Error>
    where
        C: Circuit,
    {
        if circuit.root_scope() != 0 {
            return Ok(());
        }
        let metadata = circuit.metadata_exchange().local_metadata().clone();
        let global_metadata = self.metadata_broadcast.collect(metadata).await?;
        circuit
            .metadata_exchange()
            .set_global_metadata(global_metadata);
        Ok(())
    }

    /// Requests a flush; it is cleared when a commit completes with every participant having
    /// requested it.
    fn flush(&mut self) {
        self.flush_state = true;
    }

    /// True when no flush request is outstanding.
    fn is_flush_complete(&self) -> bool {
        !self.flush_state
    }

    /// Runs one step of the circuit.
    ///
    /// Around the evaluation this exchanges metadata (twice before the first step). While
    /// committing, it also runs the commit consensus and moves to
    /// [`TransactionPhase::CommitComplete`] once every participant reports its flush complete.
    /// When the commit is complete, the circuit is ticked.
    ///
    /// # Errors
    ///
    /// Returns [`Error::StepWithoutTransaction`] if no transaction is in progress, any error
    /// from the metadata or commit exchanges, or otherwise the result of evaluating the step.
    async fn step<C>(&mut self, circuit: &C) -> Result<(), Error>
    where
        C: Circuit,
    {
        if !self.transaction_in_progress() {
            return Err(Error::StepWithoutTransaction);
        }
        circuit.log_scheduler_event(&SchedulerEvent::step_start(circuit.global_id()));
        if self.before_first_step {
            self.before_first_step = false;
            self.exchange_metadata(circuit).await?;
            if circuit.root_scope() == 0 {
                circuit.balancer().update_metadata();
            }
        }
        if circuit.root_scope() == 0 {
            circuit.balancer().start_step();
        }
        // The evaluation result is held back until after the exchanges below.
        // NOTE(review): presumably so every participant reaches the same collective
        // operations even when a local evaluation fails — confirm.
        let result = self.do_step(circuit).await;
        self.exchange_metadata(circuit).await?;
        if circuit.root_scope() == 0 {
            circuit.balancer().update_metadata();
        }
        if let TransactionPhase::Committing(unflushed_operators) = &self.transaction_phase {
            let statuses = self
                .global_commit_consensus
                .collect((*unflushed_operators == 0, self.flush_state))
                .await?;
            let commit_complete = statuses
                .iter()
                .all(|(commit_complete, _flush_requested)| *commit_complete);
            if commit_complete {
                self.transaction_phase = TransactionPhase::CommitComplete;
                // The second component is each participant's `flush_state` (flush requested).
                let flush_requested_everywhere = statuses
                    .iter()
                    .all(|(_commit_complete, flush_requested)| *flush_requested);
                if flush_requested_everywhere {
                    self.flush_state = false;
                }
                if circuit.root_scope() == 0 {
                    circuit.balancer().transaction_committed();
                }
            }
        }
        circuit.log_scheduler_event(&SchedulerEvent::step_end(circuit.global_id()));
        if self.commit_complete() {
            circuit.tick();
        }
        result
    }

    /// Evaluates every scheduled node exactly once.
    ///
    /// Resets the per-step dependency counters, spawns all nodes that are immediately
    /// runnable, and then alternates between handling completed evaluations (which unlock
    /// successors) and readiness notifications from async nodes until every task has completed.
    ///
    /// # Errors
    ///
    /// On a task error, a non-panic join error ([`Error::TokioError`]), or a kill request
    /// ([`Error::Killed`]), waits for outstanding tasks and returns the error. A panic inside
    /// a task is resumed on this thread after outstanding tasks have finished.
    async fn do_step<C>(&mut self, circuit: &C) -> Result<(), Error>
    where
        C: Circuit,
    {
        let mut completed_tasks = 0;
        self.waiting = false;
        if self.tasks.is_empty() {
            return Ok(());
        }
        // Collect first: `spawn_task` needs `&mut self`, which conflicts with `values_mut`.
        let mut spawn = Vec::with_capacity(self.tasks.len());
        for task in self.tasks.values_mut() {
            task.unsatisfied_dependencies = task.num_predecessors;
            task.scheduled = false;
            if task.unsatisfied_dependencies == 0 && task.is_ready {
                spawn.push(task.node_id);
            }
        }
        for node_id in spawn.into_iter() {
            self.spawn_task(circuit, node_id);
        }
        loop {
            select! {
                ret = self.handles.join_next(), if !self.handles.is_empty() => {
                    completed_tasks += 1;
                    let result = ret.expect("JoinSet::join_next returned None on a non-empty join set.");
                    let (node_id, task_result) = match result {
                        Err(error) => {
                            self.abort().await;
                            if error.is_panic() {
                                panic::resume_unwind(error.into_panic());
                            } else {
                                return Err(Error::TokioError { error: error.to_string() });
                            }
                        },
                        Ok(result) => result
                    };
                    // An async node must signal readiness again before its next evaluation.
                    if self.tasks[&node_id].is_async {
                        self.tasks.get_mut(&node_id).unwrap().is_ready = false;
                    }
                    let progress = match task_result {
                        Ok(progress) => progress,
                        Err(e) => {
                            self.abort().await;
                            return Err(e);
                        }
                    };
                    if Runtime::kill_in_progress() {
                        self.abort().await;
                        return Err(Error::Killed);
                    }
                    let flush_complete = match &mut self.tasks.get_mut(&node_id).unwrap().flush_state {
                        flush_state@FlushState::Started(_) if circuit.is_flush_complete(node_id) => {
                            *flush_state = FlushState::Completed(progress);
                            let TransactionPhase::Committing(ref mut unflushed_operators) = self.transaction_phase else {
                                panic!("Internal scheduler error: flush called while not committing");
                            };
                            *unflushed_operators -= 1;
                            true
                        }
                        FlushState::Started(prog) => {
                            *prog = progress;
                            false
                        }
                        _ => false
                    };
                    debug_assert!(completed_tasks <= self.tasks.len());
                    // Every node has run: the last one cannot have successors still pending.
                    if completed_tasks == self.tasks.len() {
                        return Ok(());
                    }
                    self.schedule_successors(circuit, node_id, flush_complete);
                    // Nothing is running: the remaining nodes are waiting for async readiness.
                    if self.handles.is_empty() {
                        self.waiting = true;
                        circuit.log_scheduler_event(
                            &SchedulerEvent::wait_start(circuit.global_id()));
                    }
                }
                _ = self.notifications.wait() => {
                    self.process_notifications(circuit);
                }
            }
        }
    }
}
/// Scheduler that evaluates each node as soon as its dependencies have completed and, for
/// async nodes, as soon as the node reports that it is ready.
///
/// The state is `None` until `prepare` is called; it sits in a `RefCell` because the
/// `Scheduler` methods take `&self`.
pub struct DynamicScheduler(Option<RefCell<Inner>>);
impl DynamicScheduler {
fn inner(&self) -> Ref<'_, Inner> {
self.0
.as_ref()
.expect("DynamicScheduler: prepare() must be called before running the circuit")
.borrow()
}
fn inner_mut(&self) -> RefMut<'_, Inner> {
self.0
.as_ref()
.expect("DynamicScheduler: prepare() must be called before running the circuit")
.borrow_mut()
}
}
impl Scheduler for DynamicScheduler {
fn new() -> Self {
Self(None)
}
fn prepare<C>(&mut self, circuit: &C, nodes: Option<&BTreeSet<NodeId>>) -> Result<(), Error>
where
C: Circuit,
{
self.0 = Some(RefCell::new(Inner::prepare(circuit, nodes)?));
Ok(())
}
async fn start_transaction<C>(&self, circuit: &C) -> Result<(), Error>
where
C: Circuit,
{
let inner = &mut *self.inner_mut();
inner.start_transaction(circuit);
Ok(())
}
#[allow(clippy::await_holding_refcell_ref)]
async fn step<C>(&self, circuit: &C) -> Result<(), Error>
where
C: Circuit,
{
let inner = &mut *self.inner_mut();
inner.step(circuit).await
}
fn start_commit_transaction(&self) -> Result<(), Error> {
self.inner_mut().start_commit_transaction()
}
fn is_commit_complete(&self) -> bool {
self.inner().is_commit_complete()
}
fn commit_progress(&self) -> super::CommitProgress {
self.inner().commit_progress()
}
fn flush(&self) {
self.inner_mut().flush();
}
fn is_flush_complete(&self) -> bool {
self.inner().is_flush_complete()
}
}