use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::sync::Arc;
use cranelift::codegen::ir::FuncRef;
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::FuncId;
use facet_core::{Def, Facet, Shape, Type, UserType};
use super::Tier2Incompatibility;
use super::format::{JitFormat, JitScratch, StructEncoding, make_c_sig};
use super::helpers;
use super::jit_debug;
use crate::jit::FormatJitParser;
use crate::{DeserializeError, DeserializeErrorKind};
mod support;
pub use support::*;
mod map_format_deserializer;
use map_format_deserializer::*;
mod struct_format_deserializer;
use struct_format_deserializer::*;
mod struct_positional_deserializer;
use struct_positional_deserializer::*;
mod enum_positional_deserializer;
use enum_positional_deserializer::*;
mod list_format_deserializer;
use list_format_deserializer::*;
/// Builds the signature shared by Tier-2 entry points.
///
/// Starting from `make_c_sig(module)`, it appends five pointer-sized parameters
/// (input pointer, input length, start position, output pointer, scratch pointer,
/// matching the `CompiledFn` type used by `deserialize`) and one pointer-sized
/// return value.
fn tier2_call_sig(module: &mut JITModule, pointer_type: cranelift::prelude::Type) -> Signature {
    const PARAM_COUNT: usize = 5;

    let mut sig = make_c_sig(module);
    let word = AbiParam::new(pointer_type);
    sig.params
        .extend(std::iter::repeat(word).take(PARAM_COUNT));
    sig.returns.push(word);
    sig
}
/// Emits a `func_addr` instruction for `func_ref` and returns the resulting
/// pointer-sized SSA value.
fn func_addr_value(
    builder: &mut FunctionBuilder,
    pointer_type: cranelift::prelude::Type,
    func_ref: FuncRef,
) -> Value {
    let inst_builder = builder.ins();
    inst_builder.func_addr(pointer_type, func_ref)
}
/// Map from a shape's address to the `FuncId` compiled for it, passed by `&mut`
/// into every `compile_*_deserializer` call.
/// NOTE(review): presumably lets nested/repeated shapes reuse an already-declared
/// function (and break recursion) — confirm in the compiler submodules.
type ShapeMemo = HashMap<*const Shape, FuncId>;
/// Complexity limits a shape must satisfy before Tier-2 compilation is attempted.
struct BudgetLimits {
    /// Largest number of fields allowed on any single struct in the shape tree
    /// (`FACET_TIER2_MAX_FIELDS`, default 100).
    max_fields: usize,
    /// Deepest nesting allowed (`FACET_TIER2_MAX_NESTING`, default 10).
    max_nesting_depth: usize,
}

impl BudgetLimits {
    /// Reads the limits from the environment.
    ///
    /// A variable that is unset or does not parse as `usize` falls back to its default.
    fn from_env() -> Self {
        fn read_limit(var: &str, default: usize) -> usize {
            match std::env::var(var) {
                Ok(raw) => raw.parse().unwrap_or(default),
                Err(_) => default,
            }
        }

        Self {
            max_fields: read_limit("FACET_TIER2_MAX_FIELDS", 100),
            max_nesting_depth: read_limit("FACET_TIER2_MAX_NESTING", 10),
        }
    }

    /// Checks `shape` (and everything reachable through it) against the limits,
    /// starting at depth 0. `type_name` is only used to label the error.
    fn check_shape(
        &self,
        shape: &'static Shape,
        type_name: &'static str,
    ) -> Result<(), Tier2Incompatibility> {
        self.check_shape_recursive(shape, 0, type_name)
    }

