use std::cell::RefCell;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::OnceLock;
use facet_core::{ConstTypeId, Facet};
use museair::bfast::HashMap;
use parking_lot::RwLock;
use super::compiler::{self, CachedModule, CompiledDeserializer};
use super::format_compiler::{self, CompiledFormatDeserializer};
use super::helpers;
use crate::{FormatJitParser, FormatParser};
/// Key pairing two type identities. Format lookups build it as
/// `(T::SHAPE.id, ConstTypeId::of::<P>())`.
type CacheKey = (ConstTypeId, ConstTypeId);

/// Process-wide map from cache key to compiled module, consulted by [`get_or_compile`].
static CACHE: OnceLock<RwLock<HashMap<CacheKey, Arc<CachedModule>>>> = OnceLock::new();

/// Returns the global module cache, creating an empty one on first access.
///
/// In debug builds on Unix, that first access also installs the JIT crash handler when
/// the `FACET_JIT_CRASH_HANDLER` environment variable is set (to valid Unicode).
fn cache() -> &'static RwLock<HashMap<CacheKey, Arc<CachedModule>>> {
    CACHE.get_or_init(|| {
        #[cfg(all(debug_assertions, unix))]
        {
            let wants_crash_handler = std::env::var("FACET_JIT_CRASH_HANDLER").is_ok();
            if wants_crash_handler {
                super::crash_handler::install_crash_handler();
            }
        }
        let empty: HashMap<CacheKey, Arc<CachedModule>> = HashMap::default();
        RwLock::new(empty)
    })
}
/// Returns a compiled deserializer for `T` driven by parser `P`, compiling and caching
/// the module on first use.
///
/// Only `T` is handed to the compiler; `P` selects the vtable passed to
/// [`CompiledDeserializer::from_cached`]. Returns `None` when
/// `compiler::try_compile_module::<T>()` does.
///
/// NOTE(review): failed compilations are not cached, so an unsupported `T` is
/// re-attempted on every call (unlike the format cache, which remembers misses).
pub fn get_or_compile<'de, T, P>(key: CacheKey) -> Option<CompiledDeserializer<T, P>>
where
    T: Facet<'de>,
    P: FormatParser<'de>,
{
    // Fast path under the shared read lock. The guard is a temporary of this `let`
    // statement, so it is released before we may take the write lock below
    // (parking_lot's RwLock is not reentrant).
    let existing = cache().read().get(&key).map(Arc::clone);

    let module = match existing {
        Some(module) => module,
        None => {
            // Compile without holding any lock.
            let result = compiler::try_compile_module::<T>()?;
            let compiled = Arc::new(CachedModule::new(
                result.module,
                result.nested_modules,
                result.fn_ptr,
            ));
            // If another thread published a module for `key` while we were compiling,
            // adopt theirs so every caller shares one module. Ours is then the last
            // reference and is freed when `compiled` drops at the end of this arm,
            // after `map` (declared later) has released the write lock.
            let mut map = cache().write();
            let slot = map.entry(key).or_insert_with(|| Arc::clone(&compiled));
            Arc::clone(slot)
        }
    };

    Some(CompiledDeserializer::from_cached(
        module,
        helpers::make_vtable::<P>(),
    ))
}
/// Empties the global module cache (test builds only).
///
/// Does nothing if the cache has never been initialized.
#[cfg(test)]
#[allow(dead_code)]
pub fn clear_cache() {
    let Some(modules) = CACHE.get() else {
        return;
    };
    modules.write().clear();
}
use std::sync::atomic::{AtomicU64, Ordering};
use super::Tier2Incompatibility;
use super::format_compiler::CachedFormatModule;
/// Outcome of a Tier-2 format compilation, as stored in the global and thread-local
/// format caches.
///
/// Cloning a `Hit` only bumps the `Arc` reference count; cloning a `Miss` clones the
/// recorded reason.
#[derive(Clone)]
pub enum CachedFormatCacheEntry {
    /// Compilation succeeded; the shared compiled module.
    Hit(Arc<CachedFormatModule>),
    /// Compilation was rejected. The reason is cached so lookups for the same key
    /// return it without compiling again.
    Miss(Tier2Incompatibility),
}
/// Process-wide, size-bounded cache of Tier-2 format compilation results.
///
/// Eviction is FIFO by insertion order: when full, the oldest inserted key is removed
/// first. Lookups via [`BoundedFormatCache::get`] do not refresh an entry's position.
/// Every eviction increments `CACHE_EVICTIONS`.
struct BoundedFormatCache {
    /// Cached results (compiled modules and incompatibilities), keyed by [`CacheKey`].
    entries: HashMap<CacheKey, CachedFormatCacheEntry>,
    /// Keys of `entries` from oldest to newest insertion; each key appears exactly once.
    insertion_order: VecDeque<CacheKey>,
    /// Capacity limit; see [`BoundedFormatCache::new`].
    max_entries: usize,
}

