use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use wasmtime::{Config, OptLevel, Strategy};
/// Execution tiers a function can run at, ordered from cheapest to most optimized:
/// `Interpreter < Baseline < Optimized`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompilationTier {
    /// Lowest tier.
    Interpreter,
    /// Middle tier.
    Baseline,
    /// Highest tier.
    Optimized,
}

impl CompilationTier {
    /// Returns `true` when `self` sits strictly above `other` in the tier order.
    pub fn is_higher_than(&self, other: &Self) -> bool {
        other.rank() < self.rank()
    }

    /// Position of the tier in the promotion order (0 = lowest).
    fn rank(self) -> u8 {
        match self {
            CompilationTier::Interpreter => 0,
            CompilationTier::Baseline => 1,
            CompilationTier::Optimized => 2,
        }
    }
}
/// Runtime statistics collected for a single function.
#[derive(Debug, Clone)]
pub struct FunctionProfile {
    /// Index of the profiled function.
    pub function_index: u32,
    /// Number of recorded calls (saturates at `u64::MAX`).
    pub call_count: u64,
    /// Sum of all recorded call durations in nanoseconds (saturates at `u64::MAX`).
    pub total_execution_ns: u64,
    /// `total_execution_ns / call_count` (integer division); 0 before the first call.
    pub avg_execution_ns: u64,
    /// Longest single recorded call, in nanoseconds.
    pub peak_execution_ns: u64,
    /// Tier recorded for this function; starts at `CompilationTier::Interpreter`.
    pub tier: CompilationTier,
    /// When `set_tier` was last called; `None` if it never was.
    pub last_optimized: Option<Instant>,
}

impl FunctionProfile {
    /// Creates an empty profile for `function_index` at the `Interpreter` tier.
    pub fn new(function_index: u32) -> Self {
        Self {
            function_index,
            call_count: 0,
            total_execution_ns: 0,
            avg_execution_ns: 0,
            peak_execution_ns: 0,
            tier: CompilationTier::Interpreter,
            last_optimized: None,
        }
    }

    /// Records one call that took `duration_ns` nanoseconds.
    ///
    /// Counters saturate at `u64::MAX` instead of overflowing. Unchecked `+=` panics in
    /// debug builds and wraps in release builds, which would corrupt the average.
    pub fn record_call(&mut self, duration_ns: u64) {
        self.call_count = self.call_count.saturating_add(1);
        self.total_execution_ns = self.total_execution_ns.saturating_add(duration_ns);
        // call_count is at least 1 here, so the division is never by zero.
        self.avg_execution_ns = self.total_execution_ns / self.call_count;
        self.peak_execution_ns = self.peak_execution_ns.max(duration_ns);
    }

    /// Returns `true` once both the call-count and the cumulative-time limits of
    /// `threshold` are reached. `threshold.min_reopt_interval` is not consulted here.
    pub fn is_hot(&self, threshold: &HotFunctionThreshold) -> bool {
        self.call_count >= threshold.min_call_count
            && self.total_execution_ns >= threshold.min_total_time_ns
    }

    /// Sets the recorded tier and stamps `last_optimized` with the current time.
    pub fn set_tier(&mut self, tier: CompilationTier) {
        self.tier = tier;
        self.last_optimized = Some(Instant::now());
    }
}
/// Limits a function's profile must reach before it is considered "hot".
#[derive(Debug, Clone)]
pub struct HotFunctionThreshold {
    /// Minimum number of recorded calls.
    pub min_call_count: u64,
    /// Minimum cumulative execution time, in nanoseconds.
    pub min_total_time_ns: u64,
    /// Presumably the minimum delay between successive re-optimizations of one
    /// function (judging by the name) — not read by `FunctionProfile::is_hot`.
    pub min_reopt_interval: Duration,
}

impl Default for HotFunctionThreshold {
    /// 100 calls, 1 ms cumulative time, 100 ms re-optimization interval.
    fn default() -> Self {
        Self::preset(100, 1_000_000, 100)
    }
}

impl HotFunctionThreshold {
    /// Shared constructor for the presets; the interval is given in milliseconds.
    fn preset(calls: u64, total_ns: u64, reopt_ms: u64) -> Self {
        HotFunctionThreshold {
            min_call_count: calls,
            min_total_time_ns: total_ns,
            min_reopt_interval: Duration::from_millis(reopt_ms),
        }
    }

