risc0_binfmt/
elf.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15extern crate alloc;
16
17use alloc::{collections::BTreeMap, vec, vec::Vec};
18
19use anyhow::{anyhow, bail, ensure, Context, Result};
20use elf::{endian::LittleEndian, file::Class, ElfBytes};
21use risc0_zkp::core::{digest::Digest, hash::sha::Impl};
22use risc0_zkvm_platform::WORD_SIZE;
23use serde::{Deserialize, Serialize};
24
25use crate::{Digestible as _, MemoryImage, SystemState, KERNEL_START_ADDR};
26
/// A RISC Zero program: where execution begins, plus the initial contents of memory.
pub struct Program {
    /// Initial memory contents, as a map from address to the 32-bit word stored there.
    pub(crate) image: BTreeMap<u32, u32>,

    /// Address of the first instruction to execute.
    pub(crate) entry: u32,
}
35
36impl Program {
37    /// Initialize a RISC Zero Program from an appropriate ELF file
38    pub fn load_elf(input: &[u8], max_mem: u32) -> Result<Program> {
39        let mut image: BTreeMap<u32, u32> = BTreeMap::new();
40        let elf = ElfBytes::<LittleEndian>::minimal_parse(input)
41            .map_err(|err| anyhow!("Elf parse error: {err}"))?;
42        if elf.ehdr.class != Class::ELF32 {
43            bail!("Not a 32-bit ELF");
44        }
45        if elf.ehdr.e_machine != elf::abi::EM_RISCV {
46            bail!("Invalid machine type, must be RISC-V");
47        }
48        if elf.ehdr.e_type != elf::abi::ET_EXEC {
49            bail!("Invalid ELF type, must be executable");
50        }
51        let entry: u32 = elf
52            .ehdr
53            .e_entry
54            .try_into()
55            .map_err(|err| anyhow!("e_entry was larger than 32 bits. {err}"))?;
56        if entry >= max_mem || entry % WORD_SIZE as u32 != 0 {
57            bail!("Invalid entrypoint");
58        }
59        let segments = elf
60            .segments()
61            .ok_or_else(|| anyhow!("Missing segment table"))?;
62        if segments.len() > 256 {
63            bail!("Too many program headers");
64        }
65        for segment in segments.iter().filter(|x| x.p_type == elf::abi::PT_LOAD) {
66            let file_size: u32 = segment
67                .p_filesz
68                .try_into()
69                .map_err(|err| anyhow!("filesize was larger than 32 bits. {err}"))?;
70            if file_size >= max_mem {
71                bail!("Invalid segment file_size");
72            }
73            let mem_size: u32 = segment
74                .p_memsz
75                .try_into()
76                .map_err(|err| anyhow!("mem_size was larger than 32 bits {err}"))?;
77            if mem_size >= max_mem {
78                bail!("Invalid segment mem_size");
79            }
80            let vaddr: u32 = segment
81                .p_vaddr
82                .try_into()
83                .map_err(|err| anyhow!("vaddr is larger than 32 bits. {err}"))?;
84            if vaddr % WORD_SIZE as u32 != 0 {
85                bail!("vaddr {vaddr:08x} is unaligned");
86            }
87            let offset: u32 = segment
88                .p_offset
89                .try_into()
90                .map_err(|err| anyhow!("offset is larger than 32 bits. {err}"))?;
91            for i in (0..mem_size).step_by(WORD_SIZE) {
92                let addr = vaddr.checked_add(i).context("Invalid segment vaddr")?;
93                if addr >= max_mem {
94                    bail!("Address [0x{addr:08x}] exceeds maximum address for guest programs [0x{max_mem:08x}]");
95                }
96                if i >= file_size {
97                    // Past the file size, all zeros.
98                    image.insert(addr, 0);
99                } else {
100                    let mut word = 0;
101                    // Don't read past the end of the file.
102                    let len = core::cmp::min(file_size - i, WORD_SIZE as u32);
103                    for j in 0..len {
104                        let offset = (offset + i + j) as usize;
105                        let byte = input.get(offset).context("Invalid segment offset")?;
106                        word |= (*byte as u32) << (j * 8);
107                    }
108                    image.insert(addr, word);
109                }
110            }
111        }
112        Ok(Program::new_from_entry_and_image(entry, image))
113    }
114
115    /// Create `Program` from given entry-point and image map
116    pub fn new_from_entry_and_image(entry: u32, image: BTreeMap<u32, u32>) -> Self {
117        Self { entry, image }
118    }
119
120    /// The size of the image in a count of words
121    pub fn size_in_words(&self) -> usize {
122        self.image.len()
123    }
124
125    /// Read a word from the image
126    pub fn read_u32(&self, address: &u32) -> Option<u32> {
127        self.image.get(address).copied()
128    }
129}
130
/// Magic bytes that open every encoded `ProgramBinary` ("R0BF", the RISC Zero Binary Format).
const MAGIC: &[u8] = &[b'R', b'0', b'B', b'F'];
/// Version of the RISC Zero Binary Format, written immediately after `MAGIC`.
const BINARY_FORMAT_VERSION: u32 = 1;
133
/// A single key-value attribute of a `ProgramBinaryHeader`, in its on-disk form.
///
/// `ProgramBinaryHeader::encode` serializes each attribute individually with `postcard`.
/// `ProgramBinaryHeader::decode` skips any entry that fails to deserialize into this type.
// NOTE(review): this type's serialized form is part of the binary format. postcard encodes
// enum variants by their position, so reordering variants or their fields would presumably
// change the on-disk layout. New attributes should likely only be appended as new variants.
// Confirm before editing.
#[derive(Serialize, Deserialize)]
enum ProgramBinaryHeaderValueOnDisk {
    /// The ABI kind and the ABI version that the program uses.
    AbiVersion(AbiKind, semver::Version),
}
138
139trait ReadBytesExt<'a> {
140    fn read_u32(&mut self) -> Result<u32>;
141    fn read_slice(&mut self, len: usize) -> Result<&'a [u8]>;
142}
143
144impl<'a> ReadBytesExt<'a> for &'a [u8] {
145    fn read_u32(&mut self) -> Result<u32> {
146        const U32_SIZE: usize = core::mem::size_of::<u32>();
147
148        if self.len() < U32_SIZE {
149            bail!("unexpected end of file");
150        }
151
152        let value = u32::from_le_bytes(self[..U32_SIZE].try_into().unwrap());
153        *self = &self[U32_SIZE..];
154        Ok(value)
155    }
156
157    fn read_slice(&mut self, len: usize) -> Result<&'a [u8]> {
158        if self.len() < len {
159            bail!("unexpected end of file");
160        }
161        let mut other: &[u8] = &[][..];
162        core::mem::swap(self, &mut other);
163        let (first, rest) = other.split_at(len);
164        *self = rest;
165        Ok(first)
166    }
167}
168
/// Little-endian append helpers for a growable byte buffer.
trait WriteBytesExt {
    /// Append `value` to the end of the buffer as four little-endian bytes.
    fn write_u32(&mut self, value: u32);
}

