cairo-air 1.2.2

AIR (Algebraic Intermediate Representation) definitions for Cairo programs
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
use std::array;
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;

use bzip2::read::BzDecoder;
use bzip2::write::BzEncoder;
use bzip2::Compression;
use clap::ValueEnum;
use itertools::Itertools;
use num_traits::Zero;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use stwo::core::fields::m31::BaseField;
use stwo::core::fields::qm31::{SecureField, SECURE_EXTENSION_DEGREE};
use stwo::core::pcs::TreeVec;
use stwo::core::vcs::blake2_hash::Blake2sHasher;
use stwo::core::vcs_lifted::MerkleHasherLifted;
use stwo_cairo_serialize::{CairoDeserialize, CairoSerialize};
use stwo_constraint_framework::{
    INTERACTION_TRACE_IDX, ORIGINAL_TRACE_IDX, PREPROCESSED_TRACE_IDX,
};
use tracing::{span, Level};

use crate::air::{CairoProof, MemorySection, PublicMemory};
use crate::CairoProofForRustVerifier;

/// JSON backend selection: both backends expose the same `from_str` /
/// `to_string_pretty` surface, so callers use `json::…` uniformly.
mod json {
    // NOTE(review): wasm targets fall back to `serde_json` — presumably `sonic_rs`
    // does not build for wasm; confirm before changing this split.
    #[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))]
    pub use serde_json::{from_str, to_string_pretty};
    #[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
    pub use sonic_rs::{from_str, to_string_pretty};
}

/// 2^31 — the most significant bit of a `u32` limb. `encode_felt_in_limbs` uses it
/// both as the threshold for the "small felt" check (`v1 < MSB_U32`) and as the flag
/// set on the top limb of a full 8-limb encoding.
const MSB_U32: u32 = 0x80000000;

/// Cairo proof format
#[derive(Debug, Clone, ValueEnum)]
pub enum ProofFormat {
    /// Standard JSON format.
    Json,
    /// Array of field elements serialized as hex strings.
    /// Compatible with `scarb execute`
    CairoSerde,
    /// Binary format.
    /// Additionally compressed to minimize the proof size.
    /// Bzip2-compressed bincode of the Rust-verifier proof.
    Binary,
    /// Extended binary format.
    /// Bzip2-compressed bincode of the full `CairoProof` (round-trips back to
    /// the Rust-verifier proof via `deserialize_proof_from_file`).
    ExtendedBinary,
}

/// Packs base-field values into secure-field elements, `SECURE_EXTENSION_DEGREE`
/// values per element. A trailing partial chunk is zero-padded.
pub fn pack_into_secure_felts<T: Into<BaseField>>(
    values: impl Iterator<Item = T>,
) -> Vec<SecureField> {
    let mut packed = Vec::new();
    let mut source = values.map(Into::into).peekable();
    while source.peek().is_some() {
        // Missing tail values default to zero so the last element is well-formed.
        let coordinates = array::from_fn(|_| source.next().unwrap_or_else(BaseField::zero));
        packed.push(SecureField::from_m31_array(coordinates));
    }
    packed
}

pub fn binary_serialize_to_file<T: Serialize>(
    obj: &T,
    proof_file: &File,
) -> Result<(), std::io::Error> {
    let serialized_bytes = bincode::serialize(&obj).map_err(std::io::Error::other)?;

    let mut bz_encoder = BzEncoder::new(proof_file, Compression::best());
    bz_encoder.write_all(&serialized_bytes)?;
    bz_encoder.finish()?;
    Ok(())
}

pub fn binary_deserialize_from_file<T: DeserializeOwned>(
    proof_file: &File,
) -> Result<T, std::io::Error> {
    let mut bytes = Vec::new();
    let mut bz_decoder = BzDecoder::new(proof_file);
    bz_decoder.read_to_end(&mut bytes)?;
    bincode::deserialize(&bytes).map_err(std::io::Error::other)
}

