risc0_zkvm_host/
lib.rs

1// Copyright 2022 Risc0, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![deny(missing_docs)]
16#![doc = include_str!("../README.md")]
17
18mod exception;
19mod ffi;
20
21use std::mem;
22
23use serde::{Deserialize, Deserializer, Serialize, Serializer};
24
25pub use exception::Exception;
26
// Empty `cxx` bridge module: it declares no shared types or functions.
// NOTE(review): presumably kept so the `cxx` glue code is still generated and
// linked for this crate — confirm before removing.
#[cxx::bridge]
mod bridge {}
29
/// The default digest count when generating a MethodId.
///
/// NOTE(review): presumably intended to be passed as the `limit` argument of
/// [`MethodId::compute`] — confirm against callers.
pub const DEFAULT_METHOD_ID_LIMIT: u32 = 12;
32
/// A Result specialized for [Exception].
///
/// The fallible functions in this file return this alias, so their failures
/// are all reported as an [Exception].
pub type Result<T> = std::result::Result<T, Exception>;
35
36/// A record attesting to the correct execution of a 'method'.
37///
38/// Consists of:
39/// * journal: all data the method wants to publicly output and commit to.
40/// * seal: the cryptographic blob which proves that the receipt is valid.
41pub struct Receipt {
42    ptr: *const ffi::RawReceipt,
43}
44
45/// The prover generates a [Receipt] by executing a given method in a ZKVM.
46pub struct Prover {
47    ptr: *mut ffi::RawProver,
48}
49
50/// A MethodId represents a unique identifier associated with a particular ELF
51/// binary.
52pub struct MethodId {
53    ptr: *const ffi::RawMethodId,
54}
55
56fn into_words(slice: &[u8]) -> Result<Vec<u32>> {
57    let mut vec = Vec::new();
58    let chunks = slice.chunks_exact(4);
59    assert!(chunks.remainder().len() == 0);
60    for chunk in chunks {
61        let word = chunk[0] as u32
62            | (chunk[1] as u32) << 8
63            | (chunk[2] as u32) << 16
64            | (chunk[3] as u32) << 24;
65        vec.push(word);
66    }
67    Ok(vec)
68}
69
70impl MethodId {
71    /// Compute the MethodId associated with an existing ELF binary.
72    pub fn compute(elf_contents: &[u8], limit: u32) -> Result<Self> {
73        let mut err = ffi::RawError::default();
74        let ptr = unsafe {
75            ffi::risc0_method_id_compute(&mut err, elf_contents.as_ptr(), elf_contents.len(), limit)
76        };
77        ffi::check(err, || MethodId { ptr })
78    }
79
80    /// Load an existing MethodId from a buffer.
81    pub fn from_slice(slice: &[u8]) -> Result<Self> {
82        let mut err = ffi::RawError::default();
83        let ptr = unsafe { ffi::risc0_method_id_load(&mut err, slice.as_ptr(), slice.len()) };
84        ffi::check(err, || MethodId { ptr })
85    }
86
87    /// Access the raw slice of a MethodId.
88    pub fn as_slice(&self) -> Result<&[u8]> {
89        let mut err = ffi::RawError::default();
90        let mut len: u32 = 0;
91        let ptr = unsafe { ffi::risc0_method_id_get_buf(&mut err, self.ptr, &mut len) };
92        ffi::check(err, || unsafe {
93            std::slice::from_raw_parts(ptr, len as usize)
94        })
95    }
96}
97
98impl Drop for MethodId {
99    fn drop(&mut self) {
100        let mut err = ffi::RawError::default();
101        unsafe { ffi::risc0_method_id_free(&mut err, self.ptr) };
102        ffi::check(err, || ()).unwrap()
103    }
104}
105
106impl Receipt {
107    /// Construct a new [Receipt] from individual journal and seal parts.
108    pub fn new(journal: &[u8], seal: &[u32]) -> Result<Self> {
109        let mut err = ffi::RawError::default();
110        let ptr = unsafe {
111            ffi::risc0_receipt_new(
112                &mut err,
113                journal.as_ptr(),
114                journal.len(),
115                seal.as_ptr(),
116                seal.len(),
117            )
118        };
119        ffi::check(err, || Receipt { ptr })
120    }
121
122    /// Verify that the current [Receipt] is a valid result of executing the
123    /// method associated with the given method ID in a ZKVM.
124    pub fn verify(&self, method_id: &[u8]) -> Result<()> {
125        let mut err = ffi::RawError::default();
126        unsafe {
127            ffi::risc0_receipt_verify(&mut err, self.ptr, method_id.as_ptr(), method_id.len())
128        };
129        ffi::check(err, || ())
130    }
131
132    /// Provides access to the `seal` of a [Receipt].
133    pub fn get_seal(&self) -> Result<&[u32]> {
134        unsafe {
135            let mut err = ffi::RawError::default();
136            let buf = ffi::risc0_receipt_get_seal_buf(&mut err, self.ptr);
137            let buf = ffi::check(err, || buf)?;
138            let mut err = ffi::RawError::default();
139            let len = ffi::risc0_receipt_get_seal_len(&mut err, self.ptr);
140            let len = ffi::check(err, || len)?;
141            Ok(std::slice::from_raw_parts(buf, len))
142        }
143    }
144
145    /// Provides access to the `journal` of a [Receipt].
146    pub fn get_journal(&self) -> Result<&[u8]> {
147        unsafe {
148            let mut err = ffi::RawError::default();
149            let buf = ffi::risc0_receipt_get_journal_buf(&mut err, self.ptr);
150            let buf = ffi::check(err, || buf)?;
151            let mut err = ffi::RawError::default();
152            let len = ffi::risc0_receipt_get_journal_len(&mut err, self.ptr);
153            let len = ffi::check(err, || len)?;
154            Ok(std::slice::from_raw_parts(buf, len))
155        }
156    }
157
158    /// Provides access to the `journal` of a [Receipt] as a [`Vec<u32>`].
159    pub fn get_journal_vec(&self) -> Result<Vec<u32>> {
160        into_words(self.get_journal()?)
161    }
162}
163
// TODO(nils): Lift "Receipt" from the pure-rust verify implementation so we
// don't have to proxy through this structure.
/// Owned copy of a [Receipt]'s contents; the `Serialize` and `Deserialize`
/// impls for [Receipt] convert through this type.
#[derive(Serialize, Deserialize)]
struct ReceiptData {
    /// Bytes of the receipt's journal.
    journal: Vec<u8>,
    /// Words of the receipt's seal.
    seal: Vec<u32>,
}
171
172impl Serialize for Receipt {
173    /// Generate a serialized version of the whole receipt.
174    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
175        let data: ReceiptData = ReceiptData {
176            journal: self.get_journal().unwrap().into(),
177            seal: self.get_seal().unwrap().into(),
178        };
179        data.serialize(serializer)
180    }
181}
182
183impl<'de> Deserialize<'de> for Receipt {
184    /// Deserialize a receipt.
185    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
186    where
187        D: Deserializer<'de>,
188    {
189        let data = ReceiptData::deserialize(deserializer)?;
190        Ok(Receipt::new(&data.journal, &data.seal).unwrap())
191    }
192}
193
194impl Prover {
195    /// Create a new [Prover] with the given method (specified via
196    /// `elf_contents`) and an associated method ID (specified via
197    /// `method_id`).
198    pub fn new(elf_contents: &[u8], method_id: &[u8]) -> Result<Self> {
199        let mut err = ffi::RawError::default();
200        let ptr = unsafe {
201            ffi::risc0_prover_new(
202                &mut err,
203                elf_contents.as_ptr(),
204                elf_contents.len(),
205                method_id.as_ptr(),
206                method_id.len(),
207            )
208        };
209        ffi::check(err, || Prover { ptr })
210    }
211
212    /// Provide private input data that is availble to guest-side method code
213    /// to 'read'.
214    pub fn add_input(&mut self, slice: &[u32]) -> Result<()> {
215        let mut err = ffi::RawError::default();
216        unsafe {
217            ffi::risc0_prover_add_input(
218                &mut err,
219                self.ptr,
220                slice.as_ptr().cast(),
221                slice.len() * mem::size_of::<u32>(),
222            )
223        };
224        ffi::check(err, || ())
225    }
226
227    /// Provide access to private output data written by guest-side method code.
228    pub fn get_output(&self) -> Result<&[u8]> {
229        unsafe {
230            let mut err = ffi::RawError::default();
231            let buf = ffi::risc0_prover_get_output_buf(&mut err, self.ptr);
232            let buf = ffi::check(err, || buf)?;
233            let mut err = ffi::RawError::default();
234            let len = ffi::risc0_prover_get_output_len(&mut err, self.ptr);
235            let len = ffi::check(err, || len)?;
236            Ok(std::slice::from_raw_parts(buf, len))
237        }
238    }
239
240    /// Provide access to private output data written to by guest-side method
241    /// code.
242    ///
243    /// This returns the data as a [`Vec<u32>`].
244    pub fn get_output_vec(&self) -> Result<Vec<u32>> {
245        into_words(self.get_output()?)
246    }
247
248    /// Execute the ZKVM to produce a [Receipt].
249    pub fn run(&self) -> Result<Receipt> {
250        let mut err = ffi::RawError::default();
251        let ptr = unsafe { ffi::risc0_prover_run(&mut err, self.ptr) };
252        ffi::check(err, || Receipt { ptr })
253    }
254}
255
256impl Drop for Receipt {
257    fn drop(&mut self) {
258        let mut err = ffi::RawError::default();
259        unsafe { ffi::risc0_receipt_free(&mut err, self.ptr) };
260        ffi::check(err, || ()).unwrap()
261    }
262}
263
264impl Drop for Prover {
265    fn drop(&mut self) {
266        let mut err = ffi::RawError::default();
267        unsafe { ffi::risc0_prover_free(&mut err, self.ptr) };
268        ffi::check(err, || ()).unwrap()
269    }
270}
271
// Calls the native library's `risc0_init` entry point. The `#[ctor::ctor]`
// attribute registers this function as a constructor, so it is expected to run
// automatically at load time rather than being called from Rust code.
// NOTE(review): what `risc0_init` sets up is not visible here — confirm in the
// native sources before relying on any ordering.
#[ctor::ctor]
fn init() {
    unsafe { ffi::risc0_init() };
}
276
#[cfg(test)]
mod test {
    use super::{Prover, Receipt};
    use anyhow::Result;
    use risc0_zkvm_core::Digest;
    use risc0_zkvm_methods::methods::{FAIL_ID, FAIL_PATH, IO_ID, IO_PATH, SHA_ID, SHA_PATH};
    use risc0_zkvm_serde::{from_slice, to_vec};

