use crate::core::error::{Error, Result};
use crate::{read_lock_safe, write_lock_safe};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant, SystemTime};
/// Identity of a JIT-compiled function.
///
/// Two ids are equal only when every field is equal, which makes the type
/// usable as a hash-map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FunctionId {
    /// Name of the function.
    pub name: String,
    /// Textual name of the input type.
    pub input_type: String,
    /// Textual name of the output type.
    pub output_type: String,
    /// Textual signature of the operation the function performs.
    pub operation_signature: String,
    /// Optimization level associated with the function.
    pub optimization_level: u8,
}

impl FunctionId {
    /// Builds an id, converting every string-like argument into an owned `String`.
    pub fn new(
        name: impl Into<String>,
        input_type: impl Into<String>,
        output_type: impl Into<String>,
        operation_signature: impl Into<String>,
        optimization_level: u8,
    ) -> Self {
        let name = name.into();
        let input_type = input_type.into();
        let output_type = output_type.into();
        let operation_signature = operation_signature.into();
        Self {
            name,
            input_type,
            output_type,
            operation_signature,
            optimization_level,
        }
    }

    /// Returns the 64-bit digest of this id as computed by a freshly created
    /// standard-library `DefaultHasher`.
    pub fn hash_value(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;

        let mut state = DefaultHasher::new();
        Hash::hash(self, &mut state);
        state.finish()
    }
}
/// Bookkeeping kept for every cached function: compilation cost, execution
/// statistics and the values used to decide on eviction.
#[derive(Debug, Clone)]
pub struct CachedFunctionMetadata {
    /// Wall-clock time at which this metadata was created.
    pub compiled_at: SystemTime,
    /// Number of executions recorded so far.
    pub execution_count: u64,
    /// Sum of all recorded execution times, in nanoseconds.
    pub total_execution_time_ns: u64,
    /// `total_execution_time_ns / execution_count`, refreshed on every record.
    pub avg_execution_time_ns: f64,
    /// Size attributed to the function, in bytes.
    pub function_size_bytes: usize,
    /// Time spent compiling the function, in nanoseconds.
    pub compilation_time_ns: u64,
    /// Set once the function has more than 100 recorded executions while its
    /// average execution time is below one millisecond; never reset here.
    pub is_hot: bool,
    /// Monotonic timestamp of creation or of the latest recorded execution.
    pub last_accessed: Instant,
    /// Multiplier applied in `cache_benefit`; starts at `1.0`.
    pub effectiveness_score: f64,
}

impl CachedFunctionMetadata {
    /// Creates metadata for a function that has not been executed yet.
    pub fn new(compilation_time_ns: u64, function_size_bytes: usize) -> Self {
        Self {
            compiled_at: SystemTime::now(),
            last_accessed: Instant::now(),
            compilation_time_ns,
            function_size_bytes,
            execution_count: 0,
            total_execution_time_ns: 0,
            avg_execution_time_ns: 0.0,
            effectiveness_score: 1.0,
            is_hot: false,
        }
    }

    /// Records one execution that took `execution_time_ns` nanoseconds and
    /// refreshes the derived statistics.
    pub fn record_execution(&mut self, execution_time_ns: u64) {
        self.execution_count += 1;
        self.total_execution_time_ns += execution_time_ns;

        let runs = self.execution_count as f64;
        self.avg_execution_time_ns = self.total_execution_time_ns as f64 / runs;
        self.last_accessed = Instant::now();

        // The hot flag is sticky: it is only ever switched on.
        self.is_hot |= self.execution_count > 100 && self.avg_execution_time_ns < 1_000_000.0;
    }

    /// Compilation cost amortized over the recorded executions, scaled by the
    /// effectiveness score. Returns `0.0` when nothing has been executed.
    pub fn cache_benefit(&self) -> f64 {
        match self.execution_count {
            0 => 0.0,
            runs => (self.compilation_time_ns as f64 / runs as f64) * self.effectiveness_score,
        }
    }

