use crate::buffer::{Reader, Writer};
use crate::config::Config;
use crate::context::{ContextCache, ReadContext, WriteContext};
use crate::ensure;
use crate::error::Error;
use crate::resolver::RefMode;
use crate::resolver::TypeResolver;
use crate::serializer::ForyDefault;
use crate::serializer::{Serializer, StructSerializer};
use crate::type_id::config_flags::{IS_CROSS_LANGUAGE_FLAG, IS_OUT_OF_BAND_FLAG};
use crate::type_id::SIZE_OF_REF_AND_TYPE;
use std::cell::UnsafeCell;
use std::mem;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::OnceLock;
/// Process-wide source of `Fory::id` values; each `Fory::from_config` call takes the next one.
static FORY_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
// Per-thread caches of write/read contexts, keyed by `Fory::id`.
// The contexts are stored with a `'static` lifetime, so the methods that attach a borrowed
// buffer to one (`serialize_to`, `deserialize`, `deserialize_from`) detach it again before
// returning.
// `UnsafeCell` is presumably used instead of `RefCell` to avoid runtime borrow tracking; every
// `&mut` handed out from these cells relies on the accessor not being re-entered on the same
// thread while that borrow is live.
// NOTE(review): nothing in this file evicts entries, so contexts belonging to dropped `Fory`
// instances appear to stay cached for the lifetime of the thread — confirm.
thread_local! {
static WRITE_CONTEXTS: UnsafeCell<ContextCache<WriteContext<'static>>> =
UnsafeCell::new(ContextCache::new());
static READ_CONTEXTS: UnsafeCell<ContextCache<ReadContext<'static>>> =
UnsafeCell::new(ContextCache::new());
}
/// Builder for [`Fory`] instances.
///
/// Obtain one with `Fory::builder()` (or `ForyBuilder::default()`), chain option setters,
/// then call `build()`.
#[derive(Default)]
pub struct ForyBuilder {
// Accumulated configuration; starts from `Config::default()`.
config: Config,
// True once `compatible()` has been called explicitly. While false, enabling xlang makes
// `build()` default to compatible mode.
compatible_set: bool,
}
impl ForyBuilder {
    /// Turns compatible mode on or off.
    ///
    /// `share_meta` always follows `compatible`. Calling this marks the mode as explicitly
    /// chosen, so enabling xlang no longer switches compatible mode on by default.
    /// Enabling compatible mode disables struct version checks; disabling it while xlang is
    /// already on enables them.
    pub fn compatible(mut self, compatible: bool) -> Self {
        self.compatible_set = true;
        self.config.share_meta = compatible;
        self.config.compatible = compatible;
        if compatible {
            self.config.check_struct_version = false;
        } else if self.config.xlang {
            self.config.check_struct_version = true;
        }
        self
    }

    /// Turns cross-language (xlang) mode on or off.
    ///
    /// While `compatible()` has not been called explicitly, enabling xlang only records the
    /// flag. `build()` then defaults the instance to compatible mode. Deferring those
    /// defaults to `build()` means a later `xlang(false)` cannot leave compatible mode
    /// switched on behind the caller's back.
    ///
    /// In every other case, struct version checking is switched on when compatible mode is
    /// off. It is never switched off here.
    pub fn xlang(mut self, xlang: bool) -> Self {
        self.config.xlang = xlang;
        if xlang && !self.compatible_set {
            // Compatible-by-default is applied in `build()`, not here.
            return self;
        }
        if !self.config.check_struct_version {
            self.config.check_struct_version = !self.config.compatible;
        }
        self
    }

    /// Sets the `compress_string` option.
    pub fn compress_string(mut self, compress_string: bool) -> Self {
        self.config.compress_string = compress_string;
        self
    }

    /// Sets the `check_string_read` option.
    pub fn check_string_read(mut self, check_string_read: bool) -> Self {
        self.config.check_string_read = check_string_read;
        self
    }

    /// Requests struct version checking on or off.
    ///
    /// A request to enable it is ignored while compatible mode is on. A request to disable
    /// it always takes effect.
    pub fn check_struct_version(mut self, check_struct_version: bool) -> Self {
        if self.config.compatible && check_struct_version {
            return self;
        }
        self.config.check_struct_version = check_struct_version;
        self
    }