    /// Recursive worker for [`BudgetLimits::check_shape`].
    ///
    /// Nesting rules:
    /// - `Option` is transparent and does not add a level.
    /// - List elements and struct fields each add one level.
    /// - Other shapes (e.g. maps, enums) are not descended into.
    fn check_shape_recursive(
        &self,
        shape: &'static Shape,
        depth: usize,
        type_name: &'static str,
    ) -> Result<(), Tier2Incompatibility> {
        if depth > self.max_nesting_depth {
            jit_debug!(
                "[Tier-2 JIT] Budget exceeded: nesting depth {} > {} max",
                depth,
                self.max_nesting_depth
            );
            return Err(Tier2Incompatibility::BudgetExceeded {
                type_name,
                reason: "nesting depth exceeded",
            });
        }

        match &shape.def {
            Def::Option(opt) => return self.check_shape_recursive(opt.t, depth, type_name),
            Def::List(list) => return self.check_shape_recursive(list.t, depth + 1, type_name),
            _ => {}
        }

        let Type::User(UserType::Struct(struct_def)) = &shape.ty else {
            return Ok(());
        };

        let field_count = struct_def.fields.len();
        if field_count > self.max_fields {
            jit_debug!(
                "[Tier-2 JIT] Budget exceeded: {} fields > {} max",
                field_count,
                self.max_fields
            );
            return Err(Tier2Incompatibility::BudgetExceeded {
                type_name,
                reason: "too many fields",
            });
        }

        struct_def
            .fields
            .iter()
            .try_for_each(|field| self.check_shape_recursive(field.shape(), depth + 1, type_name))
    }
}
/// Value of `JitScratch::error_code` marking an operation the Tier-2 backend does
/// not implement. `CompiledFormatDeserializer::deserialize` turns it into
/// `DeserializeErrorKind::Unsupported` instead of asking the parser to build an error.
pub const T2_ERR_UNSUPPORTED: i32 = -1;
/// A finalized Tier-2 `JITModule` paired with the address of its compiled entry point.
pub struct CachedFormatModule {
    // Never read: stored only so the module that produced `fn_ptr` lives at least
    // as long as this value.
    #[allow(dead_code)]
    module: JITModule,
    fn_ptr: *const u8,
}

impl CachedFormatModule {
    /// Bundles a finalized module with the entry-point address obtained from it.
    pub const fn new(module: JITModule, fn_ptr: *const u8) -> Self {
        CachedFormatModule { module, fn_ptr }
    }

    /// Raw address of the compiled entry function.
    pub const fn fn_ptr(&self) -> *const u8 {
        self.fn_ptr
    }
}

// SAFETY: the value is immutable after construction (no `&mut self` API), and
// `fn_ptr` is only ever copied out, never dereferenced here.
// NOTE(review): soundness also relies on `JITModule` being safe to move and share
// across threads once finalized — confirm against cranelift-jit.
unsafe impl Send for CachedFormatModule {}
// SAFETY: see the `Send` impl above; shared access is read-only.
unsafe impl Sync for CachedFormatModule {}
/// Typed handle to a cached Tier-2 function that deserializes a `T` via parser `P`.
///
/// The `Arc` keeps the owning `CachedFormatModule` alive for as long as the
/// copied `fn_ptr` may be called.
pub struct CompiledFormatDeserializer<T, P> {
    fn_ptr: *const u8,
    _cached: Arc<CachedFormatModule>,
    _phantom: PhantomData<fn(&mut P) -> T>,
}

// SAFETY: the handle holds a raw pointer to compiled code (only copied out, never
// written through) plus an `Arc<CachedFormatModule>`; `T` and `P` appear only in
// `PhantomData<fn(&mut P) -> T>`, which carries no data.
unsafe impl<T, P> Send for CompiledFormatDeserializer<T, P> {}
// SAFETY: see the `Send` impl above; the handle is never mutated after creation.
unsafe impl<T, P> Sync for CompiledFormatDeserializer<T, P> {}

impl<T, P> CompiledFormatDeserializer<T, P> {
    /// Creates a handle from a cached module, copying out its entry-point address.
    pub fn from_cached(cached: Arc<CachedFormatModule>) -> Self {
        Self {
            fn_ptr: cached.fn_ptr(),
            _cached: cached,
            _phantom: PhantomData,
        }
    }

