Skip to main content

cairo_vm/
air_public_input.rs

1use crate::Felt252;
2use serde::{Deserialize, Serialize};
3use thiserror::Error;
4
5use std::collections::HashMap;
6
7use crate::vm::{
8    errors::{trace_errors::TraceError, vm_errors::VirtualMachineError},
9    trace::trace_entry::RelocatedTraceEntry,
10};
11
12#[derive(Serialize, Deserialize, Debug, PartialEq)]
13pub struct PublicMemoryEntry {
14    pub address: usize,
15    #[serde(serialize_with = "mem_value_serde::serialize")]
16    #[serde(deserialize_with = "mem_value_serde::deserialize")]
17    pub value: Option<Felt252>,
18    pub page: usize,
19}
20
21mod mem_value_serde {
22    use core::fmt;
23
24    use super::*;
25
26    use serde::{de, Deserializer, Serializer};
27
28    pub(crate) fn serialize<S: Serializer>(
29        value: &Option<Felt252>,
30        serializer: S,
31    ) -> Result<S::Ok, S::Error> {
32        if let Some(value) = value {
33            serializer.serialize_str(&format!("0x{:x}", value))
34        } else {
35            serializer.serialize_none()
36        }
37    }
38
39    pub(crate) fn deserialize<'de, D: Deserializer<'de>>(
40        d: D,
41    ) -> Result<Option<Felt252>, D::Error> {
42        d.deserialize_str(Felt252OptionVisitor)
43    }
44
45    struct Felt252OptionVisitor;
46
47    impl de::Visitor<'_> for Felt252OptionVisitor {
48        type Value = Option<Felt252>;
49
50        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
51            formatter.write_str("Could not deserialize hexadecimal string")
52        }
53
54        fn visit_none<E>(self) -> Result<Self::Value, E>
55        where
56            E: de::Error,
57        {
58            Ok(None)
59        }
60
61        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
62        where
63            E: de::Error,
64        {
65            Felt252::from_hex(value)
66                .map_err(de::Error::custom)
67                .map(Some)
68        }
69    }
70}
71
72#[derive(Serialize, Deserialize, Debug, PartialEq)]
73pub struct MemorySegmentAddresses {
74    pub begin_addr: usize,
75    pub stop_ptr: usize,
76}
77
78impl From<(usize, usize)> for MemorySegmentAddresses {
79    fn from(addresses: (usize, usize)) -> Self {
80        let (begin_addr, stop_ptr) = addresses;
81        MemorySegmentAddresses {
82            begin_addr,
83            stop_ptr,
84        }
85    }
86}
87
88#[allow(clippy::manual_non_exhaustive)]
89#[derive(Serialize, Deserialize, Debug)]
90pub struct PublicInput<'a> {
91    pub layout: &'a str,
92    pub rc_min: isize,
93    pub rc_max: isize,
94    pub n_steps: usize,
95    pub memory_segments: HashMap<&'a str, MemorySegmentAddresses>,
96    pub public_memory: Vec<PublicMemoryEntry>,
97    #[serde(skip_deserializing)] // This is set to None by default so we can skip it
98    dynamic_params: (),
99}
100
101impl<'a> PublicInput<'a> {
102    pub fn new(
103        memory: &[Option<Felt252>],
104        layout: &'a str,
105        public_memory_addresses: &[(usize, usize)],
106        memory_segment_addresses: HashMap<&'static str, (usize, usize)>,
107        trace: &[RelocatedTraceEntry],
108        rc_limits: (isize, isize),
109    ) -> Result<Self, PublicInputError> {
110        let memory_entry =
111            |addresses: &(usize, usize)| -> Result<PublicMemoryEntry, PublicInputError> {
112                let (address, page) = addresses;
113                Ok(PublicMemoryEntry {
114                    address: *address,
115                    page: *page,
116                    value: *memory
117                        .get(*address)
118                        .ok_or(PublicInputError::MemoryNotFound(*address))?,
119                })
120            };
121        let public_memory = public_memory_addresses
122            .iter()
123            .map(memory_entry)
124            .collect::<Result<Vec<_>, _>>()?;
125
126        let (rc_min, rc_max) = rc_limits;
127
128        let trace_first = trace.first().ok_or(PublicInputError::EmptyTrace)?;
129        let trace_last = trace.last().ok_or(PublicInputError::EmptyTrace)?;
130
131        Ok(PublicInput {
132            layout,
133            dynamic_params: (),
134            rc_min,
135            rc_max,
136            n_steps: trace.len(),
137            memory_segments: {
138                let mut memory_segment_addresses = memory_segment_addresses
139                    .into_iter()
140                    .map(|(n, s)| (n, s.into()))
141                    .collect::<HashMap<_, MemorySegmentAddresses>>();
142
143                memory_segment_addresses.insert("program", (trace_first.pc, trace_last.pc).into());
144                memory_segment_addresses
145                    .insert("execution", (trace_first.ap, trace_last.ap).into());
146                memory_segment_addresses
147            },
148            public_memory,
149        })
150    }
151
152    pub fn serialize_json(&self) -> Result<String, PublicInputError> {
153        serde_json::to_string_pretty(&self).map_err(PublicInputError::from)
154    }
155}
156
157#[derive(Debug, Error)]
158pub enum PublicInputError {
159    #[error("The trace slice provided is empty")]
160    EmptyTrace,
161    #[error("The provided memory doesn't contain public address {0}")]
162    MemoryNotFound(usize),
163    #[error("Range check values are missing")]
164    NoRangeCheckLimits,
165    #[error("Failed to (de)serialize data")]
166    Serde(#[from] serde_json::Error),
167    #[error(transparent)]
168    VirtualMachine(#[from] VirtualMachineError),
169    #[error(transparent)]
170    Trace(#[from] TraceError),
171}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        cairo_run::{cairo_run, CairoRunConfig},
        hint_processor::builtin_hint_processor::builtin_hint_processor_definition::BuiltinHintProcessor,
        types::layout_name::LayoutName,
    };
    use rstest::rstest;

    /// Runs each proof-mode program, serializes its AIR public input to JSON, parses the
    /// JSON back and checks every deserialized field against the value the VM produced.
    #[rstest]
    #[case(include_bytes!("../../cairo_programs/proof_programs/fibonacci.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/bitwise_output.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/keccak_builtin.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/poseidon_builtin.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/relocate_temporary_segment_append.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/pedersen_test.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/ec_op.json"))]
    fn serialize_and_deserialize_air_public_input(#[case] program_content: &[u8]) {
        let run_config = CairoRunConfig {
            proof_mode: true,
            fill_holes: true,
            relocate_mem: true,
            trace_enabled: true,
            layout: LayoutName::all_cairo,
            ..Default::default()
        };
        let mut hint_processor = BuiltinHintProcessor::new_empty();
        let runner = cairo_run(program_content, &run_config, &mut hint_processor).unwrap();
        let original = runner.get_air_public_input().unwrap();

        // The serialized form itself is validated against the Python VM; this test only
        // checks that deserializing it gives back the same data.
        let json = original.serialize_json().unwrap();
        let parsed: PublicInput = serde_json::from_str(&json).unwrap();

        assert_eq!(original.layout, parsed.layout);
        assert_eq!(original.rc_max, parsed.rc_max);
        assert_eq!(original.rc_min, parsed.rc_min);
        assert_eq!(original.n_steps, parsed.n_steps);
        assert_eq!(original.memory_segments, parsed.memory_segments);
        assert_eq!(original.public_memory, parsed.public_memory);
    }
}