use crate::definition::{DependencyMode, FlowNodeSpec, FrameSpec};
use crate::error::MobError;
use crate::ids::{FlowNodeId, FrameId, LoopId, LoopInstanceId, RunId, StepId};
use crate::run::FrameSnapshot;
use crate::store::MobRunStore;
use meerkat_machine_kernels::generated::flow_frame;
use meerkat_machine_kernels::{KernelEffect, KernelInput, KernelValue};
use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::sync::Arc;
// Private module: `Sealed` is `pub` inside it, but the module itself is not exported,
// so the trait cannot be named from outside the enclosing module.
mod sealed {
/// Marker supertrait required by `FlowFrameMutator`; it declares no methods.
pub trait Sealed {}
}
/// Arguments for `FlowFrameMutator::complete_step`, bundled into one value.
pub struct StepCompletionOpts<'a> {
/// Node whose `CompleteNode` transition is applied to the frame.
pub node_id: &'a FlowNodeId,
/// Step whose output is recorded; passed to the store in string form.
pub step_id: &'a StepId,
/// JSON output recorded together with the frame-state update.
pub output: serde_json::Value,
/// Optional loop context forwarded unchanged to the store.
/// NOTE(review): the `u64` is presumably the loop iteration — confirm against the store API.
pub loop_context: Option<(&'a LoopId, u64)>,
/// Number of additional attempts after the first one when the compare-and-swap is lost.
pub max_retries: usize,
}
/// Mutation interface for the state of a single flow frame inside a run.
///
/// The trait is sealed through the `sealed::Sealed` supertrait. `FlowFrameKernel` is the
/// implementation in this file; the per-method notes below describe what that
/// implementation does.
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
pub trait FlowFrameMutator: sealed::Sealed {
/// Starts the frame `frame_id` of run `run_id` from `spec` and returns its snapshot.
///
/// `FlowFrameKernel` returns the already-stored snapshot when the run has a frame with
/// this id, and otherwise stores the state produced by a `StartRootFrame` transition.
async fn start_frame(
&self,
run_id: &RunId,
frame_id: &FrameId,
spec: &FrameSpec,
) -> Result<FrameSnapshot, MobError>;
/// Applies an `AdmitNextReadyNode` transition to the frame and returns its effects.
///
/// `FlowFrameKernel` always wraps the effects in `Some` on success and uses a fixed
/// retry budget.
async fn admit_next_ready_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
) -> Result<Option<Vec<KernelEffect>>, MobError>;
/// Like `admit_next_ready_node`, with a caller-chosen retry budget.
///
/// `FlowFrameKernel` returns `Ok(None)` without transitioning when the frame's
/// `ready_queue` field is an empty sequence, is missing, or is not a sequence.
async fn admit_next_ready_node_with_retry(
&self,
run_id: &RunId,
frame_id: &FrameId,
max_retries: usize,
) -> Result<Option<Vec<KernelEffect>>, MobError>;
/// Applies a `CompleteNode` transition for `opts.node_id` and records the step output
/// in the same store call.
async fn complete_step(
&self,
run_id: &RunId,
frame_id: &FrameId,
opts: StepCompletionOpts<'_>,
) -> Result<(), MobError>;
/// Applies a `CompleteNode` transition for `node_id`.
///
/// `FlowFrameKernel` returns `Ok(true)` whenever the transition is stored.
async fn complete_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
node_id: &FlowNodeId,
) -> Result<bool, MobError>;
/// Applies a `FailNode` transition for `node_id`.
///
/// `FlowFrameKernel` returns `Ok(true)` whenever the transition is stored.
async fn fail_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
node_id: &FlowNodeId,
) -> Result<bool, MobError>;
/// Applies a `SkipNode` transition for `node_id`.
///
/// `FlowFrameKernel` returns `Ok(true)` whenever the transition is stored.
async fn skip_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
node_id: &FlowNodeId,
) -> Result<bool, MobError>;
/// Applies a `CancelNode` transition for `node_id`.
///
/// `FlowFrameKernel` returns `Ok(true)` whenever the transition is stored.
async fn cancel_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
node_id: &FlowNodeId,
) -> Result<bool, MobError>;
/// Applies a `SealFrame` transition to the frame.
///
/// `FlowFrameKernel` returns `Ok(true)` whenever the transition is stored.
async fn terminalize_frame(&self, run_id: &RunId, frame_id: &FrameId)
-> Result<bool, MobError>;
}
/// `FlowFrameMutator` implementation that runs the generated `flow_frame` transitions
/// and persists the resulting snapshots through a `MobRunStore`.
pub struct FlowFrameKernel {
// Shared handle to the store used for every frame read and compare-and-swap write.
run_store: Arc<dyn MobRunStore>,
}
impl FlowFrameKernel {
    /// Creates a kernel that reads and writes frame state through `run_store`.
    pub fn new(run_store: Arc<dyn MobRunStore>) -> Self {
        Self { run_store }
    }

