1extern crate alloc;
16
17use alloc::{collections::BTreeMap, vec, vec::Vec};
18
19use anyhow::{anyhow, bail, ensure, Context, Result};
20use elf::{endian::LittleEndian, file::Class, ElfBytes};
21use risc0_zkp::core::{digest::Digest, hash::sha::Impl};
22use risc0_zkvm_platform::WORD_SIZE;
23use serde::{Deserialize, Serialize};
24
25use crate::{Digestible as _, MemoryImage, SystemState, KERNEL_START_ADDR};
26
/// A RISC-V program held as an entry point plus a sparse memory image.
///
/// Built by [`Program::load_elf`] (validated) or
/// [`Program::new_from_entry_and_image`] (unvalidated).
#[derive(Clone, Debug)]
pub struct Program {
    /// Address at which execution starts. `load_elf` only accepts a
    /// word-aligned entry point below its `max_mem` argument.
    pub(crate) entry: u32,

    /// Initial memory contents: address -> 32-bit word. When produced by
    /// `load_elf`, keys are word-aligned byte addresses and each word is
    /// assembled from the ELF bytes in little-endian order.
    pub(crate) image: BTreeMap<u32, u32>,
}
35
36impl Program {
37 pub fn load_elf(input: &[u8], max_mem: u32) -> Result<Program> {
39 let mut image: BTreeMap<u32, u32> = BTreeMap::new();
40 let elf = ElfBytes::<LittleEndian>::minimal_parse(input)
41 .map_err(|err| anyhow!("Elf parse error: {err}"))?;
42 if elf.ehdr.class != Class::ELF32 {
43 bail!("Not a 32-bit ELF");
44 }
45 if elf.ehdr.e_machine != elf::abi::EM_RISCV {
46 bail!("Invalid machine type, must be RISC-V");
47 }
48 if elf.ehdr.e_type != elf::abi::ET_EXEC {
49 bail!("Invalid ELF type, must be executable");
50 }
51 let entry: u32 = elf
52 .ehdr
53 .e_entry
54 .try_into()
55 .map_err(|err| anyhow!("e_entry was larger than 32 bits. {err}"))?;
56 if entry >= max_mem || entry % WORD_SIZE as u32 != 0 {
57 bail!("Invalid entrypoint");
58 }
59 let segments = elf
60 .segments()
61 .ok_or_else(|| anyhow!("Missing segment table"))?;
62 if segments.len() > 256 {
63 bail!("Too many program headers");
64 }
65 for segment in segments.iter().filter(|x| x.p_type == elf::abi::PT_LOAD) {
66 let file_size: u32 = segment
67 .p_filesz
68 .try_into()
69 .map_err(|err| anyhow!("filesize was larger than 32 bits. {err}"))?;
70 if file_size >= max_mem {
71 bail!("Invalid segment file_size");
72 }
73 let mem_size: u32 = segment
74 .p_memsz
75 .try_into()
76 .map_err(|err| anyhow!("mem_size was larger than 32 bits {err}"))?;
77 if mem_size >= max_mem {
78 bail!("Invalid segment mem_size");
79 }
80 let vaddr: u32 = segment
81 .p_vaddr
82 .try_into()
83 .map_err(|err| anyhow!("vaddr is larger than 32 bits. {err}"))?;
84 if vaddr % WORD_SIZE as u32 != 0 {
85 bail!("vaddr {vaddr:08x} is unaligned");
86 }
87 let offset: u32 = segment
88 .p_offset
89 .try_into()
90 .map_err(|err| anyhow!("offset is larger than 32 bits. {err}"))?;
91 for i in (0..mem_size).step_by(WORD_SIZE) {
92 let addr = vaddr.checked_add(i).context("Invalid segment vaddr")?;
93 if addr >= max_mem {
94 bail!("Address [0x{addr:08x}] exceeds maximum address for guest programs [0x{max_mem:08x}]");
95 }
96 if i >= file_size {
97 image.insert(addr, 0);
99 } else {
100 let mut word = 0;
101 let len = core::cmp::min(file_size - i, WORD_SIZE as u32);
103 for j in 0..len {
104 let offset = (offset + i + j) as usize;
105 let byte = input.get(offset).context("Invalid segment offset")?;
106 word |= (*byte as u32) << (j * 8);
107 }
108 image.insert(addr, word);
109 }
110 }
111 }
112 Ok(Program::new_from_entry_and_image(entry, image))
113 }
114
115 pub fn new_from_entry_and_image(entry: u32, image: BTreeMap<u32, u32>) -> Self {
117 Self { entry, image }
118 }
119
120 pub fn size_in_words(&self) -> usize {
122 self.image.len()
123 }
124
125 pub fn read_u32(&self, address: &u32) -> Option<u32> {
127 self.image.get(address).copied()
128 }
129}
130
131const MAGIC: &[u8] = b"R0BF"; const BINARY_FORMAT_VERSION: u32 = 1; #[derive(Serialize, Deserialize)]
135enum ProgramBinaryHeaderValueOnDisk {
136 AbiVersion(AbiKind, semver::Version),
137}
138
139trait ReadBytesExt<'a> {
140 fn read_u32(&mut self) -> Result<u32>;
141 fn read_slice(&mut self, len: usize) -> Result<&'a [u8]>;
142}
143
144impl<'a> ReadBytesExt<'a> for &'a [u8] {
145 fn read_u32(&mut self) -> Result<u32> {
146 const U32_SIZE: usize = core::mem::size_of::<u32>();
147
148 if self.len() < U32_SIZE {
149 bail!("unexpected end of file");
150 }
151
152 let value = u32::from_le_bytes(self[..U32_SIZE].try_into().unwrap());
153 *self = &self[U32_SIZE..];
154 Ok(value)
155 }
156
157 fn read_slice(&mut self, len: usize) -> Result<&'a [u8]> {
158 if self.len() < len {
159 bail!("unexpected end of file");
160 }
161 let mut other: &[u8] = &[][..];
162 core::mem::swap(self, &mut other);
163 let (first, rest) = other.split_at(len);
164 *self = rest;
165 Ok(first)
166 }
167}
168
/// Append-only counterpart of `ReadBytesExt`, used to build encoded buffers.
trait WriteBytesExt {
    /// Appends `value` as four little-endian bytes.
    fn write_u32(&mut self, value: u32);
}