/// Serializes Cairo proof given the desired format and writes it to a file.
///
/// # Errors
/// Returns an I/O error if the file cannot be created, or if serialization or the
/// write fails.
pub fn serialize_proof_to_file<H: MerkleHasherLifted + Serialize>(
    proof: &CairoProof<H>,
    proof_path: &Path,
    proof_format: ProofFormat,
) -> Result<(), std::io::Error>
where
    H::Hash: CairoSerialize,
{
    let span = span!(Level::INFO, "Serialize proof").entered();

    let mut proof_file = File::create(proof_path)?;

    match proof_format {
        ProofFormat::Json => {
            // Convert to the Rust-verifier proof shape before pretty-printing.
            let proof_for_rust_verifier: CairoProofForRustVerifier<_> = proof.clone().into();
            proof_file.write_all(json::to_string_pretty(&proof_for_rust_verifier)?.as_bytes())?;
        }
        ProofFormat::CairoSerde => {
            // Cairo-serialize the full proof into field elements, then dump them as a
            // JSON array of hex strings (the format compatible with `scarb execute`).
            let mut serialized: Vec<starknet_ff::FieldElement> = Vec::new();
            CairoSerialize::serialize(proof, &mut serialized);

            let hex_strings: Vec<String> = serialized
                .into_iter()
                .map(|felt| format!("0x{felt:x}"))
                .collect();

            proof_file.write_all(json::to_string_pretty(&hex_strings)?.as_bytes())?;
        }
        ProofFormat::Binary => {
            // Bzip2-compressed bincode of the Rust-verifier proof.
            let proof_for_rust_verifier: CairoProofForRustVerifier<_> = proof.clone().into();
            binary_serialize_to_file(&proof_for_rust_verifier, &proof_file)?;
        }
        ProofFormat::ExtendedBinary => {
            // Bzip2-compressed bincode of the full `CairoProof`.
            binary_serialize_to_file(&proof, &proof_file)?;
        }
    }

    span.exit();
    Ok(())
}

/// Loads a Cairo proof for the Rust verifier from a file in the specified format.
///
/// # Errors
/// Returns an I/O error if the file cannot be read or parsed. Requesting
/// [`ProofFormat::CairoSerde`] yields an `ErrorKind::Unsupported` error, since
/// that format cannot be deserialized.
pub fn deserialize_proof_from_file<H: MerkleHasherLifted + DeserializeOwned>(
    proof_path: &Path,
    proof_format: ProofFormat,
) -> Result<CairoProofForRustVerifier<H>, std::io::Error>
where
    H::Hash: CairoDeserialize,
{
    match proof_format {
        ProofFormat::Json => {
            let proof_str = std::fs::read_to_string(proof_path)?;
            json::from_str(&proof_str).map_err(std::io::Error::other)
        }
        ProofFormat::CairoSerde => {
            // The format is user-selectable, so this is a recoverable condition:
            // return an error instead of panicking.
            Err(std::io::Error::new(
                std::io::ErrorKind::Unsupported,
                "Deserialization from a Cairo-serialized proof is not supported.",
            ))
        }
        ProofFormat::Binary => {
            let proof_file = File::open(proof_path)?;
            binary_deserialize_from_file(&proof_file)
        }
        ProofFormat::ExtendedBinary => {
            // The extended format stores the full `CairoProof`; convert it down to
            // the Rust-verifier proof after decoding.
            let proof_file = File::open(proof_path)?;
            let extended_proof: CairoProof<H> = binary_deserialize_from_file(&proof_file)?;
            Ok(extended_proof.into())
        }
    }
}

/// The data associated with the Cairo proof.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationOutput {
    /// Program hash (blake2s of the encoded program memory section; see
    /// `get_verification_output`).
    pub program_hash: starknet_ff::FieldElement,
    /// Public output, one field element per output memory entry.
    pub output: Vec<starknet_ff::FieldElement>,
}

/// Extract program hash (blake2s) and public output from the public memory.
pub fn get_verification_output(public_memory: &PublicMemory) -> VerificationOutput {
    VerificationOutput {
        // Hash the encoded program section and fold the digest into a field element.
        program_hash: construct_f252(&encode_and_hash_memory_section(&public_memory.program)),
        // Each output entry's value becomes one field element; addresses are ignored.
        output: public_memory
            .output
            .iter()
            .map(|(_, value)| construct_f252(value))
            .collect(),
    }
}

/// Encodes a memory section and hashes it using Cairo blake.
pub fn encode_and_hash_memory_section(section: &MemorySection) -> [u32; 8] {
    let mut hasher = Blake2sHasher::new();
    for entry in section {
        let (_, val) = *entry;
        let limbs = encode_felt_in_limbs(val);
        for limb in limbs {
            // Cairo blake uses little-endian byte order for the input.
            hasher.update(&limb.to_le_bytes());
        }
    }
    let digest_bytes = hasher.finalize().0.to_vec();

    // Cairo blake uses little-endian byte order for the output as well, so we need to reverse each
    // 4-byte limb.
    let limbs: Vec<u32> = digest_bytes
        .chunks_exact(4)
        .map(|l| u32::from_le_bytes(l.try_into().unwrap()))
        .collect();

    limbs.try_into().unwrap()
}

