use cranelift::codegen::ir::AbiParam;
use cranelift::prelude::*;
use cranelift_jit::JITModule;
use cranelift_module::{FuncId, Linkage, Module};
use facet_core::{Def, Shape};
use super::super::format::{
JIT_SCRATCH_ERROR_CODE_OFFSET, JIT_SCRATCH_ERROR_POS_OFFSET,
JIT_SCRATCH_OUTPUT_INITIALIZED_OFFSET, JitCursor, JitFormat, make_c_sig,
};
use super::super::helpers;
use super::super::jit_debug;
use super::{
FormatListElementKind, ShapeMemo, compile_list_format_deserializer,
compile_struct_format_deserializer, compile_struct_positional_deserializer, func_addr_value,
tier2_call_sig,
};
pub(crate) fn compile_map_format_deserializer<F: JitFormat>(
module: &mut JITModule,
shape: &'static Shape,
memo: &mut ShapeMemo,
) -> Option<FuncId> {
jit_debug!(
"compile_map_format_deserializer ENTRY for shape {:p}",
shape
);
let shape_ptr = shape as *const Shape;
if let Some(&func_id) = memo.get(&shape_ptr) {
jit_debug!(
"compile_map_format_deserializer: using memoized FuncId for shape {:p}",
shape
);
return Some(func_id);
}
let Def::Map(map_def) = &shape.def else {
jit_debug!("[compile_map] Not a map");
return None;
};
if map_def.k.scalar_type() != Some(facet_core::ScalarType::String) {
jit_debug!("[compile_map] Only String keys supported");
return None;
}
let value_shape = map_def.v;
let value_kind = match FormatListElementKind::from_shape(value_shape) {
Some(k) => k,
None => {
jit_debug!("[compile_map] Value type not supported");
return None;
}
};
let from_pair_slice_fn = map_def.vtable.from_pair_slice?;
let pair_stride = map_def.vtable.pair_stride;
let value_offset_in_pair = map_def.vtable.value_offset_in_pair;
let pointer_type = module.target_config().pointer_type();
let sig = {
let mut s = make_c_sig(module);
s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.returns.push(AbiParam::new(pointer_type)); s
};
let func_name = format!("jit_deserialize_map_{:x}", shape as *const _ as usize);
let func_id = match module.declare_function(&func_name, Linkage::Local, &sig) {
Ok(id) => id,
Err(e) => {
jit_debug!("[compile_map] declare {} failed: {:?}", func_name, e);
jit_debug!("declare_function('{}') failed: {:?}", func_name, e);
return None;
}
};
memo.insert(shape_ptr, func_id);
jit_debug!(
"compile_map_format_deserializer: memoized FuncId for shape {:p}",
shape
);
let mut ctx = module.make_context();
ctx.func.signature = sig;
let mut builder_ctx = FunctionBuilderContext::new();
let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
let nested_call_sig_ref = builder.import_signature(tier2_call_sig(module, pointer_type));
let entry = builder.create_block();
builder.append_block_params_for_function_params(entry);
builder.switch_to_block(entry);
let input_ptr = builder.block_params(entry)[0];
let len = builder.block_params(entry)[1];
let pos_param = builder.block_params(entry)[2];
let out_ptr = builder.block_params(entry)[3];
let scratch_ptr = builder.block_params(entry)[4];
let pos_var = builder.declare_var(pointer_type);
builder.def_var(pos_var, pos_param);
let err_var = builder.declare_var(types::I32);
let zero_i32 = builder.ins().iconst(types::I32, 0);
builder.def_var(err_var, zero_i32);
let state_ptr = if F::MAP_STATE_SIZE > 0 {
let align_shift = F::MAP_STATE_ALIGN.trailing_zeros() as u8;
let state_slot = builder.create_sized_stack_slot(StackSlotData::new(
StackSlotKind::ExplicitSlot,
F::MAP_STATE_SIZE,
align_shift,
));
builder.ins().stack_addr(pointer_type, state_slot, 0)
} else {
builder.ins().iconst(pointer_type, 0)
};
let key_ptr_var = builder.declare_var(pointer_type);
let key_len_var = builder.declare_var(pointer_type);
let key_cap_var = builder.declare_var(pointer_type);
let key_owned_var = builder.declare_var(types::I8);
let zero_ptr = builder.ins().iconst(pointer_type, 0);
let zero_i8 = builder.ins().iconst(types::I8, 0);
builder.def_var(key_ptr_var, zero_ptr);
builder.def_var(key_len_var, zero_ptr);
builder.def_var(key_cap_var, zero_ptr);
builder.def_var(key_owned_var, zero_i8);
let collector_var = builder.declare_var(pointer_type);
builder.def_var(collector_var, zero_ptr);
let collector_new_sig_ref = {
let mut s = make_c_sig(module);
s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.returns.push(AbiParam::new(pointer_type)); builder.import_signature(s)
};
let collector_new_ptr = builder.ins().iconst(
pointer_type,
helpers::jit_map_collector_new as *const u8 as i64,
);
let collector_push_sig_ref = {
let mut s = make_c_sig(module);
s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(types::I8)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); builder.import_signature(s)
};
let collector_push_ptr = builder.ins().iconst(
pointer_type,
helpers::jit_map_collector_push as *const u8 as i64,
);
let collector_finalize_sig_ref = {
let mut s = make_c_sig(module);
s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); builder.import_signature(s)
};
let collector_finalize_ptr = builder.ins().iconst(
pointer_type,
helpers::jit_map_collector_finalize as *const u8 as i64,
);
let collector_abort_sig_ref = {
let mut s = make_c_sig(module);
s.params.push(AbiParam::new(pointer_type)); builder.import_signature(s)
};
let collector_abort_ptr = builder.ins().iconst(
pointer_type,
helpers::jit_map_collector_abort as *const u8 as i64,
);
let write_string_sig_ref = {
let mut s = make_c_sig(module);
s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(types::I8)); builder.import_signature(s)
};
let write_string_ptr = builder
.ins()
.iconst(pointer_type, helpers::jit_write_string as *const u8 as i64);
let drop_owned_string_sig_ref = {
let mut s = make_c_sig(module);
s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); s.params.push(AbiParam::new(pointer_type)); builder.import_signature(s)
};
let drop_owned_string_ptr = builder.ins().iconst(
pointer_type,
helpers::jit_drop_owned_string as *const u8 as i64,
);
let value_layout = match value_shape.layout.sized_layout() {
Ok(layout) => layout,
Err(_) => {
jit_debug!("[compile_map] Value shape has unsized layout");
return None;
}
};
let value_size = value_layout.size() as u32;
let value_align = value_layout.align().trailing_zeros() as u8;
let value_slot = builder.create_sized_stack_slot(StackSlotData::new(
StackSlotKind::ExplicitSlot,
value_size,
value_align,
));
let value_ptr = builder.ins().stack_addr(pointer_type, value_slot, 0);
let pair_stride_val = builder.ins().iconst(pointer_type, pair_stride as i64);
let value_offset_val = builder
.ins()
.iconst(pointer_type, value_offset_in_pair as i64);
let collector_result = builder.ins().call_indirect(
collector_new_sig_ref,
collector_new_ptr,
&[pair_stride_val, value_offset_val],
);
let collector = builder.inst_results(collector_result)[0];
builder.def_var(collector_var, collector);
let format = F::default();
let mut cursor = JitCursor {
input_ptr,
len,
pos: pos_var,
ptr_type: pointer_type,
scratch_ptr,
};
let loop_check_end = builder.create_block();
let loop_body = builder.create_block();
let done = builder.create_block();
let error = builder.create_block();
let nested_error_passthrough = builder.create_block();
let begin_err = format.emit_map_begin(module, &mut builder, &mut cursor, state_ptr);
builder.def_var(err_var, begin_err);
let begin_ok = builder.ins().icmp_imm(IntCC::Equal, begin_err, 0);
builder
.ins()
.brif(begin_ok, loop_check_end, &[], error, &[]);
builder.seal_block(entry);
builder.switch_to_block(loop_check_end);
let (is_end, end_err) = format.emit_map_is_end(module, &mut builder, &mut cursor, state_ptr);
builder.def_var(err_var, end_err);
let end_ok = builder.ins().icmp_imm(IntCC::Equal, end_err, 0);
let check_end_value = builder.create_block();
builder.ins().brif(end_ok, check_end_value, &[], error, &[]);
builder.switch_to_block(check_end_value);
builder.seal_block(check_end_value);
let is_end_bool = builder.ins().icmp_imm(IntCC::NotEqual, is_end, 0);
builder.ins().brif(is_end_bool, done, &[], loop_body, &[]);
builder.switch_to_block(loop_body);
builder.def_var(key_ptr_var, zero_ptr);
builder.def_var(key_len_var, zero_ptr);
builder.def_var(key_cap_var, zero_ptr);
builder.def_var(key_owned_var, zero_i8);
let (key_value, key_err) =
format.emit_map_read_key(module, &mut builder, &mut cursor, state_ptr);
builder.def_var(err_var, key_err);
let key_ok = builder.ins().icmp_imm(IntCC::Equal, key_err, 0);
let after_key = builder.create_block();
builder.ins().brif(key_ok, after_key, &[], error, &[]);
builder.switch_to_block(after_key);
builder.seal_block(after_key);
builder.def_var(key_ptr_var, key_value.ptr);
builder.def_var(key_len_var, key_value.len);
builder.def_var(key_cap_var, key_value.cap);
builder.def_var(key_owned_var, key_value.owned);
let sep_err = format.emit_map_kv_sep(module, &mut builder, &mut cursor, state_ptr);
builder.def_var(err_var, sep_err);
let sep_ok = builder.ins().icmp_imm(IntCC::Equal, sep_err, 0);
let after_sep = builder.create_block();
builder.ins().brif(sep_ok, after_sep, &[], error, &[]);
builder.switch_to_block(after_sep);
builder.seal_block(after_sep);
match value_kind {
FormatListElementKind::Bool => {
let (value_i8, err) = format.emit_parse_bool(module, &mut builder, &mut cursor);
builder.def_var(err_var, err);
let ok = builder.ins().icmp_imm(IntCC::Equal, err, 0);
let store = builder.create_block();
builder.ins().brif(ok, store, &[], error, &[]);
builder.switch_to_block(store);
builder
.ins()
.store(MemFlags::trusted(), value_i8, value_ptr, 0);
builder.seal_block(store);
}
FormatListElementKind::U8 => {
let (value_u8, err) = format.emit_parse_u8(module, &mut builder, &mut cursor);
builder.def_var(err_var, err);
let ok = builder.ins().icmp_imm(IntCC::Equal, err, 0);
let store = builder.create_block();
builder.ins().brif(ok, store, &[], error, &[]);
builder.switch_to_block(store);
builder
.ins()
.store(MemFlags::trusted(), value_u8, value_ptr, 0);
builder.seal_block(store);
}
FormatListElementKind::I64 => {
use facet_core::ScalarType;
let (value_i64, err) = format.emit_parse_i64(module, &mut builder, &mut cursor);
builder.def_var(err_var, err);
let ok = builder.ins().icmp_imm(IntCC::Equal, err, 0);
let store = builder.create_block();
builder.ins().brif(ok, store, &[], error, &[]);
builder.switch_to_block(store);
let scalar = value_shape.scalar_type().unwrap();
let value = match scalar {
ScalarType::I8 => builder.ins().ireduce(types::I8, value_i64),
ScalarType::I16 => builder.ins().ireduce(types::I16, value_i64),
ScalarType::I32 => builder.ins().ireduce(types::I32, value_i64),
ScalarType::I64 => value_i64,
_ => value_i64,
};
builder
.ins()
.store(MemFlags::trusted(), value, value_ptr, 0);
builder.seal_block(store);
}
FormatListElementKind::U64 => {
use facet_core::ScalarType;
let (value_u64, err) = format.emit_parse_u64(module, &mut builder, &mut cursor);
builder.def_var(err_var, err);
let ok = builder.ins().icmp_imm(IntCC::Equal, err, 0);
let store = builder.create_block();
builder.ins().brif(ok, store, &[], error, &[]);
builder.switch_to_block(store);
let scalar = value_shape.scalar_type().unwrap();
let value = match scalar {
ScalarType::U8 => builder.ins().ireduce(types::I8, value_u64),
ScalarType::U16 => builder.ins().ireduce(types::I16, value_u64),
ScalarType::U32 => builder.ins().ireduce(types::I32, value_u64),
ScalarType::U64 => value_u64,
_ => value_u64,
};
builder
.ins()
.store(MemFlags::trusted(), value, value_ptr, 0);
builder.seal_block(store);
}
FormatListElementKind::F64 => {
use facet_core::ScalarType;
let (value_f64, err) = format.emit_parse_f64(module, &mut builder, &mut cursor);
builder.def_var(err_var, err);
let ok = builder.ins().icmp_imm(IntCC::Equal, err, 0);
let store = builder.create_block();
builder.ins().brif(ok, store, &[], error, &[]);
builder.switch_to_block(store);
let scalar = value_shape.scalar_type().unwrap();
let value = if matches!(scalar, ScalarType::F32) {
builder.ins().fdemote(types::F32, value_f64)
} else {
value_f64
};
builder
.ins()
.store(MemFlags::trusted(), value, value_ptr, 0);
builder.seal_block(store);
}
FormatListElementKind::String => {
let (string_value, err) = format.emit_parse_string(module, &mut builder, &mut cursor);
builder.def_var(err_var, err);
let ok = builder.ins().icmp_imm(IntCC::Equal, err, 0);
let store = builder.create_block();
builder.ins().brif(ok, store, &[], error, &[]);
builder.switch_to_block(store);
let zero_offset = builder.ins().iconst(pointer_type, 0);
builder.ins().call_indirect(
write_string_sig_ref,
write_string_ptr,
&[
value_ptr,
zero_offset,
string_value.ptr,
string_value.len,
string_value.cap,
string_value.owned,
],
);
builder.seal_block(store);
}
FormatListElementKind::Struct(_) => {
let struct_func_id = match F::STRUCT_ENCODING {
crate::jit::StructEncoding::Map => {
compile_struct_format_deserializer::<F>(module, value_shape, memo)?
}
crate::jit::StructEncoding::Positional => {
compile_struct_positional_deserializer::<F>(module, value_shape, memo)?
}
};
let struct_func_ref = module.declare_func_in_func(struct_func_id, builder.func);
let current_pos = builder.use_var(pos_var);
let struct_func_ptr = func_addr_value(&mut builder, pointer_type, struct_func_ref);
let call_result = builder.ins().call_indirect(
nested_call_sig_ref,
struct_func_ptr,
&[input_ptr, len, current_pos, value_ptr, scratch_ptr],
);
let new_pos = builder.inst_results(call_result)[0];
let is_error = builder.ins().icmp_imm(IntCC::SignedLessThan, new_pos, 0);
let nested_ok = builder.create_block();
builder
.ins()
.brif(is_error, nested_error_passthrough, &[], nested_ok, &[]);
builder.switch_to_block(nested_ok);
builder.def_var(pos_var, new_pos);
builder.seal_block(nested_ok);
}
FormatListElementKind::List(_) => {
let list_func_id = compile_list_format_deserializer::<F>(module, value_shape, memo)?;
let list_func_ref = module.declare_func_in_func(list_func_id, builder.func);
let current_pos = builder.use_var(pos_var);
let list_func_ptr = func_addr_value(&mut builder, pointer_type, list_func_ref);
let call_result = builder.ins().call_indirect(
nested_call_sig_ref,
list_func_ptr,
&[input_ptr, len, current_pos, value_ptr, scratch_ptr],
);
let new_pos = builder.inst_results(call_result)[0];
let is_error = builder.ins().icmp_imm(IntCC::SignedLessThan, new_pos, 0);
let nested_ok = builder.create_block();
builder
.ins()
.brif(is_error, nested_error_passthrough, &[], nested_ok, &[]);
builder.switch_to_block(nested_ok);
builder.def_var(pos_var, new_pos);
builder.seal_block(nested_ok);
}
FormatListElementKind::Map(_) => {
let map_func_id = compile_map_format_deserializer::<F>(module, value_shape, memo)?;
let map_func_ref = module.declare_func_in_func(map_func_id, builder.func);
let current_pos = builder.use_var(pos_var);
let map_func_ptr = func_addr_value(&mut builder, pointer_type, map_func_ref);
let call_result = builder.ins().call_indirect(
nested_call_sig_ref,
map_func_ptr,
&[input_ptr, len, current_pos, value_ptr, scratch_ptr],
);
let new_pos = builder.inst_results(call_result)[0];
let is_error = builder.ins().icmp_imm(IntCC::SignedLessThan, new_pos, 0);
let nested_ok = builder.create_block();
builder
.ins()
.brif(is_error, nested_error_passthrough, &[], nested_ok, &[]);
builder.switch_to_block(nested_ok);
builder.def_var(pos_var, new_pos);
builder.seal_block(nested_ok);
}
}
let key_ptr_raw = builder.use_var(key_ptr_var);
let key_len_raw = builder.use_var(key_len_var);
let key_cap_raw = builder.use_var(key_cap_var);
let key_owned_raw = builder.use_var(key_owned_var);
let value_size_val = builder.ins().iconst(pointer_type, value_size as i64);
let collector_val = builder.use_var(collector_var);
builder.ins().call_indirect(
collector_push_sig_ref,
collector_push_ptr,
&[
collector_val,
key_ptr_raw,
key_len_raw,
key_cap_raw,
key_owned_raw,
value_ptr,
value_size_val,
],
);
builder.def_var(key_owned_var, zero_i8);
let next_err = format.emit_map_next(module, &mut builder, &mut cursor, state_ptr);
builder.def_var(err_var, next_err);
let next_ok = builder.ins().icmp_imm(IntCC::Equal, next_err, 0);
let after_next = builder.create_block();
builder.ins().brif(next_ok, after_next, &[], error, &[]);
builder.switch_to_block(after_next);
builder.seal_block(after_next);
builder.ins().jump(loop_check_end, &[]);
builder.seal_block(loop_body);
builder.seal_block(loop_check_end);
builder.switch_to_block(done);
let collector_val = builder.use_var(collector_var);
let from_pair_slice_fn_ptr = builder
.ins()
.iconst(pointer_type, from_pair_slice_fn as *const u8 as i64);
builder.ins().call_indirect(
collector_finalize_sig_ref,
collector_finalize_ptr,
&[collector_val, out_ptr, from_pair_slice_fn_ptr],
);
let one_i8 = builder.ins().iconst(types::I8, 1);
builder.ins().store(
MemFlags::trusted(),
one_i8,
scratch_ptr,
JIT_SCRATCH_OUTPUT_INITIALIZED_OFFSET,
);
let final_pos = builder.use_var(pos_var);
builder.ins().return_(&[final_pos]);
builder.seal_block(done);
builder.switch_to_block(nested_error_passthrough);
let collector_val = builder.use_var(collector_var);
builder.ins().call_indirect(
collector_abort_sig_ref,
collector_abort_ptr,
&[collector_val],
);
let key_owned = builder.use_var(key_owned_var);
let need_drop = builder.ins().icmp_imm(IntCC::NotEqual, key_owned, 0);
let drop_key = builder.create_block();
let nested_after_drop = builder.create_block();
builder
.ins()
.brif(need_drop, drop_key, &[], nested_after_drop, &[]);
builder.switch_to_block(drop_key);
let key_ptr_val = builder.use_var(key_ptr_var);
let key_len_val = builder.use_var(key_len_var);
let key_cap_val = builder.use_var(key_cap_var);
builder.ins().call_indirect(
drop_owned_string_sig_ref,
drop_owned_string_ptr,
&[key_ptr_val, key_len_val, key_cap_val],
);
builder.ins().jump(nested_after_drop, &[]);
builder.seal_block(drop_key);
builder.switch_to_block(nested_after_drop);
let minus_one = builder.ins().iconst(pointer_type, -1i64);
builder.ins().return_(&[minus_one]);
builder.seal_block(nested_after_drop);
builder.seal_block(nested_error_passthrough);
builder.switch_to_block(error);
let collector_val = builder.use_var(collector_var);
let collector_is_null = builder.ins().icmp_imm(IntCC::Equal, collector_val, 0);
let abort_collector = builder.create_block();
let after_abort = builder.create_block();
builder
.ins()
.brif(collector_is_null, after_abort, &[], abort_collector, &[]);
builder.switch_to_block(abort_collector);
builder.ins().call_indirect(
collector_abort_sig_ref,
collector_abort_ptr,
&[collector_val],
);
builder.ins().jump(after_abort, &[]);
builder.seal_block(abort_collector);
builder.switch_to_block(after_abort);
let key_owned = builder.use_var(key_owned_var);
let need_drop = builder.ins().icmp_imm(IntCC::NotEqual, key_owned, 0);
let drop_key = builder.create_block();
let after_drop = builder.create_block();
builder
.ins()
.brif(need_drop, drop_key, &[], after_drop, &[]);
builder.switch_to_block(drop_key);
let key_ptr_val = builder.use_var(key_ptr_var);
let key_len_val = builder.use_var(key_len_var);
let key_cap_val = builder.use_var(key_cap_var);
builder.ins().call_indirect(
drop_owned_string_sig_ref,
drop_owned_string_ptr,
&[key_ptr_val, key_len_val, key_cap_val],
);
builder.ins().jump(after_drop, &[]);
builder.seal_block(drop_key);
builder.switch_to_block(after_drop);
let err_code = builder.use_var(err_var);
let err_pos = builder.use_var(pos_var);
builder.ins().store(
MemFlags::trusted(),
err_code,
scratch_ptr,
JIT_SCRATCH_ERROR_CODE_OFFSET,
);
builder.ins().store(
MemFlags::trusted(),
err_pos,
scratch_ptr,
JIT_SCRATCH_ERROR_POS_OFFSET,
);
let minus_one = builder.ins().iconst(pointer_type, -1i64);
builder.ins().return_(&[minus_one]);
builder.seal_block(after_drop);
builder.seal_block(after_abort);
builder.seal_block(error);
builder.finalize();
if let Err(_e) = module.define_function(func_id, &mut ctx) {
jit_debug!("[compile_map] define_function failed: {:?}", _e);
return None;
}
jit_debug!("[compile_map] SUCCESS - HashMap<String, V> function compiled");
Some(func_id)
}