use core::borrow::Borrow;
use slop_air::{Air, AirBuilder, BaseAir};
use slop_algebra::{AbstractField, PrimeField32};
use slop_matrix::Matrix;
use sp1_core_executor::{
events::{ByteLookupEvent, ByteRecord, PrecompileEvent},
ByteOpcode, ExecutionRecord, Program, SyscallCode,
};
use sp1_derive::AlignedBorrow;
use sp1_hypercube::air::{InteractionScope, MachineAir};
use sp1_primitives::consts::{PROT_EXEC, PROT_READ, PROT_WRITE};
use std::{borrow::BorrowMut, mem::MaybeUninit};
use crate::{air::SP1CoreAirBuilder, memory::PageProtAccessCols, utils::next_multiple_of_32};
/// Width of the MPROTECT trace in field elements (one `MProtectCols` record per row).
const NUM_COLS: usize = size_of::<MProtectCols<u8>>();
/// Chip that proves the MPROTECT precompile syscall.
///
/// The chip is a zero-sized marker; all of its state lives in the
/// execution record it processes.
#[derive(Default)]
pub struct MProtectChip;

impl MProtectChip {
    /// Creates a new `MProtectChip`.
    pub const fn new() -> Self {
        MProtectChip
    }
}
/// Column layout for the MPROTECT chip; one row per MPROTECT precompile event.
#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct MProtectCols<T> {
    /// High limb of the event timestamp (bits 24 and above).
    pub clk_high: T,
    /// Low 24 bits of the event timestamp.
    pub clk_low: T,
    /// The syscall address as three 16-bit limbs, least significant first.
    pub addr: [T; 3],
    /// Bits 12..16 of the address (`addr[0] >> 12`); range-checked to 4 bits.
    pub addr_4_bits: T,
    /// Bits 0..12 of the address (the page offset); range-checked to 12 bits
    /// and additionally constrained to zero for real rows (page alignment).
    pub addr_12_bits: T,
    /// The raw protection value passed to mprotect.
    pub prot: T,
    /// Boolean flag: `prot` has the PROT_READ bit set.
    pub prot_read: T,
    /// Boolean flag: `prot` has the PROT_WRITE bit set.
    pub prot_write: T,
    /// Boolean flag: `prot` has the PROT_EXEC bit set.
    pub prot_exec: T,
    /// 1 for rows backed by a real event, 0 for padding rows.
    pub is_real: T,
    /// Columns for the page-prot table access performed by this event.
    pub page_prot_access: PageProtAccessCols<T>,
}
impl<F> BaseAir<F> for MProtectChip {
fn width(&self) -> usize {
NUM_COLS
}
}
impl<F: PrimeField32> MachineAir<F> for MProtectChip {
    type Record = ExecutionRecord;
    type Program = Program;
    /// Chip name used for shape lookups and diagnostics.
    fn name(&self) -> &'static str {
        "Mprotect"
    }
    /// Padded trace height: one row per MPROTECT precompile event, padded via
    /// `next_multiple_of_32` (which also receives the shard's fixed log2 height,
    /// when one is pinned, as a hint — exact semantics live in that helper).
    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let nb_rows = input.get_precompile_events(SyscallCode::MPROTECT).len();
        let size_log2 = input.fixed_log2_rows::<F, _>(self);
        let padded_nb_rows = next_multiple_of_32(nb_rows, size_log2);
        Some(padded_nb_rows)
    }
    /// Writes the shard's MPROTECT trace into `buffer`: one `MProtectCols` row
    /// per event, followed by zeroed padding rows. Byte-lookup events produced
    /// while populating rows are added to `output`.
    ///
    /// NOTE(review): assumes `buffer` holds at least `num_rows * NUM_COLS`
    /// elements — confirm the caller's sizing contract.
    fn generate_trace_into(
        &self,
        input: &ExecutionRecord,
        output: &mut ExecutionRecord,
        buffer: &mut [MaybeUninit<F>],
    ) {
        let padded_nb_rows = <MProtectChip as MachineAir<F>>::num_rows(self, input).unwrap();
        let mut blu_events = Vec::new();
        let mprotect_events = input.get_precompile_events(SyscallCode::MPROTECT);
        let num_event_rows = mprotect_events.len();
        // When untrusted programs are disabled, no MPROTECT event should ever
        // have been recorded; a stray one indicates a malformed record.
        if input.public_values.is_untrusted_programs_enabled == 0 {
            assert!(
                mprotect_events.is_empty(),
                "Page protect is disabled, but mprotect events are present"
            );
        }
        // SAFETY: zero-fill the padding region. `write_bytes` counts in units
        // of `MaybeUninit<F>`, so `padding_size` is a number of field elements,
        // not bytes. Assumes an all-zero byte pattern is a valid zero value of
        // `F` — TODO confirm for the concrete field in use.
        unsafe {
            let padding_start = num_event_rows * NUM_COLS;
            let padding_size = (padded_nb_rows - num_event_rows) * NUM_COLS;
            if padding_size > 0 {
                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
            }
        }
        // Reinterpret the event region of the buffer as initialized `F` values;
        // every element in that region is fully written by the loop below.
        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
        let values =
            unsafe { core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_COLS) };
        values.chunks_mut(NUM_COLS).enumerate().for_each(|(idx, row)| {
            let event = &mprotect_events[idx].1;
            let event =
                if let PrecompileEvent::Mprotect(event) = event { event } else { unreachable!() };
            let cols: &mut MProtectCols<F> = row.borrow_mut();
            // Each MPROTECT event carries exactly one page-prot access record.
            assert!(event.local_page_prot_access.len() == 1);
            let clk = event.local_page_prot_access[0].final_page_prot_access.timestamp;
            // Split the timestamp into a 24-bit low limb and a high limb.
            cols.clk_high = F::from_canonical_u32((clk >> 24) as u32);
            cols.clk_low = F::from_canonical_u32((clk & 0xFFFFFF) as u32);
            // Address as three 16-bit limbs (bits 0..48 of `event.addr`).
            cols.addr[0] = F::from_canonical_u32((event.addr & 0xFFFF) as u32);
            cols.addr[1] = F::from_canonical_u32(((event.addr >> 16) & 0xFFFF) as u32);
            cols.addr[2] = F::from_canonical_u32(((event.addr >> 32) & 0xFFFF) as u32);
            // Decompose the low limb into a 12-bit page offset and the next 4
            // bits, mirroring the AIR constraint
            // addr[0] = addr_12_bits + addr_4_bits * 2^12.
            let addr_12_bits = (event.addr & 0xFFF) as u16; let addr_4_bits = ((event.addr >> 12) & 0xF) as u16;
            cols.addr_12_bits = F::from_canonical_u16(addr_12_bits);
            cols.addr_4_bits = F::from_canonical_u16(addr_4_bits);
            // Range-check both limbs (4 and 12 bits) through the byte table;
            // these match the `send_byte` interactions in `Air::eval`.
            blu_events.push(ByteLookupEvent {
                opcode: ByteOpcode::Range,
                a: addr_4_bits,
                b: 4,
                c: 0,
            });
            blu_events.push(ByteLookupEvent {
                opcode: ByteOpcode::Range,
                a: addr_12_bits,
                b: 12, c: 0,
            });
            // The new protection value and its R/W/X bit decomposition.
            let page_prot = event.local_page_prot_access[0].final_page_prot_access.page_prot;
            cols.prot = F::from_canonical_u8(page_prot);
            cols.prot_read = if page_prot & PROT_READ != 0 { F::one() } else { F::zero() };
            cols.prot_write = if page_prot & PROT_WRITE != 0 { F::one() } else { F::zero() };
            cols.prot_exec = if page_prot & PROT_EXEC != 0 { F::one() } else { F::zero() };
            // Populate the page-prot access columns from the *initial* access;
            // this may emit further byte-lookup events.
            cols.page_prot_access.populate(
                &event.local_page_prot_access[0].initial_page_prot_access,
                clk,
                &mut blu_events,
            );
            cols.is_real = F::one();
        });
        output.add_byte_lookup_events(blu_events);
    }
    /// A shard includes this chip when its shape says so, or — absent a shape —
    /// when the shard actually contains MPROTECT events.
    fn included(&self, shard: &Self::Record) -> bool {
        if let Some(shape) = shard.shape.as_ref() {
            shape.included::<F, _>(self)
        } else {
            !shard.get_precompile_events(SyscallCode::MPROTECT).is_empty()
        }
    }
}
impl<AB> Air<AB> for MProtectChip
where
    AB: SP1CoreAirBuilder,
{
    /// Evaluates the MPROTECT constraints over a single trace row.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let row = main.row_slice(0);
        let cols: &MProtectCols<AB::Var> = (*row).borrow();

        // `is_real` selects event rows; padding rows are all-zero.
        builder.assert_bool(cols.is_real);

        // Decompose the low 16-bit address limb into its low 12 bits and the
        // next 4 bits: addr[0] = addr_12_bits + addr_4_bits * 2^12.
        let page_size = AB::Expr::from_canonical_u32(1 << 12);
        builder
            .when(cols.is_real)
            .assert_eq(cols.addr[0], cols.addr_12_bits + cols.addr_4_bits * page_size);

        // Range-check both decomposition limbs against the byte table; these
        // match the byte-lookup events emitted during trace generation.
        let range_opcode = AB::Expr::from_canonical_u32(ByteOpcode::Range as u32);
        builder.send_byte(
            range_opcode.clone(),
            cols.addr_4_bits.into(),
            AB::Expr::from_canonical_u32(4),
            AB::Expr::zero(),
            cols.is_real,
        );
        builder.send_byte(
            range_opcode,
            cols.addr_12_bits.into(),
            AB::Expr::from_canonical_u32(12),
            AB::Expr::zero(),
            cols.is_real,
        );

        // mprotect addresses must be page-aligned: the 12-bit offset is zero.
        builder.when(cols.is_real).assert_zero(cols.addr_12_bits);

        // The protection flags are bits, and `prot` is their weighted sum.
        builder.assert_bool(cols.prot_read);
        builder.assert_bool(cols.prot_write);
        builder.assert_bool(cols.prot_exec);
        let reconstructed_prot = cols.prot_read * AB::Expr::from_canonical_u8(PROT_READ)
            + cols.prot_write * AB::Expr::from_canonical_u8(PROT_WRITE)
            + cols.prot_exec * AB::Expr::from_canonical_u8(PROT_EXEC);
        builder.when(cols.is_real).assert_eq(cols.prot, reconstructed_prot.clone());

        // Receive the MPROTECT syscall with (addr, prot, 0, 0) as arguments.
        builder.receive_syscall(
            cols.clk_high,
            cols.clk_low,
            AB::F::from_canonical_u32(SyscallCode::MPROTECT.syscall_id()),
            cols.addr.map(Into::into),
            [cols.prot.into(), AB::Expr::zero(), AB::Expr::zero()],
            cols.is_real,
            InteractionScope::Local,
        );

        // Evaluate the page-prot table write, keyed by the page number
        // limbs (addr_4_bits, addr[1], addr[2]) — i.e. addr >> 12.
        builder.eval_page_prot_access_write(
            cols.clk_high,
            cols.clk_low,
            &[cols.addr_4_bits, cols.addr[1], cols.addr[2]],
            cols.page_prot_access,
            reconstructed_prot,
            cols.is_real,
        );
    }
}