/// Convert digest to a field element, adding every limb to the result (shifted) to
/// reduce modulo stark prime.
pub fn construct_f252(limbs: &[u32; 8]) -> starknet_ff::FieldElement {
    // Horner evaluation from the most significant limb down: each step shifts the
    // accumulator by 32 bits and adds the next limb.
    let shift = starknet_ff::FieldElement::from(0x100000000_u64);
    limbs
        .iter()
        .rev()
        .fold(starknet_ff::FieldElement::ZERO, |acc, &limb| {
            acc * shift + limb.into()
        })
}

/// Encodes a felt, represented by 8 u32 limbs in little-endian order and returns the encoded result
/// in big-endian order.
///
/// The encoding is done in the following way:
/// * If the felt is less than 2^63, it's encoded as the 2 least significant limbs.
/// * Otherwise, it's encoded as the 8 limbs, where the most significant limb has its MSB set (Note
///   that the prime is less than 2^255 so the MSB could not be set prior to this intervention).
pub fn encode_felt_in_limbs(felt: [u32; 8]) -> Vec<u32> {
    let [v0, v1, v2, v3, v4, v5, v6, v7] = felt;
    // Small iff the 6 high limbs are zero and v1's MSB is clear (< 2^63 overall).
    if felt[2..].iter().all(|&v| v == 0) && v1 < MSB_U32 {
        vec![v1, v0]
    } else {
        // Set the MSB flag with `|` rather than `+`: identical for valid felts
        // (the MSB is guaranteed clear, see above), but OR cannot overflow-panic
        // on malformed input in debug builds.
        vec![v7 | MSB_U32, v6, v5, v4, v3, v2, v1, v0]
    }
}

/// A utility function which transforms the order and layout of the queried values of a stwo proof
/// according to the format expected by the Cairo verifier.
///
/// For each commitment tree, the per-column query results are transposed into
/// row-major order (all columns' values for query 0, then for query 1, ...). For the
/// base trace and interaction trace, columns are first sorted by their log size
/// (stable sort, so equal sizes keep their original order).
pub fn sort_and_transpose_queried_values(
    queried_values: &TreeVec<Vec<Vec<BaseField>>>,
    trace_and_interaction_trace_log_sizes: Vec<&[u32]>,
) -> TreeVec<Vec<BaseField>> {
    debug_assert!(trace_and_interaction_trace_log_sizes[PREPROCESSED_TRACE_IDX].is_empty());
    debug_assert!(trace_and_interaction_trace_log_sizes.len() == 3);

    let mut new_queried_values_per_tree = vec![];
    // Every column holds one value per query, so any column's length is the query count.
    let n_queries = queried_values[0][0].len();
    // Transpose the preprocessed queried values. The preprocessed columns are already sorted in
    // ascending order so there is no need to sort the values.
    let pp_queried_values = &queried_values.first().unwrap();
    let mut new_queried_values: Vec<BaseField> = vec![];
    for row_idx in 0..n_queries {
        new_queried_values.extend(pp_queried_values.iter().map(|vals| vals[row_idx]));
    }
    new_queried_values_per_tree.push(new_queried_values);

    // Sort and transpose the queried values of the base trace and interaction trace.
    for (queried_values, col_sizes) in queried_values[ORIGINAL_TRACE_IDX..=INTERACTION_TRACE_IDX]
        .iter()
        .zip_eq(
            trace_and_interaction_trace_log_sizes[ORIGINAL_TRACE_IDX..=INTERACTION_TRACE_IDX]
                .iter(),
        )
    {
        let mut new_queried_values = vec![];
        // Pair each column with its log size, sort columns by size, and keep an
        // iterator per column so rows can be emitted in lockstep below.
        let mut sorted_queries: Vec<_> = queried_values
            .iter()
            .zip_eq(col_sizes.iter())
            .sorted_by_key(|(_, col_size)| *col_size)
            .map(|(vals, _)| vals.iter())
            .collect();
        // Emit one value from every (sorted) column per query — i.e. transpose.
        for _ in 0..n_queries {
            new_queried_values.extend(
                sorted_queries
                    .iter_mut()
                    .map(|col_iter| *col_iter.next().unwrap()),
            );
        }
        new_queried_values_per_tree.push(new_queried_values)
    }

    // Transpose the queried values of the composition polynomial commitment. All columns
    // in the composition commitment are of the same length so there is no need to sort.
    let composition_queried_values = &queried_values.last().unwrap();
    let mut new_queried_values: Vec<BaseField> = vec![];
    for row_idx in 0..n_queries {
        new_queried_values.extend(composition_queried_values.iter().map(|vals| vals[row_idx]));
    }
    new_queried_values_per_tree.push(new_queried_values);
    TreeVec(new_queried_values_per_tree)
}

