triton_isa/
program.rs

1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::collections::hash_map::Entry;
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::hash::Hash;
8use std::io::Cursor;
9
10use arbitrary::Arbitrary;
11use get_size2::GetSize;
12use itertools::Itertools;
13use serde::Deserialize;
14use serde::Serialize;
15use thiserror::Error;
16use twenty_first::prelude::*;
17
18use crate::instruction::AnInstruction;
19use crate::instruction::AssertionContext;
20use crate::instruction::Instruction;
21use crate::instruction::InstructionError;
22use crate::instruction::LabelledInstruction;
23use crate::instruction::TypeHint;
24use crate::parser;
25use crate::parser::ParseError;
26
/// A program for Triton VM. Triton VM can run and profile such programs,
/// and trace its execution in order to generate a proof of correct execution.
/// See the specification's section on [program attestation] for details on
/// how a program relates to such a proof.
///
/// A program may contain debug information, such as label names and
/// breakpoints. Access this information through methods
/// [`label_for_address()`][label_for_address] and
/// [`is_breakpoint()`][is_breakpoint]. Some operations, most notably
/// [BField-encoding](BFieldCodec::encode), discard this debug information.
/// Equality of programs only compares their instructions, not their debug
/// information.
///
/// [program attestation]: https://triton-vm.org/spec/program-attestation.html
/// [label_for_address]: Program::label_for_address
/// [is_breakpoint]: Program::is_breakpoint
#[derive(Debug, Clone, Eq, Serialize, Deserialize, GetSize)]
pub struct Program {
    /// One entry per word of the program: an instruction occupying several
    /// words (see `Instruction::size`) appears once for each of those words.
    pub instructions: Vec<Instruction>,

    /// Label names, keyed by the address they point to.
    address_to_label: HashMap<u64, String>,