    /// Raw address of the compiled entry function.
    #[inline(always)]
    pub const fn fn_ptr(&self) -> *const u8 {
        self.fn_ptr
    }
}
impl<'de, T: Facet<'de>, P: FormatJitParser<'de>> CompiledFormatDeserializer<T, P> {
    /// Deserializes a `T` by running the compiled Tier-2 function over the parser's
    /// input, starting at the parser's current JIT position.
    ///
    /// On success the parser is moved to the position returned by the compiled code.
    ///
    /// # Errors
    ///
    /// - `Unsupported` when the parser reports no JIT position (buffered state).
    /// - `Unsupported` when the compiled code reports `T2_ERR_UNSUPPORTED`.
    /// - Otherwise, the error `FormatJitParser::jit_error` builds from the code and
    ///   position left in the scratch area.
    pub fn deserialize(&self, parser: &mut P) -> Result<T, DeserializeError> {
        // Entry-point ABI: (input ptr, input len, start pos, output ptr, scratch ptr).
        // A non-negative return is the new position; a negative one signals failure.
        type CompiledFn =
            unsafe extern "C" fn(*const u8, usize, usize, *mut u8, *mut JitScratch) -> isize;

        let input = parser.jit_input();
        let pos = match parser.jit_pos() {
            Some(pos) => pos,
            None => {
                return Err(DeserializeError {
                    span: None,
                    path: None,
                    kind: DeserializeErrorKind::Unsupported {
                        message: "Tier-2 JIT: parser has buffered state".into(),
                    },
                });
            }
        };
        jit_debug!("[Tier-2] Executing: input_len={}, pos={}", input.len(), pos);

        let mut output = MaybeUninit::<T>::uninit();
        let mut scratch = JitScratch::default();
        if let Some(limit) = parser.jit_max_collection_elements() {
            scratch.max_collection_elements = limit;
        }

        let entry = self.fn_ptr();
        // SAFETY: `entry` comes from a finalized module kept alive by `self._cached`,
        // and it is assumed to have been compiled with the `CompiledFn` ABI
        // (see `tier2_call_sig`).
        let func = unsafe { std::mem::transmute::<*const u8, CompiledFn>(entry) };
        jit_debug!("[Tier-2] Calling JIT function at {:p}", entry);

        // SAFETY: `input` is a live slice for the whole call; `output` and `scratch`
        // are writable locals that outlive the call.
        let status = unsafe {
            func(
                input.as_ptr(),
                input.len(),
                pos,
                output.as_mut_ptr().cast::<u8>(),
                &mut scratch,
            )
        };
        jit_debug!("[Tier-2] JIT function returned: result={}", status);

        if status >= 0 {
            let new_pos = status as usize;
            parser.jit_set_pos(new_pos);
            jit_debug!("[Tier-2] Success! new_pos={}", new_pos);
            // SAFETY: by the compiled-code contract relied on here, a non-negative
            // status means `output` was fully initialized.
            return Ok(unsafe { output.assume_init() });
        }

        jit_debug!(
            "[Tier-2] Error: code={}, pos={}, output_initialized={}",
            scratch.error_code,
            scratch.error_pos,
            scratch.output_initialized
        );
        if scratch.output_initialized != 0 {
            if matches!(T::SHAPE.def, Def::List(_) | Def::Map(_)) {
                // SAFETY: list/map deserializers set `output_initialized` only after
                // writing a valid (possibly partial) collection into `output`, which
                // must be dropped to avoid leaking it.
                unsafe { output.assume_init_drop() };
            } else {
                jit_debug!(
                    "[Tier-2] WARNING: Struct deserializer incorrectly set output_initialized=1"
                );
            }
        }

        if scratch.error_code == T2_ERR_UNSUPPORTED {
            return Err(DeserializeError {
                span: None,
                path: None,
                kind: DeserializeErrorKind::Unsupported {
                    message: "Tier-2 format operation not implemented".into(),
                },
            });
        }
        Err(parser
            .jit_error(input, scratch.error_pos, scratch.error_code)
            .into())
    }
}
pub fn try_compile_format_module<'de, T, P>() -> Result<(JITModule, *const u8), Tier2Incompatibility>
where
T: Facet<'de>,
P: FormatJitParser<'de>,
{
let type_name = std::any::type_name::<T>();
let shape = T::SHAPE;
let encoding = P::FormatJit::STRUCT_ENCODING;
ensure_format_jit_compatible_with_encoding(shape, encoding, type_name)?;
let builder = match JITBuilder::new(cranelift_module::default_libcall_names()) {
Ok(b) => b,
Err(e) => {
jit_debug!("[Tier-2 JIT] JITBuilder::new failed: {:?}", e);
return Err(Tier2Incompatibility::JitBuilderFailed {
error: format!("{:?}", e),
});
}
};
let mut builder = builder;
let budget = BudgetLimits::from_env();
budget.check_shape(shape, type_name)?;
register_helpers(&mut builder);
P::FormatJit::register_helpers(&mut builder);
let mut module = JITModule::new(builder);
let mut memo = ShapeMemo::new();
let func_id = if let Def::List(_) = &shape.def {
match compile_list_format_deserializer::<P::FormatJit>(&mut module, shape, &mut memo) {
Some(id) => id,
None => {
jit_debug!("[Tier-2 JIT] compile_list_format_deserializer returned None");
return Err(Tier2Incompatibility::CompilationFailed {
type_name,
stage: "list deserializer",
});
}
}
} else if let Def::Map(_) = &shape.def {
match compile_map_format_deserializer::<P::FormatJit>(&mut module, shape, &mut memo) {
Some(id) => id,
None => {
jit_debug!("[Tier-2 JIT] compile_map_format_deserializer returned None");
return Err(Tier2Incompatibility::CompilationFailed {
type_name,
stage: "map deserializer",
});
}
}
} else if let Type::User(UserType::Struct(_)) = &shape.ty {
let func_id = match <P::FormatJit as JitFormat>::STRUCT_ENCODING {
StructEncoding::Map => {
compile_struct_format_deserializer::<P::FormatJit>(&mut module, shape, &mut memo)
}
StructEncoding::Positional => compile_struct_positional_deserializer::<P::FormatJit>(
&mut module,
shape,
&mut memo,
),
};
match func_id {
Some(id) => id,
None => {
jit_debug!("[Tier-2 JIT] compile_struct_format_deserializer returned None");
return Err(Tier2Incompatibility::CompilationFailed {
type_name,
stage: "struct deserializer",
});
}
}
} else if let Type::User(UserType::Enum(_)) = &shape.ty {
match compile_enum_positional_deserializer::<P::FormatJit>(&mut module, shape, &mut memo) {
Some(id) => id,
None => {
jit_debug!("[Tier-2 JIT] compile_enum_positional_deserializer returned None");
return Err(Tier2Incompatibility::CompilationFailed {
type_name,
stage: "enum deserializer",
});
}
}
} else {
jit_debug!("[Tier-2 JIT] Unsupported shape type");
return Err(Tier2Incompatibility::UnrecognizedShapeType { type_name });
};
if let Err(e) = module.finalize_definitions() {
jit_debug!("[Tier-2 JIT] finalize_definitions failed: {:?}", e);
return Err(Tier2Incompatibility::FinalizationFailed {
type_name,
error: format!("{:?}", e),
});
}
let fn_ptr = module.get_finalized_function(func_id);
Ok((module, fn_ptr))
}
/// Registers the shared Tier-2 runtime helpers with `builder`, so generated code
/// can import them by symbol name.
///
/// Format-specific helpers are registered separately by the caller.
fn register_helpers(builder: &mut JITBuilder) {
    let symbols: &[(&str, *const u8)] = &[
        (
            "jit_vec_init_with_capacity",
            helpers::jit_vec_init_with_capacity as *const u8,
        ),
        ("jit_vec_push_bool", helpers::jit_vec_push_bool as *const u8),
        ("jit_vec_push_u8", helpers::jit_vec_push_u8 as *const u8),
        ("jit_vec_push_i64", helpers::jit_vec_push_i64 as *const u8),
        ("jit_vec_push_u64", helpers::jit_vec_push_u64 as *const u8),
        ("jit_vec_push_f64", helpers::jit_vec_push_f64 as *const u8),
        ("jit_vec_push_string", helpers::jit_vec_push_string as *const u8),
        ("jit_vec_set_len", helpers::jit_vec_set_len as *const u8),
        (
            "jit_vec_as_mut_ptr_typed",
            helpers::jit_vec_as_mut_ptr_typed as *const u8,
        ),
        (
            "jit_map_init_with_capacity",
            helpers::jit_map_init_with_capacity as *const u8,
        ),
        (
            "jit_drop_owned_string",
            helpers::jit_drop_owned_string as *const u8,
        ),
        (
            "jit_option_init_none",
            helpers::jit_option_init_none as *const u8,
        ),
        (
            "jit_option_init_some_from_value",
            helpers::jit_option_init_some_from_value as *const u8,
        ),
        (
            "jit_result_init_ok_from_value",
            helpers::jit_result_init_ok_from_value as *const u8,
        ),
        (
            "jit_result_init_err_from_value",
            helpers::jit_result_init_err_from_value as *const u8,
        ),
        ("jit_drop_in_place", helpers::jit_drop_in_place as *const u8),
        ("jit_write_string", helpers::jit_write_string as *const u8),
        ("jit_memcpy", helpers::jit_memcpy as *const u8),
        (
            "jit_write_error_string",
            helpers::jit_write_error_string as *const u8,
        ),
    ];

    for &(name, addr) in symbols {
        builder.symbol(name, addr);
    }
}
/// Element kinds with dedicated handling in the format list/map code paths.
///
/// Scalar widths are collapsed:
/// - Signed integers up to 64 bits map to `I64`.
/// - Unsigned integers from 16 to 64 bits map to `U64`.
/// - Both float widths map to `F64`.
///
/// Composite kinds keep their full `Shape`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FormatListElementKind {
    Bool,
    U8,
    I64,
    U64,
    F64,
    String,
    Struct(&'static Shape),
    List(&'static Shape),
    Map(&'static Shape),
}

impl FormatListElementKind {
    /// Classifies `shape`, or returns `None` when it has no supported element kind.
    ///
    /// Precedence: list/map definitions, then `String`, then user structs, then
    /// scalar types.
    fn from_shape(shape: &'static Shape) -> Option<Self> {
        use facet_core::ScalarType;

        match &shape.def {
            Def::List(_) => return Some(Self::List(shape)),
            Def::Map(_) => return Some(Self::Map(shape)),
            _ => {}
        }
        if shape.is_type::<String>() {
            return Some(Self::String);
        }
        if matches!(shape.ty, Type::User(UserType::Struct(_))) {
            return Some(Self::Struct(shape));
        }

        let kind = match shape.scalar_type()? {
            ScalarType::Bool => Self::Bool,
            ScalarType::U8 => Self::U8,
            ScalarType::I8 | ScalarType::I16 | ScalarType::I32 | ScalarType::I64 => Self::I64,
            ScalarType::U16 | ScalarType::U32 | ScalarType::U64 => Self::U64,
            ScalarType::F32 | ScalarType::F64 => Self::F64,
            ScalarType::String => Self::String,
            _ => return None,
        };
        Some(kind)
    }
}
/// Metadata about one struct field, gathered before code generation.
/// NOTE(review): field meanings below are inferred from names — confirm against
/// the struct compiler that builds these.
#[derive(Debug)]
struct FieldCodegenInfo {
    /// Field name (presumably the key matched in the input).
    name: &'static str,
    /// Offset of the field within the parent struct (presumably in bytes).
    offset: usize,
    /// Shape of the field's type.
    shape: &'static Shape,
    /// Whether the field's type is an `Option`.
    is_option: bool,
    /// Bit index for tracking this field when it is required; `None` otherwise.
    /// NOTE(review): assumed to index a "fields seen" bitmask — confirm.
    required_bit_index: Option<u8>,
}
/// Metadata for one variant of an enum field flattened into its parent struct.
/// NOTE(review): field meanings below are inferred from names — confirm.
struct FlattenedVariantInfo {
    /// Variant name (presumably the key that selects this variant).
    variant_name: &'static str,
    /// Offset of the enum field within the parent struct.
    enum_field_offset: usize,
    /// Discriminant value associated with this variant.
    discriminant: usize,
    /// Shape of the variant's payload.
    payload_shape: &'static Shape,
    /// Offset of the payload within the enum value.
    payload_offset_in_enum: usize,
    /// Bit index presumably used to record that the flattened enum was seen.
    enum_seen_bit_index: u8,
}
/// Metadata for a map field flattened into its parent struct.
/// NOTE(review): presumably collects keys that match no regular field — confirm.
struct FlattenedMapInfo {
    /// Offset of the map field within the parent struct.
    map_field_offset: usize,
    /// Shape of the map's value type.
    value_shape: &'static Shape,
    /// Pre-classified kind of the map's value type.
    value_kind: FormatListElementKind,
}
/// Destination of a matched key during struct-field dispatch.
enum DispatchTarget {
    /// Index of a regular field (presumably into the field-info list).
    Field(usize),
    /// Index of a flattened enum variant (presumably into the flattened-variant list).
    FlattenEnumVariant(usize),
}
/// How incoming keys are matched against struct field names.
/// NOTE(review): the exact semantics of each strategy are inferred from the
/// variant names — confirm in the struct compiler.
#[derive(Debug)]
enum KeyDispatchStrategy {
    Inline,
    Linear,
    /// Switch on a packed key prefix of `prefix_len` bytes (presumably built with
    /// `compute_field_prefix`).
    PrefixSwitch { prefix_len: usize },
}
/// Packs the leading bytes of `name` into a little-endian `u64`.
///
/// Byte `i` of the name occupies bits `8*i..8*i + 8`, and unused high bytes are
/// zero.
///
/// Returns `(packed, len)`, where `len = min(name.len(), prefix_len, 8)` is the
/// number of bytes actually packed. A `u64` holds at most 8 bytes, so larger
/// `prefix_len` values are clamped rather than shifting by 64 or more bits. Such a
/// shift overflows: it panics in debug builds and corrupts the low bytes in release.
fn compute_field_prefix(name: &str, prefix_len: usize) -> (u64, usize) {
    const MAX_PREFIX_BYTES: usize = std::mem::size_of::<u64>();

    let bytes = name.as_bytes();
    let actual_len = bytes.len().min(prefix_len).min(MAX_PREFIX_BYTES);

    let mut word = [0u8; MAX_PREFIX_BYTES];
    word[..actual_len].copy_from_slice(&bytes[..actual_len]);
    (u64::from_le_bytes(word), actual_len)
}
/// The byte sequence `"<name>":` split into two little-endian 64-bit words.
#[derive(Clone, Copy, Debug)]
struct KeyColonPattern {
    /// First (up to) 8 bytes of the pattern, zero-padded.
    pattern1: u64,
    /// Number of meaningful bytes in `pattern1` (at most 8).
    pattern1_len: usize,
    /// Bytes 8..16 of the pattern, zero-padded; 0 when the pattern fits in one word.
    pattern2: u64,
    /// Number of meaningful bytes in `pattern2` (0 when unused).
    pattern2_len: usize,
    /// Total pattern length: `name.len() + 3` (two quotes and a colon).
    total_len: usize,
}

/// Builds the `"<name>":` pattern for `name`.
///
/// Returns `None` when the pattern would exceed 16 bytes, i.e. when `name` is
/// longer than 13 bytes.
fn compute_key_colon_pattern_extended(name: &str) -> Option<KeyColonPattern> {
    const WORD: usize = 8;
    const CAPACITY: usize = 2 * WORD;

    let key = name.as_bytes();
    let total_len = key.len() + 3;
    if total_len > CAPACITY {
        return None;
    }

    // Layout: opening quote, key bytes, closing quote, colon, then zero padding.
    let mut buf = [0u8; CAPACITY];
    buf[0] = b'"';
    buf[1..1 + key.len()].copy_from_slice(key);
    buf[1 + key.len()..total_len].copy_from_slice(b"\":");

    let to_word = |half: &[u8]| {
        let mut word = [0u8; WORD];
        word.copy_from_slice(half);
        u64::from_le_bytes(word)
    };
    let (low, high) = buf.split_at(WORD);

    // When the pattern fits in the first word, `high` is all zero bytes, so the
    // second word is 0.
    Some(KeyColonPattern {
        pattern1: to_word(low),
        pattern1_len: total_len.min(WORD),
        pattern2: to_word(high),
        pattern2_len: total_len.saturating_sub(WORD),
        total_len,
    })
}
/// Metadata for one field of a positionally encoded struct.
struct PositionalFieldInfo {
    /// Field name. NOTE(review): presumably used for diagnostics only — confirm.
    name: &'static str,
    /// Offset of the field within the parent struct (presumably in bytes).
    offset: usize,
    // Currently never read (hence `allow(dead_code)`); kept alongside `kind`
    // for reference.
    #[allow(dead_code)]
    shape: &'static Shape,
    /// Pre-computed classification of the field's type.
    kind: PositionalFieldKind,
}
/// Classification of a field's type for positional decoding, as produced by
/// `classify_positional_field`.
///
/// Signed integers of 16-64 bits are grouped under `I64`, and unsigned ones under
/// `U64`; both carry the original `ScalarType` so the exact width is still
/// available. `Option`/`Result` carry their definitions, and the other composite
/// kinds carry their full `Shape`.
#[derive(Clone, Debug)]
enum PositionalFieldKind {
    Bool,
    U8,
    I8,
    I64(facet_core::ScalarType),
    U64(facet_core::ScalarType),
    F32,
    F64,
    String,
    Option(&'static facet_core::OptionDef),
    Result(&'static facet_core::ResultDef),
    Struct(&'static Shape),
    List(&'static Shape),
    Map(&'static Shape),
    Enum(&'static Shape),
}
/// Classifies `shape` for positional decoding, or returns `None` if unsupported.
///
/// Precedence:
/// 1. `Option`, `Result`, list, and map definitions.
/// 2. User enums, then user structs.
/// 3. `String`.
/// 4. Scalar types. Integer widths beyond 64 bits and other scalars are rejected.
fn classify_positional_field(shape: &'static Shape) -> Option<PositionalFieldKind> {
    use facet_core::ScalarType;

    match &shape.def {
        Def::Option(opt_def) => return Some(PositionalFieldKind::Option(opt_def)),
        Def::Result(result_def) => return Some(PositionalFieldKind::Result(result_def)),
        Def::List(_) => return Some(PositionalFieldKind::List(shape)),
        Def::Map(_) => return Some(PositionalFieldKind::Map(shape)),
        _ => {}
    }

    match &shape.ty {
        Type::User(UserType::Enum(_)) => return Some(PositionalFieldKind::Enum(shape)),
        Type::User(UserType::Struct(_)) => return Some(PositionalFieldKind::Struct(shape)),
        _ => {}
    }

    if shape.is_type::<String>() {
        return Some(PositionalFieldKind::String);
    }

    let scalar = shape.scalar_type()?;
    let kind = match scalar {
        ScalarType::Bool => PositionalFieldKind::Bool,
        ScalarType::U8 => PositionalFieldKind::U8,
        ScalarType::I8 => PositionalFieldKind::I8,
        ScalarType::I16 | ScalarType::I32 | ScalarType::I64 => PositionalFieldKind::I64(scalar),
        ScalarType::U16 | ScalarType::U32 | ScalarType::U64 => PositionalFieldKind::U64(scalar),
        ScalarType::F32 => PositionalFieldKind::F32,
        ScalarType::F64 => PositionalFieldKind::F64,
        _ => return None,
    };
    Some(kind)
}