1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use std::{collections::BTreeMap, sync::Arc};

use miden_core::crypto::hash::RpoDigest;
use miden_processor::{
    AdviceExtractor, AdviceInjector, AdviceProvider, ExecutionError, Host, HostResponse,
    MastForest, MastForestStore, MemAdviceProvider, MemMastForestStore, ProcessState, RowIndex,
};

use super::{TraceEvent, TraceHandler};

/// This is an implementation of [Host] which is essentially [miden_processor::DefaultHost],
/// but extended with additional functionality for debugging. In particular, it manages trace
/// events that record the entry or exit of a procedure call frame.
///
/// In addition to answering advice requests and MAST forest lookups, it dispatches trace
/// events to user-registered callbacks and, optionally, notifies a handler when a VM
/// assertion fails.
#[derive(Default)]
pub struct DebuggerHost {
    /// Advice provider that services `get_advice` / `set_advice` requests from the VM.
    adv_provider: MemAdviceProvider,
    /// MAST forests added via `load_mast_forest`; queried by `get_mast_forest`.
    store: MemMastForestStore,
    /// Trace handlers keyed by trace id. Every handler stored under a key is invoked, in the
    /// order it was registered, when that trace id is reported through `on_trace`.
    tracing_callbacks: BTreeMap<u32, Vec<Box<TraceHandler>>>,
    /// Handler invoked from `on_assert_failed`; only the most recently registered one is kept.
    on_assert_failed: Option<Box<TraceHandler>>,
}
impl DebuggerHost {
    /// Construct a new instance of [DebuggerHost] with the given advice provider.
    pub fn new(adv_provider: MemAdviceProvider) -> Self {
        Self {
            adv_provider,
            store: Default::default(),
            tracing_callbacks: Default::default(),
            on_assert_failed: None,
        }
    }

    /// Register a trace handler for `event`
    pub fn register_trace_handler<F>(&mut self, event: TraceEvent, callback: F)
    where
        F: FnMut(RowIndex, TraceEvent) + 'static,
    {
        let key = match event {
            TraceEvent::AssertionFailed(None) => u32::MAX,
            ev => ev.into(),
        };
        self.tracing_callbacks.entry(key).or_default().push(Box::new(callback));
    }

    /// Register a handler to be called when an assertion in the VM fails
    pub fn register_assert_failed_tracer<F>(&mut self, callback: F)
    where
        F: FnMut(RowIndex, TraceEvent) + 'static,
    {
        self.on_assert_failed = Some(Box::new(callback));
    }

    /// Load `forest` into the MAST store for this host
    pub fn load_mast_forest(&mut self, forest: MastForest) {
        self.store.insert(forest);
    }
}

impl Host for DebuggerHost {
    /// Forwards an advice extraction request to the underlying advice provider.
    fn get_advice<P: ProcessState>(
        &mut self,
        process: &P,
        extractor: AdviceExtractor,
    ) -> Result<HostResponse, ExecutionError> {
        let provider = &mut self.adv_provider;
        provider.get_advice(process, &extractor)
    }

    /// Forwards an advice injection request to the underlying advice provider.
    fn set_advice<P: ProcessState>(
        &mut self,
        process: &P,
        injector: AdviceInjector,
    ) -> Result<HostResponse, ExecutionError> {
        let provider = &mut self.adv_provider;
        provider.set_advice(process, &injector)
    }

    /// Looks up the MAST forest for `node_digest` in the forests loaded into this host.
    fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option<Arc<MastForest>> {
        let store = &self.store;
        store.get(node_digest)
    }

    /// Runs every handler registered for `trace_id`, in registration order, passing the
    /// current clock cycle and the decoded [TraceEvent].
    ///
    /// Trace ids with no registered handlers are ignored. This always returns
    /// `HostResponse::None`.
    fn on_trace<S: ProcessState>(
        &mut self,
        process: &S,
        trace_id: u32,
    ) -> Result<HostResponse, ExecutionError> {
        let event = TraceEvent::from(trace_id);
        let clk = process.clk();
        self.tracing_callbacks
            .get_mut(&trace_id)
            .into_iter()
            .flatten()
            .for_each(|handler| handler(clk, event));
        Ok(HostResponse::None)
    }

    /// Notifies the registered assertion-failure handler (if any) and builds the resulting
    /// `ExecutionError::FailedAssertion`.
    ///
    /// An `err_code` of zero is reported to the handler as `AssertionFailed(None)`. The
    /// alignment-failure code is given a human-readable message; all other codes carry no
    /// message.
    fn on_assert_failed<S: ProcessState>(&mut self, process: &S, err_code: u32) -> ExecutionError {
        let clk = process.clk();
        if let Some(notify) = &mut self.on_assert_failed {
            let code = core::num::NonZeroU32::new(err_code);
            notify(clk, TraceEvent::AssertionFailed(code));
        }
        let err_msg = (err_code == midenc_hir::ASSERT_FAILED_ALIGNMENT).then(|| {
            "failed alignment: use of memory address violates minimum alignment requirements \
             for that use"
                .to_string()
        });
        ExecutionError::FailedAssertion {
            clk,
            err_code,
            err_msg,
        }
    }
}