#![allow(async_fn_in_trait)]
use super::{Circuit, GlobalNodeId, NodeId, trace::SchedulerEvent};
use crate::{DetailedError, Position};
use feldera_types::transaction::CommitProgressSummary;
use itertools::Itertools;
use serde::Serialize;
use std::{
borrow::Cow,
collections::{BTreeMap, BTreeSet},
error::Error as StdError,
fmt::{Display, Error as FmtError, Formatter},
future::Future,
pin::Pin,
string::ToString,
};
mod dynamic_scheduler;
pub use dynamic_scheduler::DynamicScheduler;
/// Errors reported by circuit schedulers and executors in this module.
///
/// Serialized with `#[serde(untagged)]`, i.e. without a variant tag.
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
#[serde(untagged)]
pub enum Error {
/// The output of node `origin` is consumed by value by more than one node.
OwnershipConflict {
/// Node whose output is contended.
origin: GlobalNodeId,
/// Nodes that all consume the output of `origin` by value.
consumers: Vec<GlobalNodeId>,
},
/// The circuit cannot be scheduled because its topology contains a cycle.
CyclicCircuit {
/// A node that the cycle passes through.
node_id: GlobalNodeId,
},
/// A commit was invoked outside of a transaction.
CommitWithoutTransaction,
/// A step was called outside of a transaction.
StepWithoutTransaction,
/// The circuit has been killed by the user.
Killed,
/// An error attributed to tokio.
TokioError {
/// Textual description of the underlying error.
error: String,
},
/// Conflicting replay information was detected.
ReplayInfoConflict {
/// Textual description of the conflict.
error: String,
},
}
impl DetailedError for Error {
    /// Returns a stable, machine-readable code for the error, equal to the
    /// name of the variant.
    fn error_code(&self) -> Cow<'static, str> {
        let code: &'static str = match self {
            Self::OwnershipConflict { .. } => "OwnershipConflict",
            Self::CyclicCircuit { .. } => "CyclicCircuit",
            Self::CommitWithoutTransaction => "CommitWithoutTransaction",
            Self::StepWithoutTransaction => "StepWithoutTransaction",
            Self::Killed => "Killed",
            Self::TokioError { .. } => "TokioError",
            Self::ReplayInfoConflict { .. } => "ReplayInfoConflict",
        };
        // All codes are string literals, so no allocation is needed.
        Cow::Borrowed(code)
    }
}
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
match self {
Self::OwnershipConflict { origin, consumers } => {
write!(
f,
"ownership conflict: output of node '{origin}' is consumed by value by the following nodes: [{}]",
consumers.iter().map(ToString::to_string).format(",")
)
}
Self::CyclicCircuit { node_id } => {
write!(
f,
"unschedulable circuit due to a cyclic topology: cycle through node '{node_id}'"
)
}
Error::CommitWithoutTransaction => {
f.write_str("commit invoked outside of a transaction")
}
Error::StepWithoutTransaction => f.write_str("step called outside of a transaction"),
Self::Killed => f.write_str("circuit has been killed by the user"),
Self::TokioError { error } => write!(f, "tokio error: {error}"),
Self::ReplayInfoConflict { error } => {
write!(f, "replay info conflict: {error}")
}
}
}
}
// Marks `Error` as a standard error type; all trait methods use their defaults.
impl StdError for Error {}
/// Per-node bookkeeping of commit progress.
///
/// Nodes are tracked in three collections: completed, in progress, and
/// remaining. The `add_completed`/`add_in_progress` methods assert (in debug
/// builds) that a node is not already present in any of the three.
#[derive(Debug)]
pub struct CommitProgress {
// Nodes recorded via `add_completed`, with the optional position reported for each.
completed: BTreeMap<NodeId, Option<Position>>,
// Nodes recorded via `add_in_progress`, with the optional position reported for each.
in_progress: BTreeMap<NodeId, Option<Position>>,
// Nodes recorded via `add_remaining`.
remaining: BTreeSet<NodeId>,
}
impl Default for CommitProgress {
fn default() -> Self {
Self::new()
}
}
impl CommitProgress {
    /// Creates an empty progress record with no nodes in any category.
    pub fn new() -> Self {
        Self {
            completed: BTreeMap::default(),
            in_progress: BTreeMap::default(),
            remaining: BTreeSet::default(),
        }
    }

    /// Records `node_id` in the set of remaining nodes.
    ///
    /// Unlike the other `add_*` methods, this one performs no disjointness
    /// assertions; inserting an id that is already present is a no-op.
    pub fn add_remaining(&mut self, node_id: NodeId) {
        self.remaining.insert(node_id);
    }