impl BoundedFormatCache {
    /// Capacity used when `FACET_TIER2_CACHE_MAX_ENTRIES` is unset or not a valid `usize`.
    const DEFAULT_MAX_ENTRIES: usize = 1024;

    /// Creates an empty cache. Its capacity comes from the `FACET_TIER2_CACHE_MAX_ENTRIES`
    /// environment variable, defaulting to 1024.
    ///
    /// A capacity of `0` behaves like `1`: the most recent insertion is always kept.
    fn new() -> Self {
        let max_entries = std::env::var("FACET_TIER2_CACHE_MAX_ENTRIES")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            .unwrap_or(Self::DEFAULT_MAX_ENTRIES);
        Self {
            entries: HashMap::default(),
            insertion_order: VecDeque::new(),
            max_entries,
        }
    }

    /// Returns the cached entry for `key`, if any, without affecting eviction order.
    fn get(&self, key: &CacheKey) -> Option<&CachedFormatCacheEntry> {
        self.entries.get(key)
    }

    /// Inserts or replaces the entry for `key`.
    ///
    /// - New key: the oldest entries are evicted first until there is room.
    /// - Existing key: the value is replaced in place and the key moves to the newest
    ///   position. Nothing is evicted, because the map does not grow.
    fn insert(&mut self, key: CacheKey, value: CachedFormatCacheEntry) {
        if let Some(slot) = self.entries.get_mut(&key) {
            *slot = value;
            // Refresh the key's position in the eviction order.
            self.insertion_order.retain(|queued| *queued != key);
            self.insertion_order.push_back(key);
            return;
        }

        while self.entries.len() >= self.max_entries {
            let Some(oldest) = self.insertion_order.pop_front() else {
                break;
            };
            if self.entries.remove(&oldest).is_some() {
                CACHE_EVICTIONS.fetch_add(1, Ordering::Relaxed);
            }
        }
        self.entries.insert(key, value);
        self.insertion_order.push_back(key);
    }

    /// Removes every entry. Eviction statistics are left untouched.
    fn clear(&mut self) {
        self.entries.clear();
        self.insertion_order.clear();
    }
}
/// Lazily created global Tier-2 format cache; access it through [`format_cache`].
static FORMAT_CACHE: OnceLock<RwLock<BoundedFormatCache>> = OnceLock::new();

/// Returns the global format cache. On first use it is created via
/// `BoundedFormatCache::new`, which reads its capacity from the environment.
fn format_cache() -> &'static RwLock<BoundedFormatCache> {
    FORMAT_CACHE.get_or_init(|| {
        let bounded = BoundedFormatCache::new();
        RwLock::new(bounded)
    })
}
/// Lookups served a compiled module from the thread-local or global format cache.
static CACHE_HIT: AtomicU64 = AtomicU64::new(0);
/// Lookups served a cached incompatibility (negative result).
static CACHE_MISS_NEGATIVE: AtomicU64 = AtomicU64::new(0);
/// Lookups that found nothing cached and triggered a compilation.
static CACHE_MISS_COMPILE: AtomicU64 = AtomicU64::new(0);
/// Entries removed from the bounded format cache to make room.
static CACHE_EVICTIONS: AtomicU64 = AtomicU64::new(0);

/// Returns the format-cache counters as `(hits, negative_hits, compiles, evictions)`.
///
/// Each counter is loaded independently with relaxed ordering. The tuple is therefore
/// not a consistent snapshot while other threads are using the cache.
#[allow(dead_code)]
pub fn get_cache_stats() -> (u64, u64, u64, u64) {
    let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
    (
        read(&CACHE_HIT),
        read(&CACHE_MISS_NEGATIVE),
        read(&CACHE_MISS_COMPILE),
        read(&CACHE_EVICTIONS),
    )
}