    /// Lower limits: 50 calls, 0.5 ms cumulative time, 50 ms interval.
    pub fn interactive() -> Self {
        Self::preset(50, 500_000, 50)
    }

    /// Higher limits: 500 calls, 5 ms cumulative time, 500 ms interval.
    pub fn batch() -> Self {
        Self::preset(500, 5_000_000, 500)
    }

    /// Intermediate limits: 200 calls, 2 ms cumulative time, 200 ms interval.
    pub fn embedded() -> Self {
        Self::preset(200, 2_000_000, 200)
    }
}
/// One (type hash -> target function) slot of a call site's inline cache.
#[derive(Debug, Clone)]
pub struct InlineCacheEntry {
    /// Type hash this slot matches.
    pub type_hash: u64,
    /// Function index returned when the type hash matches.
    pub function_index: u32,
    /// Lookups at this call site that matched this slot.
    pub hit_count: u64,
    /// Lookups at this call site that matched no slot while this slot was present.
    pub miss_count: u64,
}

impl InlineCacheEntry {
    /// Creates a slot with zeroed counters.
    pub fn new(type_hash: u64, function_index: u32) -> Self {
        Self {
            type_hash,
            function_index,
            hit_count: 0,
            miss_count: 0,
        }
    }

    /// Increments `hit_count`.
    pub fn hit(&mut self) {
        self.hit_count += 1;
    }

    /// Increments `miss_count`.
    pub fn miss(&mut self) {
        self.miss_count += 1;
    }

    /// `hit_count / (hit_count + miss_count)`, or `0.0` if the slot was never consulted.
    pub fn hit_rate(&self) -> f64 {
        let total = self.hit_count + self.miss_count;
        if total == 0 {
            0.0
        } else {
            self.hit_count as f64 / total as f64
        }
    }
}

/// Per-call-site polymorphic inline cache mapping type hashes to function indices.
///
/// Each call site holds at most `max_entries_per_site` slots (4 by default). When a site
/// is full, inserting a new type replaces the slot with the lowest per-slot hit rate.
#[derive(Debug, Clone)]
pub struct InlineCache {
    entries: HashMap<u32, Vec<InlineCacheEntry>>,
    max_entries_per_site: usize,
    /// Lookups that returned `Some` since construction.
    lookup_hits: u64,
    /// Lookups that returned `None` since construction, including unknown call sites.
    lookup_misses: u64,
}

impl InlineCache {
    /// Creates an empty cache with 4 slots per call site.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
            max_entries_per_site: 4,
            lookup_hits: 0,
            lookup_misses: 0,
        }
    }

    /// Looks up the cached target for `type_hash` at `call_site_id`.
    ///
    /// On a hit, the matching slot's `hit_count` is bumped. On a miss at a known site,
    /// every slot's `miss_count` is bumped; those per-slot counters drive eviction.
    /// Either way, exactly one cache-level hit or miss is recorded for `stats`.
    pub fn lookup(&mut self, call_site_id: u32, type_hash: u64) -> Option<u32> {
        let result = match self.entries.get_mut(&call_site_id) {
            Some(entries) => match entries.iter().position(|e| e.type_hash == type_hash) {
                Some(idx) => {
                    let entry = &mut entries[idx];
                    entry.hit();
                    Some(entry.function_index)
                }
                None => {
                    for entry in entries.iter_mut() {
                        entry.miss();
                    }
                    None
                }
            },
            None => None,
        };
        // Count the lookup once. Summing per-slot miss counters would count one miss
        // once per slot present at the site.
        match result {
            Some(_) => self.lookup_hits += 1,
            None => self.lookup_misses += 1,
        }
        result
    }

    /// Caches `type_hash -> function_index` at `call_site_id`.
    ///
    /// An existing slot for the same hash is retargeted (keeping its counters). Otherwise
    /// a new slot is appended, or, if the site is full, the slot with the lowest hit rate
    /// is replaced (the first such slot on ties).
    pub fn insert(&mut self, call_site_id: u32, type_hash: u64, function_index: u32) {
        let capacity = self.max_entries_per_site;
        let entries = self.entries.entry(call_site_id).or_default();
        if let Some(existing) = entries.iter_mut().find(|e| e.type_hash == type_hash) {
            existing.function_index = function_index;
            return;
        }
        let fresh = InlineCacheEntry::new(type_hash, function_index);
        if entries.len() < capacity {
            entries.push(fresh);
        } else if let Some(victim) = entries.iter_mut().min_by(|a, b| {
            a.hit_rate()
                .partial_cmp(&b.hit_rate())
                .unwrap_or(std::cmp::Ordering::Equal)
        }) {
            *victim = fresh;
        }
    }

    /// Summarizes the cache.
    ///
    /// `total_hits` and `total_misses` count `lookup` calls since construction (one per
    /// call), so they are unaffected by eviction or by how many slots a site holds.
    pub fn stats(&self) -> InlineCacheStats {
        InlineCacheStats {
            total_call_sites: self.entries.len(),
            total_entries: self.entries.values().map(Vec::len).sum(),
            total_hits: self.lookup_hits,
            total_misses: self.lookup_misses,
        }
    }
}