    /// Records `node_id` as completed, together with its optional position.
    ///
    /// In debug builds, asserts that the node is not yet tracked in any of
    /// the three categories.
    pub fn add_completed(&mut self, node_id: NodeId, progress: Option<Position>) {
        debug_assert!(!self.completed.contains_key(&node_id));
        debug_assert!(!self.in_progress.contains_key(&node_id));
        debug_assert!(!self.remaining.contains(&node_id));

        self.completed.insert(node_id, progress);
    }

    /// Records `node_id` as in progress, together with its optional position.
    ///
    /// In debug builds, asserts that the node is not yet tracked in any of
    /// the three categories.
    pub fn add_in_progress(&mut self, node_id: NodeId, progress: Option<Position>) {
        debug_assert!(!self.completed.contains_key(&node_id));
        debug_assert!(!self.in_progress.contains_key(&node_id));
        debug_assert!(!self.remaining.contains(&node_id));

        self.in_progress.insert(node_id, progress);
    }

    /// Returns the in-progress nodes and their optional positions.
    pub fn get_in_progress(&self) -> &BTreeMap<NodeId, Option<Position>> {
        &self.in_progress
    }

    /// Condenses the record into a [`CommitProgressSummary`].
    ///
    /// The summary holds the number of nodes in each category plus, over all
    /// in-progress nodes, the sums of `offset` and `total` of their positions.
    /// Nodes without a position contribute the default value to both sums.
    pub fn summary(&self) -> CommitProgressSummary {
        let in_progress_processed_records = self
            .in_progress
            .values()
            .map(|position| match position {
                Some(p) => p.offset,
                None => Default::default(),
            })
            .sum();

        let in_progress_total_records = self
            .in_progress
            .values()
            .map(|position| match position {
                Some(p) => p.total,
                None => Default::default(),
            })
            .sum();

        CommitProgressSummary {
            completed: self.completed.len() as u64,
            in_progress: self.in_progress.len() as u64,
            remaining: self.remaining.len() as u64,
            in_progress_processed_records,
            in_progress_total_records,
        }
    }
}
pub trait Scheduler
where
Self: Sized,
{
fn new() -> Self;
fn prepare<C>(&mut self, circuit: &C, nodes: Option<&BTreeSet<NodeId>>) -> Result<(), Error>
where
C: Circuit;
async fn start_transaction<C>(&self, circuit: &C) -> Result<(), Error>
where
C: Circuit;
fn start_commit_transaction(&self) -> Result<(), Error>;
fn is_commit_complete(&self) -> bool;
fn commit_progress(&self) -> CommitProgress;
async fn step<C>(&self, circuit: &C) -> Result<(), Error>
where
C: Circuit;
async fn transaction<C>(&self, circuit: &C) -> Result<(), Error>
where
C: Circuit,
{
self.start_transaction(circuit).await?;
self.start_commit_transaction()?;
while !self.is_commit_complete() {
self.step(circuit).await?;
}
Ok(())
}
fn flush(&self);
fn is_flush_complete(&self) -> bool;
}
/// Executor interface for a circuit of type `C`.
///
/// Mirrors the operations of [`Scheduler`], but is parameterized by the
/// circuit type and returns boxed futures instead of using `async fn`.
/// Implemented in this module by `IterativeExecutor` and `OnceExecutor`.
pub trait Executor<C>: 'static {
/// Prepares the executor for `circuit`.
///
/// NOTE(review): `nodes` presumably restricts execution to a subset of the
/// circuit's nodes when `Some` — confirm against implementations.
fn prepare(&mut self, circuit: &C, nodes: Option<&BTreeSet<NodeId>>) -> Result<(), Error>;
/// Returns a future that starts a transaction on `circuit`.
fn start_transaction<'a>(
&'a self,
circuit: &'a C,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>>;
/// Starts committing the current transaction.
fn start_commit_transaction(&self) -> Result<(), Error>;
/// Returns `true` once the commit has completed.
fn is_commit_complete(&self) -> bool;
/// Returns the current commit progress.
fn commit_progress(&self) -> CommitProgress;
/// Returns a future that performs one step of `circuit`.
fn step<'a>(&'a self, circuit: &'a C) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>>;
/// Returns a future that runs a complete transaction on `circuit`.
fn transaction<'a>(
&'a self,
circuit: &'a C,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>>;
/// Requests a flush.
fn flush(&self);
/// Returns `true` once the flush has completed.
fn is_flush_complete(&self) -> bool;
}
/// Executor that repeats scheduler transactions until a termination check
/// reports `true` (see its `Executor::transaction` implementation).
pub(crate) struct IterativeExecutor<F, S> {
// Async callback evaluated after every scheduler transaction; `Ok(true)` stops the loop.
termination_check: F,
// Scheduler that performs the individual steps and transactions.
scheduler: S,
}
impl<F, S> IterativeExecutor<F, S> {
    /// Creates an executor around a freshly constructed scheduler of type `S`
    /// and the given `termination_check` callback.
    pub(crate) fn new(termination_check: F) -> Self
    where
        S: Scheduler,
    {
        let scheduler = S::new();
        Self {
            termination_check,
            scheduler,
        }
    }
}
impl<C, F, S> Executor<C> for IterativeExecutor<F, S>
where
    F: AsyncFn() -> Result<bool, Error> + 'static,
    C: Circuit,
    S: Scheduler + 'static,
{
    /// Delegates preparation to the underlying scheduler.
    fn prepare(&mut self, circuit: &C, nodes: Option<&BTreeSet<NodeId>>) -> Result<(), Error> {
        self.scheduler.prepare(circuit, nodes)
    }

    /// Not supported by the iterative executor; panics when called.
    fn start_transaction<'a>(
        &'a self,
        _circuit: &C,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>> {
        unimplemented!()
    }

    /// Not supported by the iterative executor; panics when called.
    fn start_commit_transaction(&self) -> Result<(), Error> {
        unimplemented!()
    }

    /// Not supported by the iterative executor; panics when called.
    fn is_commit_complete(&self) -> bool {
        unimplemented!()
    }

    /// Not supported by the iterative executor; panics when called.
    fn commit_progress(&self) -> CommitProgress {
        unimplemented!()
    }

    /// Performs a single scheduler step on a clone of `circuit`.
    fn step<'a>(&'a self, circuit: &'a C) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>> {
        let owned_circuit = circuit.clone();
        Box::pin(async move { self.scheduler.step(&owned_circuit).await })
    }

    /// Runs scheduler transactions repeatedly until the termination check
    /// returns `Ok(true)`.
    ///
    /// The clock is started (and the start logged) before the first
    /// transaction and ended (and the end logged) after the last one. At least
    /// one transaction is always executed. If a transaction or the termination
    /// check fails, the error is returned immediately and the clock-end calls
    /// are skipped.
    fn transaction<'a>(
        &'a self,
        circuit: &C,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>> {
        // The future owns its own clone so it does not borrow `circuit`.
        let circuit = circuit.clone();
        Box::pin(async move {
            circuit.log_scheduler_event(&SchedulerEvent::clock_start());
            circuit.clock_start(0);

            let mut terminated = false;
            while !terminated {
                self.scheduler.transaction(&circuit).await?;
                terminated = (self.termination_check)().await?;
            }

            circuit.log_scheduler_event(&SchedulerEvent::clock_end());
            circuit.clock_end(0);
            Ok(())
        })
    }

    /// Delegates the flush request to the underlying scheduler.
    fn flush(&self) {
        self.scheduler.flush();
    }

    /// Reports whether the underlying scheduler has finished flushing.
    fn is_flush_complete(&self) -> bool {
        self.scheduler.is_flush_complete()
    }
}
/// Executor that forwards transaction and step operations directly to its
/// scheduler, without any iteration of its own.
pub(crate) struct OnceExecutor<S> {
// Scheduler that all supported operations are delegated to.
scheduler: S,
}
impl<S> OnceExecutor<S>
where
    S: Scheduler,
    Self: Sized,
{
    /// Creates an executor around a freshly constructed scheduler of type `S`.
    pub(crate) fn new() -> Self {
        let scheduler = S::new();
        Self { scheduler }
    }
}
impl<C, S> Executor<C> for OnceExecutor<S>
where
    C: Circuit,
    S: Scheduler + 'static,
{
    /// Delegates preparation to the underlying scheduler.
    fn prepare(&mut self, circuit: &C, nodes: Option<&BTreeSet<NodeId>>) -> Result<(), Error> {
        self.scheduler.prepare(circuit, nodes)
    }

    /// Starts a transaction via the underlying scheduler.
    fn start_transaction<'a>(
        &'a self,
        circuit: &'a C,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>> {
        Box::pin(async move { self.scheduler.start_transaction(circuit).await })
    }

    /// Starts the commit via the underlying scheduler.
    fn start_commit_transaction(&self) -> Result<(), Error> {
        self.scheduler.start_commit_transaction()
    }

    /// Reports whether the underlying scheduler has completed the commit.
    fn is_commit_complete(&self) -> bool {
        self.scheduler.is_commit_complete()
    }

    /// Returns the commit progress reported by the underlying scheduler.
    fn commit_progress(&self) -> CommitProgress {
        self.scheduler.commit_progress()
    }

    /// Performs a single scheduler step on `circuit`.
    fn step<'a>(&'a self, circuit: &'a C) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>> {
        Box::pin(async move { self.scheduler.step(circuit).await })
    }

    /// Runs exactly one scheduler transaction on `circuit`.
    fn transaction<'a>(
        &'a self,
        circuit: &'a C,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + 'a>> {
        Box::pin(async move { self.scheduler.transaction(circuit).await })
    }

    /// Not supported by the once executor; panics when called.
    fn flush(&self) {
        unimplemented!()
    }

    /// Not supported by the once executor; panics when called.
    fn is_flush_complete(&self) -> bool {
        unimplemented!()
    }
}
/// Helpers shared by scheduler implementations.
mod util {
    use crate::circuit::{
        Circuit, GlobalNodeId, NodeId, OwnershipPreference, circuit_builder::StreamId,
        schedule::Error,
    };
    use petgraph::graphmap::DiGraphMap;
    use std::{collections::HashMap, ops::Deref};