    // TODO(nils): Move these constants into something both the guest and host can
    // depend on
    const HEAP_START: u32 = 0x00A0_0000;
    const COMMIT_START: u32 = 0x03F0_0000;

    /// Runs the SHA method on several messages and compares each journal
    /// against the expected digest words.
    #[test]
    fn sha() {
        let cases: [(&str, [u32; 8]); 4] = [
            (
                "",
                [
                    0xe3b0c442, 0x98fc1c14, 0x9afbf4c8, 0x996fb924, 0x27ae41e4, 0x649b934c,
                    0xa495991b, 0x7852b855,
                ],
            ),
            (
                "a",
                [
                    0xca978112, 0xca1bbdca, 0xfac231b3, 0x9a23dc4d, 0xa786eff8, 0x147c4e72,
                    0xb9807785, 0xafee48bb,
                ],
            ),
            (
                "abc",
                [
                    0xba7816bf, 0x8f01cfea, 0x414140de, 0x5dae2223, 0xb00361a3, 0x96177a9c,
                    0xb410ff61, 0xf20015ad,
                ],
            ),
            (
                "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
                [
                    0x248d6a61, 0xd20638b8, 0xe5c02693, 0x0c3e6039, 0xa33ce459, 0x64ff2167,
                    0xf6ecedd4, 0x19db06c1,
                ],
            ),
        ];
        for &(msg, words) in cases.iter() {
            assert_eq!(run_sha(msg), Digest::new(words));
        }
    }

