extern crate alloc;
use alloc::string::String;
use alloc::vec::Vec;
use rkyv::{Archive, Deserialize, Serialize};
use crate::kstring::KString;
/// Constant value in its serializable (rkyv-archivable) form.
///
/// This is the element type of `Chunk::constants`. It carries the same
/// variants as [`Value`] except `DynStr`, `KStr` and `Func`, which
/// `ConstValue::try_from_value` rejects.
///
/// The explicit `serialize_bounds` / `deserialize_bounds` / `bytecheck`
/// attributes together with `#[rkyv(omit_bounds)]` on the self-referential
/// fields spell out the derive bounds by hand.
// NOTE(review): presumably required because the type is recursive; confirm
// against the rkyv derive documentation.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
#[rkyv(
serialize_bounds(__S: rkyv::ser::Writer + rkyv::ser::Allocator, __S::Error: rkyv::rancor::Source),
deserialize_bounds(__D::Error: rkyv::rancor::Source),
bytecheck(bounds(__C: rkyv::validation::ArchiveContext, <__C as rkyv::rancor::Fallible>::Error: rkyv::rancor::Source))
)]
pub enum ConstValue {
/// The unit value.
Unit,
/// A boolean.
Bool(bool),
/// A 64-bit signed integer.
Int(i64),
/// A 64-bit float.
Float(f64),
/// An owned string constant.
StaticStr(String),
/// Ordered elements of a tuple.
Tuple(#[rkyv(omit_bounds)] Vec<ConstValue>),
/// Ordered elements of an array.
Array(#[rkyv(omit_bounds)] Vec<ConstValue>),
/// A struct instance: its type name plus `(field name, value)` pairs.
Struct {
type_name: String,
#[rkyv(omit_bounds)]
fields: Vec<(String, ConstValue)>,
},
/// An enum instance: type name, variant name and positional payload.
Enum {
type_name: String,
variant: String,
#[rkyv(omit_bounds)]
fields: Vec<ConstValue>,
},
/// The `None` value (distinct from `Unit`).
None,
}
/// Runtime value.
///
/// Superset of [`ConstValue`]: in addition to the constant-representable
/// variants it has `DynStr`, `KStr` and `Func`.
#[derive(Debug, Clone)]
pub enum Value {
/// The unit value.
Unit,
/// A boolean.
Bool(bool),
/// A 64-bit signed integer.
Int(i64),
/// A 64-bit float.
Float(f64),
/// A string that originated from a constant (see `ConstValue::StaticStr`).
StaticStr(String),
/// An owned string that cannot be turned back into a constant.
/// Compares equal to a `StaticStr` with the same contents.
DynStr(String),
/// A `KString` handle; its text is only reachable through an arena
/// (see `Value::as_str_with_arena`).
KStr(KString),
/// Ordered elements of a tuple.
Tuple(Vec<Value>),
/// Ordered elements of an array.
Array(Vec<Value>),
/// A struct instance: its type name plus `(field name, value)` pairs.
Struct {
type_name: String,
fields: Vec<(String, Value)>,
},
/// An enum instance: type name, variant name and positional payload.
Enum {
type_name: String,
variant: String,
fields: Vec<Value>,
},
/// The `None` value (distinct from `Unit`).
None,
/// A function value.
// NOTE(review): `chunk_idx` presumably indexes `Module::chunks` and `env`
// presumably holds captured values for closures; confirm against the VM.
Func {
chunk_idx: u16,
env: Vec<Value>,
recursive: bool,
},
}
impl PartialEq for Value {
    /// Structural equality with two special cases:
    ///
    /// * `StaticStr` and `DynStr` are interchangeable: any pairing of the two
    ///   compares by string contents.
    /// * `KStr` handles compare by `epoch()` only.
    ///
    /// Values of otherwise different variants are never equal.
    fn eq(&self, other: &Self) -> bool {
        match self {
            Value::Unit => matches!(other, Value::Unit),
            Value::None => matches!(other, Value::None),
            Value::Bool(lhs) => matches!(other, Value::Bool(rhs) if lhs == rhs),
            Value::Int(lhs) => matches!(other, Value::Int(rhs) if lhs == rhs),
            Value::Float(lhs) => matches!(other, Value::Float(rhs) if lhs == rhs),
            // Either string flavour on the left matches either flavour on the right.
            Value::StaticStr(lhs) | Value::DynStr(lhs) => {
                matches!(other, Value::StaticStr(rhs) | Value::DynStr(rhs) if lhs == rhs)
            }
            Value::KStr(lhs) => {
                matches!(other, Value::KStr(rhs) if lhs.epoch() == rhs.epoch())
            }
            Value::Tuple(lhs) => matches!(other, Value::Tuple(rhs) if lhs == rhs),
            Value::Array(lhs) => matches!(other, Value::Array(rhs) if lhs == rhs),
            Value::Struct { type_name, fields } => matches!(
                other,
                Value::Struct {
                    type_name: other_name,
                    fields: other_fields,
                } if type_name == other_name && fields == other_fields
            ),
            Value::Enum {
                type_name,
                variant,
                fields,
            } => matches!(
                other,
                Value::Enum {
                    type_name: other_name,
                    variant: other_variant,
                    fields: other_fields,
                } if type_name == other_name
                    && variant == other_variant
                    && fields == other_fields
            ),
            Value::Func {
                chunk_idx,
                env,
                recursive,
            } => matches!(
                other,
                Value::Func {
                    chunk_idx: other_idx,
                    env: other_env,
                    recursive: other_recursive,
                } if chunk_idx == other_idx && env == other_env && recursive == other_recursive
            ),
        }
    }
}
impl Value {
    /// Returns the name of this value's variant (e.g. `"Int"`, `"Struct"`).
    ///
    /// This is the variant tag, not the `type_name` field carried by
    /// `Struct` / `Enum` values.
    pub fn type_name(&self) -> &'static str {
        match self {
            Value::Unit => "Unit",
            Value::Bool(_) => "Bool",
            Value::Int(_) => "Int",
            Value::Float(_) => "Float",
            Value::StaticStr(_) => "StaticStr",
            Value::DynStr(_) => "DynStr",
            Value::KStr(_) => "KStr",
            Value::Tuple(_) => "Tuple",
            Value::Array(_) => "Array",
            Value::Struct { .. } => "Struct",
            Value::Enum { .. } => "Enum",
            Value::None => "None",
            Value::Func { .. } => "Func",
        }
    }

    /// Borrows the text of a `StaticStr` or `DynStr`.
    ///
    /// Every other variant, including `KStr`, yields `None`; use
    /// [`Value::as_str_with_arena`] to resolve `KStr` handles.
    pub fn as_str(&self) -> Option<&str> {
        if let Value::StaticStr(text) | Value::DynStr(text) = self {
            Some(text.as_str())
        } else {
            Option::None
        }
    }

    /// Like [`Value::as_str`], but additionally resolves `KStr` handles
    /// through `arena`.
    ///
    /// # Errors
    ///
    /// Propagates the `Stale` error returned by the handle's `get` call.
    pub fn as_str_with_arena<'a>(
        &'a self,
        arena: &'a keleusma_arena::Arena,
    ) -> Result<Option<&'a str>, keleusma_arena::Stale> {
        match self {
            Value::KStr(handle) => handle.get(arena).map(Some),
            // Owned strings need no arena; all remaining variants are not strings.
            other => Ok(other.as_str()),
        }
    }

    /// Reports whether this value is, or transitively contains inside a
    /// tuple, array, struct or enum payload, a `DynStr` or `KStr`.
    pub fn contains_dynstr(&self) -> bool {
        match self {
            Value::DynStr(_) | Value::KStr(_) => true,
            Value::Tuple(elems) | Value::Array(elems) | Value::Enum { fields: elems, .. } => {
                elems.iter().any(|elem| elem.contains_dynstr())
            }
            Value::Struct { fields, .. } => fields.iter().any(|(_, field)| field.contains_dynstr()),
            // NOTE(review): the `env` of a `Func` is not inspected; confirm that
            // this is intentional.
            Value::Unit
            | Value::Bool(_)
            | Value::Int(_)
            | Value::Float(_)
            | Value::StaticStr(_)
            | Value::None
            | Value::Func { .. } => false,
        }
    }

    /// Builds an owned `Value` from an archived constant, copying strings
    /// and recursing into aggregates.
    pub fn from_const_archived(c: &ArchivedConstValue) -> Value {
        match c {
            ArchivedConstValue::Unit => Value::Unit,
            ArchivedConstValue::None => Value::None,
            ArchivedConstValue::Bool(flag) => Value::Bool(*flag),
            ArchivedConstValue::Int(int) => Value::Int(int.to_native()),
            ArchivedConstValue::Float(float) => Value::Float(float.to_native()),
            ArchivedConstValue::StaticStr(text) => Value::StaticStr(String::from(text.as_str())),
            ArchivedConstValue::Tuple(elems) => {
                let owned = elems.iter().map(Self::from_const_archived).collect();
                Value::Tuple(owned)
            }
            ArchivedConstValue::Array(elems) => {
                let owned = elems.iter().map(Self::from_const_archived).collect();
                Value::Array(owned)
            }
            ArchivedConstValue::Struct { type_name, fields } => {
                let owned_fields = fields
                    .iter()
                    .map(|kv| {
                        let name = String::from(kv.0.as_str());
                        (name, Self::from_const_archived(&kv.1))
                    })
                    .collect();
                Value::Struct {
                    type_name: String::from(type_name.as_str()),
                    fields: owned_fields,
                }
            }
            ArchivedConstValue::Enum {
                type_name,
                variant,
                fields,
            } => Value::Enum {
                type_name: String::from(type_name.as_str()),
                variant: String::from(variant.as_str()),
                fields: fields.iter().map(Self::from_const_archived).collect(),
            },
        }
    }
}
/// Kind tag attached to every chunk (`Chunk::block_type`).
// NOTE(review): the runtime meaning of each kind is defined outside this
// view; the variant names are the only documentation available here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Archive, Serialize, Deserialize)]
pub enum BlockType {
Func,
Reentrant,
Stream,
}
/// A single bytecode instruction.
///
/// Operands are small unsigned integers. What this file establishes about
/// them:
///
/// * `NewStruct(i)`: `i` indexes `Chunk::struct_templates`
///   (see `CostModel::heap_alloc_bytes`).
/// * `NewEnum(_, _, n)`, `NewArray(n)`, `NewTuple(n)`: `n` is the number of
///   value slots allocated and the number of stack values consumed
///   (see `CostModel::heap_alloc_bytes` and `Op::stack_shrink`).
/// * `Call(_, n)`, `CallNative(_, n)`, `CallIndirect(n)`, `MakeClosure(_, n)`,
///   `MakeRecursiveClosure(_, n)`: `n` is a count of stack values consumed
///   (see `Op::stack_shrink`).
///
/// NOTE(review): the remaining operands (pool indices, jump targets, and so
/// on) are interpreted by the VM, which is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Archive, Serialize, Deserialize)]
pub enum Op {
Const(u16),
PushUnit,
PushTrue,
PushFalse,
GetLocal(u16),
SetLocal(u16),
GetData(u16),
SetData(u16),
Add,
Sub,
Mul,
Div,
Mod,
Neg,
CmpEq,
CmpNe,
CmpLt,
CmpGt,
CmpLe,
CmpGe,
Not,
If(u32),
Else(u32),
EndIf,
Loop(u32),
EndLoop(u32),
Break(u32),
BreakIf(u32),
Stream,
Reset,
Call(u16, u8),
CallNative(u16, u8),
CallIndirect(u8),
PushFunc(u16),
MakeClosure(u16, u8),
MakeRecursiveClosure(u16, u8),
Return,
Yield,
Pop,
Dup,
NewStruct(u16),
NewEnum(u16, u16, u8),
NewArray(u16),
NewTuple(u8),
WrapSome,
PushNone,
GetField(u16),
GetIndex,
GetTupleField(u8),
GetEnumField(u8),
Len,
IsEnum(u16, u16),
IsStruct(u16),
IntToFloat,
FloatToInt,
Trap(u16),
}
/// Size in bytes charged for one value slot by [`NOMINAL_COST_MODEL`].
pub const VALUE_SLOT_SIZE_BYTES: u32 = 32;
/// Pluggable pricing used to estimate execution cycles and heap bytes of
/// individual instructions.
#[derive(Clone, Copy)]
pub struct CostModel {
/// Bytes accounted for each value slot (see `CostModel::slots_to_bytes`).
pub value_slot_bytes: u32,
/// Function returning the cycle cost of a single instruction.
pub op_cycles: fn(&Op) -> u32,
}
impl CostModel {
    /// Cycle cost of `op` according to this model's `op_cycles` function.
    pub fn cycles(&self, op: &Op) -> u32 {
        let price = self.op_cycles;
        price(op)
    }