impl Default for InlineCache {
    fn default() -> Self {
        Self::new()
    }
}

/// Snapshot of inline-cache counters produced by `InlineCache::stats`.
#[derive(Debug, Clone, Copy)]
pub struct InlineCacheStats {
    /// Number of call sites with a slot vector.
    pub total_call_sites: usize,
    /// Number of slots across all call sites.
    pub total_entries: usize,
    /// Lookups that returned a target.
    pub total_hits: u64,
    /// Lookups that returned `None`.
    pub total_misses: u64,
}

impl InlineCacheStats {
    /// `total_hits / (total_hits + total_misses)`, or `0.0` when nothing was looked up.
    pub fn hit_rate(&self) -> f64 {
        let total = self.total_hits + self.total_misses;
        if total == 0 {
            0.0
        } else {
            self.total_hits as f64 / total as f64
        }
    }
}
/// Collects per-function profiles behind a mutex and owns a shared inline cache.
///
/// Clones share the same profile map and inline cache (both behind `Arc`). Each clone
/// carries its own copy of `threshold` and of the `profiling_enabled` flag.
#[derive(Clone)]
pub struct ProfileGuidedOptimizer {
    profiles: Arc<Mutex<HashMap<u32, FunctionProfile>>>,
    threshold: HotFunctionThreshold,
    inline_cache: Arc<Mutex<InlineCache>>,
    profiling_enabled: bool,
}

impl ProfileGuidedOptimizer {
    /// Creates an optimizer with empty profiles, an empty inline cache and profiling on.
    pub fn new(threshold: HotFunctionThreshold) -> Self {
        let profiles = Arc::new(Mutex::new(HashMap::new()));
        let inline_cache = Arc::new(Mutex::new(InlineCache::new()));
        Self {
            profiles,
            threshold,
            inline_cache,
            profiling_enabled: true,
        }
    }

    /// Turns call recording on or off for this handle only.
    pub fn set_profiling_enabled(&mut self, enabled: bool) {
        self.profiling_enabled = enabled;
    }

    /// Adds one call of `duration_ns` to `function_index`'s profile, creating it on
    /// first use. Does nothing while profiling is disabled.
    ///
    /// # Panics
    /// If the profiles mutex is poisoned.
    pub fn record_call(&self, function_index: u32, duration_ns: u64) {
        if self.profiling_enabled {
            self.profiles
                .lock()
                .expect("JIT profiles lock poisoned")
                .entry(function_index)
                .or_insert_with(|| FunctionProfile::new(function_index))
                .record_call(duration_ns);
        }
    }

    /// Indices of all profiles that satisfy this optimizer's hot threshold, in map
    /// iteration order.
    pub fn get_hot_functions(&self) -> Vec<u32> {
        let threshold = &self.threshold;
        let guard = self.profiles.lock().expect("JIT profiles lock poisoned");
        let mut hot = Vec::new();
        for profile in guard.values() {
            if profile.is_hot(threshold) {
                hot.push(profile.function_index);
            }
        }
        hot
    }

    /// Returns an owned snapshot of every profile.
    pub fn get_profiles(&self) -> HashMap<u32, FunctionProfile> {
        let guard = self.profiles.lock().expect("JIT profiles lock poisoned");
        (*guard).clone()
    }