    /// Builds a directed graph with one vertex per circuit node and one arc
    /// (`from` -> `to`) per circuit edge.
    pub(crate) fn circuit_graph<C>(circuit: &C) -> DiGraphMap<NodeId, ()>
    where
        C: Circuit,
    {
        let mut graph = DiGraphMap::<NodeId, ()>::new();

        // Add nodes explicitly so that nodes without edges are represented too.
        for node_id in circuit.node_ids() {
            graph.add_node(node_id);
        }

        for edge in circuit.edges().deref().iter() {
            graph.add_edge(edge.from, edge.to, ());
        }

        graph
    }

    /// Derives node-ordering constraints from the ownership preferences of
    /// stream consumers.
    ///
    /// Edges that carry a stream id are grouped by `(origin, stream id)`. For
    /// each group with exactly one consumer whose preference is at least
    /// [`OwnershipPreference::STRONGLY_PREFER_OWNED`], a pair
    /// `(other, strong)` is emitted for every other consumer in the group that
    /// has an ownership preference.
    ///
    /// NOTE(review): a pair `(a, b)` presumably means that `a` must be
    /// evaluated before `b`, so the strong consumer runs last and can take the
    /// value by ownership — confirm against callers.
    ///
    /// The order of the returned constraints follows hash-map iteration order
    /// and is therefore unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OwnershipConflict`] if some stream has more than one
    /// consumer with a strong ownership preference.
    pub(crate) fn ownership_constraints<C>(circuit: &C) -> Result<Vec<(NodeId, NodeId)>, Error>
    where
        C: Circuit,
    {
        // Consumers of each stream, keyed by the stream's origin and id.
        let mut successors: HashMap<
            (GlobalNodeId, StreamId),
            Vec<(NodeId, Option<OwnershipPreference>)>,
        > = HashMap::with_capacity(circuit.num_nodes());

        for edge in circuit.edges().deref().iter() {
            // Edges without a stream id do not take part in ownership analysis.
            let Some(stream_id) = edge.stream_id() else {
                continue;
            };

            successors
                .entry((edge.origin.clone(), stream_id))
                .or_default()
                .push((edge.to, edge.ownership_preference));
        }

        let mut constraints = Vec::new();

        for ((origin, _), consumers) in successors {
            // Positions (within `consumers`) and ids of all strong consumers.
            let strong_successors: Vec<(usize, NodeId)> = consumers
                .iter()
                .enumerate()
                .filter(|(_, (_, pref))| {
                    pref.is_some_and(|p| p >= OwnershipPreference::STRONGLY_PREFER_OWNED)
                })
                .map(|(index, (node_id, _))| (index, *node_id))
                .collect();

            // More than one strong consumer cannot be satisfied.
            if strong_successors.len() > 1 {
                return Err(Error::OwnershipConflict {
                    origin,
                    consumers: strong_successors
                        .into_iter()
                        .map(|(_, node_id)| GlobalNodeId::child_of(circuit, node_id))
                        .collect(),
                });
            }

            // No strong consumer: this stream imposes no constraints.
            let Some(&(strong_index, strong_node)) = strong_successors.first() else {
                continue;
            };

            // Compare by position rather than node id so that a second edge to
            // the same node is still treated as a distinct consumer.
            constraints.extend(
                consumers
                    .iter()
                    .enumerate()
                    .filter(|(index, (_, pref))| *index != strong_index && pref.is_some())
                    .map(|(_, (node_id, _))| (*node_id, strong_node)),
            );
        }

        Ok(constraints)
    }
}