    /// Converts a slot count into bytes, saturating at `u32::MAX`.
    pub fn slots_to_bytes(&self, slots: u32) -> u32 {
        self.value_slot_bytes.saturating_mul(slots)
    }

    /// Heap bytes charged for executing `op` inside `chunk`.
    ///
    /// Only the allocating instructions are charged:
    ///
    /// * `NewStruct`: one slot per field of the referenced struct template;
    ///   an out-of-range template index is charged as zero slots.
    /// * `NewEnum` / `NewArray` / `NewTuple`: one slot per element operand.
    ///
    /// Every other instruction costs zero bytes.
    pub fn heap_alloc_bytes(&self, op: &Op, chunk: &Chunk) -> u32 {
        let slots = match *op {
            Op::NewStruct(template_idx) => {
                match chunk.struct_templates.get(usize::from(template_idx)) {
                    Some(template) => template.field_names.len() as u32,
                    None => 0,
                }
            }
            Op::NewEnum(_, _, count) | Op::NewTuple(count) => u32::from(count),
            Op::NewArray(count) => u32::from(count),
            _ => return 0,
        };
        self.slots_to_bytes(slots)
    }
}
/// Default cost model: [`VALUE_SLOT_SIZE_BYTES`] per slot and
/// [`nominal_op_cycles`] for instruction pricing. `Op::cost` and
/// `Op::heap_alloc` delegate to it.
pub const NOMINAL_COST_MODEL: CostModel = CostModel {
value_slot_bytes: VALUE_SLOT_SIZE_BYTES,
op_cycles: nominal_op_cycles,
};
/// Nominal cycle price of a single instruction, used by
/// [`NOMINAL_COST_MODEL`].
///
/// Prices fall into six tiers: 0, 1, 2, 3, 5 and 10 cycles. The match is
/// exhaustive on purpose so that adding an `Op` variant forces a pricing
/// decision here.
pub fn nominal_op_cycles(op: &Op) -> u32 {
    match op {
        // Free.
        Op::PushFunc(_) => 0,

        // Pushes, local/data moves, control-flow markers and other trivial work.
        Op::Const(_)
        | Op::PushUnit
        | Op::PushTrue
        | Op::PushFalse
        | Op::PushNone
        | Op::GetLocal(_)
        | Op::SetLocal(_)
        | Op::GetData(_)
        | Op::SetData(_)
        | Op::Pop
        | Op::Dup
        | Op::WrapSome
        | Op::Not
        | Op::If(_)
        | Op::Else(_)
        | Op::EndIf
        | Op::Loop(_)
        | Op::EndLoop(_)
        | Op::Break(_)
        | Op::BreakIf(_)
        | Op::Stream
        | Op::Reset
        | Op::Yield
        | Op::Trap(_) => 1,

        // Arithmetic, comparisons, positional access, conversions, return.
        Op::Add
        | Op::Sub
        | Op::Mul
        | Op::Neg
        | Op::CmpEq
        | Op::CmpNe
        | Op::CmpLt
        | Op::CmpGt
        | Op::CmpLe
        | Op::CmpGe
        | Op::GetIndex
        | Op::GetTupleField(_)
        | Op::GetEnumField(_)
        | Op::Len
        | Op::IntToFloat
        | Op::FloatToInt
        | Op::Return => 2,

        // Division and field/type tests.
        Op::Div | Op::Mod | Op::GetField(_) | Op::IsEnum(_, _) | Op::IsStruct(_) => 3,

        // Aggregate and closure construction.
        Op::NewStruct(_)
        | Op::NewEnum(_, _, _)
        | Op::NewArray(_)
        | Op::NewTuple(_)
        | Op::MakeClosure(_, _)
        | Op::MakeRecursiveClosure(_, _) => 5,

        // Calls.
        Op::Call(_, _) | Op::CallNative(_, _) | Op::CallIndirect(_) => 10,
    }
}
impl Op {
    /// Cycle cost of this instruction under [`NOMINAL_COST_MODEL`].
    pub fn cost(&self) -> u32 {
        nominal_cost_of(self)
    }