    /// Breakpoints, type hints, and assertion contexts.
    debug_information: DebugInformation,
}
46
47impl Display for Program {
48    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
49        for instruction in self.labelled_instructions() {
50            writeln!(f, "{instruction}")?;
51        }
52        Ok(())
53    }
54}
55
56impl PartialEq for Program {
57    fn eq(&self, other: &Program) -> bool {
58        self.instructions.eq(&other.instructions)
59    }
60}
61
62impl BFieldCodec for Program {
63    type Error = ProgramDecodingError;
64
65    fn decode(sequence: &[BFieldElement]) -> Result<Box<Self>, Self::Error> {
66        if sequence.is_empty() {
67            return Err(Self::Error::EmptySequence);
68        }
69        let program_length = sequence[0].value() as usize;
70        let sequence = &sequence[1..];
71        if sequence.len() < program_length {
72            return Err(Self::Error::SequenceTooShort);
73        }
74        if sequence.len() > program_length {
75            return Err(Self::Error::SequenceTooLong);
76        }
77
78        // instantiating with claimed capacity is a potential DOS vector
79        let mut instructions = vec![];
80        let mut read_idx = 0;
81        while read_idx < program_length {
82            let opcode = sequence[read_idx];
83            let mut instruction = Instruction::try_from(opcode)
84                .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?;
85            let instruction_has_arg = instruction.arg().is_some();
86            if instruction_has_arg && instructions.len() + instruction.size() > program_length {
87                return Err(Self::Error::MissingArgument(read_idx, instruction));
88            }
89            if instruction_has_arg {
90                let arg = sequence[read_idx + 1];
91                instruction = instruction
92                    .change_arg(arg)
93                    .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?;
94            }
95
96            instructions.extend(vec![instruction; instruction.size()]);
97            read_idx += instruction.size();
98        }
99
100        if read_idx != program_length {
101            return Err(Self::Error::LengthMismatch);
102        }
103        if instructions.len() != program_length {
104            return Err(Self::Error::LengthMismatch);
105        }
106
107        Ok(Box::new(Program {
108            instructions,
109            address_to_label: HashMap::default(),
110            debug_information: DebugInformation::default(),
111        }))
112    }
113
114    fn encode(&self) -> Vec<BFieldElement> {
115        let mut sequence = Vec::with_capacity(self.len_bwords() + 1);
116        sequence.push(bfe!(self.len_bwords() as u64));
117        sequence.extend(self.to_bwords());
118        sequence
119    }
120
121    fn static_length() -> Option<usize> {
122        None
123    }
124}
125
126impl<'a> Arbitrary<'a> for Program {
127    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
128        let contains_label = |labelled_instructions: &[_], maybe_label: &_| {
129            let LabelledInstruction::Label(label) = maybe_label else {
130                return false;
131            };
132            labelled_instructions
133                .iter()
134                .any(|labelled_instruction| match labelled_instruction {
135                    LabelledInstruction::Label(l) => l == label,
136                    _ => false,
137                })
138        };
139        let is_assertion = |maybe_instruction: &_| {
140            matches!(
141                maybe_instruction,
142                LabelledInstruction::Instruction(
143                    AnInstruction::Assert | AnInstruction::AssertVector
144                )
145            )
146        };
147
148        let mut labelled_instructions = vec![];
149        for _ in 0..u.arbitrary_len::<LabelledInstruction>()? {
150            let labelled_instruction = u.arbitrary()?;
151            if contains_label(&labelled_instructions, &labelled_instruction) {
152                continue;
153            }
154            if let LabelledInstruction::AssertionContext(_) = labelled_instruction {
155                // assertion context must come after an assertion
156                continue;
157            }
158
159            let is_assertion = is_assertion(&labelled_instruction);
160            labelled_instructions.push(labelled_instruction);
161
162            if is_assertion && u.arbitrary()? {
163                let assertion_context = LabelledInstruction::AssertionContext(u.arbitrary()?);
164                labelled_instructions.push(assertion_context);
165            }
166        }
167
168        let all_call_targets = labelled_instructions
169            .iter()
170            .filter_map(|instruction| match instruction {
171                LabelledInstruction::Instruction(AnInstruction::Call(target)) => Some(target),
172                _ => None,
173            })
174            .unique();
175        let labels_that_are_called_but_not_declared = all_call_targets
176            .map(|target| LabelledInstruction::Label(target.clone()))
177            .filter(|label| !contains_label(&labelled_instructions, label))
178            .collect_vec();
179
180        for label in labels_that_are_called_but_not_declared {
181            let insertion_index = u.choose_index(labelled_instructions.len() + 1)?;
182            labelled_instructions.insert(insertion_index, label);
183        }
184
185        Ok(Program::new(&labelled_instructions))
186    }
187}
188
/// An `InstructionIter` iterates over the instructions of a `Program`,
/// yielding each instruction once by skipping the duplicate placeholders that
/// multi-word instructions occupy in the program's instruction list.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct InstructionIter {
    /// The program's instructions; the cursor's position is the index of the
    /// next instruction to yield.
    cursor: Cursor<Vec<Instruction>>,
}
195
196impl Iterator for InstructionIter {
197    type Item = Instruction;
198
199    fn next(&mut self) -> Option<Self::Item> {
200        let pos = self.cursor.position() as usize;
201        let instructions = self.cursor.get_ref();
202        let instruction = *instructions.get(pos)?;
203        self.cursor.set_position((pos + instruction.size()) as u64);
204
205        Some(instruction)
206    }
207}
208
209impl IntoIterator for Program {
210    type Item = Instruction;
211
212    type IntoIter = InstructionIter;
213
214    fn into_iter(self) -> Self::IntoIter {
215        let cursor = Cursor::new(self.instructions);
216        InstructionIter { cursor }
217    }
218}
219
/// Debug information attached to a [`Program`]. It is not part of the
/// program's [BField-encoding](BFieldCodec::encode); decoding yields an
/// empty (default) instance.
#[derive(Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary, GetSize)]
struct DebugInformation {
    /// One flag per word of the program: `true` if execution should break
    /// before the instruction occupying that word.
    breakpoints: Vec<bool>,

