use serde::{Deserialize, Serialize};
use super::{Chunk, CompiledProgram, TypeRegistry};
use crate::ast::{Program, Span};
use crate::interpreter::Value;
use std::sync::Arc;
/// Version number stored in the `schema_version` field of every [`ProgramBlob`]
/// produced by `serialize_program`. `deserialize_program` refuses any blob
/// whose stored version differs from this value.
pub const BLOB_SCHEMA_VERSION: u32 = 1;
/// Serializable container for a compiled program: the value that
/// `serialize_program` encodes with postcard and `deserialize_program` decodes.
///
/// NOTE(review): postcard presumably encodes struct fields positionally, so
/// reordering or retyping fields likely changes the wire format — confirm and
/// bump `BLOB_SCHEMA_VERSION` when doing so.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProgramBlob {
/// Format version; written as `BLOB_SCHEMA_VERSION` and checked on load.
pub schema_version: u32,
/// Wire form of each entry in `CompiledProgram::chunks`, in the same order.
pub chunks: Vec<WireChunk>,
/// Copy of `CompiledProgram::func_names`; the loader requires it to have the
/// same length as `chunks`.
pub func_names: Vec<String>,
/// Copy of `CompiledProgram::is_tool`; the loader requires it to have the
/// same length as `chunks`.
pub is_tool: Vec<bool>,
/// One `(name, fields, num_fields)` tuple per type in the program's type
/// registry; replayed through `TypeRegistry::register` on load.
pub type_registry_entries: Vec<(String, Vec<String>, u64)>,
/// The program AST as JSON text produced by `serde_json`; the literal
/// `{"declarations":[]}` is stored when the program carries no AST.
pub ast_json: String,
}
/// Serializable mirror of a `Chunk`, built by `WireChunk::from_chunk` and
/// turned back into a `Chunk` by `WireChunk::into_chunk`.
#[derive(Debug, Serialize, Deserialize)]
pub struct WireChunk {
/// Copy of `Chunk::code`.
pub code: Vec<u32>,
/// The chunk's constants, each converted from `Value` to `WireConst`.
pub constants: Vec<WireConst>,
/// Copy of `Chunk::param_count`.
pub param_count: u8,
/// Copy of `Chunk::reg_count`.
pub reg_count: u8,
/// `(start, end)` of each `Span` in `Chunk::spans`, narrowed to `u32`.
pub spans: Vec<(u32, u32)>,
/// Copy of `Chunk::all_regs_numeric`.
pub all_regs_numeric: bool,
}
/// Serializable counterpart of the `Value` variants that are converted for
/// chunk constants: `Nil`, `Number`, `Text`, `Bool` and `List`.
///
/// NOTE(review): the encoded form presumably depends on variant order —
/// confirm before inserting or reordering variants.
#[derive(Debug, Serialize, Deserialize)]
pub enum WireConst {
/// Counterpart of `Value::Nil`.
Nil,
/// Counterpart of `Value::Number`.
Number(f64),
/// Counterpart of `Value::Text`, holding an owned `String` rather than an `Arc`.
Text(String),
/// Counterpart of `Value::Bool`.
Bool(bool),
/// Counterpart of `Value::List`; elements are converted recursively.
List(Vec<WireConst>),
}
/// Converts a borrowed runtime `Value` into its serializable `WireConst` form.
///
/// # Panics
///
/// Panics when given any `Value` variant other than `Nil`, `Number`, `Text`,
/// `Bool` or `List`; no wire variant exists for the others.
impl From<&Value> for WireConst {
    fn from(v: &Value) -> Self {
        match v {
            Value::Nil => Self::Nil,
            Value::Bool(flag) => Self::Bool(*flag),
            Value::Number(num) => Self::Number(*num),
            // The runtime shares text behind an `Arc`; the wire form owns a copy.
            Value::Text(text) => Self::Text(text.as_str().to_owned()),
            Value::List(elems) => {
                // Lists may nest, so each element goes through this same conversion.
                let wire_elems = elems.iter().map(Self::from).collect();
                Self::List(wire_elems)
            }
            unsupported => panic!(
                "aot_blob: unexpected chunk constant variant {:?} — only Nil/Number/Text/Bool/List are emitted by RegCompiler today; add a wire variant before lifting this",
                std::mem::discriminant(unsupported)
            ),
        }
    }
}
/// Rebuilds a runtime `Value` from a decoded `WireConst`, wrapping text and
/// list payloads in fresh `Arc`s.
impl From<WireConst> for Value {
    fn from(c: WireConst) -> Self {
        match c {
            WireConst::Nil => Self::Nil,
            WireConst::Bool(flag) => Self::Bool(flag),
            WireConst::Number(num) => Self::Number(num),
            WireConst::Text(text) => Self::Text(Arc::new(text)),
            WireConst::List(elems) => {
                // Convert nested elements first, then share the finished list.
                let values = elems.into_iter().map(Self::from).collect();
                Self::List(Arc::new(values))
            }
        }
    }
}
impl WireChunk {
    /// Builds the serializable form of `chunk`.
    ///
    /// Code and scalar fields are copied, constants are converted to
    /// `WireConst`, and each span is stored as a `(start, end)` pair. Span
    /// offsets are narrowed with `as u32`, so offsets above `u32::MAX` are
    /// truncated.
    ///
    /// # Panics
    ///
    /// Panics if a constant has no `WireConst` counterpart.
    pub fn from_chunk(chunk: &Chunk) -> Self {
        let code = chunk.code.clone();
        let constants = chunk.constants.iter().map(WireConst::from).collect();
        let spans = chunk
            .spans
            .iter()
            .map(|span| (span.start as u32, span.end as u32))
            .collect();

        Self {
            code,
            constants,
            param_count: chunk.param_count,
            reg_count: chunk.reg_count,
            spans,
            all_regs_numeric: chunk.all_regs_numeric,
        }
    }