    /// Decides whether this entry is an eviction candidate.
    ///
    /// An entry qualifies when any of the following holds:
    /// * `cache_pressure` exceeds `0.8` and the cache benefit is below `1000.0`;
    /// * it has not been accessed for more than 3600 seconds;
    /// * its effectiveness score is below `0.5`.
    pub fn should_evict(&self, cache_pressure: f64) -> bool {
        let idle_secs = self.last_accessed.elapsed().as_secs_f64();
        let low_value_under_pressure = cache_pressure > 0.8 && self.cache_benefit() < 1000.0;
        let stale = idle_secs > 3600.0;
        let ineffective = self.effectiveness_score < 0.5;

        low_value_under_pressure || stale || ineffective
    }
}
/// A function held by the cache together with its bookkeeping data.
pub struct CachedFunction {
/// Execution and compilation statistics for the function.
pub metadata: CachedFunctionMetadata,
/// Type-erased function payload; the concrete type behind `Any` is chosen
/// by whoever stores the function and is not visible from this file.
pub function: Box<dyn std::any::Any + Send + Sync>,
/// Input/output signature describing the function.
pub signature: FunctionSignature,
}
/// Textual description of a function's inputs and output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionSignature {
    /// Names of the input types, in order.
    pub input_types: Vec<String>,
    /// Name of the output type.
    pub output_type: String,
    /// When set, `matches` ignores the input types.
    pub is_variadic: bool,
}

impl FunctionSignature {
    /// Assembles a signature from its parts.
    pub fn new(input_types: Vec<String>, output_type: String, is_variadic: bool) -> Self {
        Self {
            is_variadic,
            output_type,
            input_types,
        }
    }

    /// Checks whether `other` is compatible with this signature.
    ///
    /// The output types must be equal. The input types must be equal as well
    /// unless `self` is variadic; `other.is_variadic` is not consulted.
    pub fn matches(&self, other: &FunctionSignature) -> bool {
        if self.output_type != other.output_type {
            return false;
        }
        if self.is_variadic {
            return true;
        }
        self.input_types == other.input_types
    }
}
/// One slot of the cache map: the live bookkeeping plus the shared record
/// handed out by [`JitFunctionCache::get`].
///
/// The two are kept apart because callers may still hold a handle while the
/// execution statistics of the same function keep changing.
struct CacheEntry {
    /// Authoritative metadata: updated by `record_execution`, read by eviction.
    metadata: CachedFunctionMetadata,
    /// Reference-counted record shared with callers; the only owner of the
    /// boxed function.
    handle: Arc<CachedFunction>,
}

/// Size-bounded cache of compiled functions keyed by [`FunctionId`].
///
/// All state sits behind `RwLock`s, so the cache can be used through a shared
/// reference. When several locks are needed they are always taken in the
/// order map -> size -> counters.
pub struct JitFunctionCache {
    cache: RwLock<HashMap<FunctionId, CacheEntry>>,
    max_cache_size_bytes: usize,
    current_cache_size_bytes: RwLock<usize>,
    cache_hits: RwLock<u64>,
    cache_misses: RwLock<u64>,
    cache_evictions: RwLock<u64>,
}