#[cfg(test)]
mod tests {
    use stwo::core::fields::m31::M31;

    use super::*;

    // Covers both encoding branches: a small felt (2 limbs) and a large felt
    // (8 limbs with the MSB flag set on the top limb).
    #[test]
    fn test_encode_felt_in_limbs() {
        let felt0 = [0x12345678, 0x70000000, 0, 0, 0, 0, 0, 0];
        let felt1 = [
            0x12345678, 0x90abcdef, 0xabcdef12, 0x34567890, 0x01234567, 0x89abcdef, 0x01234567, 0,
        ];
        let limbs0 = encode_felt_in_limbs(felt0);
        let limbs1 = encode_felt_in_limbs(felt1);
        assert_eq!(limbs0, vec![1879048192, 305419896]);
        assert_eq!(
            limbs1,
            vec![
                2147483648, 19088743, 2309737967, 19088743, 878082192, 2882400018, 2427178479,
                305419896
            ]
        );
    }

    // Hashes a two-entry section and compares against a precomputed blake2s digest
    // (expressed as 8 little-endian u32 limbs).
    #[test]
    fn test_encode_and_hash_memory_section() {
        let memory_section = vec![
            (0, [0x12345678, 0x90abcdef, 0, 0, 0, 0, 0, 0]),
            (1, [0xabcdef12, 0x34567890, 0, 0, 0, 0, 0, 0]),
        ];
        let hash = encode_and_hash_memory_section(&memory_section);
        let expected = [
            2421522214_u32,
            635981307,
            2862863578,
            1664236125,
            1878536921,
            1607560013,
            4274188691,
            2957079540,
        ];
        assert_eq!(hash, expected);
    }

    // Folds the same digest limbs as above into a field element and checks the
    // decimal value (reduced modulo the stark prime).
    #[test]
    fn test_construct_f252() {
        let limbs = [
            2421522214_u32,
            635981307,
            2862863578,
            1664236125,
            1878536921,
            1607560013,
            4274188691,
            2957079540,
        ];
        let expected = starknet_ff::FieldElement::from_dec_str(
            "115645365096977585374207223166120623839439046970571781411593222716768222992",
        )
        .unwrap();
        assert_eq!(construct_f252(&limbs), expected);
    }

    // Verifies sorting-by-log-size and row-major transposition for all four trees
    // (preprocessed, base trace, interaction trace, composition).
    #[test]
    fn test_sort_queried_values() {
        let trace_and_interaction_trace_log_sizes = [vec![], vec![4, 3, 2, 1], vec![4, 1, 3, 2]];
        let trace_and_interaction_trace_log_sizes: Vec<&[u32]> =
            trace_and_interaction_trace_log_sizes
                .iter()
                .map(|v| v.as_slice())
                .collect();
        let unsorted_queried_values = TreeVec(vec![
            vec![
                vec![M31::from(1), M31::from(2)],
                vec![M31::from(3), M31::from(4)],
                vec![M31::from(5), M31::from(6)],
            ],
            vec![
                vec![M31::from(1), M31::from(2)],
                vec![M31::from(3), M31::from(4)],
                vec![M31::from(5), M31::from(6)],
                vec![M31::from(7), M31::from(8)],
            ],
            vec![
                vec![M31::from(1), M31::from(2)],
                vec![M31::from(3), M31::from(4)],
                vec![M31::from(5), M31::from(6)],
                vec![M31::from(7), M31::from(8)],
            ],
            vec![vec![M31::from(1), M31::from(2)]; 8],
        ]);
        let sorted_queried_values = TreeVec(vec![
            vec![
                M31::from(1),
                M31::from(3),
                M31::from(5),
                M31::from(2),
                M31::from(4),
                M31::from(6),
            ],
            vec![
                M31::from(7),
                M31::from(5),
                M31::from(3),
                M31::from(1),
                M31::from(8),
                M31::from(6),
                M31::from(4),
                M31::from(2),
            ],
            vec![
                M31::from(3),
                M31::from(7),
                M31::from(5),
                M31::from(1),
                M31::from(4),
                M31::from(8),
                M31::from(6),
                M31::from(2),
            ],
            [[M31::from(1); 8], [M31::from(2); 8]].concat(),
        ]);

        assert_eq!(
            sorted_queried_values.0,
            sort_and_transpose_queried_values(
                &unsorted_queried_values,
                trace_and_interaction_trace_log_sizes
            )
            .0
        );
    }

    // TODO(Leo): add tests for serializing and deserializing the proof for rust verifier.
}