    /// Records `tier` on an existing profile; unknown function indices are ignored.
    pub fn set_tier(&self, function_index: u32, tier: CompilationTier) {
        if let Some(profile) = self
            .profiles
            .lock()
            .expect("JIT profiles lock poisoned")
            .get_mut(&function_index)
        {
            profile.set_tier(tier);
        }
    }

    /// Returns a new handle to the shared inline cache.
    pub fn inline_cache(&self) -> Arc<Mutex<InlineCache>> {
        Arc::clone(&self.inline_cache)
    }

    /// Current inline-cache statistics.
    pub fn inline_cache_stats(&self) -> InlineCacheStats {
        let cache = self.inline_cache.lock().expect("Inline cache lock poisoned");
        cache.stats()
    }

    /// Drops all profiles, then replaces the inline cache with a fresh one. The two
    /// locks are taken one after the other, never held together.
    pub fn reset(&self) {
        let mut profiles = self.profiles.lock().expect("JIT profiles lock poisoned");
        profiles.clear();
        drop(profiles);

        let mut cache = self.inline_cache.lock().expect("Inline cache lock poisoned");
        *cache = InlineCache::new();
    }
}

impl Default for ProfileGuidedOptimizer {
    fn default() -> Self {
        Self::new(HotFunctionThreshold::default())
    }
}
/// Feature switches for the JIT layer plus the tier the engine starts at.
#[derive(Debug, Clone)]
pub struct JitConfig {
    /// Whether `JitOptimizer::get_recompilation_candidates` proposes promotions.
    pub tiered_compilation: bool,
    /// Whether `JitOptimizer::record_execution` records timings.
    pub profile_guided_optimization: bool,
    /// Inline-caching switch.
    // NOTE(review): not read anywhere in this file — confirm it is consumed elsewhere.
    pub inline_caching: bool,
    /// Limits used to classify functions as hot.
    pub hot_threshold: HotFunctionThreshold,
    /// Tier whose compiler settings `apply_to_config` writes into the engine config.
    pub initial_tier: CompilationTier,
}

impl JitConfig {
    /// All features on, default thresholds, `Baseline` start tier.
    pub fn new() -> Self {
        Self {
            tiered_compilation: true,
            profile_guided_optimization: true,
            inline_caching: true,
            hot_threshold: HotFunctionThreshold::default(),
            initial_tier: CompilationTier::Baseline,
        }
    }

    /// Like `new`, with the lower `HotFunctionThreshold::interactive` limits.
    pub fn interactive() -> Self {
        Self {
            hot_threshold: HotFunctionThreshold::interactive(),
            ..Self::new()
        }
    }

    /// Like `new`, with `HotFunctionThreshold::batch` limits and an `Optimized` start tier.
    pub fn batch() -> Self {
        Self {
            hot_threshold: HotFunctionThreshold::batch(),
            initial_tier: CompilationTier::Optimized,
            ..Self::new()
        }
    }

    /// All features off, `HotFunctionThreshold::embedded` limits, `Baseline` start tier.
    pub fn embedded() -> Self {
        Self {
            hot_threshold: HotFunctionThreshold::embedded(),
            initial_tier: CompilationTier::Baseline,
            ..Self::no_optimization()
        }
    }

    /// All features off, default thresholds, `Interpreter` start tier.
    pub fn no_optimization() -> Self {
        Self {
            tiered_compilation: false,
            profile_guided_optimization: false,
            inline_caching: false,
            hot_threshold: HotFunctionThreshold::default(),
            initial_tier: CompilationTier::Interpreter,
        }
    }

    /// Writes the wasmtime compiler settings for `initial_tier` into `config`.
    ///
    /// - `Interpreter` -> `Strategy::Auto` + `OptLevel::None`
    /// - `Baseline` -> `Strategy::Cranelift` + `OptLevel::Speed`
    /// - `Optimized` -> `Strategy::Cranelift` + `OptLevel::SpeedAndSize`
    // NOTE(review): `Strategy::Auto` lets wasmtime choose its compiler and is presumably
    // not an interpreter. `Speed` vs `SpeedAndSize` differ mainly in code-size focus,
    // not in "baseline vs optimized". Confirm this mapping against the wasmtime version in use.
    pub fn apply_to_config(&self, config: &mut Config) {
        match self.initial_tier {
            CompilationTier::Interpreter => {
                config.strategy(Strategy::Auto);
                config.cranelift_opt_level(OptLevel::None);
            }
            CompilationTier::Baseline => {
                config.strategy(Strategy::Cranelift);
                config.cranelift_opt_level(OptLevel::Speed);
            }
            CompilationTier::Optimized => {
                config.strategy(Strategy::Cranelift);
                config.cranelift_opt_level(OptLevel::SpeedAndSize);
            }
        }
    }
}

