use std::{
collections::{BTreeMap, VecDeque},
num::NonZeroU32,
sync::Arc,
};
use miden_assembly::SourceManager;
use miden_core::{
Word,
events::{EventId, EventName},
};
use miden_debug_types::{Location, SourceFile, SourceSpan};
use miden_processor::{
ExecutionError, FutureMaybeSend, Host, MastForestStore, MemMastForestStore, ProcessorState,
TraceError,
advice::AdviceMutation,
event::{EventError, EventHandler, EventHandlerRegistry},
mast::MastForest,
trace::RowIndex,
};
use super::{TraceEvent, TraceHandler};
/// A [`Host`] implementation that wires the processor up to debugger-side state.
///
/// It owns the MAST forests made available to execution, the registered event handlers,
/// per-trace-id callbacks, an optional assertion-failure tracer, the source manager used to
/// resolve source locations, and a queue of pre-recorded advice mutations that can be
/// replayed in place of live event handling.
pub struct DebuggerHost<S>
where
    S: SourceManager + ?Sized,
{
    /// Resolves `Location`s into spans and source files.
    source_manager: Arc<S>,
    /// In-memory store holding every MAST forest loaded into this host.
    store: MemMastForestStore,
    /// Registry consulted to dispatch events when no replay data is queued.
    event_handlers: EventHandlerRegistry,
    /// Pending batches of advice mutations, consumed front-first, one batch per event.
    event_replay: VecDeque<Vec<AdviceMutation>>,
    /// Trace callbacks grouped by their numeric trace key.
    tracing_callbacks: BTreeMap<u32, Vec<Box<TraceHandler>>>,
    /// Optional single callback invoked when an assertion failure is reported.
    on_assert_failed: Option<Box<TraceHandler>>,
}
impl<S> DebuggerHost<S>
where
S: SourceManager + ?Sized,
{
pub fn new(source_manager: Arc<S>) -> Self {
Self {
store: Default::default(),
event_handlers: EventHandlerRegistry::default(),
tracing_callbacks: Default::default(),
on_assert_failed: None,
source_manager,
event_replay: VecDeque::new(),
}
}
pub fn set_event_replay(&mut self, events: VecDeque<Vec<AdviceMutation>>) {
self.event_replay = events;
}
pub fn register_trace_handler<F>(&mut self, event: TraceEvent, callback: F)
where
F: FnMut(RowIndex, TraceEvent) + 'static,
{
let key = match event {
TraceEvent::AssertionFailed(None) => u32::MAX,
ev => ev.into(),
};
self.tracing_callbacks.entry(key).or_default().push(Box::new(callback));
}
pub fn register_assert_failed_tracer<F>(&mut self, callback: F)
where
F: FnMut(RowIndex, TraceEvent) + 'static,
{
self.on_assert_failed = Some(Box::new(callback));
}
pub fn handle_assert_failed(&mut self, clk: RowIndex, err_code: Option<NonZeroU32>) {
if let Some(handler) = self.on_assert_failed.as_mut() {
handler(clk, TraceEvent::AssertionFailed(err_code));
}
}
pub fn load_mast_forest(&mut self, forest: Arc<MastForest>) {
self.store.insert(forest);
}
pub fn register_event_handler(
&mut self,
event: EventName,
handler: Arc<dyn EventHandler>,
) -> Result<(), ExecutionError> {
self.event_handlers.register(event, handler)
}
}
impl<S> Host for DebuggerHost<S>
where
    S: SourceManager + ?Sized,
{
    /// Resolves `location` to a span and, when known, the source file it refers to.
    ///
    /// The file is looked up by the location's URI. If the source manager cannot convert
    /// the location into a span, `SourceSpan::default()` is returned instead of an error.
    fn get_label_and_source_file(
        &self,
        location: &Location,
    ) -> (SourceSpan, Option<Arc<SourceFile>>) {
        let maybe_file = self.source_manager.get_by_uri(location.uri());
        let span = self.source_manager.location_to_span(location.clone()).unwrap_or_default();
        (span, maybe_file)
    }

    /// Looks up the forest containing `node_digest` in the in-memory store.
    ///
    /// The lookup is synchronous, so the returned future is already resolved.
    fn get_mast_forest(&self, node_digest: &Word) -> impl FutureMaybeSend<Option<Arc<MastForest>>> {
        std::future::ready(self.store.get(node_digest))
    }

    /// Produces the advice mutations for the event currently being raised.
    ///
    /// If replay data is queued, the front batch is consumed and returned. In that case
    /// neither the stack nor the registered handlers are consulted. Otherwise the event id
    /// is read from the top stack item and dispatched to the event-handler registry.
    ///
    /// # Errors
    ///
    /// - Errors from the registry are returned unchanged.
    /// - If no handler claims the event, a "no event handler registered" error is returned.
    fn on_event(
        &mut self,
        process: &ProcessorState<'_>,
    ) -> impl FutureMaybeSend<Result<Vec<AdviceMutation>, EventError>> {
        #[derive(Debug, thiserror::Error)]
        #[error("no event handler registered")]
        struct UnhandledEvent;

        // Replay mode: a single pop both checks for and takes the next recorded batch.
        if let Some(mutations) = self.event_replay.pop_front() {
            return std::future::ready(Ok(mutations));
        }

        let event_id = EventId::from_felt(process.get_stack_item(0));
        let result = self
            .event_handlers
            .handle_event(event_id, process)
            // `Ok(None)` means no handler claimed the event; surface that as an error.
            .and_then(|handled| handled.ok_or_else(|| UnhandledEvent.into()));
        std::future::ready(result)
    }

    /// Invokes every callback registered under `trace_id`, in registration order.
    ///
    /// Trace ids with no registered callbacks are ignored. This method always returns `Ok`.
    fn on_trace(&mut self, process: &ProcessorState<'_>, trace_id: u32) -> Result<(), TraceError> {
        let Some(handlers) = self.tracing_callbacks.get_mut(&trace_id) else {
            return Ok(());
        };
        // Build the event and read the clock only once we know someone is listening.
        let event = TraceEvent::from(trace_id);
        let clk = process.clock();
        for handler in handlers.iter_mut() {
            handler(clk, event);
        }
        Ok(())
    }

    /// Maps `event_id` back to its registered name via the event-handler registry.
    fn resolve_event(&self, event_id: EventId) -> Option<&EventName> {
        self.event_handlers.resolve_event(event_id)
    }
}