1use serde::{Deserialize, Serialize};
23use std::collections::BTreeMap;
24
25use crate::hostcall_superinstructions::HostcallSuperinstructionPlan;
26
// Cost model for a compiled trace: a fixed dispatch overhead plus a step cost
// charged once per unit of trace width (see `estimated_jit_cost`).
const JIT_DISPATCH_COST_UNITS: i64 = 3;
const JIT_DISPATCH_STEP_COST_UNITS: i64 = 1;

// Fallback values used by `TraceJitConfig::from_env` when the corresponding
// `PI_HOSTCALL_TRACE_JIT_*` variable is unset or does not parse.
const DEFAULT_MIN_JIT_EXECUTIONS: u64 = 8;
const DEFAULT_MAX_COMPILED_TRACES: usize = 64;
const DEFAULT_MAX_GUARD_FAILURES: u64 = 4;
39
/// Tuning knobs for the trace-JIT tier.
///
/// Built explicitly with [`TraceJitConfig::new`] or read from the
/// environment with [`TraceJitConfig::from_env`] (which is also what
/// `Default` does).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TraceJitConfig {
    /// Master switch. When `false`, the compiler neither records plan
    /// executions nor dispatches (dispatch reports `DeoptReason::JitDisabled`).
    pub enabled: bool,
    /// Number of recorded executions a plan needs before it is compiled.
    pub min_jit_executions: u64,
    /// Size bound for the compiled-trace cache. When the cache already holds
    /// this many traces, compiling another first evicts the least-recently-used one.
    pub max_compiled_traces: usize,
    /// Consecutive guard failures after which a compiled trace is invalidated
    /// and dropped from the cache for good.
    pub max_guard_failures: u64,
}
54
55impl Default for TraceJitConfig {
56 fn default() -> Self {
57 Self::from_env()
58 }
59}
60
61impl TraceJitConfig {
62 #[must_use]
64 pub const fn new(
65 enabled: bool,
66 min_jit_executions: u64,
67 max_compiled_traces: usize,
68 max_guard_failures: u64,
69 ) -> Self {
70 Self {
71 enabled,
72 min_jit_executions,
73 max_compiled_traces,
74 max_guard_failures,
75 }
76 }
77
78 #[must_use]
80 pub fn from_env() -> Self {
81 let enabled = bool_from_env("PI_HOSTCALL_TRACE_JIT", true);
82 let min_jit_executions = u64_from_env(
83 "PI_HOSTCALL_TRACE_JIT_MIN_EXECUTIONS",
84 DEFAULT_MIN_JIT_EXECUTIONS,
85 );
86 let max_compiled_traces = usize_from_env(
87 "PI_HOSTCALL_TRACE_JIT_MAX_TRACES",
88 DEFAULT_MAX_COMPILED_TRACES,
89 );
90 let max_guard_failures = u64_from_env(
91 "PI_HOSTCALL_TRACE_JIT_MAX_GUARD_FAILURES",
92 DEFAULT_MAX_GUARD_FAILURES,
93 );
94 Self::new(
95 enabled,
96 min_jit_executions,
97 max_compiled_traces,
98 max_guard_failures,
99 )
100 }
101}
102
/// A runtime precondition that must hold before a compiled trace is dispatched.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TraceGuard {
    /// The observed trace must start with exactly these opcodes, in order.
    OpcodePrefix(Vec<String>),
    /// `GuardContext::safety_envelope_vetoing` must be `false`.
    SafetyEnvelopeNotVetoing,
    /// `GuardContext::current_support_count` must be at least this value.
    MinSupportCount(u32),
}
115
116impl TraceGuard {
117 #[must_use]
119 pub fn check(&self, trace: &[String], ctx: &GuardContext) -> bool {
120 match self {
121 Self::OpcodePrefix(window) => {
122 trace.len() >= window.len()
123 && trace
124 .iter()
125 .zip(window.iter())
126 .all(|(actual, expected)| actual == expected)
127 }
128 Self::SafetyEnvelopeNotVetoing => !ctx.safety_envelope_vetoing,
129 Self::MinSupportCount(min) => ctx.current_support_count >= *min,
130 }
131 }
132}
133
/// Runtime state that trace guards are evaluated against at dispatch time.
#[derive(Debug, Clone, Default)]
pub struct GuardContext {
    /// `true` while the safety envelope is vetoing. This fails
    /// `TraceGuard::SafetyEnvelopeNotVetoing`.
    pub safety_envelope_vetoing: bool,
    /// Support count compared against `TraceGuard::MinSupportCount`.
    pub current_support_count: u32,
}
142
/// The optimization tier that produced a compiled artifact.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompilationTier {
    /// Fused superinstruction tier. Not produced by the trace-JIT compiler
    /// in this file.
    Superinstruction,
    /// Guarded trace-JIT tier; set by `CompiledTrace::from_plan`.
    TraceJit,
}
153
/// A superinstruction plan lowered to a guarded trace-JIT entry.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompiledTrace {
    /// Identifier of the plan this trace was compiled from; also its cache key.
    pub plan_id: String,
    /// Trace signature copied from the plan.
    pub trace_signature: String,
    /// Guards checked in order before dispatch. `DeoptReason::GuardFailure::guard_index`
    /// indexes into this vector.
    pub guards: Vec<TraceGuard>,
    /// Opcode window copied from the plan.
    pub opcode_window: Vec<String>,
    /// `plan.width()` at compile time.
    pub width: usize,
    /// `estimated_jit_cost(width)`.
    pub estimated_cost_jit: i64,
    /// Fused-tier cost estimate copied from the plan.
    pub estimated_cost_fused: i64,
    /// `estimated_cost_fused - estimated_cost_jit` (saturating; may be negative).
    /// Reported as `cost_delta` on a JIT hit.
    pub tier_improvement_delta: i64,
    /// Tier that produced this artifact.
    pub tier: CompilationTier,
}
176
177impl CompiledTrace {
178 #[must_use]
180 pub fn from_plan(plan: &HostcallSuperinstructionPlan) -> Self {
181 let width = plan.width();
182 let estimated_cost_jit = estimated_jit_cost(width);
183 let tier_improvement_delta = plan.estimated_cost_fused.saturating_sub(estimated_cost_jit);
184
185 let guards = vec![
186 TraceGuard::OpcodePrefix(plan.opcode_window.clone()),
187 TraceGuard::SafetyEnvelopeNotVetoing,
188 TraceGuard::MinSupportCount(plan.support_count / 2),
189 ];
190
191 Self {
192 plan_id: plan.plan_id.clone(),
193 trace_signature: plan.trace_signature.clone(),
194 guards,
195 opcode_window: plan.opcode_window.clone(),
196 width,
197 estimated_cost_jit,
198 estimated_cost_fused: plan.estimated_cost_fused,
199 tier_improvement_delta,
200 tier: CompilationTier::TraceJit,
201 }
202 }
203
204 #[must_use]
206 pub fn guards_pass(&self, trace: &[String], ctx: &GuardContext) -> bool {
207 self.guards.iter().all(|guard| guard.check(trace, ctx))
208 }
209}
210
/// Why a dispatch attempt fell back from the trace-JIT tier.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeoptReason {
    /// A guard rejected the trace, but the trace stays compiled.
    GuardFailure {
        /// Index into `CompiledTrace::guards` of the first failing guard.
        guard_index: usize,
        /// Stable tag such as `"opcode_prefix_mismatch"`,
        /// `"safety_envelope_vetoing"` or `"support_count_below_threshold"`.
        description: String,
    },
    /// This guard failure reached `max_guard_failures` and the trace was
    /// invalidated (removed from the cache; never recompiled until reset).
    TraceInvalidated {
        /// Consecutive guard failures at the moment of invalidation.
        total_failures: u64,
    },
    /// The JIT is disabled in the configuration.
    JitDisabled,
    /// No compiled trace is cached for the plan: never compiled, evicted,
    /// or invalidated.
    NotCompiled,
    /// Not emitted by `TraceJitCompiler` in this file. A vetoing safety
    /// envelope currently surfaces as `GuardFailure`.
    SafetyVeto,
}
235
/// Outcome of `TraceJitCompiler::try_jit_dispatch`.
#[derive(Debug, Clone)]
pub struct JitExecutionResult {
    /// `true` when every guard passed and the compiled trace was used.
    pub jit_hit: bool,
    /// Plan the dispatch was attempted for (always `Some` from `try_jit_dispatch`).
    pub plan_id: Option<String>,
    /// `None` on a hit, otherwise why the dispatch fell back.
    pub deopt_reason: Option<DeoptReason>,
    /// The trace's `tier_improvement_delta` on a hit; `0` otherwise.
    pub cost_delta: i64,
}
248
/// Per-plan profiling state kept by `TraceJitCompiler`.
#[derive(Debug, Clone, Default)]
struct PlanProfile {
    /// Calls to `record_plan_execution` for this plan while the JIT was enabled.
    execution_count: u64,
    /// Guard failures since the last successful dispatch.
    consecutive_guard_failures: u64,
    /// Set once the guard-failure limit is hit; the plan is then never recompiled.
    invalidated: bool,
    /// Compiler generation at the last recorded execution or dispatch (LRU key).
    last_access_generation: u64,
}
263
/// Counters describing trace-JIT activity since construction or the last `reset`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TraceJitTelemetry {
    /// Calls to `record_plan_execution` while the JIT was enabled.
    pub plans_evaluated: u64,
    /// Traces inserted into the cache.
    pub traces_compiled: u64,
    /// Dispatches whose guards all passed.
    pub jit_hits: u64,
    /// Dispatches rejected by a guard (`NotCompiled`/`JitDisabled` are not counted).
    pub jit_misses: u64,
    /// Guard-failure deoptimizations (incremented alongside `jit_misses`).
    pub deopts: u64,
    /// Traces invalidated after too many consecutive guard failures.
    pub invalidations: u64,
    /// Traces evicted to make room in a full cache.
    pub evictions: u64,
    /// Cache length as of the last compile, eviction or invalidation.
    pub cache_size: u64,
}
286
/// Profiles superinstruction plans, compiles hot ones into guarded traces,
/// and dispatches through an LRU-bounded compiled-trace cache.
#[derive(Debug, Clone)]
pub struct TraceJitCompiler {
    /// Active configuration (no method in this file mutates it).
    config: TraceJitConfig,
    /// Compiled traces keyed by plan id.
    cache: BTreeMap<String, CompiledTrace>,
    /// Per-plan profiling state keyed by plan id. Entries survive eviction
    /// and invalidation and are dropped only by `reset`.
    profiles: BTreeMap<String, PlanProfile>,
    /// Logical clock bumped on every recorded execution and dispatch attempt
    /// while enabled; used to pick the least-recently-used trace to evict.
    generation: u64,
    /// Running counters exposed through `telemetry()`.
    telemetry: TraceJitTelemetry,
}
305
306impl Default for TraceJitCompiler {
307 fn default() -> Self {
308 Self::new(TraceJitConfig::default())
309 }
310}
311
312impl TraceJitCompiler {
313 #[must_use]
315 pub fn new(config: TraceJitConfig) -> Self {
316 Self {
317 config,
318 cache: BTreeMap::new(),
319 profiles: BTreeMap::new(),
320 generation: 0,
321 telemetry: TraceJitTelemetry::default(),
322 }
323 }
324
325 #[must_use]
327 pub const fn enabled(&self) -> bool {
328 self.config.enabled
329 }
330
331 #[must_use]
333 pub const fn config(&self) -> &TraceJitConfig {
334 &self.config
335 }
336
337 #[must_use]
339 pub const fn telemetry(&self) -> &TraceJitTelemetry {
340 &self.telemetry
341 }
342
343 #[must_use]
345 pub fn cache_size(&self) -> usize {
346 self.cache.len()
347 }
348
349 pub fn record_plan_execution(&mut self, plan: &HostcallSuperinstructionPlan) -> bool {
354 if !self.config.enabled {
355 return false;
356 }
357
358 self.telemetry.plans_evaluated += 1;
359 self.generation += 1;
360
361 let profile = self.profiles.entry(plan.plan_id.clone()).or_default();
362
363 profile.execution_count += 1;
364 profile.last_access_generation = self.generation;
365
366 if profile.invalidated {
367 return false;
368 }
369
370 if profile.execution_count >= self.config.min_jit_executions
372 && !self.cache.contains_key(&plan.plan_id)
373 {
374 self.compile_trace(plan);
375 return true;
376 }
377
378 false
379 }
380
381 fn compile_trace(&mut self, plan: &HostcallSuperinstructionPlan) {
383 if self.cache.len() >= self.config.max_compiled_traces {
385 self.evict_lru();
386 }
387
388 let compiled = CompiledTrace::from_plan(plan);
389 self.cache.insert(plan.plan_id.clone(), compiled);
390 self.telemetry.traces_compiled += 1;
391 self.telemetry.cache_size = u64::try_from(self.cache.len()).unwrap_or(u64::MAX);
392 }
393
394 fn evict_lru(&mut self) {
396 let lru_plan_id = self
397 .cache
398 .keys()
399 .min_by_key(|plan_id| {
400 self.profiles
401 .get(*plan_id)
402 .map_or(0, |profile| profile.last_access_generation)
403 })
404 .cloned();
405
406 if let Some(plan_id) = lru_plan_id {
407 self.cache.remove(&plan_id);
408 self.telemetry.evictions += 1;
409 self.telemetry.cache_size = u64::try_from(self.cache.len()).unwrap_or(u64::MAX);
410 }
411 }
412
413 pub fn try_jit_dispatch(
418 &mut self,
419 plan_id: &str,
420 trace: &[String],
421 ctx: &GuardContext,
422 ) -> JitExecutionResult {
423 if !self.config.enabled {
424 return JitExecutionResult {
425 jit_hit: false,
426 plan_id: Some(plan_id.to_string()),
427 deopt_reason: Some(DeoptReason::JitDisabled),
428 cost_delta: 0,
429 };
430 }
431
432 self.generation += 1;
433
434 let compiled = match self.cache.get(plan_id) {
436 Some(compiled) => compiled.clone(),
437 None => {
438 return JitExecutionResult {
439 jit_hit: false,
440 plan_id: Some(plan_id.to_string()),
441 deopt_reason: Some(DeoptReason::NotCompiled),
442 cost_delta: 0,
443 };
444 }
445 };
446
447 if let Some(profile) = self.profiles.get_mut(plan_id) {
449 profile.last_access_generation = self.generation;
450 }
451
452 for (idx, guard) in compiled.guards.iter().enumerate() {
454 if !guard.check(trace, ctx) {
455 let invalidated_after_failures = self.record_guard_failure(plan_id);
456 let description = match guard {
457 TraceGuard::OpcodePrefix(_) => "opcode_prefix_mismatch",
458 TraceGuard::SafetyEnvelopeNotVetoing => "safety_envelope_vetoing",
459 TraceGuard::MinSupportCount(_) => "support_count_below_threshold",
460 };
461 let deopt_reason = invalidated_after_failures.map_or_else(
462 || DeoptReason::GuardFailure {
463 guard_index: idx,
464 description: description.to_string(),
465 },
466 |total_failures| DeoptReason::TraceInvalidated { total_failures },
467 );
468 return JitExecutionResult {
469 jit_hit: false,
470 plan_id: Some(plan_id.to_string()),
471 deopt_reason: Some(deopt_reason),
472 cost_delta: 0,
473 };
474 }
475 }
476
477 self.telemetry.jit_hits += 1;
479 if let Some(profile) = self.profiles.get_mut(plan_id) {
480 profile.consecutive_guard_failures = 0;
481 }
482
483 JitExecutionResult {
484 jit_hit: true,
485 plan_id: Some(plan_id.to_string()),
486 deopt_reason: None,
487 cost_delta: compiled.tier_improvement_delta,
488 }
489 }
490
491 fn record_guard_failure(&mut self, plan_id: &str) -> Option<u64> {
493 self.telemetry.deopts += 1;
494 self.telemetry.jit_misses += 1;
495
496 if let Some(profile) = self.profiles.get_mut(plan_id) {
497 profile.consecutive_guard_failures += 1;
498 if !profile.invalidated
499 && profile.consecutive_guard_failures >= self.config.max_guard_failures
500 {
501 profile.invalidated = true;
502 self.cache.remove(plan_id);
503 self.telemetry.invalidations += 1;
504 self.telemetry.cache_size = u64::try_from(self.cache.len()).unwrap_or(u64::MAX);
505 return Some(profile.consecutive_guard_failures);
506 }
507 }
508 None
509 }
510
511 #[must_use]
513 pub fn get_compiled_trace(&self, plan_id: &str) -> Option<&CompiledTrace> {
514 self.cache.get(plan_id)
515 }
516
517 #[must_use]
519 pub fn is_invalidated(&self, plan_id: &str) -> bool {
520 self.profiles
521 .get(plan_id)
522 .is_some_and(|profile| profile.invalidated)
523 }
524
525 pub fn reset(&mut self) {
527 self.cache.clear();
528 self.profiles.clear();
529 self.generation = 0;
530 self.telemetry = TraceJitTelemetry::default();
531 }
532}
533
534#[must_use]
538pub fn estimated_jit_cost(width: usize) -> i64 {
539 let width_units = i64::try_from(width).unwrap_or(i64::MAX);
540 JIT_DISPATCH_COST_UNITS.saturating_add(width_units.saturating_mul(JIT_DISPATCH_STEP_COST_UNITS))
541}
542
/// Reads a boolean flag from the environment variable `var`.
///
/// An unset or non-UTF-8 variable yields `default`. Any set value reads as
/// `true` unless, after trimming and ASCII-lowercasing, it is `0`, `false`,
/// `off` or `disabled`. An empty value therefore reads as `true`.
fn bool_from_env(var: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(var) else {
        return default;
    };
    let normalized = raw.trim().to_ascii_lowercase();
    !matches!(normalized.as_str(), "0" | "false" | "off" | "disabled")
}
553
/// Reads a `u64` from the environment variable `var` after trimming whitespace.
///
/// Returns `default` if the variable is unset, not valid UTF-8, or does not
/// parse as a `u64`.
fn u64_from_env(var: &str, default: u64) -> u64 {
    match std::env::var(var) {
        Ok(raw) => raw.trim().parse::<u64>().unwrap_or(default),
        Err(_) => default,
    }
}
560
/// Reads a `usize` from the environment variable `var` after trimming whitespace.
///
/// Returns `default` if the variable is unset, not valid UTF-8, or does not
/// parse as a `usize`.
fn usize_from_env(var: &str, default: usize) -> usize {
    if let Ok(raw) = std::env::var(var) {
        if let Ok(value) = raw.trim().parse::<usize>() {
            return value;
        }
    }
    default
}
567
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hostcall_superinstructions::{
        HOSTCALL_SUPERINSTRUCTION_PLAN_VERSION, HOSTCALL_SUPERINSTRUCTION_SCHEMA_VERSION,
        HostcallSuperinstructionPlan,
    };

    /// Builds a plan over `window` with costs following the model used by
    /// these tests: baseline `10w` and fused `6 + 2w`.
    fn make_plan(
        plan_id: &str,
        window: &[&str],
        support_count: u32,
    ) -> HostcallSuperinstructionPlan {
        let opcode_window: Vec<String> = window.iter().map(ToString::to_string).collect();
        let width = opcode_window.len();
        HostcallSuperinstructionPlan {
            schema: HOSTCALL_SUPERINSTRUCTION_SCHEMA_VERSION.to_string(),
            version: HOSTCALL_SUPERINSTRUCTION_PLAN_VERSION,
            plan_id: plan_id.to_string(),
            trace_signature: format!("sig_{plan_id}"),
            opcode_window,
            support_count,
            estimated_cost_baseline: i64::try_from(width).unwrap_or(0) * 10,
            estimated_cost_fused: 6 + i64::try_from(width).unwrap_or(0) * 2,
            expected_cost_delta: i64::try_from(width).unwrap_or(0) * 8 - 6,
        }
    }

    /// Converts string literals into an owned opcode trace.
    fn trace(opcodes: &[&str]) -> Vec<String> {
        opcodes.iter().map(ToString::to_string).collect()
    }

    /// A context in which the safety and support-count guards always pass
    /// for the plans built above.
    fn default_ctx() -> GuardContext {
        GuardContext {
            safety_envelope_vetoing: false,
            current_support_count: 100,
        }
    }

    #[test]
    fn config_default_values() {
        let config = TraceJitConfig::new(true, 8, 64, 4);
        assert!(config.enabled);
        assert_eq!(config.min_jit_executions, 8);
        assert_eq!(config.max_compiled_traces, 64);
        assert_eq!(config.max_guard_failures, 4);
    }

    #[test]
    fn config_disabled_prevents_compilation() {
        let config = TraceJitConfig::new(false, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);

        let promoted = jit.record_plan_execution(&plan);
        assert!(!promoted);
        assert_eq!(jit.cache_size(), 0);
    }

    #[test]
    fn plan_promoted_after_reaching_threshold() {
        let config = TraceJitConfig::new(true, 3, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["session.get_state", "session.get_messages"], 10);

        assert!(!jit.record_plan_execution(&plan));
        assert!(!jit.record_plan_execution(&plan));
        // The third execution reaches the threshold and compiles the trace.
        assert!(jit.record_plan_execution(&plan));
        assert_eq!(jit.cache_size(), 1);

        // Already compiled: further executions do not recompile.
        assert!(!jit.record_plan_execution(&plan));
        assert_eq!(jit.telemetry().traces_compiled, 1);
    }

    #[test]
    fn plan_not_promoted_before_threshold() {
        let config = TraceJitConfig::new(true, 10, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 5);

        for _ in 0..9 {
            assert!(!jit.record_plan_execution(&plan));
        }
        assert_eq!(jit.cache_size(), 0);
        assert!(jit.record_plan_execution(&plan));
        assert_eq!(jit.cache_size(), 1);
    }

    #[test]
    fn guard_opcode_prefix_passes_on_match() {
        let guard = TraceGuard::OpcodePrefix(trace(&["a", "b"]));
        let ctx = default_ctx();
        assert!(guard.check(&trace(&["a", "b", "c"]), &ctx));
        assert!(guard.check(&trace(&["a", "b"]), &ctx));
    }

    #[test]
    fn guard_opcode_prefix_fails_on_mismatch() {
        let guard = TraceGuard::OpcodePrefix(trace(&["a", "b"]));
        let ctx = default_ctx();
        assert!(!guard.check(&trace(&["a", "c"]), &ctx));
        assert!(!guard.check(&trace(&["a"]), &ctx));
        assert!(!guard.check(&trace(&[]), &ctx));
    }

    #[test]
    fn guard_safety_envelope_passes_when_not_vetoing() {
        let guard = TraceGuard::SafetyEnvelopeNotVetoing;
        let ctx = GuardContext {
            safety_envelope_vetoing: false,
            ..default_ctx()
        };
        assert!(guard.check(&[], &ctx));
    }

    #[test]
    fn guard_safety_envelope_fails_when_vetoing() {
        let guard = TraceGuard::SafetyEnvelopeNotVetoing;
        let ctx = GuardContext {
            safety_envelope_vetoing: true,
            ..default_ctx()
        };
        assert!(!guard.check(&[], &ctx));
    }

    #[test]
    fn guard_min_support_count_passes() {
        let guard = TraceGuard::MinSupportCount(5);
        let ctx = GuardContext {
            current_support_count: 10,
            ..default_ctx()
        };
        assert!(guard.check(&[], &ctx));
    }

    #[test]
    fn guard_min_support_count_fails() {
        let guard = TraceGuard::MinSupportCount(5);
        let ctx = GuardContext {
            current_support_count: 3,
            ..default_ctx()
        };
        assert!(!guard.check(&[], &ctx));
    }

    #[test]
    fn compiled_trace_from_plan_sets_tier() {
        let plan = make_plan("p1", &["a", "b", "c"], 10);
        let compiled = CompiledTrace::from_plan(&plan);

        assert_eq!(compiled.plan_id, "p1");
        assert_eq!(compiled.tier, CompilationTier::TraceJit);
        assert_eq!(compiled.width, 3);
        assert_eq!(compiled.guards.len(), 3);
    }

    #[test]
    fn compiled_trace_cost_lower_than_fused() {
        let plan = make_plan("p1", &["a", "b", "c"], 10);
        let compiled = CompiledTrace::from_plan(&plan);

        assert!(
            compiled.estimated_cost_jit < compiled.estimated_cost_fused,
            "JIT cost ({}) should be less than fused cost ({})",
            compiled.estimated_cost_jit,
            compiled.estimated_cost_fused
        );
        assert!(compiled.tier_improvement_delta > 0);
    }

    #[test]
    fn compiled_trace_guards_pass_on_matching_trace() {
        let plan = make_plan("p1", &["a", "b"], 10);
        let compiled = CompiledTrace::from_plan(&plan);
        let ctx = default_ctx();

        assert!(compiled.guards_pass(&trace(&["a", "b", "c"]), &ctx));
    }

    #[test]
    fn compiled_trace_guards_fail_on_wrong_prefix() {
        let plan = make_plan("p1", &["a", "b"], 10);
        let compiled = CompiledTrace::from_plan(&plan);
        let ctx = default_ctx();

        assert!(!compiled.guards_pass(&trace(&["x", "y"]), &ctx));
    }

    #[test]
    fn compiled_trace_guards_fail_on_safety_veto() {
        let plan = make_plan("p1", &["a", "b"], 10);
        let compiled = CompiledTrace::from_plan(&plan);
        let ctx = GuardContext {
            safety_envelope_vetoing: true,
            ..default_ctx()
        };

        assert!(!compiled.guards_pass(&trace(&["a", "b"]), &ctx));
    }

    #[test]
    fn jit_dispatch_hits_after_promotion() {
        let config = TraceJitConfig::new(true, 2, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);

        jit.record_plan_execution(&plan);
        jit.record_plan_execution(&plan);
        assert_eq!(jit.cache_size(), 1);

        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b", "c"]), &default_ctx());
        assert!(result.jit_hit);
        assert!(result.deopt_reason.is_none());
        assert!(result.cost_delta > 0);
        assert_eq!(jit.telemetry().jit_hits, 1);
    }

    #[test]
    fn jit_dispatch_returns_not_compiled_before_promotion() {
        let config = TraceJitConfig::new(true, 10, 64, 4);
        let mut jit = TraceJitCompiler::new(config);

        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        assert!(!result.jit_hit);
        assert_eq!(result.deopt_reason, Some(DeoptReason::NotCompiled));
    }

    #[test]
    fn jit_dispatch_deopt_on_guard_failure() {
        let config = TraceJitConfig::new(true, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);

        let result = jit.try_jit_dispatch("p1", &trace(&["x", "y"]), &default_ctx());
        assert!(!result.jit_hit);
        assert!(matches!(
            result.deopt_reason,
            Some(DeoptReason::GuardFailure { guard_index: 0, .. })
        ));
        assert_eq!(jit.telemetry().deopts, 1);
    }

    #[test]
    fn jit_dispatch_deopt_on_safety_veto() {
        let config = TraceJitConfig::new(true, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);

        let ctx = GuardContext {
            safety_envelope_vetoing: true,
            ..default_ctx()
        };
        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &ctx);
        assert!(!result.jit_hit);
        assert!(matches!(
            result.deopt_reason,
            Some(DeoptReason::GuardFailure { guard_index: 1, .. })
        ));
    }

    #[test]
    fn jit_dispatch_deopt_on_support_count_guard() {
        let config = TraceJitConfig::new(true, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        // Support 20 installs MinSupportCount(10); 9 falls just below it.
        let plan = make_plan("p1", &["a", "b"], 20);
        jit.record_plan_execution(&plan);

        let ctx = GuardContext {
            safety_envelope_vetoing: false,
            current_support_count: 9,
        };
        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &ctx);
        assert!(!result.jit_hit);
        assert_eq!(
            result.deopt_reason,
            Some(DeoptReason::GuardFailure {
                guard_index: 2,
                description: "support_count_below_threshold".to_string(),
            })
        );
    }

    #[test]
    fn jit_dispatch_disabled_returns_jit_disabled() {
        let config = TraceJitConfig::new(false, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);

        let result = jit.try_jit_dispatch("p1", &trace(&["a"]), &default_ctx());
        assert!(!result.jit_hit);
        assert_eq!(result.deopt_reason, Some(DeoptReason::JitDisabled));
    }

    #[test]
    fn trace_invalidated_after_max_guard_failures() {
        let config = TraceJitConfig::new(true, 1, 64, 3);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);
        assert_eq!(jit.cache_size(), 1);

        for _ in 0..3 {
            jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        }

        assert!(jit.is_invalidated("p1"));
        assert_eq!(jit.cache_size(), 0);
        assert_eq!(jit.telemetry().invalidations, 1);

        // Invalidated plans are never recompiled.
        assert!(!jit.record_plan_execution(&plan));
    }

    #[test]
    fn threshold_crossing_failure_reports_trace_invalidated() {
        let config = TraceJitConfig::new(true, 1, 64, 2);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);

        let first = jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        assert_eq!(
            first.deopt_reason,
            Some(DeoptReason::GuardFailure {
                guard_index: 0,
                description: "opcode_prefix_mismatch".to_string(),
            })
        );

        let second = jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        assert_eq!(
            second.deopt_reason,
            Some(DeoptReason::TraceInvalidated { total_failures: 2 })
        );

        assert!(jit.is_invalidated("p1"));
        assert_eq!(jit.cache_size(), 0);

        // With the trace gone, dispatch reports NotCompiled and no new miss.
        let after_invalidation = jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        assert_eq!(
            after_invalidation.deopt_reason,
            Some(DeoptReason::NotCompiled)
        );

        let telemetry = jit.telemetry();
        assert_eq!(telemetry.deopts, 2);
        assert_eq!(telemetry.jit_misses, 2);
        assert_eq!(telemetry.invalidations, 1);
    }

    #[test]
    fn guard_failure_counter_resets_on_success() {
        let config = TraceJitConfig::new(true, 1, 64, 3);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);

        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        assert!(result.jit_hit);
        // The hit reset the streak, so two more failures stay below the limit of 3.
        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());

        assert!(!jit.is_invalidated("p1"));
        assert_eq!(jit.cache_size(), 1);
    }

    #[test]
    fn lru_eviction_when_cache_full() {
        let config = TraceJitConfig::new(true, 1, 2, 4);
        let mut jit = TraceJitCompiler::new(config);

        let p1 = make_plan("p1", &["a", "b"], 10);
        let p2 = make_plan("p2", &["c", "d"], 10);
        let p3 = make_plan("p3", &["e", "f"], 10);

        jit.record_plan_execution(&p1);
        jit.record_plan_execution(&p2);
        assert_eq!(jit.cache_size(), 2);

        // Touching p2 makes p1 the least-recently-used entry.
        jit.try_jit_dispatch("p2", &trace(&["c", "d"]), &default_ctx());

        jit.record_plan_execution(&p3);
        assert_eq!(jit.cache_size(), 2);
        assert!(jit.get_compiled_trace("p1").is_none());
        assert!(jit.get_compiled_trace("p2").is_some());
        assert!(jit.get_compiled_trace("p3").is_some());
        assert_eq!(jit.telemetry().evictions, 1);
    }

    #[test]
    fn telemetry_tracks_all_counters() {
        let config = TraceJitConfig::new(true, 2, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);

        jit.record_plan_execution(&plan);
        jit.record_plan_execution(&plan);

        jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());

        let t = jit.telemetry();
        assert_eq!(t.plans_evaluated, 2);
        assert_eq!(t.traces_compiled, 1);
        assert_eq!(t.jit_hits, 1);
        assert_eq!(t.jit_misses, 1);
        assert_eq!(t.deopts, 1);
        assert_eq!(t.cache_size, 1);
    }

    #[test]
    fn telemetry_serializes_round_trip() {
        let telemetry = TraceJitTelemetry {
            plans_evaluated: 100,
            traces_compiled: 10,
            jit_hits: 50,
            jit_misses: 5,
            deopts: 5,
            invalidations: 1,
            evictions: 2,
            cache_size: 8,
        };

        let json = serde_json::to_string(&telemetry).expect("serialize");
        let parsed: TraceJitTelemetry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(telemetry, parsed);
    }

    #[test]
    fn reset_clears_all_state() {
        let config = TraceJitConfig::new(true, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);

        jit.record_plan_execution(&plan);
        jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        assert!(jit.cache_size() > 0);
        assert!(jit.telemetry().jit_hits > 0);

        jit.reset();
        assert_eq!(jit.cache_size(), 0);
        assert_eq!(jit.telemetry().jit_hits, 0);
        assert_eq!(jit.telemetry().traces_compiled, 0);
    }

    #[test]
    fn jit_cost_less_than_fused_cost() {
        for width in 2..=8 {
            let jit_cost = estimated_jit_cost(width);
            let fused_cost = 6 + i64::try_from(width).unwrap() * 2;
            assert!(
                jit_cost < fused_cost,
                "JIT cost ({jit_cost}) should be less than fused cost ({fused_cost}) for width {width}"
            );
        }
    }

    #[test]
    fn jit_cost_scales_linearly() {
        let cost_2 = estimated_jit_cost(2);
        let cost_4 = estimated_jit_cost(4);
        let delta = cost_4 - cost_2;
        // Two extra width units at one step-cost unit each.
        assert_eq!(delta, 2);
    }

    #[test]
    fn compiled_trace_serializes_round_trip() {
        let plan = make_plan("p_rt", &["a", "b", "c"], 10);
        let compiled = CompiledTrace::from_plan(&plan);

        let json = serde_json::to_string(&compiled).expect("serialize");
        let parsed: CompiledTrace = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(compiled, parsed);
    }

    #[test]
    fn deopt_reason_serializes_round_trip() {
        let reasons = vec![
            DeoptReason::GuardFailure {
                guard_index: 1,
                description: "test".to_string(),
            },
            DeoptReason::TraceInvalidated { total_failures: 5 },
            DeoptReason::JitDisabled,
            DeoptReason::NotCompiled,
            DeoptReason::SafetyVeto,
        ];

        for reason in &reasons {
            let value = serde_json::to_value(reason).expect("serialize to value");
            let parsed: DeoptReason =
                serde_json::from_value(value).expect("deserialize from value");
            assert_eq!(*reason, parsed);
        }
    }

    #[test]
    fn multiple_plans_compile_independently() {
        let config = TraceJitConfig::new(true, 2, 64, 4);
        let mut jit = TraceJitCompiler::new(config);

        let p1 = make_plan("p1", &["a", "b"], 10);
        let p2 = make_plan("p2", &["c", "d"], 10);

        jit.record_plan_execution(&p1);
        jit.record_plan_execution(&p1);
        jit.record_plan_execution(&p2);
        jit.record_plan_execution(&p2);
        assert_eq!(jit.cache_size(), 2);

        let r1 = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        let r2 = jit.try_jit_dispatch("p2", &trace(&["c", "d"]), &default_ctx());
        assert!(r1.jit_hit);
        assert!(r2.jit_hit);
        assert_eq!(jit.telemetry().jit_hits, 2);
    }

    mod proptest_trace_jit {
        use super::*;

        use proptest::prelude::*;

        /// One of a fixed set of realistic hostcall opcodes.
        fn arb_opcode() -> impl Strategy<Value = String> {
            prop::sample::select(vec![
                "session.get_state".to_string(),
                "session.get_messages".to_string(),
                "events.list".to_string(),
                "tool.read".to_string(),
                "tool.write".to_string(),
                "events.emit".to_string(),
            ])
        }

        /// An opcode window of width 2 to 5.
        fn arb_window() -> impl Strategy<Value = Vec<String>> {
            prop::collection::vec(arb_opcode(), 2..6)
        }

        /// A plan with the same cost model as `make_plan`. Plan ids encode
        /// width and support, so distinct plans may share an id.
        fn arb_plan() -> impl Strategy<Value = HostcallSuperinstructionPlan> {
            (arb_window(), 2..100u32).prop_map(|(window, support)| {
                let width = window.len();
                let baseline = i64::try_from(width).unwrap_or(0) * 10;
                let fused = 6 + i64::try_from(width).unwrap_or(0) * 2;
                HostcallSuperinstructionPlan {
                    schema: HOSTCALL_SUPERINSTRUCTION_SCHEMA_VERSION.to_string(),
                    version: HOSTCALL_SUPERINSTRUCTION_PLAN_VERSION,
                    plan_id: format!("arb_{width}_{support}"),
                    trace_signature: format!("sig_arb_{width}_{support}"),
                    opcode_window: window,
                    support_count: support,
                    estimated_cost_baseline: baseline,
                    estimated_cost_fused: fused,
                    expected_cost_delta: baseline - fused,
                }
            })
        }

        fn arb_guard_context() -> impl Strategy<Value = GuardContext> {
            (any::<bool>(), 0..200u32).prop_map(|(vetoing, support)| GuardContext {
                safety_envelope_vetoing: vetoing,
                current_support_count: support,
            })
        }

        /// Enabled configs with a cache capacity of at least 2.
        fn arb_config() -> impl Strategy<Value = TraceJitConfig> {
            (1..16u64, 2..32usize, 1..8u64).prop_map(|(min_exec, max_traces, max_failures)| {
                TraceJitConfig::new(true, min_exec, max_traces, max_failures)
            })
        }

        proptest! {
            #[test]
            fn jit_cost_less_than_fused_for_width_ge_2(width in 2..1000usize) {
                let jit_cost = estimated_jit_cost(width);
                let fused_cost = 6 + i64::try_from(width).unwrap() * 2;
                prop_assert!(
                    jit_cost < fused_cost,
                    "JIT cost ({}) must be < fused cost ({}) at width {}",
                    jit_cost,
                    fused_cost,
                    width
                );
            }

            #[test]
            fn compiled_trace_tier_improvement_nonnegative(plan in arb_plan()) {
                let compiled = CompiledTrace::from_plan(&plan);
                prop_assert!(
                    compiled.tier_improvement_delta >= 0,
                    "tier_improvement_delta must be non-negative, got {}",
                    compiled.tier_improvement_delta
                );
            }

            #[test]
            fn compiled_trace_always_has_three_guards(plan in arb_plan()) {
                let compiled = CompiledTrace::from_plan(&plan);
                prop_assert!(
                    compiled.guards.len() == 3,
                    "compiled trace must have 3 guards (OpcodePrefix, SafetyEnvelope, MinSupport)"
                );
            }

            #[test]
            fn compiled_trace_width_matches_plan(plan in arb_plan()) {
                let compiled = CompiledTrace::from_plan(&plan);
                prop_assert!(
                    compiled.width == plan.width(),
                    "compiled width {} != plan width {}",
                    compiled.width,
                    plan.width()
                );
            }

            #[test]
            fn disabled_jit_never_promotes(
                plan in arb_plan(),
                executions in 1..50usize,
            ) {
                let config = TraceJitConfig::new(false, 1, 64, 4);
                let mut jit = TraceJitCompiler::new(config);
                for _ in 0..executions {
                    let promoted = jit.record_plan_execution(&plan);
                    prop_assert!(!promoted, "disabled JIT must never promote");
                }
                prop_assert!(
                    jit.cache_size() == 0,
                    "disabled JIT must have empty cache"
                );
            }

            #[test]
            fn cache_size_never_exceeds_max(
                config in arb_config(),
                plans in prop::collection::vec(arb_plan(), 1..20),
            ) {
                let max = config.max_compiled_traces;
                let min_exec = config.min_jit_executions;
                let mut jit = TraceJitCompiler::new(config);
                for plan in &plans {
                    for _ in 0..min_exec {
                        jit.record_plan_execution(plan);
                    }
                }
                prop_assert!(
                    jit.cache_size() <= max,
                    "cache size {} exceeds max {}",
                    jit.cache_size(),
                    max
                );
            }

            #[test]
            fn telemetry_traces_compiled_matches_cache_plus_evictions(
                config in arb_config(),
                plans in prop::collection::vec(arb_plan(), 1..10),
            ) {
                let min_exec = config.min_jit_executions;
                let mut jit = TraceJitCompiler::new(config);
                for plan in &plans {
                    for _ in 0..min_exec {
                        jit.record_plan_execution(plan);
                    }
                }
                let t = jit.telemetry();
                // No dispatches happen here, so nothing is invalidated. Every
                // compiled trace either still sits in the cache or was evicted.
                prop_assert_eq!(t.invalidations, 0);
                prop_assert_eq!(t.traces_compiled, t.cache_size + t.evictions);
            }

            #[test]
            fn guard_check_is_deterministic(
                plan in arb_plan(),
                trace_opcodes in arb_window(),
                ctx in arb_guard_context(),
            ) {
                let compiled = CompiledTrace::from_plan(&plan);
                let r1 = compiled.guards_pass(&trace_opcodes, &ctx);
                let r2 = compiled.guards_pass(&trace_opcodes, &ctx);
                prop_assert!(r1 == r2, "guard check must be deterministic");
            }

            #[test]
            fn jit_hit_implies_zero_deopt_reason(
                config in arb_config(),
                plan in arb_plan(),
            ) {
                let min_exec = config.min_jit_executions;
                let mut jit = TraceJitCompiler::new(config);
                for _ in 0..min_exec {
                    jit.record_plan_execution(&plan);
                }
                // The plan's own window with its full support count satisfies
                // every guard, so the freshly compiled trace must hit.
                let ctx = GuardContext {
                    safety_envelope_vetoing: false,
                    current_support_count: plan.support_count,
                };
                let result = jit.try_jit_dispatch(&plan.plan_id, &plan.opcode_window, &ctx);
                prop_assert!(
                    result.jit_hit,
                    "plan must hit after {} recorded executions",
                    min_exec
                );
                prop_assert!(
                    result.deopt_reason.is_none(),
                    "JIT hit must have no deopt reason"
                );
                prop_assert!(
                    result.cost_delta >= 0,
                    "JIT hit must have non-negative cost delta"
                );
            }

            #[test]
            fn deopts_stop_growing_after_invalidation(
                max_guard_failures in 1..8u64,
                attempts in 1..40u64,
            ) {
                let config = TraceJitConfig::new(true, 1, 8, max_guard_failures);
                let mut jit = TraceJitCompiler::new(config);
                let plan = make_plan("prop_invalidation", &["a", "b"], 10);
                jit.record_plan_execution(&plan);

                for _ in 0..attempts {
                    let _ = jit.try_jit_dispatch("prop_invalidation", &trace(&["x"]), &default_ctx());
                }

                let telemetry = jit.telemetry();
                let expected_deopts = attempts.min(max_guard_failures);
                prop_assert_eq!(telemetry.deopts, expected_deopts);
                prop_assert_eq!(telemetry.jit_misses, expected_deopts);
                prop_assert!(telemetry.invalidations <= 1);

                if attempts >= max_guard_failures {
                    prop_assert!(jit.is_invalidated("prop_invalidation"));
                    prop_assert_eq!(telemetry.invalidations, 1);
                } else {
                    prop_assert!(!jit.is_invalidated("prop_invalidation"));
                    prop_assert_eq!(telemetry.invalidations, 0);
                }
            }
        }
    }
}