    /// Encodes a node id as the string-valued `KernelValue` used in kernel inputs.
    fn node_val(node_id: &FlowNodeId) -> KernelValue {
        KernelValue::String(node_id.to_string())
    }

    /// Loads the stored snapshot of `frame_id` in run `run_id`.
    ///
    /// # Errors
    /// - `MobError::RunNotFound` when the store has no such run.
    /// - `MobError::Internal` when the run exists but has no frame with this id.
    /// - Any error returned by the store lookup itself.
    async fn require_frame(
        &self,
        run_id: &RunId,
        frame_id: &FrameId,
    ) -> Result<FrameSnapshot, MobError> {
        let run = self
            .run_store
            .get_run(run_id)
            .await?
            .ok_or_else(|| MobError::RunNotFound(run_id.clone()))?;
        run.frames.get(frame_id).cloned().ok_or_else(|| {
            MobError::Internal(format!("frame '{frame_id}' not found in run '{run_id}'"))
        })
    }

    /// Applies `input` to the frame with an optimistic compare-and-swap loop.
    ///
    /// Each attempt re-reads the frame, computes the transition against that snapshot and
    /// tries to store the result only if the frame is still the snapshot that was read.
    /// Up to `max_retries + 1` attempts are made.
    ///
    /// # Returns
    /// The effects produced by the transition whose result was stored.
    ///
    /// # Errors
    /// - Errors from `require_frame` or from the store's compare-and-swap call.
    /// - `MobError::Internal` when the kernel rejects the transition (not retried).
    /// - `MobError::Internal` when every attempt lost the compare-and-swap.
    async fn transition_frame(
        &self,
        run_id: &RunId,
        frame_id: &FrameId,
        input: KernelInput,
        max_retries: usize,
    ) -> Result<Vec<KernelEffect>, MobError> {
        for _ in 0..=max_retries {
            let current = self.require_frame(run_id, frame_id).await?;
            let outcome = flow_frame::transition(&current.kernel_state, &input)
                .map_err(|e| MobError::Internal(format!("flow_frame transition failed: {e:?}")))?;
            // Split the outcome by move: the effects go back to the caller, the state
            // goes into the snapshot offered to the store.
            let effects = outcome.effects;
            let next = FrameSnapshot {
                kernel_state: outcome.next_state,
            };
            let won = self
                .run_store
                .cas_frame_state(run_id, frame_id, Some(&current), next)
                .await?;
            if won {
                return Ok(effects);
            }
        }
        Err(MobError::Internal(format!(
            "transition_frame: CAS exhausted {max_retries} retries for frame '{frame_id}'"
        )))
    }
}
// Marker impl that satisfies the `sealed::Sealed` supertrait bound of `FlowFrameMutator`.
impl sealed::Sealed for FlowFrameKernel {}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl FlowFrameMutator for FlowFrameKernel {
async fn start_frame(
&self,
run_id: &RunId,
frame_id: &FrameId,
spec: &FrameSpec,
) -> Result<FrameSnapshot, MobError> {
let run = self
.run_store
.get_run(run_id)
.await?
.ok_or_else(|| MobError::RunNotFound(run_id.clone()))?;
if let Some(existing) = run.frames.get(frame_id) {
return Ok(existing.clone());
}
let initial = flow_frame::initial_state()
.map_err(|e| MobError::Internal(format!("flow_frame initial_state failed: {e:?}")))?;
let ordered = topological_order(spec)?;
let start_input = build_start_root_frame_input(frame_id, spec, &ordered);
let outcome = flow_frame::transition(&initial, &start_input)
.map_err(|e| MobError::Internal(format!("flow_frame StartRootFrame failed: {e:?}")))?;
let snapshot = FrameSnapshot {
kernel_state: outcome.next_state,
};
let inserted = self
.run_store
.cas_frame_state(run_id, frame_id, None, snapshot.clone())
.await?;
if !inserted {
let run2 = self
.run_store
.get_run(run_id)
.await?
.ok_or_else(|| MobError::RunNotFound(run_id.clone()))?;
return run2.frames.get(frame_id).cloned().ok_or_else(|| {
MobError::Internal(format!(
"frame '{frame_id}' missing after concurrent insert in run '{run_id}'"
))
});
}
Ok(snapshot)
}
async fn admit_next_ready_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
) -> Result<Option<Vec<KernelEffect>>, MobError> {
let input = KernelInput {
variant: "AdmitNextReadyNode".into(),
fields: BTreeMap::new(),
};
self.transition_frame(run_id, frame_id, input, 5)
.await
.map(Some)
}
async fn admit_next_ready_node_with_retry(
&self,
run_id: &RunId,
frame_id: &FrameId,
max_retries: usize,
) -> Result<Option<Vec<KernelEffect>>, MobError> {
for _ in 0..=max_retries {
let snap = self.require_frame(run_id, frame_id).await?;
let queue_empty = match snap.kernel_state.fields.get("ready_queue") {
Some(KernelValue::Seq(seq)) => seq.is_empty(),
_ => true,
};
if queue_empty {
return Ok(None); }
let admit_input = KernelInput {
variant: "AdmitNextReadyNode".into(),
fields: BTreeMap::new(),
};
let outcome = flow_frame::transition(&snap.kernel_state, &admit_input)
.map_err(|e| MobError::Internal(format!("AdmitNextReadyNode failed: {e:?}")))?;
let next_snap = FrameSnapshot {
kernel_state: outcome.next_state,
};
let won = self
.run_store
.cas_frame_state(run_id, frame_id, Some(&snap), next_snap)
.await?;
if won {
return Ok(Some(outcome.effects));
}
}
Err(MobError::Internal(format!(
"admit_next_ready_node: CAS exhausted {max_retries} retries for frame '{frame_id}' \
— queue was non-empty but every attempt lost the CAS"
)))
}
async fn complete_step(
&self,
run_id: &RunId,
frame_id: &FrameId,
opts: StepCompletionOpts<'_>,
) -> Result<(), MobError> {
let StepCompletionOpts {
node_id,
step_id,
output,
loop_context,
max_retries,
} = opts;
for attempt in 0..=max_retries {
let snap = self.require_frame(run_id, frame_id).await?;
let complete_input = KernelInput {
variant: "CompleteNode".into(),
fields: BTreeMap::from([("node_id".into(), Self::node_val(node_id))]),
};
let next_outcome = flow_frame::transition(&snap.kernel_state, &complete_input)
.map_err(|e| MobError::Internal(format!("CompleteNode failed: {e:?}")))?;
let next_snap = FrameSnapshot {
kernel_state: next_outcome.next_state,
};
let won = self
.run_store
.cas_complete_step_and_record_output(
run_id,
frame_id,
&snap,
next_snap,
step_id.to_string(),
output.clone(),
loop_context,
)
.await?;
if won {
return Ok(());
}
if attempt == max_retries {
return Err(MobError::Internal(format!(
"CompleteNode CAS failed after {} attempts for node '{node_id}'",
max_retries + 1
)));
}
}
Err(MobError::Internal("CompleteNode CAS exhausted".into()))
}
async fn complete_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
node_id: &FlowNodeId,
) -> Result<bool, MobError> {
let input = KernelInput {
variant: "CompleteNode".into(),
fields: BTreeMap::from([("node_id".into(), Self::node_val(node_id))]),
};
self.transition_frame(run_id, frame_id, input, 5)
.await
.map(|_| true)
}
async fn fail_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
node_id: &FlowNodeId,
) -> Result<bool, MobError> {
let input = KernelInput {
variant: "FailNode".into(),
fields: BTreeMap::from([("node_id".into(), Self::node_val(node_id))]),
};
self.transition_frame(run_id, frame_id, input, 5)
.await
.map(|_| true)
}
async fn skip_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
node_id: &FlowNodeId,
) -> Result<bool, MobError> {
let input = KernelInput {
variant: "SkipNode".into(),
fields: BTreeMap::from([("node_id".into(), Self::node_val(node_id))]),
};
self.transition_frame(run_id, frame_id, input, 5)
.await
.map(|_| true)
}
async fn terminalize_frame(
&self,
run_id: &RunId,
frame_id: &FrameId,
) -> Result<bool, MobError> {
let input = KernelInput {
variant: "SealFrame".into(),
fields: BTreeMap::new(),
};
self.transition_frame(run_id, frame_id, input, 5)
.await
.map(|_| true)
}
async fn cancel_node(
&self,
run_id: &RunId,
frame_id: &FrameId,
node_id: &FlowNodeId,
) -> Result<bool, MobError> {
let input = KernelInput {
variant: "CancelNode".into(),
fields: BTreeMap::from([("node_id".into(), Self::node_val(node_id))]),
};
self.transition_frame(run_id, frame_id, input, 5)
.await
.map(|_| true)
}
}
fn build_frame_start_fields(
frame_id: &FrameId,
spec: &FrameSpec,
ordered: &[FlowNodeId],
) -> BTreeMap<String, KernelValue> {
let ordered_kv: Vec<KernelValue> = ordered
.iter()
.map(|n| KernelValue::String(n.to_string()))
.collect();
let tracked: BTreeSet<KernelValue> = ordered
.iter()
.map(|n| KernelValue::String(n.to_string()))
.collect();
let mut node_kind: BTreeMap<KernelValue, KernelValue> = BTreeMap::new();
let mut node_deps: BTreeMap<KernelValue, KernelValue> = BTreeMap::new();
let mut node_dep_modes: BTreeMap<KernelValue, KernelValue> = BTreeMap::new();
let mut node_branches: BTreeMap<KernelValue, KernelValue> = BTreeMap::new();
for (node_id, node_spec) in &spec.nodes {
let k = KernelValue::String(node_id.to_string());
match node_spec {
FlowNodeSpec::Step(s) => {
node_kind.insert(
k.clone(),
KernelValue::NamedVariant {
enum_name: "FlowNodeKind".into(),
variant: "Step".into(),
},
);
node_deps.insert(
k.clone(),
KernelValue::Seq(
s.depends_on
.iter()
.map(|d| KernelValue::String(d.to_string()))
.collect(),
),
);
node_dep_modes.insert(k.clone(), dep_mode_kv(&s.depends_on_mode));
node_branches.insert(
k.clone(),
s.branch
.as_ref()
.map_or(KernelValue::None, |b| KernelValue::String(b.to_string())),
);
}
FlowNodeSpec::RepeatUntil(l) => {
node_kind.insert(
k.clone(),
KernelValue::NamedVariant {
enum_name: "FlowNodeKind".into(),
variant: "Loop".into(),
},
);
node_deps.insert(
k.clone(),
KernelValue::Seq(
l.depends_on
.iter()
.map(|d| KernelValue::String(d.to_string()))
.collect(),
),
);
node_dep_modes.insert(k.clone(), dep_mode_kv(&l.depends_on_mode));
node_branches.insert(k.clone(), KernelValue::None);
}
}
}
BTreeMap::from([
("frame_id".into(), KernelValue::String(frame_id.to_string())),
("tracked_nodes".into(), KernelValue::Set(tracked)),
("ordered_nodes".into(), KernelValue::Seq(ordered_kv)),
("node_kind".into(), KernelValue::Map(node_kind)),
("node_dependencies".into(), KernelValue::Map(node_deps)),
(
"node_dependency_modes".into(),
KernelValue::Map(node_dep_modes),
),
("node_branches".into(), KernelValue::Map(node_branches)),
])
}
/// Builds the `StartRootFrame` kernel input for `frame_id`.
///
/// The fields are exactly those produced by `build_frame_start_fields`.
pub(crate) fn build_start_root_frame_input(
    frame_id: &FrameId,
    spec: &FrameSpec,
    ordered: &[FlowNodeId],
) -> KernelInput {
    let fields = build_frame_start_fields(frame_id, spec, ordered);
    let variant = "StartRootFrame".into();
    KernelInput { variant, fields }
}
/// Builds the `StartBodyFrame` kernel input for `frame_id`.
///
/// Starts from the fields produced by `build_frame_start_fields` and adds
/// `loop_instance_id` (as a string) and `iteration` (as a `U64`).
pub(crate) fn build_start_body_frame_input(
    frame_id: &FrameId,
    loop_instance_id: &LoopInstanceId,
    iteration: u64,
    spec: &FrameSpec,
    ordered: &[FlowNodeId],
) -> KernelInput {
    let mut fields = build_frame_start_fields(frame_id, spec, ordered);
    fields.extend([
        (
            String::from("loop_instance_id"),
            KernelValue::String(loop_instance_id.to_string()),
        ),
        (String::from("iteration"), KernelValue::U64(iteration)),
    ]);
    let variant = "StartBodyFrame".into();
    KernelInput { variant, fields }
}
/// Encodes a `DependencyMode` as the kernel's `DependencyMode` named variant.
fn dep_mode_kv(mode: &DependencyMode) -> KernelValue {
    KernelValue::NamedVariant {
        enum_name: "DependencyMode".into(),
        variant: match mode {
            DependencyMode::All => "All",
            DependencyMode::Any => "Any",
        }
        .into(),
    }
}
/// Returns the nodes of `spec` in dependency order: every node appears after all of the
/// nodes it depends on.
///
/// Nodes are released breadth-first. The initial queue holds the nodes without
/// dependencies in `spec.nodes` iteration order, and a dependent is queued when its last
/// outstanding dependency has been emitted.
///
/// # Errors
/// - `MobError::Internal` when a node depends on an id that is not in `spec.nodes`.
/// - `MobError::Internal` when not every node could be emitted, i.e. the graph has a cycle.
pub(crate) fn topological_order(spec: &FrameSpec) -> Result<Vec<FlowNodeId>, MobError> {
    // Outstanding dependency count per node, and the reverse edges (dep -> dependents).
    let mut pending: BTreeMap<FlowNodeId, usize> = BTreeMap::new();
    let mut dependents: BTreeMap<FlowNodeId, Vec<FlowNodeId>> = BTreeMap::new();
    for id in spec.nodes.keys() {
        pending.insert(id.clone(), 0);
        dependents.insert(id.clone(), Vec::new());
    }

    for (node_id, node_spec) in &spec.nodes {
        let deps = match node_spec {
            FlowNodeSpec::Step(step) => &step.depends_on,
            FlowNodeSpec::RepeatUntil(repeat) => &repeat.depends_on,
        };
        for dep in deps.iter() {
            if !pending.contains_key(dep) {
                return Err(MobError::Internal(format!(
                    "node '{node_id}' depends on unknown node '{dep}'"
                )));
            }
            *pending.entry(node_id.clone()).or_insert(0) += 1;
            dependents
                .entry(dep.clone())
                .or_default()
                .push(node_id.clone());
        }
    }

    let mut ready: VecDeque<FlowNodeId> = spec
        .nodes
        .keys()
        .filter(|id| pending.get(*id) == Some(&0))
        .cloned()
        .collect();

    let mut ordered = Vec::with_capacity(spec.nodes.len());
    while let Some(current) = ready.pop_front() {
        for child in dependents.get(&current).into_iter().flatten() {
            match pending.get_mut(child) {
                Some(count) if *count > 0 => {
                    *count -= 1;
                    if *count == 0 {
                        ready.push_back(child.clone());
                    }
                }
                _ => {}
            }
        }
        ordered.push(current);
    }

    // Nodes on a cycle never reach a count of zero, so they are never emitted.
    if ordered.len() == spec.nodes.len() {
        Ok(ordered)
    } else {
        Err(MobError::Internal(
            "frame contains a cycle; cannot compute topological order".to_string(),
        ))
    }
}