    /// Number of stack slots this instruction is accounted as adding.
    ///
    /// The result is always 0 or 1. The match is exhaustive so that a new
    /// `Op` variant must be classified explicitly.
    pub fn stack_growth(&self) -> u32 {
        match self {
            Op::Const(_)
            | Op::PushUnit
            | Op::PushTrue
            | Op::PushFalse
            | Op::PushNone
            | Op::GetLocal(_)
            | Op::GetData(_)
            | Op::Dup
            | Op::Call(_, _)
            | Op::CallNative(_, _)
            | Op::CallIndirect(_)
            | Op::NewStruct(_)
            | Op::NewEnum(_, _, _)
            | Op::NewArray(_)
            | Op::NewTuple(_)
            | Op::MakeClosure(_, _)
            | Op::MakeRecursiveClosure(_, _) => 1,

            // NOTE(review): `PushFunc` is accounted as adding nothing even
            // though its name suggests a push; confirm against the VM.
            Op::PushFunc(_)
            | Op::WrapSome
            | Op::Not
            | Op::Neg
            | Op::Add
            | Op::Sub
            | Op::Mul
            | Op::Div
            | Op::Mod
            | Op::CmpEq
            | Op::CmpNe
            | Op::CmpLt
            | Op::CmpGt
            | Op::CmpLe
            | Op::CmpGe
            | Op::SetLocal(_)
            | Op::SetData(_)
            | Op::Pop
            | Op::If(_)
            | Op::BreakIf(_)
            | Op::Else(_)
            | Op::EndIf
            | Op::Loop(_)
            | Op::EndLoop(_)
            | Op::Break(_)
            | Op::Stream
            | Op::Reset
            | Op::Yield
            | Op::Return
            | Op::GetField(_)
            | Op::GetIndex
            | Op::GetTupleField(_)
            | Op::GetEnumField(_)
            | Op::Len
            | Op::IsEnum(_, _)
            | Op::IsStruct(_)
            | Op::IntToFloat
            | Op::FloatToInt
            | Op::Trap(_) => 0,
        }
    }