/// Resets every format-cache counter to zero.
pub fn reset_cache_stats() {
    for counter in [
        &CACHE_HIT,
        &CACHE_MISS_NEGATIVE,
        &CACHE_MISS_COMPILE,
        &CACHE_EVICTIONS,
    ] {
        counter.store(0, Ordering::Relaxed);
    }
}
/// One-slot, per-thread memo of the most recent format-cache result. It lets repeated
/// lookups of the same key on a thread skip the global cache's `RwLock`.
struct TlsCacheEntry {
    /// Key the memoized result belongs to; compared with the requested key on lookup.
    key: CacheKey,
    /// Memoized result (compiled module or incompatibility) for `key`.
    entry: CachedFormatCacheEntry,
}
thread_local! {
    // Holds at most one entry. Every global-cache hit or fresh compile on this thread
    // overwrites it. `clear_format_cache` only resets the calling thread's slot.
    static FORMAT_TLS_CACHE: RefCell<Option<TlsCacheEntry>> = const { RefCell::new(None) };
}
pub fn get_or_compile_format<'de, T, P>(key: CacheKey) -> Option<CompiledFormatDeserializer<T, P>>
where
T: Facet<'de>,
P: FormatJitParser<'de>,
{
let tls_result = FORMAT_TLS_CACHE.with(|cache| {
let cache = cache.borrow();
if let Some(entry) = cache.as_ref()
&& entry.key == key
{
match &entry.entry {
CachedFormatCacheEntry::Hit(module) => {
CACHE_HIT.fetch_add(1, Ordering::Relaxed);
return Some(Some(CompiledFormatDeserializer::from_cached(Arc::clone(
module,
))));
}
CachedFormatCacheEntry::Miss(_reason) => {
CACHE_MISS_NEGATIVE.fetch_add(1, Ordering::Relaxed);
return Some(None);
}
}
}
None
});
if let Some(result) = tls_result {
return result;
}
let global_result = {
let cache = format_cache().read();
cache.get(&key).cloned()
};
if let Some(cached_entry) = global_result {
let entry = cached_entry.clone();
FORMAT_TLS_CACHE.with(|tls| {
*tls.borrow_mut() = Some(TlsCacheEntry {
key,
entry: cached_entry,
});
});
match entry {
CachedFormatCacheEntry::Hit(module) => {
CACHE_HIT.fetch_add(1, Ordering::Relaxed);
return Some(CompiledFormatDeserializer::from_cached(module));
}
CachedFormatCacheEntry::Miss(_reason) => {
CACHE_MISS_NEGATIVE.fetch_add(1, Ordering::Relaxed);
return None;
}
}
}
CACHE_MISS_COMPILE.fetch_add(1, Ordering::Relaxed);
let cache_entry = match format_compiler::try_compile_format_module::<T, P>() {
Ok((module, fn_ptr)) => {
let cached_module = Arc::new(CachedFormatModule::new(module, fn_ptr));
CachedFormatCacheEntry::Hit(cached_module)
}
Err(reason) => {
CachedFormatCacheEntry::Miss(reason)
}
};
{
let mut cache = format_cache().write();
if cache.get(&key).is_none() {
cache.insert(key, cache_entry.clone());
}
}
FORMAT_TLS_CACHE.with(|tls| {
*tls.borrow_mut() = Some(TlsCacheEntry {
key,
entry: cache_entry.clone(),
});
});
match cache_entry {
CachedFormatCacheEntry::Hit(module) => {
Some(CompiledFormatDeserializer::from_cached(module))
}
CachedFormatCacheEntry::Miss(_reason) => None,
}
}
/// Looks up the Tier-2 format deserializer for `T` with parser `P`, compiling it on
/// first use.
///
/// The cache key is `(T::SHAPE.id, ConstTypeId::of::<P>())`. Returns `None` when the
/// pair is incompatible with Tier-2 compilation; call
/// [`get_format_deserializer_with_reason`] to learn why.
pub fn get_format_deserializer<'de, T, P>() -> Option<CompiledFormatDeserializer<T, P>>
where
    T: Facet<'de>,
    P: FormatJitParser<'de>,
{
    let shape_id = T::SHAPE.id;
    let parser_id = ConstTypeId::of::<P>();
    get_or_compile_format::<T, P>((shape_id, parser_id))
}
/// Drops cached Tier-2 format results: the global bounded cache and the calling
/// thread's one-entry memo.
///
/// NOTE(review): other threads' memos cannot be reached from here. Each of them keeps
/// serving its last-looked-up entry until that thread looks up a different key.
pub fn clear_format_cache() {
    let global = FORMAT_CACHE.get();
    if let Some(formats) = global {
        formats.write().clear();
    }
    FORMAT_TLS_CACHE.with(|slot| {
        let _previous = slot.take();
    });
}
pub fn get_or_compile_format_with_reason<'de, T, P>(
key: CacheKey,
) -> Result<CompiledFormatDeserializer<T, P>, Tier2Incompatibility>
where
T: Facet<'de>,
P: FormatJitParser<'de>,
{
let tls_result = FORMAT_TLS_CACHE.with(|cache| {
let cache = cache.borrow();
if let Some(entry) = cache.as_ref()
&& entry.key == key
{
match &entry.entry {
CachedFormatCacheEntry::Hit(module) => {
CACHE_HIT.fetch_add(1, Ordering::Relaxed);
return Some(Ok(CompiledFormatDeserializer::from_cached(Arc::clone(
module,
))));
}
CachedFormatCacheEntry::Miss(reason) => {
CACHE_MISS_NEGATIVE.fetch_add(1, Ordering::Relaxed);
return Some(Err(reason.clone()));
}
}
}
None
});
if let Some(result) = tls_result {
return result;
}
let global_result = {
let cache = format_cache().read();
cache.get(&key).cloned()
};
if let Some(cached_entry) = global_result {
let entry = cached_entry.clone();
FORMAT_TLS_CACHE.with(|tls| {
*tls.borrow_mut() = Some(TlsCacheEntry {
key,
entry: cached_entry,
});
});
return match entry {
CachedFormatCacheEntry::Hit(module) => {
CACHE_HIT.fetch_add(1, Ordering::Relaxed);
Ok(CompiledFormatDeserializer::from_cached(module))
}
CachedFormatCacheEntry::Miss(reason) => {
CACHE_MISS_NEGATIVE.fetch_add(1, Ordering::Relaxed);
Err(reason)
}
};
}
CACHE_MISS_COMPILE.fetch_add(1, Ordering::Relaxed);
let cache_entry = match format_compiler::try_compile_format_module::<T, P>() {
Ok((module, fn_ptr)) => {
let cached_module = Arc::new(CachedFormatModule::new(module, fn_ptr));
CachedFormatCacheEntry::Hit(cached_module)
}
Err(reason) => CachedFormatCacheEntry::Miss(reason),
};
{
let mut cache = format_cache().write();
if cache.get(&key).is_none() {
cache.insert(key, cache_entry.clone());
}
}
FORMAT_TLS_CACHE.with(|tls| {
*tls.borrow_mut() = Some(TlsCacheEntry {
key,
entry: cache_entry.clone(),
});
});
match cache_entry {
CachedFormatCacheEntry::Hit(module) => Ok(CompiledFormatDeserializer::from_cached(module)),
CachedFormatCacheEntry::Miss(reason) => Err(reason),
}
}
/// Like [`get_format_deserializer`], but reports why Tier-2 compilation is unavailable.
///
/// # Errors
///
/// Returns the [`Tier2Incompatibility`] recorded for the key
/// `(T::SHAPE.id, ConstTypeId::of::<P>())`.
pub fn get_format_deserializer_with_reason<'de, T, P>()
-> Result<CompiledFormatDeserializer<T, P>, Tier2Incompatibility>
where
    T: Facet<'de>,
    P: FormatJitParser<'de>,
{
    get_or_compile_format_with_reason::<T, P>((T::SHAPE.id, ConstTypeId::of::<P>()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cloning a negative entry must preserve both the variant and its reason.
    #[test]
    fn test_cache_entry_clone() {
        let miss = CachedFormatCacheEntry::Miss(Tier2Incompatibility::Not64BitPlatform);
        let cloned = miss.clone();
        assert!(matches!(
            cloned,
            CachedFormatCacheEntry::Miss(Tier2Incompatibility::Not64BitPlatform)
        ));
        // The original is still intact after cloning.
        assert!(matches!(miss, CachedFormatCacheEntry::Miss(_)));
    }

    /// A full bounded cache evicts the oldest inserted key first, and `clear` empties
    /// both the map and the eviction order.
    #[test]
    fn test_bounded_cache_evicts_oldest_first() {
        let mut cache = BoundedFormatCache {
            entries: HashMap::default(),
            insertion_order: VecDeque::new(),
            max_entries: 2,
        };
        let miss = || CachedFormatCacheEntry::Miss(Tier2Incompatibility::Not64BitPlatform);
        let parser = ConstTypeId::of::<()>();
        let a: CacheKey = (ConstTypeId::of::<u8>(), parser);
        let b: CacheKey = (ConstTypeId::of::<u16>(), parser);
        let c: CacheKey = (ConstTypeId::of::<u32>(), parser);

        cache.insert(a, miss());
        cache.insert(b, miss());
        cache.insert(c, miss());

        assert!(cache.get(&a).is_none());
        assert!(cache.get(&b).is_some());
        assert!(cache.get(&c).is_some());
        assert_eq!(cache.entries.len(), 2);
        assert_eq!(cache.insertion_order.len(), 2);

        cache.clear();
        assert!(cache.get(&b).is_none());
        assert!(cache.get(&c).is_none());
        assert!(cache.insertion_order.is_empty());
    }
}