use alloc::{sync::Arc, vec::Vec};
use miden_core::{
mast::{MastForest, MastNodeId},
program::Program,
};
/// Initial capacity (64 entries) reserved by [`ContinuationStack::new`].
///
/// This is only a hint: the backing `Vec` still grows past it when more continuations are pushed.
const CONTINUATION_STACK_SIZE_HINT: usize = 1 << 6;
#[derive(Debug, Clone)]
pub enum Continuation {
StartNode(MastNodeId),
FinishJoin(MastNodeId),
FinishSplit(MastNodeId),
FinishLoop { node_id: MastNodeId, was_entered: bool },
FinishCall(MastNodeId),
FinishDyn(MastNodeId),
FinishExternal(MastNodeId),
ResumeBasicBlock {
node_id: MastNodeId,
batch_index: usize,
op_idx_in_batch: usize,
},
Respan { node_id: MastNodeId, batch_index: usize },
FinishBasicBlock(MastNodeId),
EnterForest(Arc<MastForest>),
AfterExitDecorators(MastNodeId),
AfterExitDecoratorsBasicBlock(MastNodeId),
}
impl Continuation {
pub fn increments_clk(&self) -> bool {
use Continuation::*;
match self {
StartNode(_)
| FinishJoin(_)
| FinishSplit(_)
| FinishLoop { node_id: _, was_entered: _ }
| FinishCall(_)
| FinishDyn(_)
| ResumeBasicBlock {
node_id: _,
batch_index: _,
op_idx_in_batch: _,
}
| Respan { node_id: _, batch_index: _ }
| FinishBasicBlock(_) => true,
FinishExternal(_)
| EnterForest(_)
| AfterExitDecorators(_)
| AfterExitDecoratorsBasicBlock(_) => false,
}
}
}
/// A LIFO stack of [`Continuation`]s: the most recently pushed entry is the first one popped.
///
/// The derived `Default` yields an empty stack with no reserved capacity. In contrast,
/// [`ContinuationStack::new`] reserves space up front and seeds the program entrypoint.
#[derive(Clone, Debug, Default)]
pub struct ContinuationStack {
    /// Backing storage; the top of the stack is the last element of the vector.
    stack: Vec<Continuation>,
}
impl ContinuationStack {
    /// Creates a stack whose only entry is a [`Continuation::StartNode`] for the entrypoint of
    /// `program`.
    ///
    /// Capacity for `CONTINUATION_STACK_SIZE_HINT` entries is reserved up front; the stack still
    /// grows beyond that if more continuations are pushed.
    pub fn new(program: &Program) -> Self {
        let mut stack = Vec::with_capacity(CONTINUATION_STACK_SIZE_HINT);
        stack.push(Continuation::StartNode(program.entrypoint()));
        Self { stack }
    }

    /// Pushes an arbitrary continuation onto the top of the stack.
    pub fn push_continuation(&mut self, continuation: Continuation) {
        self.stack.push(continuation);
    }

    /// Pushes a [`Continuation::EnterForest`] for `forest`.
    pub fn push_enter_forest(&mut self, forest: Arc<MastForest>) {
        self.stack.push(Continuation::EnterForest(forest));
    }

    /// Pushes a [`Continuation::FinishJoin`] for `node_id`.
    pub fn push_finish_join(&mut self, node_id: MastNodeId) {
        self.stack.push(Continuation::FinishJoin(node_id));
    }

    /// Pushes a [`Continuation::FinishSplit`] for `node_id`.
    pub fn push_finish_split(&mut self, node_id: MastNodeId) {
        self.stack.push(Continuation::FinishSplit(node_id));
    }

    /// Pushes a [`Continuation::FinishLoop`] for `node_id` with `was_entered` set to `true`.
    pub fn push_finish_loop_entered(&mut self, node_id: MastNodeId) {
        self.stack.push(Continuation::FinishLoop { node_id, was_entered: true });
    }

    /// Pushes a [`Continuation::FinishCall`] for `node_id`.
    pub fn push_finish_call(&mut self, node_id: MastNodeId) {
        self.stack.push(Continuation::FinishCall(node_id));
    }

    /// Pushes a [`Continuation::FinishDyn`] for `node_id`.
    pub fn push_finish_dyn(&mut self, node_id: MastNodeId) {
        self.stack.push(Continuation::FinishDyn(node_id));
    }

    /// Pushes a [`Continuation::FinishExternal`] for `node_id`.
    pub fn push_finish_external(&mut self, node_id: MastNodeId) {
        self.stack.push(Continuation::FinishExternal(node_id));
    }

    /// Pushes a [`Continuation::StartNode`] for `node_id`.
    pub fn push_start_node(&mut self, node_id: MastNodeId) {
        self.stack.push(Continuation::StartNode(node_id));
    }