    /// Number of stack slots this instruction is accounted as removing.
    ///
    /// For instructions carrying an element or argument count the result is
    /// that count (`CallIndirect` adds one to it); otherwise it is 0 or 1.
    pub fn stack_shrink(&self) -> u32 {
        match self {
            // Count taken from a `u8` operand.
            Op::Call(_, count)
            | Op::CallNative(_, count)
            | Op::NewEnum(_, _, count)
            | Op::NewTuple(count)
            | Op::MakeClosure(_, count)
            | Op::MakeRecursiveClosure(_, count) => u32::from(*count),
            // Count taken from a `u16` operand.
            Op::NewArray(count) => u32::from(*count),
            // One more than the operand.
            Op::CallIndirect(count) => u32::from(*count) + 1,

            Op::Add
            | Op::Sub
            | Op::Mul
            | Op::Div
            | Op::Mod
            | Op::CmpEq
            | Op::CmpNe
            | Op::CmpLt
            | Op::CmpGt
            | Op::CmpLe
            | Op::CmpGe
            | Op::SetLocal(_)
            | Op::SetData(_)
            | Op::Pop
            | Op::If(_)
            | Op::BreakIf(_)
            | Op::Yield
            | Op::GetField(_)
            | Op::GetIndex
            | Op::GetTupleField(_)
            | Op::GetEnumField(_) => 1,

            // `NewStruct` is accounted as removing nothing: its field count
            // lives in the chunk's struct template, which is not available here.
            Op::NewStruct(_)
            | Op::Const(_)
            | Op::PushUnit
            | Op::PushTrue
            | Op::PushFalse
            | Op::PushNone
            | Op::PushFunc(_)
            | Op::GetLocal(_)
            | Op::GetData(_)
            | Op::Dup
            | Op::WrapSome
            | Op::Not
            | Op::Neg
            | Op::Else(_)
            | Op::EndIf
            | Op::Loop(_)
            | Op::EndLoop(_)
            | Op::Break(_)
            | Op::Stream
            | Op::Reset
            | Op::Return
            | Op::Len
            | Op::IsEnum(_, _)
            | Op::IsStruct(_)
            | Op::IntToFloat
            | Op::FloatToInt
            | Op::Trap(_) => 0,
        }
    }

    /// Heap bytes charged for this instruction inside `chunk` under
    /// [`NOMINAL_COST_MODEL`].
    pub fn heap_alloc(&self, chunk: &Chunk) -> u32 {
        let model = NOMINAL_COST_MODEL;
        model.heap_alloc_bytes(self, chunk)
    }
}

/// Prices `op` with the nominal cost model; backs `Op::cost`.
fn nominal_cost_of(op: &Op) -> u32 {
    NOMINAL_COST_MODEL.cycles(op)
}
/// Shape of a struct type referenced by `Op::NewStruct`: its type name and
/// the names of its fields.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct StructTemplate {
pub type_name: String,
/// Field names; their count is what `CostModel::heap_alloc_bytes` charges
/// for a `NewStruct`.
pub field_names: Vec<String>,
}
/// One named entry of a [`DataLayout`].
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct DataSlot {
pub name: String,
}
/// Ordered list of named data slots declared by a module.
// NOTE(review): presumably addressed by the operands of `Op::GetData` and
// `Op::SetData`; confirm against the VM.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct DataLayout {
pub slots: Vec<DataSlot>,
}
/// A compiled unit of bytecode: its instructions plus the constant and
/// struct-template pools.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct Chunk {
pub name: String,
/// Instruction sequence.
pub ops: Vec<Op>,
/// Constant pool.
pub constants: Vec<ConstValue>,
/// Struct templates, indexed by the operand of `Op::NewStruct`.
pub struct_templates: Vec<StructTemplate>,
pub local_count: u16,
pub param_count: u8,
pub block_type: BlockType,
}
/// A complete bytecode module, the unit serialized by `Module::to_bytes`.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct Module {
pub chunks: Vec<Chunk>,
pub native_names: Vec<String>,
pub entry_point: Option<usize>,
pub data_layout: Option<DataLayout>,
/// Base-2 logarithm of the word width in bits; copied into the header by
/// `to_bytes` and checked against `RUNTIME_WORD_BITS_LOG2` on load.
pub word_bits_log2: u8,
/// Base-2 logarithm of the address width in bits; same header treatment.
pub addr_bits_log2: u8,
/// Base-2 logarithm of the float width in bits; same header treatment.
pub float_bits_log2: u8,
/// Declared WCET in cycles; `u32::MAX` in the header is rejected on load.
pub wcet_cycles: u32,
/// Declared WCMU in bytes; `u32::MAX` in the header is rejected on load.
pub wcmu_bytes: u32,
}
// Framing layout written by `Module::to_bytes` (all integers little-endian):
//   0..4   magic           4..6   version        6..10  total frame length
//   10     word_bits_log2  11     addr_bits_log2 12     float_bits_log2
//   13..16 zero padding    16..20 WCET           20..24 WCMU
//   24..   rkyv body, followed by a 4-byte CRC-32 footer.
/// First four bytes of every bytecode frame.
pub const BYTECODE_MAGIC: [u8; 4] = *b"KELE";
/// Only framing version accepted by the loaders.
pub const BYTECODE_VERSION: u16 = 1;
/// Largest `word_bits_log2` the loaders accept (2^6 = 64 bits).
pub const RUNTIME_WORD_BITS_LOG2: u8 = 6;
/// Largest `addr_bits_log2` the loaders accept (2^6 = 64 bits).
pub const RUNTIME_ADDRESS_BITS_LOG2: u8 = 6;
/// Largest `float_bits_log2` the loaders accept (2^6 = 64 bits).
pub const RUNTIME_FLOAT_BITS_LOG2: u8 = 6;
/// Size of the fixed header that precedes the rkyv body.
const HEADER_LEN: usize = 24;
/// Byte offset of the declared WCET inside the header.
const HEADER_WCET_OFFSET: usize = 16;
/// Byte offset of the declared WCMU inside the header.
const HEADER_WCMU_OFFSET: usize = 20;
/// Size of the trailing CRC-32.
const FOOTER_LEN: usize = 4;
/// Polynomial used by `crc32`, which shifts right (least-significant bit first).
const CRC32_POLY: u32 = 0xEDB88320;
/// Value the loaders expect `crc32` to return over a whole frame, footer
/// included; anything else is reported as `LoadError::BadChecksum`.
const CRC32_RESIDUE: u32 = 0x2144DF1C;
/// Drops a leading `#!...` line, including its terminating `\n`, if present.
///
/// Input that does not start with `#!`, or that starts with `#!` but
/// contains no newline at all, is returned unchanged.
fn strip_shebang_prefix(bytes: &[u8]) -> &[u8] {
    if !bytes.starts_with(b"#!") {
        return bytes;
    }
    match bytes.iter().position(|&byte| byte == b'\n') {
        Some(newline_at) => &bytes[newline_at + 1..],
        None => bytes,
    }
}
/// Bitwise CRC-32 of `bytes`.
///
/// The register starts at all ones, each byte is folded in
/// least-significant bit first using `CRC32_POLY`, and the final register
/// is inverted.
pub(crate) fn crc32(bytes: &[u8]) -> u32 {
    let register = bytes.iter().fold(u32::MAX, |state, &byte| {
        // Mix the byte into the low bits, then shift it out one bit at a time.
        (0..8).fold(state ^ u32::from(byte), |acc, _| {
            let shifted = acc >> 1;
            if acc & 1 == 0 {
                shifted
            } else {
                shifted ^ CRC32_POLY
            }
        })
    });
    !register
}
/// Formats a bit width that is stored as its base-2 logarithm.
///
/// The logarithm comes straight from an untrusted bytecode header and can
/// be any `u8`. Widths that fit in a `u64` are printed as plain numbers
/// (`6` prints `64`); larger ones are printed symbolically (`255` prints
/// `2^255`) instead of overflowing a shift.
struct Log2Bits(u8);