impl WriteBytesExt for Vec<u8> {
    fn write_u32(&mut self, value: u32) {
        let le_bytes: [u8; 4] = value.to_le_bytes();
        self.extend(le_bytes);
    }
}
178
179#[non_exhaustive]
181#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
182pub enum AbiKind {
183 V1Compat,
185 Linux, }
188
189#[non_exhaustive]
191#[derive(Clone, Debug, PartialEq, Eq)]
192pub struct ProgramBinaryHeader {
193 pub abi_kind: AbiKind,
195
196 pub abi_version: semver::Version,
198}
199
200impl Default for ProgramBinaryHeader {
201 fn default() -> Self {
202 Self {
203 abi_version: semver::Version::new(1, 0, 0),
204 abi_kind: AbiKind::V1Compat,
205 }
206 }
207}
208
209impl ProgramBinaryHeader {
210 fn decode(mut bytes: &[u8]) -> Result<Self> {
211 let num_kv_pairs = bytes.read_u32().context("Malformed ProgramBinaryHeader")?;
212
213 let mut kv_pairs = vec![];
215 for _ in 0..num_kv_pairs {
216 let kv_pair_len = bytes.read_u32().context("Malformed ProgramBinaryHeader")?;
217 let kv_bytes = bytes
218 .read_slice(kv_pair_len as usize)
219 .context("Malformed ProgramBinaryHeader")?;
220
221 if let Ok(kv_pair) = postcard::from_bytes(kv_bytes) {
223 kv_pairs.push(kv_pair);
224 }
225 }
226
227 if !bytes.is_empty() {
228 bail!("Malformed ProgramBinaryHeader: trailing bytes");
229 }
230
231 if kv_pairs.len() != 1 {
233 bail!("Malformed ProgramBinaryHeader: duplicate attributes");
234 }
235 let (abi_kind, abi_version) = kv_pairs
236 .into_iter()
237 .map(|pair| {
238 let ProgramBinaryHeaderValueOnDisk::AbiVersion(abi_kind, abi_version) = pair;
239 (abi_kind, abi_version)
240 })
241 .next()
242 .ok_or_else(|| anyhow!("ProgramBinary header missing AbiVersion"))?;
243
244 Ok(Self {
245 abi_kind,
246 abi_version,
247 })
248 }
249
250 fn encode(&self) -> Vec<u8> {
251 let kv_pairs = vec![ProgramBinaryHeaderValueOnDisk::AbiVersion(
252 self.abi_kind,
253 self.abi_version.clone(),
254 )];
255
256 let mut ret = vec![];
257
258 ret.write_u32(kv_pairs.len() as u32);
259 for p in &kv_pairs {
260 let kv_bytes = postcard::to_allocvec(p).unwrap();
261 ret.write_u32(kv_bytes.len() as u32);
262 ret.extend_from_slice(&kv_bytes[..]);
263 }
264
265 ret
266 }
267}
268
#[test]
fn header_encode_decode() {
    // A default header must survive an encode/decode round trip unchanged.
    let original = ProgramBinaryHeader::default();
    let bytes = original.encode();
    let decoded = ProgramBinaryHeader::decode(bytes.as_slice()).unwrap();
    assert_eq!(original, decoded);
}
276
#[test]
fn header_decode_errors_on_too_short() {
    // An attribute count with no attribute length after it.
    let count_only: &[u8] = &[1, 2, 3, 4];
    // An attribute count followed by a truncated attribute length.
    let truncated_len: &[u8] = &[1, 2, 3, 4, 5, 6];
    for input in [count_only, truncated_len] {
        assert!(ProgramBinaryHeader::decode(input).is_err());
    }
}
282
#[test]
fn header_decode_errors_on_trailing_bytes() {
    // A valid header followed by extra bytes must be rejected.
    let padded: Vec<u8> = ProgramBinaryHeader::default()
        .encode()
        .into_iter()
        .chain([1, 2, 3, 4])
        .collect();
    assert!(ProgramBinaryHeader::decode(&padded).is_err());
}
289
#[test]
fn header_decode_ignores_unknown_attributes() {
    let mut encoded = ProgramBinaryHeader::default().encode();

    // Bump the little-endian attribute count from 1 to 2...
    encoded[0] += 1;
    // ...and append a 2-byte attribute that cannot be deserialized.
    let unknown_len: [u8; 4] = 2u32.to_le_bytes();
    encoded.extend_from_slice(&unknown_len);
    encoded.extend_from_slice(&[0xFF, 0xFF]);

    let decoded = ProgramBinaryHeader::decode(&encoded).unwrap();
    assert_eq!(decoded, ProgramBinaryHeader::default());
}
299
300#[non_exhaustive]
302#[derive(Debug, PartialEq, Eq)]
303pub struct ProgramBinary<'a> {
304 pub header: ProgramBinaryHeader,
306
307 pub user_elf: &'a [u8],
309
310 pub kernel_elf: &'a [u8],
312}
313
314impl<'a> ProgramBinary<'a> {
315 pub fn new(user_elf: &'a [u8], kernel_elf: &'a [u8]) -> Self {
317 Self {
318 header: ProgramBinaryHeader::default(),
319 user_elf,
320 kernel_elf,
321 }
322 }
323
324 pub fn decode(mut blob: &'a [u8]) -> Result<Self> {
326 let magic = blob
328 .read_slice(MAGIC.len())
329 .context("Malformed ProgramBinary")?;
330 ensure!(magic == MAGIC, "Malformed ProgramBinary");
331
332 let binary_format_version = blob.read_u32().context("Malformed ProgramBinary")?;
334 ensure!(
335 binary_format_version == BINARY_FORMAT_VERSION,
336 "ProgramBinary binary format version mismatch"
337 );
338
339 let header_len = blob.read_u32().context("Malformed ProgramBinary")? as usize;
341 let header = ProgramBinaryHeader::decode(
342 blob.read_slice(header_len)
343 .context("Malformed ProgramBinary")?,
344 )?;
345
346 let user_len = blob.read_u32().context("Malformed ProgramBinary")? as usize;
348 let user_elf = blob
349 .read_slice(user_len)
350 .context("Malformed ProgramBinary")?;
351 ensure!(!user_elf.is_empty(), "Malformed ProgramBinary");
352
353 let kernel_elf = blob;
354 ensure!(!kernel_elf.is_empty(), "Malformed ProgramBinary");
355
356 Ok(Self {
357 header,
358 user_elf,
359 kernel_elf,
360 })
361 }
362
363 pub fn encode(&self) -> Vec<u8> {
365 let mut ret = vec![];
366
367 ret.extend_from_slice(MAGIC);
369 ret.write_u32(BINARY_FORMAT_VERSION);
370
371 let header_bytes = ProgramBinaryHeader::encode(&self.header);
373 ret.write_u32(header_bytes.len() as u32);
374 ret.extend_from_slice(&header_bytes[..]);
375
376 ret.write_u32(self.user_elf.len() as u32);
378 ret.extend_from_slice(self.user_elf);
379 ret.extend_from_slice(self.kernel_elf);
380
381 ret
382 }
383
384 pub fn to_image(&self) -> Result<MemoryImage> {
386 let user_program =
387 Program::load_elf(self.user_elf, KERNEL_START_ADDR.0).context("Loading user ELF")?;
388 let kernel_program =
389 Program::load_elf(self.kernel_elf, u32::MAX).context("Loading kernel ELF")?;
390 Ok(MemoryImage::with_kernel(user_program, kernel_program))
391 }
392
393 pub fn compute_image_id(&self) -> Result<Digest> {
395 let merkle_root = self.to_image()?.image_id();
396 Ok(SystemState { pc: 0, merkle_root }.digest::<Impl>())
397 }
398}
399
#[test]
fn encode_decode() {
    // Encoding then decoding must reproduce the original binary exactly.
    let original = ProgramBinary::new(&[1, 2, 3, 4], &[5, 6, 7, 8]);
    let blob = original.encode();
    let decoded = ProgramBinary::decode(blob.as_slice()).unwrap();
    assert_eq!(original, decoded);
}
408
#[test]
fn bad_magic() {
    let mut blob = ProgramBinary::new(&[1, 2, 3, 4], &[5, 6, 7, 8]).encode();
    // Corrupt the first magic byte.
    blob[0] = 0xbe;
    assert!(ProgramBinary::decode(&blob).is_err());
}
416
#[test]
fn bad_version() {
    let mut blob = ProgramBinary::new(&[1, 2, 3, 4], &[5, 6, 7, 8]).encode();
    // The format version's least-significant byte sits right after the magic.
    let version_byte = MAGIC.len();
    blob[version_byte] = 0xbe;
    assert!(ProgramBinary::decode(&blob).is_err());
}