1use std::{collections::BTreeMap, num::NonZeroU32, sync::Arc};
2
3use miden_assembly::SourceManager;
4use miden_core::{
5 Word,
6 events::{EventId, EventName},
7};
8use miden_debug_types::{Location, SourceFile, SourceSpan};
9use miden_processor::{
10 ExecutionError, FutureMaybeSend, Host, MastForestStore, MemMastForestStore, ProcessorState,
11 TraceError,
12 advice::AdviceMutation,
13 event::{EventError, EventHandler, EventHandlerRegistry},
14 mast::MastForest,
15 trace::RowIndex,
16};
17
18use super::{TraceEvent, TraceHandler};
19
/// A [`Host`] used when running programs under the debugger.
///
/// It serves MAST forests from an in-memory store, dispatches VM events to registered
/// [`EventHandler`]s, forwards trace instructions to user-supplied callbacks, and resolves
/// source locations through a [`SourceManager`].
pub struct DebuggerHost<S: SourceManager + ?Sized> {
    /// MAST forests added via `load_mast_forest`, looked up by node digest in
    /// `Host::get_mast_forest`.
    store: MemMastForestStore,
    /// Event handlers, populated by `register_event_handler` (keyed by [`EventName`]) and
    /// queried by [`EventId`] in `Host::on_event`.
    event_handlers: EventHandlerRegistry,
    /// Trace callbacks keyed by trace id, invoked in registration order by `Host::on_trace`.
    /// `TraceEvent::AssertionFailed(None)` is registered under the key `u32::MAX`.
    tracing_callbacks: BTreeMap<u32, Vec<Box<TraceHandler>>>,
    /// Optional callback invoked by `handle_assert_failed`; at most one is kept.
    on_assert_failed: Option<Box<TraceHandler>>,
    /// Used to map [`Location`]s to [`SourceSpan`]s and [`SourceFile`]s.
    source_manager: Arc<S>,
}
30impl<S> DebuggerHost<S>
31where
32 S: SourceManager + ?Sized,
33{
34 pub fn new(source_manager: Arc<S>) -> Self {
36 Self {
37 store: Default::default(),
38 event_handlers: EventHandlerRegistry::default(),
39 tracing_callbacks: Default::default(),
40 on_assert_failed: None,
41 source_manager,
42 }
43 }
44
45 pub fn register_trace_handler<F>(&mut self, event: TraceEvent, callback: F)
47 where
48 F: FnMut(RowIndex, TraceEvent) + 'static,
49 {
50 let key = match event {
51 TraceEvent::AssertionFailed(None) => u32::MAX,
52 ev => ev.into(),
53 };
54 self.tracing_callbacks.entry(key).or_default().push(Box::new(callback));
55 }
56
57 pub fn register_assert_failed_tracer<F>(&mut self, callback: F)
59 where
60 F: FnMut(RowIndex, TraceEvent) + 'static,
61 {
62 self.on_assert_failed = Some(Box::new(callback));
63 }
64
65 pub fn handle_assert_failed(&mut self, clk: RowIndex, err_code: Option<NonZeroU32>) {
70 if let Some(handler) = self.on_assert_failed.as_mut() {
71 handler(clk, TraceEvent::AssertionFailed(err_code));
72 }
73 }
74
75 pub fn load_mast_forest(&mut self, forest: Arc<MastForest>) {
77 self.store.insert(forest);
78 }
79
80 pub fn register_event_handler(
82 &mut self,
83 event: EventName,
84 handler: Arc<dyn EventHandler>,
85 ) -> Result<(), ExecutionError> {
86 self.event_handlers.register(event, handler)
87 }
88}
89
90impl<S> Host for DebuggerHost<S>
91where
92 S: SourceManager + ?Sized,
93{
94 fn get_label_and_source_file(
95 &self,
96 location: &Location,
97 ) -> (SourceSpan, Option<Arc<SourceFile>>) {
98 let maybe_file = self.source_manager.get_by_uri(location.uri());
99 let span = self.source_manager.location_to_span(location.clone()).unwrap_or_default();
100 (span, maybe_file)
101 }
102
103 fn get_mast_forest(&self, node_digest: &Word) -> impl FutureMaybeSend<Option<Arc<MastForest>>> {
104 std::future::ready(self.store.get(node_digest))
105 }
106
107 fn on_event(
108 &mut self,
109 process: &ProcessorState<'_>,
110 ) -> impl FutureMaybeSend<Result<Vec<AdviceMutation>, EventError>> {
111 let event_id = EventId::from_felt(process.get_stack_item(0));
112 let result = match self.event_handlers.handle_event(event_id, process) {
113 Ok(Some(mutations)) => Ok(mutations),
114 Ok(None) => {
115 #[derive(Debug, thiserror::Error)]
116 #[error("no event handler registered")]
117 struct UnhandledEvent;
118
119 Err(UnhandledEvent.into())
120 }
121 Err(err) => Err(err),
122 };
123 std::future::ready(result)
124 }
125
126 fn on_trace(&mut self, process: &ProcessorState<'_>, trace_id: u32) -> Result<(), TraceError> {
127 let event = TraceEvent::from(trace_id);
128 let clk = process.clock();
129 if let Some(handlers) = self.tracing_callbacks.get_mut(&trace_id) {
130 for handler in handlers.iter_mut() {
131 handler(clk, event);
132 }
133 }
134 Ok(())
135 }
136
137 fn resolve_event(&self, event_id: EventId) -> Option<&EventName> {
138 self.event_handlers.resolve_event(event_id)
139 }
140}