impl core::fmt::Display for Log2Bits {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match 1u64.checked_shl(u32::from(self.0)) {
            Some(bits) => write!(f, "{}", bits),
            None => write!(f, "2^{}", self.0),
        }
    }
}

/// Reasons a bytecode frame can fail to encode or load.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The frame does not start with `BYTECODE_MAGIC`.
    BadMagic,
    /// The input is shorter than the minimum framing, or the recorded frame
    /// length is below the minimum or exceeds the input.
    Truncated,
    /// The header's version differs from `BYTECODE_VERSION`.
    UnsupportedVersion {
        got: u16,
        expected: u16,
    },
    /// The header's word width (as log2 of bits) exceeds the runtime's.
    WordSizeMismatch {
        got: u8,
        max_supported: u8,
    },
    /// The header's address width (as log2 of bits) exceeds the runtime's.
    AddressSizeMismatch {
        got: u8,
        max_supported: u8,
    },
    /// The header's float width (as log2 of bits) exceeds the runtime's.
    FloatSizeMismatch {
        got: u8,
        max_supported: u8,
    },
    /// The CRC-32 over the frame did not produce the expected residue.
    BadChecksum,
    /// The header declares a WCET of `u32::MAX`.
    WcetOverflow,
    /// The header declares a WCMU of `u32::MAX`.
    WcmuOverflow,
    /// The rkyv codec, or a precondition of it, failed; carries the message.
    Codec(String),
}

impl core::fmt::Display for LoadError {
    /// Writes a one-line, human-readable description.
    ///
    /// Width mismatches are reported in bits. The stored logarithms are
    /// rendered through `Log2Bits`, so a hostile header value such as
    /// `got == 255` cannot trigger a shift overflow while formatting.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            LoadError::BadMagic => f.write_str("bytecode header missing magic 'KELE'"),
            LoadError::Truncated => f.write_str(
                "bytecode truncated, recorded length exceeds slice, or below minimum framing",
            ),
            LoadError::UnsupportedVersion { got, expected } => {
                write!(
                    f,
                    "bytecode version {} not supported, expected {}",
                    got, expected
                )
            }
            LoadError::WordSizeMismatch { got, max_supported } => {
                write!(
                    f,
                    "bytecode requires {}-bit words, runtime supports up to {}-bit",
                    Log2Bits(*got),
                    Log2Bits(*max_supported)
                )
            }
            LoadError::AddressSizeMismatch { got, max_supported } => {
                write!(
                    f,
                    "bytecode requires {}-bit addresses, runtime supports up to {}-bit",
                    Log2Bits(*got),
                    Log2Bits(*max_supported)
                )
            }
            LoadError::FloatSizeMismatch { got, max_supported } => {
                write!(
                    f,
                    "bytecode requires {}-bit floats, runtime supports up to {}-bit",
                    Log2Bits(*got),
                    Log2Bits(*max_supported)
                )
            }
            LoadError::BadChecksum => f.write_str("bytecode CRC-32 residue check failed"),
            LoadError::WcetOverflow => {
                f.write_str("declared WCET is u32::MAX (overflow); no representable bound")
            }
            LoadError::WcmuOverflow => {
                f.write_str("declared WCMU is u32::MAX (overflow); no representable bound")
            }
            LoadError::Codec(msg) => write!(f, "bytecode codec error: {}", msg),
        }
    }
}

impl core::error::Error for LoadError {}
impl Module {
    /// Serializes the module into a self-describing frame:
    /// 24-byte header, rkyv body, 4-byte CRC-32 footer.
    ///
    /// The CRC covers the header and body and is appended little-endian.
    ///
    /// # Errors
    ///
    /// Returns `LoadError::Codec` when rkyv fails to encode the module or
    /// when the complete frame would not fit the header's 32-bit length
    /// field. The latter used to be truncated silently, producing a frame
    /// with a wrong recorded length.
    pub fn to_bytes(&self) -> Result<Vec<u8>, LoadError> {
        use alloc::format;
        let body = rkyv::to_bytes::<rkyv::rancor::Error>(self)
            .map_err(|e| LoadError::Codec(format!("encode failed: {}", e)))?;

        let too_large = || {
            LoadError::Codec(format!(
                "encode failed: {}-byte body does not fit the 32-bit frame length field",
                body.len()
            ))
        };
        let frame_len = HEADER_LEN
            .checked_add(body.len())
            .and_then(|n| n.checked_add(FOOTER_LEN))
            .ok_or_else(too_large)?;
        let total_len =
            <u32 as core::convert::TryFrom<usize>>::try_from(frame_len).map_err(|_| too_large())?;

        let mut buf = Vec::with_capacity(frame_len);
        buf.extend_from_slice(&BYTECODE_MAGIC);
        buf.extend_from_slice(&BYTECODE_VERSION.to_le_bytes());
        buf.extend_from_slice(&total_len.to_le_bytes());
        buf.push(self.word_bits_log2);
        buf.push(self.addr_bits_log2);
        buf.push(self.float_bits_log2);
        // Three padding bytes bring the WCET field to offset 16.
        buf.extend_from_slice(&[0u8; 3]);
        buf.extend_from_slice(&self.wcet_cycles.to_le_bytes());
        buf.extend_from_slice(&self.wcmu_bytes.to_le_bytes());
        buf.extend_from_slice(&body);
        let crc = crc32(&buf);
        buf.extend_from_slice(&crc.to_le_bytes());
        Ok(buf)
    }