impl Default for JitConfig {
    fn default() -> Self {
        Self::new()
    }
}
pub struct JitOptimizer {
config: JitConfig,
pgo: ProfileGuidedOptimizer,
}
impl JitOptimizer {
pub fn new(config: JitConfig) -> Self {
let pgo = ProfileGuidedOptimizer::new(config.hot_threshold.clone());
Self { config, pgo }
}
pub fn with_default() -> Self {
Self::new(JitConfig::default())
}
pub fn for_interactive() -> Self {
Self::new(JitConfig::interactive())
}
pub fn for_batch() -> Self {
Self::new(JitConfig::batch())
}
pub fn for_embedded() -> Self {
Self::new(JitConfig::embedded())
}
pub fn config(&self) -> &JitConfig {
&self.config
}
pub fn pgo(&self) -> &ProfileGuidedOptimizer {
&self.pgo
}
pub fn record_execution(&self, function_index: u32, duration: Duration) {
if self.config.profile_guided_optimization {
self.pgo
.record_call(function_index, duration.as_nanos() as u64);
}
}
pub fn get_recompilation_candidates(&self) -> Vec<(u32, CompilationTier)> {
if !self.config.tiered_compilation {
return Vec::new();
}
let hot_functions = self.pgo.get_hot_functions();
let profiles = self.pgo.get_profiles();
hot_functions
.into_iter()
.filter_map(|idx| {
let profile = profiles.get(&idx)?;
let target_tier = match profile.tier {
CompilationTier::Interpreter => CompilationTier::Baseline,
CompilationTier::Baseline => CompilationTier::Optimized,
CompilationTier::Optimized => return None,
};
Some((idx, target_tier))
})
.collect()
}
pub fn configure_engine(&self, config: &mut Config) {
self.config.apply_to_config(config);
}
}
impl Default for JitOptimizer {
fn default() -> Self {
Self::with_default()
}
}
// Unit tests for the tiering, profiling, inline-cache and configuration types in this file.
#[cfg(test)]
mod tests {
use super::*;
// Tier order is Interpreter < Baseline < Optimized, and the comparison is strict.
#[test]
fn test_compilation_tier_ordering() {
assert!(CompilationTier::Baseline.is_higher_than(&CompilationTier::Interpreter));
assert!(CompilationTier::Optimized.is_higher_than(&CompilationTier::Baseline));
assert!(!CompilationTier::Interpreter.is_higher_than(&CompilationTier::Baseline));
}
// Count, total, integer average and peak are updated after every recorded call.
#[test]
fn test_function_profile_recording() {
let mut profile = FunctionProfile::new(0);
assert_eq!(profile.call_count, 0);
assert_eq!(profile.total_execution_ns, 0);
profile.record_call(1000);
assert_eq!(profile.call_count, 1);
assert_eq!(profile.total_execution_ns, 1000);
assert_eq!(profile.avg_execution_ns, 1000);
profile.record_call(2000);
assert_eq!(profile.call_count, 2);
assert_eq!(profile.total_execution_ns, 3000);
assert_eq!(profile.avg_execution_ns, 1500);
assert_eq!(profile.peak_execution_ns, 2000);
}
// Default threshold is 100 calls and 1 ms total; 100 calls of 10 µs meet both exactly.
#[test]
fn test_hot_function_detection() {
let threshold = HotFunctionThreshold::default();
let mut profile = FunctionProfile::new(0);
assert!(!profile.is_hot(&threshold));
for _ in 0..100 {
profile.record_call(10_000);
}
assert!(profile.is_hot(&threshold));
}
// Miss before insert, hit after, and a miss for a different type hash at the same site.
#[test]
fn test_inline_cache_basic() {
let mut cache = InlineCache::new();
assert_eq!(cache.lookup(0, 123), None);
cache.insert(0, 123, 456);
assert_eq!(cache.lookup(0, 123), Some(456));
assert_eq!(cache.lookup(0, 999), None);
}
// With both slots full, a third type replaces a slot instead of growing the site.
// Both slots have hit rate 1.0 here, so only the resulting size is asserted.
#[test]
fn test_inline_cache_eviction() {
let mut cache = InlineCache::new();
cache.max_entries_per_site = 2;
cache.insert(0, 1, 10);
cache.insert(0, 2, 20);
assert_eq!(cache.lookup(0, 1), Some(10));
assert_eq!(cache.lookup(0, 2), Some(20));
cache.insert(0, 3, 30);
let entries = &cache.entries[&0];
assert_eq!(entries.len(), 2);
}
// One hit and one miss at a single-slot call site show up in the stats.
#[test]
fn test_inline_cache_stats() {
let mut cache = InlineCache::new();
cache.insert(0, 1, 10);
cache.lookup(0, 1);
cache.lookup(0, 2);
let stats = cache.stats();
assert_eq!(stats.total_hits, 1);
assert!(stats.total_misses > 0);
}
// Two calls to one function yield a single profile with call_count 2.
#[test]
fn test_pgo_basic() {
let pgo = ProfileGuidedOptimizer::default();
pgo.record_call(0, 1000);
pgo.record_call(0, 2000);
let profiles = pgo.get_profiles();
assert_eq!(profiles.len(), 1);
assert_eq!(profiles[&0].call_count, 2);
}
// Limits are 5 calls and 10 µs total. Function 0 (10 x 2 µs) meets both; function 1
// has too few calls (1); function 2 has too little time (10 x 100 ns = 1 µs).
#[test]
fn test_pgo_hot_functions() {
let threshold = HotFunctionThreshold {
min_call_count: 5,
min_total_time_ns: 10_000,
min_reopt_interval: Duration::from_millis(0),
};
let pgo = ProfileGuidedOptimizer::new(threshold);
for _ in 0..10 {
pgo.record_call(0, 2000);
}
pgo.record_call(1, 10_000);
for _ in 0..10 {
pgo.record_call(2, 100);
}
let hot = pgo.get_hot_functions();
assert_eq!(hot.len(), 1);
assert!(hot.contains(&0));
}
// Spot-checks feature flags and start tiers of the preset constructors.
#[test]
fn test_jit_config_presets() {
let interactive = JitConfig::interactive();
assert!(interactive.tiered_compilation);
assert!(interactive.profile_guided_optimization);
let batch = JitConfig::batch();
assert_eq!(batch.initial_tier, CompilationTier::Optimized);
let embedded = JitConfig::embedded();
assert!(!embedded.tiered_compilation);
let no_opt = JitConfig::no_optimization();
assert!(!no_opt.tiered_compilation);
assert_eq!(no_opt.initial_tier, CompilationTier::Interpreter);
}
// 10 executions of 2 µs (20 µs total) clear the 5-call / 10 µs limits. Starting from
// the Interpreter tier, the function is proposed for Baseline.
#[test]
fn test_jit_optimizer_recompilation_candidates() {
let config = JitConfig {
tiered_compilation: true,
profile_guided_optimization: true,
inline_caching: false,
hot_threshold: HotFunctionThreshold {
min_call_count: 5,
min_total_time_ns: 10_000,
min_reopt_interval: Duration::from_millis(0),
},
initial_tier: CompilationTier::Interpreter,
};
let optimizer = JitOptimizer::new(config);
for _ in 0..10 {
optimizer.record_execution(0, Duration::from_micros(2));
}
let candidates = optimizer.get_recompilation_candidates();
assert_eq!(candidates.len(), 1);
assert_eq!(candidates[0].0, 0);
assert_eq!(candidates[0].1, CompilationTier::Baseline);
}
// A function that is hot under the default threshold (1000 x 10 µs) still yields no
// candidates when tiered compilation is disabled.
#[test]
fn test_jit_optimizer_with_tiered_compilation_disabled() {
let config = JitConfig {
tiered_compilation: false,
profile_guided_optimization: true,
inline_caching: false,
hot_threshold: HotFunctionThreshold::default(),
initial_tier: CompilationTier::Baseline,
};
let optimizer = JitOptimizer::new(config);
for _ in 0..1000 {
optimizer.record_execution(0, Duration::from_micros(10));
}
let candidates = optimizer.get_recompilation_candidates();
assert_eq!(candidates.len(), 0);
}
}