    /// Proves the SHA method over `msg` and decodes the journal as a digest.
    fn run_sha(msg: &str) -> Digest {
        let elf = std::fs::read(SHA_PATH).unwrap();
        let mut prover = Prover::new(&elf, SHA_ID).unwrap();
        let input = to_vec(&msg).unwrap();
        prover.add_input(&input).unwrap();
        let receipt = prover.run().unwrap();
        let journal = receipt.get_journal_vec().unwrap();
        from_slice::<Digest>(&journal).unwrap()
    }

    #[test]
    fn memory_io() {
        // Writing the same value twice to WOM is fine
        assert!(run_memio(&[(COMMIT_START, 1), (COMMIT_START, 1)]).is_ok());

        // Writing different values twice to the same WOM address fails
        assert!(run_memio(&[(COMMIT_START, 1), (COMMIT_START, 2)]).is_err());

        // But different values are OK at different addresses
        assert!(run_memio(&[(COMMIT_START, 1), (COMMIT_START + 4, 2)]).is_ok());

        // Aligned write is fine
        assert!(run_memio(&[(HEAP_START, 1)]).is_ok());

        // Unaligned write is bad
        assert!(run_memio(&[(HEAP_START + 1, 1)]).is_err());

        // Aligned read is fine
        assert!(run_memio(&[(HEAP_START, 0)]).is_ok());

        // Unaligned read is bad
        assert!(run_memio(&[(HEAP_START + 1, 0)]).is_err());
    }