    /// Validates the framing of `bytes` and returns the exact frame
    /// (header, body and footer, with any shebang line and trailing bytes
    /// removed).
    ///
    /// Checks run in this order, each mapping to its own error: minimum
    /// length, magic, recorded length, CRC-32 residue, version, word /
    /// address / float widths, WCET overflow, WCMU overflow.
    fn validate_frame(bytes: &[u8]) -> Result<&[u8], LoadError> {
        /// Reads a little-endian `u32` starting at `at`.
        fn read_u32_le(frame: &[u8], at: usize) -> u32 {
            u32::from_le_bytes([frame[at], frame[at + 1], frame[at + 2], frame[at + 3]])
        }

        let bytes = strip_shebang_prefix(bytes);
        if bytes.len() < HEADER_LEN + FOOTER_LEN {
            return Err(LoadError::Truncated);
        }
        if bytes[0..4] != BYTECODE_MAGIC {
            return Err(LoadError::BadMagic);
        }
        // The recorded length sits right after magic (4) and version (2).
        let length = read_u32_le(bytes, 6) as usize;
        if length < HEADER_LEN + FOOTER_LEN || length > bytes.len() {
            return Err(LoadError::Truncated);
        }
        let frame = &bytes[..length];
        // The CRC over a frame that includes its own footer must equal the residue.
        if crc32(frame) != CRC32_RESIDUE {
            return Err(LoadError::BadChecksum);
        }
        let version = u16::from_le_bytes([frame[4], frame[5]]);
        if version != BYTECODE_VERSION {
            return Err(LoadError::UnsupportedVersion {
                got: version,
                expected: BYTECODE_VERSION,
            });
        }
        let word_bits_log2 = frame[10];
        if word_bits_log2 > RUNTIME_WORD_BITS_LOG2 {
            return Err(LoadError::WordSizeMismatch {
                got: word_bits_log2,
                max_supported: RUNTIME_WORD_BITS_LOG2,
            });
        }
        let addr_bits_log2 = frame[11];
        if addr_bits_log2 > RUNTIME_ADDRESS_BITS_LOG2 {
            return Err(LoadError::AddressSizeMismatch {
                got: addr_bits_log2,
                max_supported: RUNTIME_ADDRESS_BITS_LOG2,
            });
        }
        let float_bits_log2 = frame[12];
        if float_bits_log2 > RUNTIME_FLOAT_BITS_LOG2 {
            return Err(LoadError::FloatSizeMismatch {
                got: float_bits_log2,
                max_supported: RUNTIME_FLOAT_BITS_LOG2,
            });
        }
        if read_u32_le(frame, HEADER_WCET_OFFSET) == u32::MAX {
            return Err(LoadError::WcetOverflow);
        }
        if read_u32_le(frame, HEADER_WCMU_OFFSET) == u32::MAX {
            return Err(LoadError::WcmuOverflow);
        }
        Ok(frame)
    }