    /// Consumes the wire form and rebuilds a `Chunk`.
    ///
    /// `stmt_debug` is not part of the wire form, so the rebuilt chunk gets an
    /// empty vector for it.
    pub fn into_chunk(self) -> Chunk {
        let WireChunk {
            code,
            constants,
            param_count,
            reg_count,
            spans,
            all_regs_numeric,
        } = self;

        let constants = constants.into_iter().map(Value::from).collect();
        let spans = spans
            .into_iter()
            .map(|(start, end)| Span {
                start: start as usize,
                end: end as usize,
            })
            .collect();

        Chunk {
            code,
            constants,
            param_count,
            reg_count,
            spans,
            all_regs_numeric,
            stmt_debug: Vec::new(),
        }
    }
}
/// Encodes `program` as a postcard byte blob.
///
/// The blob holds the schema version, the wire form of every chunk, the
/// function-name and tool-flag tables, the type registry entries, and the AST
/// as JSON text. When the program has no AST, the JSON for an empty
/// declaration list is stored instead.
///
/// # Errors
///
/// Returns a message if the AST cannot be serialized to JSON or if postcard
/// fails to encode the blob.
///
/// # Panics
///
/// Panics if a chunk constant has no `WireConst` counterpart.
pub fn serialize_program(program: &CompiledProgram) -> Result<Vec<u8>, String> {
    // Chunks are converted first, ahead of the fallible AST step.
    let mut chunks: Vec<WireChunk> = Vec::new();
    for chunk in program.chunks.iter() {
        chunks.push(WireChunk::from_chunk(chunk));
    }

    let mut type_registry_entries: Vec<(String, Vec<String>, u64)> = Vec::new();
    for info in program.type_registry.types.iter() {
        type_registry_entries.push((info.name.clone(), info.fields.clone(), info.num_fields));
    }

    let ast_json = if let Some(ast) = &program.ast {
        serde_json::to_string(ast.as_ref())
            .map_err(|err| format!("serde_json serialize ast: {}", err))?
    } else {
        "{\"declarations\":[]}".to_string()
    };

    let blob = ProgramBlob {
        schema_version: BLOB_SCHEMA_VERSION,
        chunks,
        func_names: program.func_names.clone(),
        is_tool: program.is_tool.clone(),
        type_registry_entries,
        ast_json,
    };

    match postcard::to_allocvec(&blob) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(format!("postcard serialize: {}", err)),
    }
}
/// Decodes a blob produced by `serialize_program` back into a `CompiledProgram`.
///
/// The leading `schema_version` is read and checked *before* the rest of the
/// blob is decoded. A blob written with a different layout then yields the
/// version-mismatch message rather than whatever postcard error the layout
/// change happens to trigger.
///
/// Values not stored in the blob are rebuilt here: `nan_constants` from each
/// chunk's constants, and `is_defer_fn` by scanning the AST's function and
/// tool declarations. `defer_fns` is returned empty.
///
/// # Errors
///
/// Returns a message when:
/// - the bytes are not a valid postcard encoding,
/// - the stored schema version differs from `BLOB_SCHEMA_VERSION`,
/// - `chunks`, `func_names` and `is_tool` do not all have the same length,
/// - the embedded AST JSON cannot be parsed.
pub fn deserialize_program(bytes: &[u8]) -> Result<CompiledProgram, String> {
    // `schema_version` is the first field of `ProgramBlob`, so it sits at the
    // start of the encoding and can be decoded on its own.
    let (schema_version, _rest): (u32, &[u8]) = postcard::take_from_bytes(bytes)
        .map_err(|e| format!("postcard deserialize: {}", e))?;
    if schema_version != BLOB_SCHEMA_VERSION {
        return Err(format!(
            "AOT program blob schema_version mismatch: binary embeds v{} but this runtime expects v{}. Recompile with this ilo version.",
            schema_version, BLOB_SCHEMA_VERSION
        ));
    }

    let blob: ProgramBlob =
        postcard::from_bytes(bytes).map_err(|e| format!("postcard deserialize: {}", e))?;

    // The three per-function tables are indexed in parallel; reject blobs
    // where they disagree instead of risking out-of-bounds access later.
    let n_wire_chunks = blob.chunks.len();
    let n_func_names = blob.func_names.len();
    let n_is_tool = blob.is_tool.len();
    if n_wire_chunks != n_func_names {
        return Err(format!(
            "AOT blob invariant violated: chunks.len()={} != func_names.len()={}",
            n_wire_chunks, n_func_names
        ));
    }
    if n_wire_chunks != n_is_tool {
        return Err(format!(
            "AOT blob invariant violated: chunks.len()={} != is_tool.len()={}",
            n_wire_chunks, n_is_tool
        ));
    }

    let ast: Program = serde_json::from_str(&blob.ast_json)
        .map_err(|e| format!("serde_json deserialize ast: {}", e))?;

    let chunks: Vec<Chunk> = blob.chunks.into_iter().map(WireChunk::into_chunk).collect();
    let nan_constants: Vec<Vec<super::NanVal>> = chunks
        .iter()
        .map(|c| c.constants.iter().map(super::NanVal::from_value).collect())
        .collect();

    let mut type_registry = TypeRegistry::default();
    for (name, fields, num_fields) in blob.type_registry_entries {
        type_registry.register(name, fields, num_fields);
    }

    // Slot `i` describes the i-th function-or-tool declaration in the AST.
    // `zip` stops at the shorter side, so surplus declarations are ignored and
    // slots without a matching declaration stay `false`.
    // NOTE(review): presumably this mirrors the order in which the compiler
    // lays out function chunks — confirm against `compile`.
    let mut is_defer_fn = vec![false; n_func_names];
    let fn_like_decls = ast.declarations.iter().filter(|d| {
        matches!(
            d,
            crate::ast::Decl::Function { .. } | crate::ast::Decl::Tool { .. }
        )
    });
    for (slot, decl) in is_defer_fn.iter_mut().zip(fn_like_decls) {
        if let crate::ast::Decl::Function { body, .. } = decl {
            *slot = crate::vm::body_has_defer(body);
        }
    }

    Ok(CompiledProgram {
        chunks,
        func_names: blob.func_names,
        nan_constants,
        type_registry,
        is_tool: blob.is_tool,
        is_defer_fn,
        ast: Some(Arc::new(ast)),
        defer_fns: std::collections::HashSet::new(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::compile;

    /// Lexes, parses and compiles `src`, then sends the compiled program
    /// through `serialize_program` and `deserialize_program` and returns the
    /// decoded result. Panics on any failure along the way.
    fn roundtrip(src: &str) -> CompiledProgram {
        let mut token_spans = Vec::new();
        for (token, range) in crate::lexer::lex(src).expect("lex") {
            let span = crate::ast::Span {
                start: range.start,
                end: range.end,
            };
            token_spans.push((token, span));
        }

        let (program, errors) = crate::parser::parse(token_spans);
        assert!(errors.is_empty(), "parse errors: {:?}", errors);

        let compiled = compile(&program).expect("compile");
        let encoded = serialize_program(&compiled).expect("serialize");
        deserialize_program(&encoded).expect("deserialize")
    }

    /// A single-function program keeps its name, chunk count and AST.
    #[test]
    fn empty_program_roundtrips() {
        let decoded = roundtrip("main>n;42");
        assert_eq!(decoded.func_names, ["main"]);
        assert_eq!(decoded.chunks.len(), 1);
        assert!(decoded.ast.is_some());
    }

    /// A blob stamped with a foreign schema version must be refused.
    #[test]
    fn schema_version_mismatch_is_rejected() {
        let foreign = ProgramBlob {
            schema_version: 999,
            chunks: Vec::new(),
            func_names: Vec::new(),
            is_tool: Vec::new(),
            type_registry_entries: Vec::new(),
            ast_json: "{\"declarations\":[]}".to_string(),
        };
        let encoded = postcard::to_allocvec(&foreign).unwrap();

        let err = deserialize_program(&encoded)
            .err()
            .expect("expected schema mismatch error, got Ok");
        assert!(err.contains("schema_version mismatch"), "got: {}", err);
    }

    /// Both `main` and the lambda's `__lit_`-prefixed function survive the
    /// round trip.
    #[test]
    fn map_lambda_program_roundtrips() {
        let decoded = roundtrip("main>L n;map (x:n>n;*x 2) [1,2,3]");
        let names = &decoded.func_names;
        assert!(names.iter().any(|name| name == "main"));
        assert!(names.iter().any(|name| name.starts_with("__lit_")));
    }

    /// A program with two named functions keeps both names.
    #[test]
    fn fld_program_roundtrips() {
        let decoded = roundtrip("add a:n b:n>n;+a b\nmain>n;fld add [1,2,3,4] 0");
        for expected in ["add", "main"] {
            assert!(decoded.func_names.iter().any(|name| name == expected));
        }
    }

    /// Type registry entries are restored with their fields and `num_fields`.
    #[test]
    fn type_registry_roundtrips() {
        let mut registry = TypeRegistry::default();
        registry.register(
            "point".to_string(),
            vec!["x".to_string(), "y".to_string()],
            0b11,
        );
        let program = CompiledProgram {
            chunks: vec![],
            func_names: vec![],
            nan_constants: vec![],
            type_registry: registry,
            is_tool: vec![],
            is_defer_fn: vec![],
            ast: None,
            defer_fns: std::collections::HashSet::new(),
        };

        let encoded = serialize_program(&program).expect("serialize");
        let decoded = deserialize_program(&encoded).expect("deserialize");

        assert!(decoded.type_registry.name_to_id.contains_key("point"));
        let type_id = decoded.type_registry.name_to_id["point"];
        let info = &decoded.type_registry.types[type_id as usize];
        assert_eq!(info.fields, ["x", "y"]);
        assert_eq!(info.num_fields, 0b11);
    }
}