    /// Feeds the pair count followed by the flattened pairs to the IO method,
    /// then proves it and verifies the resulting receipt.
    fn run_memio(pairs: &[(u32, u32)]) -> Result<Receipt> {
        let mut input = Vec::with_capacity(1 + 2 * pairs.len());
        input.push(pairs.len() as u32);
        for &(first, second) in pairs {
            input.extend_from_slice(&[first, second]);
        }
        let elf = std::fs::read(IO_PATH).unwrap();
        let mut prover = Prover::new(&elf, IO_ID).unwrap();
        prover.add_input(&input).unwrap();
        let receipt = prover.run()?;
        receipt.verify(IO_ID).unwrap();
        Ok(receipt)
    }

    /// Round-trips a receipt through serde and checks that the copy has the
    /// same journal and seal and still verifies.
    #[test]
    fn receipt_serde() {
        let original: Receipt = run_memio(&[(HEAP_START, 0)]).unwrap();
        let encoded: Vec<u32> = to_vec(&original).unwrap();
        let decoded: Receipt = from_slice(&encoded).unwrap();
        assert_eq!(decoded.get_journal().unwrap(), original.get_journal().unwrap());
        assert_eq!(decoded.get_seal().unwrap(), original.get_seal().unwrap());
        decoded.verify(IO_ID).unwrap();
    }

    #[test]
    fn fail() {
        // Check that a compliant host will fault.
        let elf = std::fs::read(FAIL_PATH).unwrap();
        let prover = Prover::new(&elf, FAIL_ID).unwrap();
        assert!(prover.run().is_err());
    }
}
387        assert!(prover.run().is_err());
388    }
389}