impl JitFunctionCache {
    /// Creates an empty cache whose byte budget is `max_size_mb * 1024 * 1024`
    /// (saturating at `usize::MAX` instead of overflowing).
    pub fn new(max_size_mb: usize) -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
            max_cache_size_bytes: max_size_mb.saturating_mul(1024 * 1024),
            current_cache_size_bytes: RwLock::new(0),
            cache_hits: RwLock::new(0),
            cache_misses: RwLock::new(0),
            cache_evictions: RwLock::new(0),
        }
    }

    /// Looks up `function_id`, counting a hit or a miss.
    ///
    /// Returns `None` on a miss or when the map lock cannot be acquired. A
    /// failure to lock one of the counters is ignored (best effort).
    ///
    /// The returned handle shares ownership of the stored record, so it stays
    /// valid even if the entry is later evicted, replaced or cleared. Its
    /// `metadata` is a snapshot: it is refreshed from the live statistics on
    /// lookup whenever no other handle to the same entry is alive; otherwise
    /// the snapshot already shared with the other holders is returned.
    pub fn get(&self, function_id: &FunctionId) -> Option<Arc<CachedFunction>> {
        // Exclusive access is required to refresh the snapshot inside the handle.
        let mut cache = write_lock_safe!(self.cache, "jit cache write").ok()?;
        match cache.get_mut(function_id) {
            Some(entry) => {
                if let Ok(mut hits) = write_lock_safe!(self.cache_hits, "jit cache hits write") {
                    *hits += 1;
                }
                // `get_mut` succeeds only while the cache is the sole owner.
                if let Some(record) = Arc::get_mut(&mut entry.handle) {
                    record.metadata = entry.metadata.clone();
                }
                // Sharing the `Arc` keeps exactly one owner of the boxed
                // function; copying the `Box` bits would free it twice.
                Some(Arc::clone(&entry.handle))
            }
            None => {
                if let Ok(mut misses) =
                    write_lock_safe!(self.cache_misses, "jit cache misses write")
                {
                    *misses += 1;
                }
                None
            }
        }
    }

    /// Inserts `function` under `function_id`, evicting other entries first
    /// if the byte budget would be exceeded. An existing entry with the same
    /// id is replaced; handles previously returned for it remain usable.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidOperation` when not enough space can be freed,
    /// and propagates whatever error the lock macros report.
    pub fn store(
        &self,
        function_id: FunctionId,
        function: Box<dyn std::any::Any + Send + Sync>,
        signature: FunctionSignature,
        metadata: CachedFunctionMetadata,
    ) -> Result<()> {
        let new_size = metadata.function_size_bytes;

        // Eviction and insertion happen under one lock acquisition so that
        // concurrent stores cannot both pass the size check.
        let mut cache = write_lock_safe!(self.cache, "jit cache write")?;
        let mut current_size =
            write_lock_safe!(self.current_cache_size_bytes, "jit cache size write")?;

        self.evict_if_needed(&mut *cache, &mut *current_size, new_size)?;

        if let Some(previous) = cache.remove(&function_id) {
            *current_size = (*current_size).saturating_sub(previous.metadata.function_size_bytes);
        }
        *current_size = (*current_size).saturating_add(new_size);

        let entry = CacheEntry {
            metadata: metadata.clone(),
            handle: Arc::new(CachedFunction {
                metadata,
                function,
                signature,
            }),
        };
        cache.insert(function_id, entry);
        Ok(())
    }

    /// Adds one execution of `execution_time_ns` nanoseconds to the live
    /// statistics of `function_id`. Unknown ids are silently ignored.
    ///
    /// # Errors
    ///
    /// Propagates the error reported when the map lock cannot be acquired.
    pub fn record_execution(&self, function_id: &FunctionId, execution_time_ns: u64) -> Result<()> {
        let mut cache = write_lock_safe!(self.cache, "jit cache write")?;
        if let Some(entry) = cache.get_mut(function_id) {
            entry.metadata.record_execution(execution_time_ns);
        }
        Ok(())
    }

    /// Returns a snapshot of the counters and the current size figures.
    ///
    /// Each value is read under its own lock, so under concurrent use the
    /// figures may not describe one single instant.
    ///
    /// # Errors
    ///
    /// Propagates the error reported when any lock cannot be acquired.
    pub fn get_stats(&self) -> Result<CacheStats> {
        let hits = *read_lock_safe!(self.cache_hits, "jit cache hits read")?;
        let misses = *read_lock_safe!(self.cache_misses, "jit cache misses read")?;
        let evictions = *read_lock_safe!(self.cache_evictions, "jit cache evictions read")?;
        let cache_size = *read_lock_safe!(self.current_cache_size_bytes, "jit cache size read")?;
        let cache_entries = read_lock_safe!(self.cache, "jit cache read")?.len();
        let hit_rate = if hits + misses > 0 {
            hits as f64 / (hits + misses) as f64
        } else {
            0.0
        };
        Ok(CacheStats {
            hit_rate,
            hits,
            misses,
            evictions,
            cache_size_bytes: cache_size,
            cache_entries,
            max_cache_size_bytes: self.max_cache_size_bytes,
        })
    }

    /// Removes every entry and resets the tracked size to zero. The hit, miss
    /// and eviction counters are left untouched.
    ///
    /// # Errors
    ///
    /// Propagates the error reported when a lock cannot be acquired.
    pub fn clear(&self) -> Result<()> {
        let mut cache = write_lock_safe!(self.cache, "jit cache write")?;
        let mut current_size =
            write_lock_safe!(self.current_cache_size_bytes, "jit cache size write")?;
        cache.clear();
        *current_size = 0;
        Ok(())
    }

    /// Frees space for `new_function_size` additional bytes.
    ///
    /// The caller passes the already locked map and size so that the check
    /// and the following insertion are atomic. Entries flagged by
    /// `should_evict` are removed in order of ascending cache benefit until
    /// the projected size fits the budget.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidOperation` when the budget is still exceeded
    /// after all candidates were considered; entries evicted up to that point
    /// stay evicted.
    fn evict_if_needed(
        &self,
        cache: &mut HashMap<FunctionId, CacheEntry>,
        current_size: &mut usize,
        new_function_size: usize,
    ) -> Result<()> {
        let projected = (*current_size).saturating_add(new_function_size);
        if projected <= self.max_cache_size_bytes {
            return Ok(());
        }

        let mut evictions = write_lock_safe!(self.cache_evictions, "jit cache evictions write")?;
        let cache_pressure = projected as f64 / self.max_cache_size_bytes as f64;

        let mut candidates: Vec<(FunctionId, usize, f64)> = cache
            .iter()
            .filter(|(_, entry)| entry.metadata.should_evict(cache_pressure))
            .map(|(id, entry)| {
                (
                    id.clone(),
                    entry.metadata.function_size_bytes,
                    entry.metadata.cache_benefit(),
                )
            })
            .collect();
        // Least beneficial first; `total_cmp` gives NaN a defined position.
        candidates.sort_by(|a, b| a.2.total_cmp(&b.2));

        // `projected > max` was established above, so this cannot underflow.
        let needed_space = projected - self.max_cache_size_bytes;
        let mut freed_space = 0usize;
        for (function_id, function_size, _benefit) in candidates {
            if freed_space >= needed_space {
                break;
            }
            cache.remove(&function_id);
            *current_size = (*current_size).saturating_sub(function_size);
            freed_space = freed_space.saturating_add(function_size);
            *evictions += 1;
        }

        if (*current_size).saturating_add(new_function_size) > self.max_cache_size_bytes {
            return Err(Error::InvalidOperation(
                "Unable to free enough cache space for new function".to_string(),
            ));
        }
        Ok(())
    }
}
/// Point-in-time figures describing a function cache.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// `hits / (hits + misses)`, or `0.0` when no lookup was counted.
    pub hit_rate: f64,
    /// Number of successful lookups.
    pub hits: u64,
    /// Number of failed lookups.
    pub misses: u64,
    /// Number of evicted entries.
    pub evictions: u64,
    /// Bytes currently attributed to cached functions.
    pub cache_size_bytes: usize,
    /// Number of cached entries.
    pub cache_entries: usize,
    /// Configured byte budget.
    pub max_cache_size_bytes: usize,
}