    /// Enables reference tracking.
    ///
    /// When enabled, `Fory` writes and reads the root value with `RefMode::Tracking`;
    /// otherwise it uses `RefMode::NullOnly`.
    pub fn track_ref(mut self, track_ref: bool) -> Self {
        self.config.track_ref = track_ref;
        self
    }

    /// Sets the `max_dyn_depth` limit carried in the configuration.
    pub fn max_dyn_depth(mut self, max_dyn_depth: u32) -> Self {
        self.config.max_dyn_depth = max_dyn_depth;
        self
    }

    /// Sets the `max_binary_size` limit carried in the configuration.
    pub fn max_binary_size(mut self, max_binary_size: u32) -> Self {
        self.config.max_binary_size = max_binary_size;
        self
    }

    /// Sets the `max_collection_size` limit carried in the configuration.
    pub fn max_collection_size(mut self, max_collection_size: u32) -> Self {
        self.config.max_collection_size = max_collection_size;
        self
    }

    /// Finalizes the configuration and creates a [`Fory`].
    ///
    /// If xlang is enabled and `compatible()` was never called, the instance defaults to
    /// compatible mode: `share_meta` and `compatible` on, struct version checks off.
    pub fn build(self) -> Fory {
        let mut config = self.config;
        if config.xlang && !self.compatible_set {
            config.share_meta = true;
            config.compatible = true;
            config.check_struct_version = false;
        }
        Fory::from_config(config)
    }
}
/// Serialization entry point holding the configuration and the type registry.
///
/// Types are registered through the `register*` methods. The first serialize/deserialize
/// call freezes the registry into `final_type_resolver`, after which registration is rejected.
pub struct Fory {
// Per-instance id taken from `FORY_ID_COUNTER`; the key into the thread-local context caches.
id: u64,
// Configuration produced by the builder.
config: Config,
// Registry that the `register*` methods write into.
type_resolver: TypeResolver,
// Snapshot built once from `type_resolver` on first use; caches either the built resolver
// or the error from building it.
final_type_resolver: OnceLock<Result<TypeResolver, Error>>,
}
impl Default for Fory {
fn default() -> Self {
Self::builder().build()
}
}
impl Fory {
pub fn builder() -> ForyBuilder {
ForyBuilder::default()
}
fn from_config(config: Config) -> Self {
let mut type_resolver = TypeResolver::default();
type_resolver.set_compatible(config.compatible);
type_resolver.set_xlang(config.xlang);
Self {
id: FORY_ID_COUNTER.fetch_add(1, Ordering::Relaxed),
config,
type_resolver,
final_type_resolver: OnceLock::new(),
}
}
pub fn is_xlang(&self) -> bool {
self.config.xlang
}
pub fn is_compatible(&self) -> bool {
self.config.compatible
}
pub fn is_compress_string(&self) -> bool {
self.config.compress_string
}
pub fn is_check_string_read(&self) -> bool {
self.config.check_string_read
}
pub fn is_share_meta(&self) -> bool {
self.config.share_meta
}
pub fn get_max_dyn_depth(&self) -> u32 {
self.config.max_dyn_depth
}
pub fn get_max_binary_size(&self) -> u32 {
self.config.max_binary_size
}
pub fn get_max_collection_size(&self) -> u32 {
self.config.max_collection_size
}
pub fn is_check_struct_version(&self) -> bool {
self.config.check_struct_version
}
pub fn config(&self) -> &Config {
&self.config
}
fn check_registration_allowed(&self) -> Result<(), Error> {
if self.final_type_resolver.get().is_some() {
return Err(Error::not_allowed(
"Type registration is not allowed after the first serialize/deserialize call. \
The type resolver snapshot has already been finalized. \
Please complete all type registrations before performing any serialization or deserialization.",
));
}
Ok(())
}
pub fn serialize<T: Serializer>(&self, record: &T) -> Result<Vec<u8>, Error> {
self.with_write_context(
|context| match self.serialize_with_context(record, context) {
Ok(_) => {
let result = context.writer.dump();
context.writer.reset();
Ok(result)
}
Err(err) => {
context.writer.reset();
Err(err)
}
},
)
}
pub fn serialize_to<T: Serializer>(
&self,
buf: &mut Vec<u8>,
record: &T,
) -> Result<usize, Error> {
let start = buf.len();
self.with_write_context(|context| {
let outlive_buffer = unsafe { mem::transmute::<&mut Vec<u8>, &mut Vec<u8>>(buf) };
context.attach_writer(Writer::from_buffer(outlive_buffer));
let result = self.serialize_with_context(record, context);
let written_size = context.writer.len() - start;
context.detach_writer();
match result {
Ok(_) => Ok(written_size),
Err(err) => Err(err),
}
})
}
#[inline(always)]
fn get_final_type_resolver(&self) -> Result<&TypeResolver, Error> {
let result = self
.final_type_resolver
.get_or_init(|| self.type_resolver.build_final_type_resolver());
result
.as_ref()
.map_err(|e| Error::type_error(format!("Failed to build type resolver: {}", e)))
}
#[inline(always)]
fn with_write_context<R>(
&self,
f: impl FnOnce(&mut WriteContext) -> Result<R, Error>,
) -> Result<R, Error> {
WRITE_CONTEXTS.with(|cache| {
let cache = unsafe { &mut *cache.get() };
let id = self.id;
let context = cache.get_or_insert_result(id, || {
let type_resolver = self.get_final_type_resolver()?;
Ok(Box::new(WriteContext::new(
type_resolver.clone(),
self.config.clone(),
)))
})?;
f(context)
})
}
#[inline(always)]
fn serialize_with_context<T: Serializer>(
&self,
record: &T,
context: &mut WriteContext,
) -> Result<(), Error> {
let result = self.serialize_with_context_inner::<T>(record, context);
context.reset();
result
}
#[inline(always)]
fn serialize_with_context_inner<T: Serializer>(
&self,
record: &T,
context: &mut WriteContext,
) -> Result<(), Error> {
self.write_head::<T>(&mut context.writer);
let ref_mode = if self.config.track_ref {
RefMode::Tracking
} else {
RefMode::NullOnly
};
<T as Serializer>::fory_write(record, context, ref_mode, true, false)?;
Ok(())
}
pub fn register<T: 'static + StructSerializer + Serializer + ForyDefault>(
&mut self,
id: u32,
) -> Result<(), Error> {
self.check_registration_allowed()?;
self.type_resolver.register::<T>(id)
}
pub fn register_union<T: 'static + StructSerializer + Serializer + ForyDefault>(
&mut self,
id: u32,
) -> Result<(), Error> {
self.check_registration_allowed()?;
self.type_resolver.register_union::<T>(id)
}
pub fn register_by_name<T: 'static + StructSerializer + Serializer + ForyDefault>(
&mut self,
namespace: &str,
type_name: &str,
) -> Result<(), Error> {
self.check_registration_allowed()?;
self.type_resolver
.register_by_name::<T>(namespace, type_name)
}
pub fn register_union_by_name<T: 'static + StructSerializer + Serializer + ForyDefault>(
&mut self,
namespace: &str,
type_name: &str,
) -> Result<(), Error> {
self.check_registration_allowed()?;
self.type_resolver
.register_union_by_name::<T>(namespace, type_name)
}
pub fn register_serializer<T: Serializer + ForyDefault>(
&mut self,
id: u32,
) -> Result<(), Error> {
self.check_registration_allowed()?;
self.type_resolver.register_serializer::<T>(id)
}
pub fn register_serializer_by_name<T: Serializer + ForyDefault>(
&mut self,
namespace: &str,
type_name: &str,
) -> Result<(), Error> {
self.check_registration_allowed()?;
self.type_resolver
.register_serializer_by_name::<T>(namespace, type_name)
}
pub fn register_generic_trait<T: 'static + Serializer + ForyDefault>(
&mut self,
) -> Result<(), Error> {
self.check_registration_allowed()?;
self.type_resolver.register_generic_trait::<T>()
}
#[inline(always)]
pub fn write_head<T: Serializer>(&self, writer: &mut Writer) {
const HEAD_SIZE: usize = 10;
writer.reserve(T::fory_reserved_space() + SIZE_OF_REF_AND_TYPE + HEAD_SIZE);
let bitmap = if self.config.xlang {
IS_CROSS_LANGUAGE_FLAG
} else {
0
};
writer.write_u8(bitmap);
}
pub fn deserialize<T: Serializer + ForyDefault>(&self, bf: &[u8]) -> Result<T, Error> {
self.with_read_context(|context| {
let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) };
context.attach_reader(Reader::new(outlive_buffer));
let result = self.deserialize_with_context(context);
context.detach_reader();
result
})
}
pub fn deserialize_from<T: Serializer + ForyDefault>(
&self,
reader: &mut Reader,
) -> Result<T, Error> {
self.with_read_context(|context| {
let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(reader.bf) };
let mut new_reader = Reader::new(outlive_buffer);
new_reader.set_cursor(reader.cursor);
context.attach_reader(new_reader);
let result = self.deserialize_with_context(context);
let end = context.detach_reader().get_cursor();
reader.set_cursor(end);
result
})
}
#[inline(always)]
fn with_read_context<R>(
&self,
f: impl FnOnce(&mut ReadContext) -> Result<R, Error>,
) -> Result<R, Error> {
READ_CONTEXTS.with(|cache| {
let cache = unsafe { &mut *cache.get() };
let id = self.id;
let context = cache.get_or_insert_result(id, || {
let type_resolver = self.get_final_type_resolver()?;
Ok(Box::new(ReadContext::new(
type_resolver.clone(),
self.config.clone(),
)))
})?;
f(context)
})
}
#[inline(always)]
fn deserialize_with_context<T: Serializer + ForyDefault>(
&self,
context: &mut ReadContext,
) -> Result<T, Error> {
let result = self.deserialize_with_context_inner::<T>(context);
context.reset();
result
}
#[inline(always)]
fn deserialize_with_context_inner<T: Serializer + ForyDefault>(
&self,
context: &mut ReadContext,
) -> Result<T, Error> {
self.read_head(&mut context.reader)?;
let ref_mode = if self.config.track_ref {
RefMode::Tracking
} else {
RefMode::NullOnly
};
let result = <T as Serializer>::fory_read(context, ref_mode, true);
context.ref_reader.resolve_callbacks();
result
}
#[inline(always)]
fn read_head(&self, reader: &mut Reader) -> Result<(), Error> {
let bitmap = reader.read_u8()?;
let expected = if self.config.xlang {
IS_CROSS_LANGUAGE_FLAG
} else {
0
};
if bitmap != expected {
return self.read_head_slow(bitmap, expected);
}
Ok(())
}
#[cold]
#[inline(never)]
fn read_head_slow(&self, bitmap: u8, expected: u8) -> Result<(), Error> {
const KNOWN_FLAGS: u8 = IS_CROSS_LANGUAGE_FLAG | IS_OUT_OF_BAND_FLAG;
ensure!(
(bitmap & !KNOWN_FLAGS) == 0 && (bitmap & IS_OUT_OF_BAND_FLAG) == 0,
Error::invalid_data("unsupported root header bitmap")
);
ensure!(
(bitmap & IS_CROSS_LANGUAGE_FLAG) == (expected & IS_CROSS_LANGUAGE_FLAG),
Error::invalid_data("header bitmap mismatch at xlang bit")
);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::Fory;

    /// The three mode flags that `xlang` and `compatible` interact on, as
    /// `(compatible, share_meta, check_struct_version)`.
    fn mode_flags(fory: &Fory) -> (bool, bool, bool) {
        (
            fory.is_compatible(),
            fory.is_share_meta(),
            fory.is_check_struct_version(),
        )
    }

    #[test]
    fn xlang_defaults_to_compatible_unless_explicitly_set() {
        // xlang on its own opts into compatible mode with shared metadata.
        let implicit = Fory::builder().xlang(true).build();
        assert_eq!(mode_flags(&implicit), (true, true, false));

        // An explicit `compatible(false)` wins regardless of call order and turns struct
        // version checking back on.
        let schema_consistent = [
            Fory::builder().compatible(false).xlang(true).build(),
            Fory::builder().xlang(true).compatible(false).build(),
        ];
        for fory in &schema_consistent {
            assert_eq!(mode_flags(fory), (false, false, true));
        }
    }
}