    /// Type hints, keyed by the address at which they appear in the source,
    /// i.e., the address of the next instruction.
    type_hints: HashMap<u64, Vec<TypeHint>>,

    /// Assertion contexts, keyed by the address of the associated assertion.
    assertion_context: HashMap<u64, AssertionContext>,
}
226
227impl Program {
228    pub fn new(labelled_instructions: &[LabelledInstruction]) -> Self {
229        let label_to_address = parser::build_label_to_address_map(labelled_instructions);
230        let instructions =
231            parser::turn_labels_into_addresses(labelled_instructions, &label_to_address);
232        let address_to_label = Self::flip_map(label_to_address);
233        let debug_information = Self::extract_debug_information(labelled_instructions);
234
235        debug_assert_eq!(instructions.len(), debug_information.breakpoints.len());
236        Program {
237            instructions,
238            address_to_label,
239            debug_information,
240        }
241    }
242
243    fn flip_map<Key, Value: Eq + Hash>(map: HashMap<Key, Value>) -> HashMap<Value, Key> {
244        map.into_iter().map(|(key, value)| (value, key)).collect()
245    }
246
247    fn extract_debug_information(
248        labelled_instructions: &[LabelledInstruction],
249    ) -> DebugInformation {
250        let mut address = 0;
251        let mut break_before_next_instruction = false;
252        let mut debug_info = DebugInformation::default();
253        for instruction in labelled_instructions {
254            match instruction {
255                LabelledInstruction::Instruction(instruction) => {
256                    let new_breakpoints = vec![break_before_next_instruction; instruction.size()];
257                    debug_info.breakpoints.extend(new_breakpoints);
258                    break_before_next_instruction = false;
259                    address += instruction.size() as u64;
260                }
261                LabelledInstruction::Label(_) => (),
262                LabelledInstruction::Breakpoint => break_before_next_instruction = true,
263                LabelledInstruction::TypeHint(hint) => match debug_info.type_hints.entry(address) {
264                    Entry::Occupied(mut entry) => entry.get_mut().push(hint.clone()),
265                    Entry::Vacant(entry) => entry.insert(vec![]).push(hint.clone()),
266                },
267                LabelledInstruction::AssertionContext(ctx) => {
268                    let address_of_associated_assertion = address.saturating_sub(1);
269                    debug_info
270                        .assertion_context
271                        .insert(address_of_associated_assertion, ctx.clone());
272                }
273            }
274        }
275
276        debug_info
277    }
278
279    /// Create a `Program` by parsing source code.
280    pub fn from_code(code: &str) -> Result<Self, ParseError<'_>> {
281        parser::parse(code)
282            .map(|tokens| parser::to_labelled_instructions(&tokens))
283            .map(|instructions| Program::new(&instructions))
284    }
285
286    pub fn labelled_instructions(&self) -> Vec<LabelledInstruction> {
287        let call_targets = self.call_targets();
288        let instructions_with_labels = self.instructions.iter().map(|instruction| {
289            instruction.map_call_address(|&address| self.label_for_address(address.value()))
290        });
291
292        let mut labelled_instructions = vec![];
293        let mut address = 0;
294        let mut instruction_stream = instructions_with_labels.into_iter();
295        while let Some(instruction) = instruction_stream.next() {
296            let instruction_size = instruction.size() as u64;
297            if call_targets.contains(&address) {
298                let label = self.label_for_address(address);
299                let label = LabelledInstruction::Label(label);
300                labelled_instructions.push(label);
301            }
302            for type_hint in self.type_hints_at(address) {
303                labelled_instructions.push(LabelledInstruction::TypeHint(type_hint));
304            }
305            if self.is_breakpoint(address) {
306                labelled_instructions.push(LabelledInstruction::Breakpoint);
307            }
308            labelled_instructions.push(LabelledInstruction::Instruction(instruction));
309            if let Some(context) = self.assertion_context_at(address) {
310                labelled_instructions.push(LabelledInstruction::AssertionContext(context));
311            }
312
313            for _ in 1..instruction_size {
314                instruction_stream.next();
315            }
316            address += instruction_size;
317        }
318
319        let leftover_labels = self
320            .address_to_label
321            .iter()
322            .filter(|&(&labels_address, _)| labels_address >= address)
323            .sorted();
324        for (_, label) in leftover_labels {
325            labelled_instructions.push(LabelledInstruction::Label(label.clone()));
326        }
327
328        labelled_instructions
329    }
330
331    fn call_targets(&self) -> HashSet<u64> {
332        self.instructions
333            .iter()
334            .filter_map(|instruction| match instruction {
335                Instruction::Call(address) => Some(address.value()),
336                _ => None,
337            })
338            .collect()
339    }
340
341    pub fn is_breakpoint(&self, address: u64) -> bool {
342        let address: usize = address.try_into().unwrap();
343        self.debug_information
344            .breakpoints
345            .get(address)
346            .copied()
347            .unwrap_or_default()
348    }
349
350    pub fn type_hints_at(&self, address: u64) -> Vec<TypeHint> {
351        self.debug_information
352            .type_hints
353            .get(&address)
354            .cloned()
355            .unwrap_or_default()
356    }
357
358    pub fn assertion_context_at(&self, address: u64) -> Option<AssertionContext> {
359        self.debug_information
360            .assertion_context
361            .get(&address)
362            .cloned()
363    }
364
365    /// Turn the program into a sequence of `BFieldElement`s. Each instruction
366    /// is encoded as its opcode, followed by its argument (if any).
367    ///
368    /// **Note**: This is _almost_ (but not quite!) equivalent to
369    /// [encoding](BFieldCodec::encode) the program. For that, use
370    /// [`encode()`](Self::encode()) instead.
371    pub fn to_bwords(&self) -> Vec<BFieldElement> {
372        self.clone()
373            .into_iter()
374            .flat_map(|instruction| {
375                let opcode = instruction.opcode_b();
376                if let Some(arg) = instruction.arg() {
377                    vec![opcode, arg]
378                } else {
379                    vec![opcode]
380                }
381            })
382            .collect()
383    }
384
385    /// The total length of the program as `BFieldElement`s. Double-word
386    /// instructions contribute two `BFieldElement`s.
387    pub fn len_bwords(&self) -> usize {
388        self.instructions.len()
389    }
390
391    pub fn is_empty(&self) -> bool {
392        self.instructions.is_empty()
393    }
394
395    /// Produces the program's canonical hash digest. Uses [`Tip5`], the
396    /// canonical hash function for Triton VM.
397    pub fn hash(&self) -> Digest {
398        // not encoded using `BFieldCodec` because that would prepend the length
399        Tip5::hash_varlen(&self.to_bwords())
400    }
401
402    /// The label for the given address, or a deterministic, unique substitute
403    /// if no label is found.
404    pub fn label_for_address(&self, address: u64) -> String {
405        // Uniqueness of the label is relevant for printing and subsequent
406        // parsing: Parsing fails on duplicate labels.
407        self.address_to_label
408            .get(&address)
409            .cloned()
410            .unwrap_or_else(|| format!("address_{address}"))
411    }
412}
413
/// Errors that can occur when [decoding](BFieldCodec::decode) a [`Program`].
/// Indices refer to positions after the length prefix.
#[non_exhaustive]
#[derive(Debug, Clone, Eq, PartialEq, Error)]
pub enum ProgramDecodingError {
    /// The sequence has no elements, not even a length prefix.
    #[error("sequence to decode is empty")]
    EmptySequence,