impl CacheStats {
    /// Share of the byte budget in use, as a percentage. Returns `0.0` for a
    /// zero budget.
    pub fn utilization_percent(&self) -> f64 {
        if self.max_cache_size_bytes == 0 {
            return 0.0;
        }
        let used_fraction = self.cache_size_bytes as f64 / self.max_cache_size_bytes as f64;
        used_fraction * 100.0
    }

    /// Mean size of a cached entry in bytes. Returns `0.0` for an empty cache.
    pub fn avg_function_size_bytes(&self) -> f64 {
        match self.cache_entries {
            0 => 0.0,
            entries => self.cache_size_bytes as f64 / entries as f64,
        }
    }
}
static GLOBAL_CACHE: std::sync::OnceLock<JitFunctionCache> = std::sync::OnceLock::new();
pub fn get_global_cache() -> &'static JitFunctionCache {
GLOBAL_CACHE.get_or_init(|| JitFunctionCache::new(128)) }
pub fn init_global_cache(max_size_mb: usize) -> Result<()> {
GLOBAL_CACHE
.set(JitFunctionCache::new(max_size_mb))
.map_err(|_| Error::InvalidOperation("Global cache already initialized".to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor arguments end up in the matching public fields.
    #[test]
    fn test_function_id_creation() {
        let function_id = FunctionId::new("test_sum", "f64", "f64", "sum_operation", 2);

        assert_eq!(function_id.name, "test_sum");
        assert_eq!(function_id.input_type, "f64");
        assert_eq!(function_id.output_type, "f64");
        assert_eq!(function_id.optimization_level, 2);
    }

    /// Recording one execution updates the count and the average.
    #[test]
    fn test_cache_metadata() {
        let compilation_time_ns = 1_000_000;
        let size_bytes = 1024;
        let mut stats = CachedFunctionMetadata::new(compilation_time_ns, size_bytes);
        assert_eq!(stats.execution_count, 0);

        stats.record_execution(500_000);

        assert_eq!(stats.execution_count, 1);
        assert_eq!(stats.avg_execution_time_ns, 500_000.0);
    }

    /// Non-variadic signatures match only when the input types agree.
    #[test]
    fn test_function_signature_matching() {
        let returning_f64 = |input: &str| {
            FunctionSignature::new(vec![input.to_string()], "f64".to_string(), false)
        };
        let reference = returning_f64("f64");
        let identical = returning_f64("f64");
        let different_input = returning_f64("i64");

        assert!(reference.matches(&identical));
        assert!(!reference.matches(&different_input));
    }

    /// A lookup in an empty cache is counted as a miss, not a hit.
    #[test]
    fn test_cache_operations() {
        let cache = JitFunctionCache::new(1);
        let function_id = FunctionId::new("test", "f64", "f64", "test_op", 1);

        assert!(cache.get(&function_id).is_none());

        let CacheStats { hits, misses, .. } =
            cache.get_stats().expect("operation should succeed");
        assert_eq!(misses, 1);
        assert_eq!(hits, 0);
    }
}