    /// Removes and returns the top continuation, or `None` if the stack is empty.
    pub fn pop_continuation(&mut self) -> Option<Continuation> {
        self.stack.pop()
    }

    /// Returns the number of continuations currently on the stack.
    pub fn len(&self) -> usize {
        self.stack.len()
    }

    /// Returns `true` if the stack holds no continuations.
    ///
    /// Companion to [`ContinuationStack::len`] (clippy `len_without_is_empty`).
    pub fn is_empty(&self) -> bool {
        self.stack.is_empty()
    }

    /// Returns a reference to the top continuation without removing it, or `None` if the stack
    /// is empty.
    pub fn peek_continuation(&self) -> Option<&Continuation> {
        self.stack.last()
    }

    /// Iterates from the top of the stack downward. It yields every continuation up to and
    /// including the first one for which [`Continuation::increments_clk`] returns `true`.
    ///
    /// If no entry increments the clock, every entry is yielded (top to bottom). An empty stack
    /// yields nothing.
    pub fn iter_continuations_for_next_clock(&self) -> impl Iterator<Item = &Continuation> {
        let mut found_incrementing_cont = false;
        self.stack.iter().rev().take_while(move |continuation| {
            if found_incrementing_cont {
                // The previous item already incremented the clock; stop here.
                false
            } else {
                // Yield this item, and remember whether it is the clock-incrementing one so the
                // next call terminates the iteration.
                found_incrementing_cont = continuation.increments_clk();
                true
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn get_next_clock_cycle_increment_empty_stack() {
        let stack = ContinuationStack::default();
        let result: Vec<_> = stack.iter_continuations_for_next_clock().collect();
        assert!(result.is_empty());
    }

    #[test]
    fn get_next_clock_cycle_increment_ends_with_incrementing() {
        let mut stack = ContinuationStack::default();
        stack.push_continuation(Continuation::StartNode(MastNodeId::new_unchecked(0)));
        let result: Vec<_> = stack.iter_continuations_for_next_clock().collect();
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Continuation::StartNode(_)));
    }

    #[test]
    fn get_next_clock_cycle_increment_non_incrementing_after_incrementing() {
        let mut stack = ContinuationStack::default();
        stack.push_continuation(Continuation::StartNode(MastNodeId::new_unchecked(0)));
        stack.push_continuation(Continuation::AfterExitDecorators(MastNodeId::new_unchecked(0)));
        let result: Vec<_> = stack.iter_continuations_for_next_clock().collect();
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Continuation::AfterExitDecorators(_)));
        assert!(matches!(result[1], Continuation::StartNode(_)));
    }

    #[test]
    fn get_next_clock_cycle_increment_two_non_incrementing_after_incrementing() {
        let mut stack = ContinuationStack::default();
        stack.push_continuation(Continuation::StartNode(MastNodeId::new_unchecked(0)));
        stack.push_continuation(Continuation::AfterExitDecorators(MastNodeId::new_unchecked(0)));
        stack.push_continuation(Continuation::EnterForest(Arc::new(MastForest::new())));
        let result: Vec<_> = stack.iter_continuations_for_next_clock().collect();
        assert_eq!(result.len(), 3);
        assert!(matches!(result[0], Continuation::EnterForest(_)));
        assert!(matches!(result[1], Continuation::AfterExitDecorators(_)));
        assert!(matches!(result[2], Continuation::StartNode(_)));
    }

    /// Entries below the first clock-incrementing continuation must not be yielded.
    #[test]
    fn get_next_clock_cycle_increment_stops_at_first_incrementing() {
        let mut stack = ContinuationStack::default();
        stack.push_continuation(Continuation::StartNode(MastNodeId::new_unchecked(0)));
        stack.push_continuation(Continuation::FinishJoin(MastNodeId::new_unchecked(1)));
        stack.push_continuation(Continuation::AfterExitDecorators(MastNodeId::new_unchecked(1)));
        let result: Vec<_> = stack.iter_continuations_for_next_clock().collect();
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Continuation::AfterExitDecorators(_)));
        assert!(matches!(result[1], Continuation::FinishJoin(_)));
    }

    /// When no entry increments the clock, the whole stack is yielded top to bottom.
    #[test]
    fn get_next_clock_cycle_increment_only_non_incrementing() {
        let mut stack = ContinuationStack::default();
        stack.push_continuation(Continuation::FinishExternal(MastNodeId::new_unchecked(0)));
        stack.push_continuation(Continuation::AfterExitDecoratorsBasicBlock(
            MastNodeId::new_unchecked(0),
        ));
        let result: Vec<_> = stack.iter_continuations_for_next_clock().collect();
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Continuation::AfterExitDecoratorsBasicBlock(_)));
        assert!(matches!(result[1], Continuation::FinishExternal(_)));
    }
}