    /// Decodes an owned `Module` from a frame produced by [`Module::to_bytes`].
    ///
    /// A leading `#!` line and any bytes after the recorded frame length are
    /// ignored. The body is copied into an 8-byte-aligned buffer first, so
    /// the input slice may have any alignment.
    ///
    /// The header fields are validated only; the returned module's fields
    /// all come from the rkyv body.
    ///
    /// # Errors
    ///
    /// Any framing error listed on [`LoadError`], or `LoadError::Codec` when
    /// rkyv rejects the body.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, LoadError> {
        use alloc::format;
        let frame = Self::validate_frame(bytes)?;
        let body = &frame[HEADER_LEN..frame.len() - FOOTER_LEN];
        let mut aligned = rkyv::util::AlignedVec::<8>::with_capacity(body.len());
        aligned.extend_from_slice(body);
        rkyv::from_bytes::<Module, rkyv::rancor::Error>(&aligned)
            .map_err(|e| LoadError::Codec(format!("decode failed: {}", e)))
    }

    /// Validates a frame and returns a zero-copy view of the archived module
    /// that borrows from `bytes`.
    ///
    /// # Errors
    ///
    /// Same framing errors as [`Module::from_bytes`]. Additionally returns
    /// `LoadError::Codec` when the body does not start at an 8-byte-aligned
    /// address (use `from_bytes` for such input) or when rkyv validation of
    /// the body fails.
    pub fn access_bytes(bytes: &[u8]) -> Result<&ArchivedModule, LoadError> {
        use alloc::format;
        let frame = Self::validate_frame(bytes)?;
        let body = &frame[HEADER_LEN..frame.len() - FOOTER_LEN];
        if !(body.as_ptr() as usize).is_multiple_of(8) {
            return Err(LoadError::Codec(format!(
                "body not 8-byte aligned (slice base 0x{:x}); use Module::from_bytes for unaligned input",
                frame.as_ptr() as usize
            )));
        }
        rkyv::access::<ArchivedModule, rkyv::rancor::Error>(body)
            .map_err(|e| LoadError::Codec(format!("rkyv access failed: {}", e)))
    }

    /// Validates a frame in place via [`Module::access_bytes`] and then
    /// deserializes the archived view into an owned `Module`.
    ///
    /// # Errors
    ///
    /// Everything `access_bytes` can return, plus `LoadError::Codec` when
    /// deserialization fails.
    pub fn view_bytes(bytes: &[u8]) -> Result<Module, LoadError> {
        use alloc::format;
        let archived = Self::access_bytes(bytes)?;
        rkyv::deserialize::<Module, rkyv::rancor::Error>(archived)
            .map_err(|e| LoadError::Codec(format!("deserialize failed: {}", e)))
    }
}
pub fn op_from_archived(archived: &ArchivedOp) -> Op {
match archived {
ArchivedOp::Const(idx) => Op::Const(idx.to_native()),
ArchivedOp::PushUnit => Op::PushUnit,
ArchivedOp::PushTrue => Op::PushTrue,
ArchivedOp::PushFalse => Op::PushFalse,
ArchivedOp::GetLocal(idx) => Op::GetLocal(idx.to_native()),
ArchivedOp::SetLocal(idx) => Op::SetLocal(idx.to_native()),
ArchivedOp::GetData(idx) => Op::GetData(idx.to_native()),
ArchivedOp::SetData(idx) => Op::SetData(idx.to_native()),
ArchivedOp::Add => Op::Add,
ArchivedOp::Sub => Op::Sub,
ArchivedOp::Mul => Op::Mul,
ArchivedOp::Div => Op::Div,
ArchivedOp::Mod => Op::Mod,
ArchivedOp::Neg => Op::Neg,
ArchivedOp::CmpEq => Op::CmpEq,
ArchivedOp::CmpNe => Op::CmpNe,
ArchivedOp::CmpLt => Op::CmpLt,
ArchivedOp::CmpGt => Op::CmpGt,
ArchivedOp::CmpLe => Op::CmpLe,
ArchivedOp::CmpGe => Op::CmpGe,
ArchivedOp::Not => Op::Not,
ArchivedOp::If(t) => Op::If(t.to_native()),
ArchivedOp::Else(t) => Op::Else(t.to_native()),
ArchivedOp::EndIf => Op::EndIf,
ArchivedOp::Loop(t) => Op::Loop(t.to_native()),
ArchivedOp::EndLoop(t) => Op::EndLoop(t.to_native()),
ArchivedOp::Break(t) => Op::Break(t.to_native()),
ArchivedOp::BreakIf(t) => Op::BreakIf(t.to_native()),
ArchivedOp::Stream => Op::Stream,
ArchivedOp::Reset => Op::Reset,
ArchivedOp::Call(c, n) => Op::Call(c.to_native(), *n),
ArchivedOp::CallNative(c, n) => Op::CallNative(c.to_native(), *n),
ArchivedOp::CallIndirect(n) => Op::CallIndirect(*n),
ArchivedOp::PushFunc(idx) => Op::PushFunc(idx.to_native()),
ArchivedOp::MakeClosure(idx, n) => Op::MakeClosure(idx.to_native(), *n),
ArchivedOp::MakeRecursiveClosure(idx, n) => Op::MakeRecursiveClosure(idx.to_native(), *n),
ArchivedOp::Return => Op::Return,
ArchivedOp::Yield => Op::Yield,
ArchivedOp::Pop => Op::Pop,
ArchivedOp::Dup => Op::Dup,
ArchivedOp::NewStruct(t) => Op::NewStruct(t.to_native()),
ArchivedOp::NewEnum(t, v, n) => Op::NewEnum(t.to_native(), v.to_native(), *n),
ArchivedOp::NewArray(n) => Op::NewArray(n.to_native()),
ArchivedOp::NewTuple(n) => Op::NewTuple(*n),
ArchivedOp::WrapSome => Op::WrapSome,
ArchivedOp::PushNone => Op::PushNone,
ArchivedOp::GetField(idx) => Op::GetField(idx.to_native()),
ArchivedOp::GetIndex => Op::GetIndex,
ArchivedOp::GetTupleField(idx) => Op::GetTupleField(*idx),
ArchivedOp::GetEnumField(idx) => Op::GetEnumField(*idx),
ArchivedOp::Len => Op::Len,
ArchivedOp::IsEnum(t, v) => Op::IsEnum(t.to_native(), v.to_native()),
ArchivedOp::IsStruct(t) => Op::IsStruct(t.to_native()),
ArchivedOp::IntToFloat => Op::IntToFloat,
ArchivedOp::FloatToInt => Op::FloatToInt,
ArchivedOp::Trap(idx) => Op::Trap(idx.to_native()),
}
}
impl ConstValue {
pub fn try_from_value(value: Value) -> Result<Self, &'static str> {
match value {
Value::Unit => Ok(ConstValue::Unit),
Value::Bool(b) => Ok(ConstValue::Bool(b)),
Value::Int(i) => Ok(ConstValue::Int(i)),
Value::Float(f) => Ok(ConstValue::Float(f)),
Value::StaticStr(s) => Ok(ConstValue::StaticStr(s)),
Value::DynStr(_) => Err("DynStr cannot be a compile-time constant"),
Value::KStr(_) => Err("KStr cannot be a compile-time constant"),
Value::Func { .. } => Err("Func cannot be a compile-time constant"),
Value::Tuple(items) => items
.into_iter()
.map(ConstValue::try_from_value)
.collect::<Result<Vec<_>, _>>()
.map(ConstValue::Tuple),
Value::Array(items) => items
.into_iter()
.map(ConstValue::try_from_value)
.collect::<Result<Vec<_>, _>>()
.map(ConstValue::Array),
Value::Struct { type_name, fields } => {
let cfields: Result<Vec<_>, _> = fields
.into_iter()
.map(|(n, v)| ConstValue::try_from_value(v).map(|cv| (n, cv)))
.collect();
Ok(ConstValue::Struct {
type_name,
fields: cfields?,
})
}
Value::Enum {
type_name,
variant,
fields,
} => {
let cfields: Result<Vec<_>, _> =
fields.into_iter().map(ConstValue::try_from_value).collect();
Ok(ConstValue::Enum {
type_name,
variant,
fields: cfields?,
})
}
Value::None => Ok(ConstValue::None),
}
}
pub fn into_value(self) -> Value {
match self {
ConstValue::Unit => Value::Unit,
ConstValue::Bool(b) => Value::Bool(b),
ConstValue::Int(i) => Value::Int(i),
ConstValue::Float(f) => Value::Float(f),
ConstValue::StaticStr(s) => Value::StaticStr(s),
ConstValue::Tuple(items) => {
Value::Tuple(items.into_iter().map(ConstValue::into_value).collect())
}
ConstValue::Array(items) => {
Value::Array(items.into_iter().map(ConstValue::into_value).collect())
}
ConstValue::Struct { type_name, fields } => Value::Struct {
type_name,
fields: fields
.into_iter()
.map(|(n, v)| (n, v.into_value()))
.collect(),
},
ConstValue::Enum {
type_name,
variant,
fields,
} => Value::Enum {
type_name,
variant,
fields: fields.into_iter().map(ConstValue::into_value).collect(),
},
ConstValue::None => Value::None,
}
}
}
impl PartialEq for ConstValue {
    /// Structural equality: two constants are equal when they are the same
    /// variant and all of their contents compare equal. Constants of
    /// different variants are never equal.
    fn eq(&self, other: &Self) -> bool {
        match self {
            ConstValue::Unit => matches!(other, ConstValue::Unit),
            ConstValue::None => matches!(other, ConstValue::None),
            ConstValue::Bool(lhs) => matches!(other, ConstValue::Bool(rhs) if lhs == rhs),
            ConstValue::Int(lhs) => matches!(other, ConstValue::Int(rhs) if lhs == rhs),
            ConstValue::Float(lhs) => matches!(other, ConstValue::Float(rhs) if lhs == rhs),
            ConstValue::StaticStr(lhs) => {
                matches!(other, ConstValue::StaticStr(rhs) if lhs == rhs)
            }
            ConstValue::Tuple(lhs) => matches!(other, ConstValue::Tuple(rhs) if lhs == rhs),
            ConstValue::Array(lhs) => matches!(other, ConstValue::Array(rhs) if lhs == rhs),
            ConstValue::Struct { type_name, fields } => matches!(
                other,
                ConstValue::Struct {
                    type_name: other_name,
                    fields: other_fields,
                } if type_name == other_name && fields == other_fields
            ),
            ConstValue::Enum {
                type_name,
                variant,
                fields,
            } => matches!(
                other,
                ConstValue::Enum {
                    type_name: other_name,
                    variant: other_variant,
                    fields: other_fields,
                } if type_name == other_name
                    && variant == other_variant
                    && fields == other_fields
            ),
        }
    }
}
/// Free-function form of `Value::from_const_archived`: builds an owned
/// runtime value from an archived constant.
pub fn value_from_archived(archived: &ArchivedConstValue) -> Value {
Value::from_const_archived(archived)
}
/// Wraps `value` to a signed integer of `2^word_bits_log2` bits and
/// sign-extends the result back to `i64`.
///
/// A `word_bits_log2` of 6 or more (64 bits and wider) returns `value`
/// unchanged.
pub(crate) fn truncate_int(value: i64, word_bits_log2: u8) -> i64 {
    const FULL_WIDTH_LOG2: u8 = 6;
    if word_bits_log2 < FULL_WIDTH_LOG2 {
        // Push the kept bits to the top, then arithmetic-shift them back down
        // so the new sign bit is replicated.
        let discard = i64::BITS - (1u32 << word_bits_log2);
        value.wrapping_shl(discard).wrapping_shr(discard)
    } else {
        value
    }
}
// Unit tests for `CostModel`, `NOMINAL_COST_MODEL` and the `Op` cost helpers.
#[cfg(test)]
mod cost_model_tests {
use super::*;
// The nominal model must price slots with the published constant.
#[test]
fn nominal_cost_model_value_slot_bytes_matches_constant() {
assert_eq!(NOMINAL_COST_MODEL.value_slot_bytes, VALUE_SLOT_SIZE_BYTES);
}
// `Op::cost` must agree with the nominal model on a sample of instructions.
#[test]
fn nominal_cost_model_cycles_match_op_cost_method() {
let ops: alloc::vec::Vec<Op> = alloc::vec![
Op::Const(0),
Op::PushUnit,
Op::Add,
Op::Mul,
Op::Div,
Op::NewArray(2),
Op::Call(0, 0),
Op::PushFunc(0),
Op::MakeClosure(0, 0),
Op::Yield,
];
for op in &ops {
assert_eq!(NOMINAL_COST_MODEL.cycles(op), op.cost());
}
}
// `slots_to_bytes` multiplies by the model's own slot size.
#[test]
fn cost_model_slots_to_bytes_uses_slot_size() {
let model = CostModel {
value_slot_bytes: 8,
op_cycles: nominal_op_cycles,
};
assert_eq!(model.slots_to_bytes(0), 0);
assert_eq!(model.slots_to_bytes(1), 8);
assert_eq!(model.slots_to_bytes(4), 32);
}
// Halving the slot size must halve the heap estimate for the same instruction.
#[test]
fn cost_model_heap_alloc_bytes_scales_with_slot_size() {
let nominal = NOMINAL_COST_MODEL;
let custom = CostModel {
value_slot_bytes: VALUE_SLOT_SIZE_BYTES / 2,
op_cycles: nominal_op_cycles,
};
let chunk = Chunk {
name: alloc::string::String::from("test"),
ops: alloc::vec::Vec::new(),
constants: alloc::vec::Vec::new(),
struct_templates: alloc::vec::Vec::new(),
local_count: 0,
param_count: 0,
block_type: BlockType::Func,
};
let op = Op::NewArray(4);
let nominal_bytes = nominal.heap_alloc_bytes(&op, &chunk);
let custom_bytes = custom.heap_alloc_bytes(&op, &chunk);
assert_eq!(nominal_bytes, 4 * VALUE_SLOT_SIZE_BYTES);
assert_eq!(custom_bytes, 4 * (VALUE_SLOT_SIZE_BYTES / 2));
assert_eq!(custom_bytes * 2, nominal_bytes);
}
// A model with a custom pricing function must use it for every instruction.
#[test]
fn custom_cost_model_returns_custom_cycles() {
fn flat_hundred(_op: &Op) -> u32 {
100
}
let custom = CostModel {
value_slot_bytes: VALUE_SLOT_SIZE_BYTES,
op_cycles: flat_hundred,
};
assert_eq!(custom.cycles(&Op::Add), 100);
assert_eq!(custom.cycles(&Op::PushUnit), 100);
assert_eq!(custom.cycles(&Op::Call(0, 0)), 100);
}
}