impl WriteBytesExt for Vec<u8> {
    fn write_u32(&mut self, value: u32) {
        self.extend(value.to_le_bytes().iter());
    }
}
178
/// Which kind of ABI the program uses.
///
/// Stored in the program binary header alongside the ABI version. Because this type is
/// serialized there, the declaration order of the variants is likely part of the on-disk
/// format. TODO: confirm before reordering or removing variants.
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbiKind {
    /// The v1 version of the ABI.
    V1Compat,
    /// The Linux ABI.
    Linux, // currently unused
}
188
/// Metadata about a program, stored in an encoded `ProgramBinary` ahead of the ELFs.
///
/// On disk the header is a list of key-value attributes (see `ProgramBinaryHeader::encode`).
/// In memory, each attribute this version understands is a field.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProgramBinaryHeader {
    /// The ABI the program uses
    pub abi_kind: AbiKind,

    /// The version of the ABI that the program uses
    pub abi_version: semver::Version,
}
199
200impl Default for ProgramBinaryHeader {
201    fn default() -> Self {
202        Self {
203            abi_version: semver::Version::new(1, 0, 0),
204            abi_kind: AbiKind::V1Compat,
205        }
206    }
207}
208
209impl ProgramBinaryHeader {
210    fn decode(mut bytes: &[u8]) -> Result<Self> {
211        let num_kv_pairs = bytes.read_u32().context("Malformed ProgramBinaryHeader")?;
212
213        // Decode the key-value pairs
214        let mut kv_pairs = vec![];
215        for _ in 0..num_kv_pairs {
216            let kv_pair_len = bytes.read_u32().context("Malformed ProgramBinaryHeader")?;
217            let kv_bytes = bytes
218                .read_slice(kv_pair_len as usize)
219                .context("Malformed ProgramBinaryHeader")?;
220
221            // Skip any entries we can't decode
222            if let Ok(kv_pair) = postcard::from_bytes(kv_bytes) {
223                kv_pairs.push(kv_pair);
224            }
225        }
226
227        if !bytes.is_empty() {
228            bail!("Malformed ProgramBinaryHeader: trailing bytes");
229        }
230
231        // Find the individual key-value pairs we need
232        if kv_pairs.len() != 1 {
233            bail!("Malformed ProgramBinaryHeader: duplicate attributes");
234        }
235        let (abi_kind, abi_version) = kv_pairs
236            .into_iter()
237            .map(|pair| {
238                let ProgramBinaryHeaderValueOnDisk::AbiVersion(abi_kind, abi_version) = pair;
239                (abi_kind, abi_version)
240            })
241            .next()
242            .ok_or_else(|| anyhow!("ProgramBinary header missing AbiVersion"))?;
243
244        Ok(Self {
245            abi_kind,
246            abi_version,
247        })
248    }
249
250    fn encode(&self) -> Vec<u8> {
251        let kv_pairs = vec![ProgramBinaryHeaderValueOnDisk::AbiVersion(
252            self.abi_kind,
253            self.abi_version.clone(),
254        )];
255
256        let mut ret = vec![];
257
258        ret.write_u32(kv_pairs.len() as u32);
259        for p in &kv_pairs {
260            let kv_bytes = postcard::to_allocvec(p).unwrap();
261            ret.write_u32(kv_bytes.len() as u32);
262            ret.extend_from_slice(&kv_bytes[..]);
263        }
264
265        ret
266    }
267}
268
#[test]
fn header_encode_decode() {
    let original = ProgramBinaryHeader::default();
    let bytes = original.encode();
    let decoded = ProgramBinaryHeader::decode(&bytes).unwrap();

    assert_eq!(original, decoded);
}
276
#[test]
fn header_decode_errors_on_too_short() {
    // An entry count with nothing after it, and a count followed by a truncated entry length.
    for truncated in [&[1u8, 2, 3, 4][..], &[1, 2, 3, 4, 5, 6][..]] {
        assert!(ProgramBinaryHeader::decode(truncated).is_err());
    }
}
282
#[test]
fn header_decode_errors_on_trailing_bytes() {
    let mut encoded = ProgramBinaryHeader::default().encode();
    // Junk after the last entry must be rejected.
    encoded.extend([1u8, 2, 3, 4]);
    assert!(ProgramBinaryHeader::decode(&encoded).is_err());
}
289
#[test]
fn header_decode_ignores_unknown_attributes() {
    let mut encoded = ProgramBinaryHeader::default().encode();

    // Bump the entry count (the low byte of the little-endian u32 prefix). Then append an
    // entry with a 2-byte payload that is not a decodable attribute; the decoder should
    // skip it.
    encoded[0] += 1;
    encoded.extend_from_slice(&2u32.to_le_bytes());
    encoded.extend_from_slice(&[0xFF, 0xFF]);

    let decoded = ProgramBinaryHeader::decode(&encoded).unwrap();
    assert_eq!(decoded, ProgramBinaryHeader::default());
}
299
/// A container to hold a user ELF and a kernel ELF together.
///
/// Both ELFs are borrowed rather than owned. A `ProgramBinary` produced by `decode` points
/// directly into the blob it was decoded from.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq)]
pub struct ProgramBinary<'a> {
    /// The header, describing the program's ABI kind and version.
    pub header: ProgramBinaryHeader,

    /// The user ELF.
    pub user_elf: &'a [u8],

    /// The kernel ELF.
    pub kernel_elf: &'a [u8],
}
313
314impl<'a> ProgramBinary<'a> {
315    /// Construct from a pair of ELFs.
316    pub fn new(user_elf: &'a [u8], kernel_elf: &'a [u8]) -> Self {
317        Self {
318            header: ProgramBinaryHeader::default(),
319            user_elf,
320            kernel_elf,
321        }
322    }
323
324    /// Parse a blob into a `ProgramBinary`.
325    pub fn decode(mut blob: &'a [u8]) -> Result<Self> {
326        // Read MAGIC bytes. These signal the file format.
327        let magic = blob
328            .read_slice(MAGIC.len())
329            .context("Malformed ProgramBinary")?;
330        ensure!(magic == MAGIC, "Malformed ProgramBinary");
331
332        // Read the format version number.
333        let binary_format_version = blob.read_u32().context("Malformed ProgramBinary")?;
334        ensure!(
335            binary_format_version == BINARY_FORMAT_VERSION,
336            "ProgramBinary binary format version mismatch"
337        );
338
339        // Read the header.
340        let header_len = blob.read_u32().context("Malformed ProgramBinary")? as usize;
341        let header = ProgramBinaryHeader::decode(
342            blob.read_slice(header_len)
343                .context("Malformed ProgramBinary")?,
344        )?;
345
346        // Read user length, and calculate kernel offset / length
347        let user_len = blob.read_u32().context("Malformed ProgramBinary")? as usize;
348        let user_elf = blob
349            .read_slice(user_len)
350            .context("Malformed ProgramBinary")?;
351        ensure!(!user_elf.is_empty(), "Malformed ProgramBinary");
352
353        let kernel_elf = blob;
354        ensure!(!kernel_elf.is_empty(), "Malformed ProgramBinary");
355
356        Ok(Self {
357            header,
358            user_elf,
359            kernel_elf,
360        })
361    }
362
363    /// Convert this binary into a blob.
364    pub fn encode(&self) -> Vec<u8> {
365        let mut ret = vec![];
366
367        // Write magic and format version
368        ret.extend_from_slice(MAGIC);
369        ret.write_u32(BINARY_FORMAT_VERSION);
370
371        // Write the header
372        let header_bytes = ProgramBinaryHeader::encode(&self.header);
373        ret.write_u32(header_bytes.len() as u32);
374        ret.extend_from_slice(&header_bytes[..]);
375
376        // Write the user and kernel elfs
377        ret.write_u32(self.user_elf.len() as u32);
378        ret.extend_from_slice(self.user_elf);
379        ret.extend_from_slice(self.kernel_elf);
380
381        ret
382    }
383
384    /// Convert this binary into a `MemoryImage`.
385    pub fn to_image(&self) -> Result<MemoryImage> {
386        let user_program =
387            Program::load_elf(self.user_elf, KERNEL_START_ADDR.0).context("Loading user ELF")?;
388        let kernel_program =
389            Program::load_elf(self.kernel_elf, u32::MAX).context("Loading kernel ELF")?;
390        Ok(MemoryImage::with_kernel(user_program, kernel_program))
391    }
392
393    /// Compute and return the ImageID of this binary.
394    pub fn compute_image_id(&self) -> Result<Digest> {
395        let merkle_root = self.to_image()?.image_id();
396        Ok(SystemState { pc: 0, merkle_root }.digest::<Impl>())
397    }
398}
399
#[test]
fn encode_decode() {
    let original = ProgramBinary::new(&[1, 2, 3, 4], &[5, 6, 7, 8]);
    let blob = original.encode();
    let decoded = ProgramBinary::decode(&blob).unwrap();

    assert_eq!(original, decoded);
}
408
#[test]
fn bad_magic() {
    let mut blob = ProgramBinary::new(&[1, 2, 3, 4], &[5, 6, 7, 8]).encode();
    // Corrupt the first magic byte.
    blob[0] = 0xbe;
    assert!(ProgramBinary::decode(&blob).is_err());
}
416
#[test]
fn bad_version() {
    let mut blob = ProgramBinary::new(&[1, 2, 3, 4], &[5, 6, 7, 8]).encode();
    // Overwrite the low byte of the format version, which immediately follows the magic.
    blob[MAGIC.len()] = 0xbe;
    assert!(ProgramBinary::decode(&blob).is_err());
}