    /// Fewer elements follow the length prefix than it claims.
    #[error("sequence to decode is too short")]
    SequenceTooShort,

    /// More elements follow the length prefix than it claims.
    #[error("sequence to decode is too long")]
    SequenceTooLong,

    /// The decoded instructions do not add up to the claimed length.
    #[error("length of decoded program is unexpected")]
    LengthMismatch,

    /// The element at the given index is not a valid opcode, or the argument
    /// following it is not valid for that instruction.
    #[error("sequence to decode contains invalid instruction at index {0}: {1}")]
    InvalidInstruction(usize, InstructionError),

    /// The instruction at the given index takes an argument, but the program
    /// ends before that argument.
    #[error("missing argument for instruction {1} at index {0}")]
    MissingArgument(usize, Instruction),
}
435
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use assert2::assert;
    use assert2::let_assert;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use rand::Rng;
    use test_strategy::proptest;

    use crate::triton_program;

    use super::*;

    #[proptest]
    fn random_program_encode_decode_equivalence(#[strategy(arb())] program: Program) {
        let encoding = program.encode();
        let decoding = *Program::decode(&encoding).unwrap();
        prop_assert_eq!(program, decoding);
    }

    #[test]
    fn decode_program_with_missing_argument_as_last_instruction() {
        let program = triton_program!(push 3 push 3 eq assert push 3);
        let program_length = program.len_bwords() as u64;
        let encoded = program.encode();

        let mut encoded = encoded[0..encoded.len() - 1].to_vec();
        encoded[0] = bfe!(program_length - 1);

        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::MissingArgument(6, _) = err);
    }

    #[test]
    fn decode_program_with_shorter_than_indicated_sequence() {
        let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
        let mut encoded = program.encode();
        encoded[0] += bfe!(1);
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::SequenceTooShort = err);
    }

    #[test]
    fn decode_program_with_longer_than_indicated_sequence() {
        let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
        let mut encoded = program.encode();
        encoded[0] -= bfe!(1);
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::SequenceTooLong = err);
    }

    #[test]
    fn decode_program_from_empty_sequence() {
        let encoded = vec![];
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::EmptySequence = err);
    }

    #[test]
    fn hash_simple_program() {
        let program = triton_program!(halt);
        let digest = program.hash();

        let expected_digest = bfe_array![
            0x4338_de79_520b_3949_u64,
            0xe6a2_129b_2885_0dc9_u64,
            0xfd3c_d098_6a86_0450_u64,
            0x69fd_ba91_0ceb_a7bc_u64,
            0x7e5b_118c_9594_c062_u64,
        ];
        let expected_digest = Digest::new(expected_digest);

        assert!(expected_digest == digest);
    }

    #[test]
    fn empty_program_is_empty() {
        let program = triton_program!();
        assert!(program.is_empty());
    }

    #[test]
    fn create_program_from_code() {
        let element_3 = rand::rng().random_range(0..BFieldElement::P);
        let element_2 = 1337_usize;
        let element_1 = "17";
        let element_0 = bfe!(0);
        let instruction_push = Instruction::Push(bfe!(42));
        let dup_arg = 1;
        let label = "my_label".to_string();

        let source_code = format!(
            "push {element_3} push {element_2} push {element_1} push {element_0}
             call {label} halt
             {label}:
                {instruction_push}
                dup {dup_arg}
                skiz
                recurse
                return"
        );
        let program_from_code = Program::from_code(&source_code).unwrap();
        let program_from_macro = triton_program!({ source_code });
        assert!(program_from_code == program_from_macro);
    }

    #[test]
    fn parser_macro_with_interpolated_label_as_first_argument() {
        let label = "my_label";
        let _program = triton_program!(
            {label}: push 1 assert halt
        );
    }

    #[test]
    fn breakpoints_propagate_to_debug_information_as_expected() {
        let program = triton_program! {
            break push 1 push 2
            break break break break
            pop 2 hash halt
            break // no effect
        };

        assert!(program.is_breakpoint(0));
        assert!(program.is_breakpoint(1));
        assert!(!program.is_breakpoint(2));
        assert!(!program.is_breakpoint(3));
        assert!(program.is_breakpoint(4));
        assert!(program.is_breakpoint(5));
        assert!(!program.is_breakpoint(6));
        assert!(!program.is_breakpoint(7));

        // going beyond the length of the program must not break things
        assert!(!program.is_breakpoint(8));
        assert!(!program.is_breakpoint(9));
    }

    #[test]
    fn print_program_without_any_debug_information() {
        let program = triton_program! {
            call foo
            call bar
            call baz
            halt
            foo: nop nop return
            bar: call baz return
            baz: push 1 return
        };
        let encoding = program.encode();
        let program = Program::decode(&encoding).unwrap();
        let printed = program.to_string();
        println!("{printed}");

        // Decoding drops all labels, so call targets are printed with
        // generated labels. The printed code must still describe the same
        // program.
        let reparsed = Program::from_code(&printed).unwrap();
        assert!(*program == reparsed);
    }

    #[proptest]
    fn printed_program_can_be_parsed_again(#[strategy(arb())] program: Program) {
        parser::parse(&program.to_string())?;
    }

    /// Parses `input` and checks that it yields exactly one type hint, located
    /// at address 0, equal to `expected`.
    struct TypeHintTestCase {
        expected: TypeHint,
        input: &'static str,
    }

    impl TypeHintTestCase {
        fn run(&self) {
            let program = Program::from_code(self.input).unwrap();
            let [ref type_hint] = program.type_hints_at(0)[..] else {
                panic!("Expected a single type hint at address 0");
            };
            assert!(&self.expected == type_hint);
        }
    }

    #[test]
    fn parse_simple_type_hint() {
        let expected = TypeHint {
            starting_index: 0,
            length: 1,
            type_name: Some("Type".to_string()),
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo: Type = stack[0]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range() {
        let expected = TypeHint {
            starting_index: 0,
            length: 5,
            type_name: Some("Digest".to_string()),
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo: Digest = stack[0..5]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range_and_offset() {
        let expected = TypeHint {
            starting_index: 7,
            length: 3,
            type_name: Some("XFieldElement".to_string()),
            variable_name: "bar".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint bar: XFieldElement = stack[7..10]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range_and_offset_and_weird_whitespace() {
        let expected = TypeHint {
            starting_index: 2,
            length: 12,
            type_name: Some("BigType".to_string()),
            variable_name: "bar".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: " hint \t \t bar  :BigType=stack[ 2\t.. 14 ]\t \n",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_no_type_only_variable_name() {
        let expected = TypeHint {
            starting_index: 0,
            length: 1,
            type_name: None,
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo = stack[0]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_no_type_only_variable_name_with_range() {
        let expected = TypeHint {
            starting_index: 2,
            length: 5,
            type_name: None,
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo = stack[2..7]",
        }
        .run();
    }

    #[test]
    fn assertion_context_is_propagated_into_debug_info() {
        let program = triton_program! {push 1000 assert error_id 17 halt};
        //                              ↑0   ↑1   ↑2

        let assertion_contexts = program.debug_information.assertion_context;
        assert!(1 == assertion_contexts.len());
        let_assert!(AssertionContext::ID(error_id) = &assertion_contexts[&2]);
        assert!(17 == *error_id);
    }

    #[test]
    fn printing_program_includes_debug_information() {
        let source_code = "\
            call foo\n\
            break\n\
            call bar\n\
            halt\n\
            foo:\n\
            break\n\
            call baz\n\
            push 1\n\
            nop\n\
            return\n\
            baz:\n\
            hash\n\
            hint my_digest: Digest = stack[0..5]\n\
            hint random_stuff = stack[17]\n\
            return\n\
            nop\n\
            pop 1\n\
            bar:\n\
            divine 1\n\
            hint got_insight: Magic = stack[0]\n\
            skiz\n\
            split\n\
            break\n\
            assert\n\
            error_id 1337\n\
            return\n\
        ";
        let program = Program::from_code(source_code).unwrap();
        let printed_program = format!("{program}");
        assert_eq!(source_code, &printed_program);
    }
}