1use serde::{Deserialize, Serialize};
23use std::collections::BTreeMap;
24
25use crate::hostcall_superinstructions::HostcallSuperinstructionPlan;
26
/// Fallback for `TraceJitConfig::min_jit_executions` when
/// `PI_HOSTCALL_TRACE_JIT_MIN_EXECUTIONS` is unset or does not parse.
const DEFAULT_MIN_JIT_EXECUTIONS: u64 = 8;
/// Fallback for `TraceJitConfig::max_compiled_traces` when
/// `PI_HOSTCALL_TRACE_JIT_MAX_TRACES` is unset or does not parse.
const DEFAULT_MAX_COMPILED_TRACES: usize = 64;
/// Fallback for `TraceJitConfig::max_guard_failures` when
/// `PI_HOSTCALL_TRACE_JIT_MAX_GUARD_FAILURES` is unset or does not parse.
const DEFAULT_MAX_GUARD_FAILURES: u64 = 4;
/// Fixed cost charged by `estimated_jit_cost` for every dispatch, independent of width.
const JIT_DISPATCH_COST_UNITS: i64 = 3;
/// Extra cost charged by `estimated_jit_cost` for each opcode in the trace window.
const JIT_DISPATCH_STEP_COST_UNITS: i64 = 1;
39
/// Tuning knobs for the trace-JIT tier.
///
/// `TraceJitConfig::default()` reads these values from the environment via
/// [`TraceJitConfig::from_env`]; use [`TraceJitConfig::new`] for explicit values.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TraceJitConfig {
    /// Master switch. When `false`, no plan is ever promoted and every
    /// dispatch reports `DeoptReason::JitDisabled`.
    pub enabled: bool,
    /// Number of recorded executions a plan needs (compared with `>=`)
    /// before it is compiled into a trace.
    pub min_jit_executions: u64,
    /// Capacity of the compiled-trace cache; when the cache is at capacity the
    /// least recently accessed trace is evicted before a new one is inserted.
    pub max_compiled_traces: usize,
    /// Consecutive guard failures after which a plan's trace is invalidated
    /// and dropped from the cache for good (until `reset`).
    pub max_guard_failures: u64,
}
54
55impl Default for TraceJitConfig {
56 fn default() -> Self {
57 Self::from_env()
58 }
59}
60
61impl TraceJitConfig {
62 #[must_use]
64 pub const fn new(
65 enabled: bool,
66 min_jit_executions: u64,
67 max_compiled_traces: usize,
68 max_guard_failures: u64,
69 ) -> Self {
70 Self {
71 enabled,
72 min_jit_executions,
73 max_compiled_traces,
74 max_guard_failures,
75 }
76 }
77
78 #[must_use]
80 pub fn from_env() -> Self {
81 let enabled = bool_from_env("PI_HOSTCALL_TRACE_JIT", true);
82 let min_jit_executions = u64_from_env(
83 "PI_HOSTCALL_TRACE_JIT_MIN_EXECUTIONS",
84 DEFAULT_MIN_JIT_EXECUTIONS,
85 );
86 let max_compiled_traces = usize_from_env(
87 "PI_HOSTCALL_TRACE_JIT_MAX_TRACES",
88 DEFAULT_MAX_COMPILED_TRACES,
89 );
90 let max_guard_failures = u64_from_env(
91 "PI_HOSTCALL_TRACE_JIT_MAX_GUARD_FAILURES",
92 DEFAULT_MAX_GUARD_FAILURES,
93 );
94 Self::new(
95 enabled,
96 min_jit_executions,
97 max_compiled_traces,
98 max_guard_failures,
99 )
100 }
101}
102
/// A runtime precondition that must hold before a compiled trace may be used.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TraceGuard {
    /// The observed opcode trace must begin with exactly this opcode sequence.
    OpcodePrefix(Vec<String>),
    /// The safety envelope must not currently be vetoing execution.
    SafetyEnvelopeNotVetoing,
    /// `GuardContext::current_support_count` must be at least this value.
    MinSupportCount(u32),
}
115
116impl TraceGuard {
117 #[must_use]
119 pub fn check(&self, trace: &[String], ctx: &GuardContext) -> bool {
120 match self {
121 Self::OpcodePrefix(window) => {
122 trace.len() >= window.len()
123 && trace
124 .iter()
125 .zip(window.iter())
126 .all(|(actual, expected)| actual == expected)
127 }
128 Self::SafetyEnvelopeNotVetoing => !ctx.safety_envelope_vetoing,
129 Self::MinSupportCount(min) => ctx.current_support_count >= *min,
130 }
131 }
132}
133
/// Runtime state consulted by [`TraceGuard::check`].
#[derive(Debug, Clone, Default)]
pub struct GuardContext {
    /// When `true`, the `SafetyEnvelopeNotVetoing` guard fails.
    pub safety_envelope_vetoing: bool,
    /// Observed support count, compared against `MinSupportCount` thresholds.
    pub current_support_count: u32,
}
142
/// Optimization tier a trace was compiled at.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompilationTier {
    /// Superinstruction tier (presumably the fused-plan tier below the trace JIT; confirm).
    Superinstruction,
    /// Trace-JIT tier; this is the tier `CompiledTrace::from_plan` assigns.
    TraceJit,
}
153
/// A guarded, compiled form of a superinstruction plan held in the JIT cache.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompiledTrace {
    /// Identifier of the plan this trace was compiled from (also the cache key).
    pub plan_id: String,
    /// Trace signature copied from the source plan.
    pub trace_signature: String,
    /// Guards evaluated in order before the trace may be used; the first
    /// failing guard's index is reported in `DeoptReason::GuardFailure`.
    pub guards: Vec<TraceGuard>,
    /// Opcode window copied from the source plan.
    pub opcode_window: Vec<String>,
    /// Width of the source plan.
    pub width: usize,
    /// Estimated cost of dispatching through this trace (`estimated_jit_cost(width)`).
    pub estimated_cost_jit: i64,
    /// Estimated fused cost copied from the source plan.
    pub estimated_cost_fused: i64,
    /// `estimated_cost_fused - estimated_cost_jit` (saturating); reported as
    /// `cost_delta` on a JIT hit. Can be negative if the JIT estimate is higher.
    pub tier_improvement_delta: i64,
    /// Tier this trace was compiled at.
    pub tier: CompilationTier,
}
176
177impl CompiledTrace {
178 #[must_use]
180 pub fn from_plan(plan: &HostcallSuperinstructionPlan) -> Self {
181 let width = plan.width();
182 let estimated_cost_jit = estimated_jit_cost(width);
183 let tier_improvement_delta = plan.estimated_cost_fused.saturating_sub(estimated_cost_jit);
184
185 let guards = vec![
186 TraceGuard::OpcodePrefix(plan.opcode_window.clone()),
187 TraceGuard::SafetyEnvelopeNotVetoing,
188 TraceGuard::MinSupportCount(plan.support_count / 2),
189 ];
190
191 Self {
192 plan_id: plan.plan_id.clone(),
193 trace_signature: plan.trace_signature.clone(),
194 guards,
195 opcode_window: plan.opcode_window.clone(),
196 width,
197 estimated_cost_jit,
198 estimated_cost_fused: plan.estimated_cost_fused,
199 tier_improvement_delta,
200 tier: CompilationTier::TraceJit,
201 }
202 }
203
204 #[must_use]
206 pub fn guards_pass(&self, trace: &[String], ctx: &GuardContext) -> bool {
207 self.guards.iter().all(|guard| guard.check(trace, ctx))
208 }
209}
210
/// Why a dispatch fell back instead of running the compiled trace.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeoptReason {
    /// A guard rejected the dispatch.
    GuardFailure {
        /// Index of the first failing guard in `CompiledTrace::guards`.
        guard_index: usize,
        /// Fixed label naming the failed guard kind (e.g. `"opcode_prefix_mismatch"`).
        description: String,
    },
    /// This guard failure reached `max_guard_failures`, so the trace was
    /// invalidated and removed from the cache.
    TraceInvalidated {
        /// Consecutive guard failures at the moment of invalidation.
        total_failures: u64,
    },
    /// The trace JIT is disabled by configuration.
    JitDisabled,
    /// No compiled trace is cached for the plan (never promoted, evicted, or invalidated).
    NotCompiled,
    /// NOTE(review): not constructed by `TraceJitCompiler::try_jit_dispatch`; a
    /// vetoing safety envelope surfaces there as `GuardFailure` instead.
    SafetyVeto,
}
235
/// Outcome of `TraceJitCompiler::try_jit_dispatch`.
#[derive(Debug, Clone)]
pub struct JitExecutionResult {
    /// `true` when all guards passed and the compiled trace was used.
    pub jit_hit: bool,
    /// Plan the dispatch was attempted for (always `Some` from `try_jit_dispatch`).
    pub plan_id: Option<String>,
    /// Reason for falling back; `None` on a hit.
    pub deopt_reason: Option<DeoptReason>,
    /// The trace's `tier_improvement_delta` on a hit, `0` otherwise.
    pub cost_delta: i64,
}
248
/// Per-plan bookkeeping kept by the compiler, independent of whether the plan
/// is currently compiled.
#[derive(Debug, Clone, Default)]
struct PlanProfile {
    /// Executions recorded via `record_plan_execution` while the JIT is enabled.
    execution_count: u64,
    /// Guard failures since the last successful dispatch.
    consecutive_guard_failures: u64,
    /// Set once the failure limit is reached; blocks further compilation until `reset`.
    invalidated: bool,
    /// Value of the compiler's access clock at the last touch; used for LRU eviction.
    last_access_generation: u64,
}
263
/// Counters describing trace-JIT activity since construction or the last `reset`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TraceJitTelemetry {
    /// Calls to `record_plan_execution` made while the JIT was enabled.
    pub plans_evaluated: u64,
    /// Traces inserted into the cache, including recompilations after eviction.
    pub traces_compiled: u64,
    /// Dispatches whose guards all passed.
    pub jit_hits: u64,
    /// Dispatches rejected by a guard; `NotCompiled` and `JitDisabled`
    /// outcomes are not counted.
    pub jit_misses: u64,
    /// Guard-failure deoptimizations (incremented together with `jit_misses`).
    pub deopts: u64,
    /// Traces invalidated after reaching the guard-failure limit.
    pub invalidations: u64,
    /// Traces evicted to make room in a full cache.
    pub evictions: u64,
    /// Number of cached traces as of the most recent cache mutation.
    pub cache_size: u64,
}
286
/// Profiles superinstruction plans, compiles hot ones into guarded traces, and
/// dispatches through those traces while their guards hold.
#[derive(Debug, Clone)]
pub struct TraceJitCompiler {
    config: TraceJitConfig,
    /// Compiled traces keyed by plan id.
    cache: BTreeMap<String, CompiledTrace>,
    /// Execution/failure profiles keyed by plan id; survive eviction.
    profiles: BTreeMap<String, PlanProfile>,
    /// Monotonic access clock, bumped on every enabled record/dispatch call.
    generation: u64,
    telemetry: TraceJitTelemetry,
}
305
306impl Default for TraceJitCompiler {
307 fn default() -> Self {
308 Self::new(TraceJitConfig::default())
309 }
310}
311
312impl TraceJitCompiler {
313 #[must_use]
315 pub fn new(config: TraceJitConfig) -> Self {
316 Self {
317 config,
318 cache: BTreeMap::new(),
319 profiles: BTreeMap::new(),
320 generation: 0,
321 telemetry: TraceJitTelemetry::default(),
322 }
323 }
324
325 #[must_use]
327 pub const fn enabled(&self) -> bool {
328 self.config.enabled
329 }
330
331 #[must_use]
333 pub const fn config(&self) -> &TraceJitConfig {
334 &self.config
335 }
336
337 #[must_use]
339 pub const fn telemetry(&self) -> &TraceJitTelemetry {
340 &self.telemetry
341 }
342
343 #[must_use]
345 pub fn cache_size(&self) -> usize {
346 self.cache.len()
347 }
348
349 pub fn record_plan_execution(&mut self, plan: &HostcallSuperinstructionPlan) -> bool {
354 if !self.config.enabled {
355 return false;
356 }
357
358 self.telemetry.plans_evaluated += 1;
359 self.generation += 1;
360
361 let profile = self.profiles.entry(plan.plan_id.clone()).or_default();
362
363 profile.execution_count += 1;
364 profile.last_access_generation = self.generation;
365
366 if profile.invalidated {
367 return false;
368 }
369
370 if profile.execution_count >= self.config.min_jit_executions
372 && !self.cache.contains_key(&plan.plan_id)
373 {
374 self.compile_trace(plan);
375 return true;
376 }
377
378 false
379 }
380
381 fn compile_trace(&mut self, plan: &HostcallSuperinstructionPlan) {
383 if self.cache.len() >= self.config.max_compiled_traces {
385 self.evict_lru();
386 }
387
388 let compiled = CompiledTrace::from_plan(plan);
389 self.cache.insert(plan.plan_id.clone(), compiled);
390 self.telemetry.traces_compiled += 1;
391 self.telemetry.cache_size = u64::try_from(self.cache.len()).unwrap_or(u64::MAX);
392 }
393
394 fn evict_lru(&mut self) {
396 let lru_plan_id = self
397 .cache
398 .keys()
399 .min_by_key(|plan_id| {
400 self.profiles
401 .get(*plan_id)
402 .map_or(0, |profile| profile.last_access_generation)
403 })
404 .cloned();
405
406 if let Some(plan_id) = lru_plan_id {
407 self.cache.remove(&plan_id);
408 self.telemetry.evictions += 1;
409 self.telemetry.cache_size = u64::try_from(self.cache.len()).unwrap_or(u64::MAX);
410 }
411 }
412
413 pub fn try_jit_dispatch(
418 &mut self,
419 plan_id: &str,
420 trace: &[String],
421 ctx: &GuardContext,
422 ) -> JitExecutionResult {
423 if !self.config.enabled {
424 return JitExecutionResult {
425 jit_hit: false,
426 plan_id: Some(plan_id.to_string()),
427 deopt_reason: Some(DeoptReason::JitDisabled),
428 cost_delta: 0,
429 };
430 }
431
432 self.generation += 1;
433
434 let (tier_improvement_delta, failed_guard) = {
436 let Some(compiled) = self.cache.get(plan_id) else {
437 return JitExecutionResult {
438 jit_hit: false,
439 plan_id: Some(plan_id.to_string()),
440 deopt_reason: Some(DeoptReason::NotCompiled),
441 cost_delta: 0,
442 };
443 };
444
445 let mut failed_guard = None;
446 for (idx, guard) in compiled.guards.iter().enumerate() {
447 if !guard.check(trace, ctx) {
448 let description = match guard {
449 TraceGuard::OpcodePrefix(_) => "opcode_prefix_mismatch",
450 TraceGuard::SafetyEnvelopeNotVetoing => "safety_envelope_vetoing",
451 TraceGuard::MinSupportCount(_) => "support_count_below_threshold",
452 };
453 failed_guard = Some((idx, description));
454 break;
455 }
456 }
457
458 (compiled.tier_improvement_delta, failed_guard)
459 };
460
461 if let Some(profile) = self.profiles.get_mut(plan_id) {
463 profile.last_access_generation = self.generation;
464 }
465
466 if let Some((idx, description)) = failed_guard {
468 let invalidated_after_failures = self.record_guard_failure(plan_id);
469 let deopt_reason = invalidated_after_failures.map_or_else(
470 || DeoptReason::GuardFailure {
471 guard_index: idx,
472 description: description.to_string(),
473 },
474 |total_failures| DeoptReason::TraceInvalidated { total_failures },
475 );
476 return JitExecutionResult {
477 jit_hit: false,
478 plan_id: Some(plan_id.to_string()),
479 deopt_reason: Some(deopt_reason),
480 cost_delta: 0,
481 };
482 }
483
484 self.telemetry.jit_hits += 1;
486 if let Some(profile) = self.profiles.get_mut(plan_id) {
487 profile.consecutive_guard_failures = 0;
488 }
489
490 JitExecutionResult {
491 jit_hit: true,
492 plan_id: Some(plan_id.to_string()),
493 deopt_reason: None,
494 cost_delta: tier_improvement_delta,
495 }
496 }
497
498 fn record_guard_failure(&mut self, plan_id: &str) -> Option<u64> {
500 self.telemetry.deopts += 1;
501 self.telemetry.jit_misses += 1;
502
503 if let Some(profile) = self.profiles.get_mut(plan_id) {
504 profile.consecutive_guard_failures += 1;
505 if !profile.invalidated
506 && profile.consecutive_guard_failures >= self.config.max_guard_failures
507 {
508 profile.invalidated = true;
509 self.cache.remove(plan_id);
510 self.telemetry.invalidations += 1;
511 self.telemetry.cache_size = u64::try_from(self.cache.len()).unwrap_or(u64::MAX);
512 return Some(profile.consecutive_guard_failures);
513 }
514 }
515 None
516 }
517
518 #[must_use]
520 pub fn get_compiled_trace(&self, plan_id: &str) -> Option<&CompiledTrace> {
521 self.cache.get(plan_id)
522 }
523
524 #[must_use]
526 pub fn is_invalidated(&self, plan_id: &str) -> bool {
527 self.profiles
528 .get(plan_id)
529 .is_some_and(|profile| profile.invalidated)
530 }
531
532 pub fn reset(&mut self) {
534 self.cache.clear();
535 self.profiles.clear();
536 self.generation = 0;
537 self.telemetry = TraceJitTelemetry::default();
538 }
539}
540
541#[must_use]
545pub fn estimated_jit_cost(width: usize) -> i64 {
546 let width_units = i64::try_from(width).unwrap_or(i64::MAX);
547 JIT_DISPATCH_COST_UNITS.saturating_add(width_units.saturating_mul(JIT_DISPATCH_STEP_COST_UNITS))
548}
549
/// Reads a boolean flag from environment variable `var`.
///
/// Returns `default` when the variable is unset or not valid Unicode. Any
/// other value is `true` unless, after trimming and lowercasing, it is one of
/// `0`, `false`, `off`, or `disabled`.
fn bool_from_env(var: &str, default: bool) -> bool {
    const DISABLED_VALUES: [&str; 4] = ["0", "false", "off", "disabled"];
    match std::env::var(var) {
        Ok(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            !DISABLED_VALUES.contains(&normalized.as_str())
        }
        Err(_) => default,
    }
}
560
/// Reads a `u64` from environment variable `var` (surrounding whitespace is
/// ignored), returning `default` when the variable is unset, not Unicode, or
/// does not parse.
fn u64_from_env(var: &str, default: u64) -> u64 {
    match std::env::var(var) {
        Ok(raw) => raw.trim().parse().unwrap_or(default),
        Err(_) => default,
    }
}
567
/// Reads a `usize` from environment variable `var` (surrounding whitespace is
/// ignored), returning `default` when the variable is unset, not Unicode, or
/// does not parse.
fn usize_from_env(var: &str, default: usize) -> usize {
    match std::env::var(var) {
        Ok(raw) => raw.trim().parse().unwrap_or(default),
        Err(_) => default,
    }
}
574
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hostcall_superinstructions::{
        HOSTCALL_SUPERINSTRUCTION_PLAN_VERSION, HOSTCALL_SUPERINSTRUCTION_SCHEMA_VERSION,
        HostcallSuperinstructionPlan,
    };

    /// Builds a plan whose fused cost follows `6 + 2 * width`.
    fn make_plan(
        plan_id: &str,
        window: &[&str],
        support_count: u32,
    ) -> HostcallSuperinstructionPlan {
        let opcode_window: Vec<String> = window.iter().map(ToString::to_string).collect();
        let width = opcode_window.len();
        HostcallSuperinstructionPlan {
            schema: HOSTCALL_SUPERINSTRUCTION_SCHEMA_VERSION.to_string(),
            version: HOSTCALL_SUPERINSTRUCTION_PLAN_VERSION,
            plan_id: plan_id.to_string(),
            trace_signature: format!("sig_{plan_id}"),
            opcode_window,
            support_count,
            estimated_cost_baseline: i64::try_from(width).unwrap_or(0) * 10,
            estimated_cost_fused: 6 + i64::try_from(width).unwrap_or(0) * 2,
            expected_cost_delta: i64::try_from(width).unwrap_or(0) * 8 - 6,
        }
    }

    /// Converts string literals into an owned opcode trace.
    fn trace(opcodes: &[&str]) -> Vec<String> {
        opcodes.iter().map(ToString::to_string).collect()
    }

    /// A context that satisfies the safety and support-count guards of every
    /// plan built in these tests.
    fn default_ctx() -> GuardContext {
        GuardContext {
            safety_envelope_vetoing: false,
            current_support_count: 100,
        }
    }

    // ── Configuration ──

    #[test]
    fn config_default_values() {
        let config = TraceJitConfig::new(true, 8, 64, 4);
        assert!(config.enabled);
        assert_eq!(config.min_jit_executions, 8);
        assert_eq!(config.max_compiled_traces, 64);
        assert_eq!(config.max_guard_failures, 4);
    }

    #[test]
    fn config_disabled_prevents_compilation() {
        let config = TraceJitConfig::new(false, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);

        let promoted = jit.record_plan_execution(&plan);
        assert!(!promoted);
        assert_eq!(jit.cache_size(), 0);
    }

    // ── Promotion ──

    #[test]
    fn plan_promoted_after_reaching_threshold() {
        let config = TraceJitConfig::new(true, 3, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["session.get_state", "session.get_messages"], 10);

        assert!(!jit.record_plan_execution(&plan));
        assert!(!jit.record_plan_execution(&plan));
        // Third execution reaches the threshold and compiles.
        assert!(jit.record_plan_execution(&plan));
        assert_eq!(jit.cache_size(), 1);

        // Already compiled: further executions do not recompile.
        assert!(!jit.record_plan_execution(&plan));
        assert_eq!(jit.telemetry().traces_compiled, 1);
    }

    #[test]
    fn plan_not_promoted_before_threshold() {
        let config = TraceJitConfig::new(true, 10, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 5);

        for _ in 0..9 {
            assert!(!jit.record_plan_execution(&plan));
        }
        assert_eq!(jit.cache_size(), 0);
        // Tenth execution reaches the threshold.
        assert!(jit.record_plan_execution(&plan));
        assert_eq!(jit.cache_size(), 1);
    }

    // ── Guards ──

    #[test]
    fn guard_opcode_prefix_passes_on_match() {
        let guard = TraceGuard::OpcodePrefix(trace(&["a", "b"]));
        let ctx = default_ctx();
        assert!(guard.check(&trace(&["a", "b", "c"]), &ctx));
        assert!(guard.check(&trace(&["a", "b"]), &ctx));
    }

    #[test]
    fn guard_opcode_prefix_fails_on_mismatch() {
        let guard = TraceGuard::OpcodePrefix(trace(&["a", "b"]));
        let ctx = default_ctx();
        assert!(!guard.check(&trace(&["a", "c"]), &ctx));
        assert!(!guard.check(&trace(&["a"]), &ctx));
        assert!(!guard.check(&trace(&[]), &ctx));
    }

    #[test]
    fn guard_safety_envelope_passes_when_not_vetoing() {
        let guard = TraceGuard::SafetyEnvelopeNotVetoing;
        let ctx = GuardContext {
            safety_envelope_vetoing: false,
            ..default_ctx()
        };
        assert!(guard.check(&[], &ctx));
    }

    #[test]
    fn guard_safety_envelope_fails_when_vetoing() {
        let guard = TraceGuard::SafetyEnvelopeNotVetoing;
        let ctx = GuardContext {
            safety_envelope_vetoing: true,
            ..default_ctx()
        };
        assert!(!guard.check(&[], &ctx));
    }

    #[test]
    fn guard_min_support_count_passes() {
        let guard = TraceGuard::MinSupportCount(5);
        let ctx = GuardContext {
            current_support_count: 10,
            ..default_ctx()
        };
        assert!(guard.check(&[], &ctx));
    }

    #[test]
    fn guard_min_support_count_fails() {
        let guard = TraceGuard::MinSupportCount(5);
        let ctx = GuardContext {
            current_support_count: 3,
            ..default_ctx()
        };
        assert!(!guard.check(&[], &ctx));
    }

    // ── Compiled traces ──

    #[test]
    fn compiled_trace_from_plan_sets_tier() {
        let plan = make_plan("p1", &["a", "b", "c"], 10);
        let compiled = CompiledTrace::from_plan(&plan);

        assert_eq!(compiled.plan_id, "p1");
        assert_eq!(compiled.tier, CompilationTier::TraceJit);
        assert_eq!(compiled.width, 3);
        assert_eq!(compiled.guards.len(), 3);
    }

    #[test]
    fn compiled_trace_cost_lower_than_fused() {
        let plan = make_plan("p1", &["a", "b", "c"], 10);
        let compiled = CompiledTrace::from_plan(&plan);

        assert!(
            compiled.estimated_cost_jit < compiled.estimated_cost_fused,
            "JIT cost ({}) should be less than fused cost ({})",
            compiled.estimated_cost_jit,
            compiled.estimated_cost_fused
        );
        assert!(compiled.tier_improvement_delta > 0);
    }

    #[test]
    fn compiled_trace_guards_pass_on_matching_trace() {
        let plan = make_plan("p1", &["a", "b"], 10);
        let compiled = CompiledTrace::from_plan(&plan);
        let ctx = default_ctx();

        assert!(compiled.guards_pass(&trace(&["a", "b", "c"]), &ctx));
    }

    #[test]
    fn compiled_trace_guards_fail_on_wrong_prefix() {
        let plan = make_plan("p1", &["a", "b"], 10);
        let compiled = CompiledTrace::from_plan(&plan);
        let ctx = default_ctx();

        assert!(!compiled.guards_pass(&trace(&["x", "y"]), &ctx));
    }

    #[test]
    fn compiled_trace_guards_fail_on_safety_veto() {
        let plan = make_plan("p1", &["a", "b"], 10);
        let compiled = CompiledTrace::from_plan(&plan);
        let ctx = GuardContext {
            safety_envelope_vetoing: true,
            ..default_ctx()
        };

        assert!(!compiled.guards_pass(&trace(&["a", "b"]), &ctx));
    }

    // ── Dispatch ──

    #[test]
    fn jit_dispatch_hits_after_promotion() {
        let config = TraceJitConfig::new(true, 2, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);

        jit.record_plan_execution(&plan);
        jit.record_plan_execution(&plan);
        assert_eq!(jit.cache_size(), 1);

        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b", "c"]), &default_ctx());
        assert!(result.jit_hit);
        assert!(result.deopt_reason.is_none());
        assert!(result.cost_delta > 0);
        assert_eq!(jit.telemetry().jit_hits, 1);
    }

    #[test]
    fn jit_dispatch_returns_not_compiled_before_promotion() {
        let config = TraceJitConfig::new(true, 10, 64, 4);
        let mut jit = TraceJitCompiler::new(config);

        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        assert!(!result.jit_hit);
        assert_eq!(result.deopt_reason, Some(DeoptReason::NotCompiled));
    }

    #[test]
    fn jit_dispatch_deopt_on_guard_failure() {
        let config = TraceJitConfig::new(true, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);

        let result = jit.try_jit_dispatch("p1", &trace(&["x", "y"]), &default_ctx());
        assert!(!result.jit_hit);
        assert!(matches!(
            result.deopt_reason,
            Some(DeoptReason::GuardFailure { guard_index: 0, .. })
        ));
        assert_eq!(jit.telemetry().deopts, 1);
    }

    #[test]
    fn jit_dispatch_deopt_on_safety_veto() {
        let config = TraceJitConfig::new(true, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);

        let ctx = GuardContext {
            safety_envelope_vetoing: true,
            ..default_ctx()
        };
        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &ctx);
        assert!(!result.jit_hit);
        assert!(matches!(
            result.deopt_reason,
            Some(DeoptReason::GuardFailure { guard_index: 1, .. })
        ));
    }

    #[test]
    fn jit_dispatch_deopt_on_support_count_guard() {
        let config = TraceJitConfig::new(true, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 20);
        jit.record_plan_execution(&plan);

        let ctx = GuardContext {
            safety_envelope_vetoing: false,
            // The guard threshold is 20 / 2 = 10, so 9 fails.
            current_support_count: 9,
        };
        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &ctx);
        assert!(!result.jit_hit);
        assert_eq!(
            result.deopt_reason,
            Some(DeoptReason::GuardFailure {
                guard_index: 2,
                description: "support_count_below_threshold".to_string(),
            })
        );
    }

    #[test]
    fn jit_dispatch_disabled_returns_jit_disabled() {
        let config = TraceJitConfig::new(false, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);

        let result = jit.try_jit_dispatch("p1", &trace(&["a"]), &default_ctx());
        assert!(!result.jit_hit);
        assert_eq!(result.deopt_reason, Some(DeoptReason::JitDisabled));
    }

    // ── Invalidation ──

    #[test]
    fn trace_invalidated_after_max_guard_failures() {
        let config = TraceJitConfig::new(true, 1, 64, 3);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);
        assert_eq!(jit.cache_size(), 1);

        for _ in 0..3 {
            jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        }

        assert!(jit.is_invalidated("p1"));
        assert_eq!(jit.cache_size(), 0);
        assert_eq!(jit.telemetry().invalidations, 1);

        // Invalidated plans are never re-promoted.
        assert!(!jit.record_plan_execution(&plan));
    }

    #[test]
    fn threshold_crossing_failure_reports_trace_invalidated() {
        let config = TraceJitConfig::new(true, 1, 64, 2);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);

        let first = jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        assert_eq!(
            first.deopt_reason,
            Some(DeoptReason::GuardFailure {
                guard_index: 0,
                description: "opcode_prefix_mismatch".to_string(),
            })
        );

        let second = jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        assert_eq!(
            second.deopt_reason,
            Some(DeoptReason::TraceInvalidated { total_failures: 2 })
        );

        assert!(jit.is_invalidated("p1"));
        assert_eq!(jit.cache_size(), 0);

        let after_invalidation = jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        assert_eq!(
            after_invalidation.deopt_reason,
            Some(DeoptReason::NotCompiled)
        );

        // The NotCompiled dispatch is not counted as a miss or deopt.
        let telemetry = jit.telemetry();
        assert_eq!(telemetry.deopts, 2);
        assert_eq!(telemetry.jit_misses, 2);
        assert_eq!(telemetry.invalidations, 1);
    }

    #[test]
    fn guard_failure_counter_resets_on_success() {
        let config = TraceJitConfig::new(true, 1, 64, 3);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);
        jit.record_plan_execution(&plan);

        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        // A successful dispatch resets the consecutive-failure counter.
        let result = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        assert!(result.jit_hit);
        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());
        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());

        // Only two consecutive failures since the reset, below the limit of three.
        assert!(!jit.is_invalidated("p1"));
        assert_eq!(jit.cache_size(), 1);
    }

    // ── Eviction ──

    #[test]
    fn lru_eviction_when_cache_full() {
        let config = TraceJitConfig::new(true, 1, 2, 4);
        let mut jit = TraceJitCompiler::new(config);

        let p1 = make_plan("p1", &["a", "b"], 10);
        let p2 = make_plan("p2", &["c", "d"], 10);
        let p3 = make_plan("p3", &["e", "f"], 10);

        jit.record_plan_execution(&p1);
        jit.record_plan_execution(&p2);
        assert_eq!(jit.cache_size(), 2);

        // Touch p2 so p1 becomes the least recently used entry.
        jit.try_jit_dispatch("p2", &trace(&["c", "d"]), &default_ctx());

        jit.record_plan_execution(&p3);
        assert_eq!(jit.cache_size(), 2);
        assert!(jit.get_compiled_trace("p1").is_none());
        assert!(jit.get_compiled_trace("p2").is_some());
        assert!(jit.get_compiled_trace("p3").is_some());
        assert_eq!(jit.telemetry().evictions, 1);
    }

    // ── Telemetry ──

    #[test]
    fn telemetry_tracks_all_counters() {
        let config = TraceJitConfig::new(true, 2, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);

        jit.record_plan_execution(&plan);
        jit.record_plan_execution(&plan);

        // One hit followed by one guard failure.
        jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        jit.try_jit_dispatch("p1", &trace(&["x"]), &default_ctx());

        let t = jit.telemetry();
        assert_eq!(t.plans_evaluated, 2);
        assert_eq!(t.traces_compiled, 1);
        assert_eq!(t.jit_hits, 1);
        assert_eq!(t.jit_misses, 1);
        assert_eq!(t.deopts, 1);
        assert_eq!(t.cache_size, 1);
    }

    #[test]
    fn telemetry_serializes_round_trip() {
        let telemetry = TraceJitTelemetry {
            plans_evaluated: 100,
            traces_compiled: 10,
            jit_hits: 50,
            jit_misses: 5,
            deopts: 5,
            invalidations: 1,
            evictions: 2,
            cache_size: 8,
        };

        let json = serde_json::to_string(&telemetry).expect("serialize");
        let parsed: TraceJitTelemetry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(telemetry, parsed);
    }

    #[test]
    fn reset_clears_all_state() {
        let config = TraceJitConfig::new(true, 1, 64, 4);
        let mut jit = TraceJitCompiler::new(config);
        let plan = make_plan("p1", &["a", "b"], 10);

        jit.record_plan_execution(&plan);
        jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        assert!(jit.cache_size() > 0);
        assert!(jit.telemetry().jit_hits > 0);

        jit.reset();
        assert_eq!(jit.cache_size(), 0);
        assert_eq!(jit.telemetry().jit_hits, 0);
        assert_eq!(jit.telemetry().traces_compiled, 0);
    }

    // ── Cost model ──

    #[test]
    fn jit_cost_less_than_fused_cost() {
        for width in 2..=8 {
            let jit_cost = estimated_jit_cost(width);
            let fused_cost = 6 + i64::try_from(width).unwrap() * 2;
            assert!(
                jit_cost < fused_cost,
                "JIT cost ({jit_cost}) should be less than fused cost ({fused_cost}) for width {width}"
            );
        }
    }

    #[test]
    fn jit_cost_scales_linearly() {
        let cost_2 = estimated_jit_cost(2);
        let cost_4 = estimated_jit_cost(4);
        let delta = cost_4 - cost_2;
        // Two extra opcodes at one step-cost unit each.
        assert_eq!(delta, 2);
    }

    // ── Serialization ──

    #[test]
    fn compiled_trace_serializes_round_trip() {
        let plan = make_plan("p_rt", &["a", "b", "c"], 10);
        let compiled = CompiledTrace::from_plan(&plan);

        let json = serde_json::to_string(&compiled).expect("serialize");
        let parsed: CompiledTrace = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(compiled, parsed);
    }

    #[test]
    fn deopt_reason_serializes_round_trip() {
        let reasons = vec![
            DeoptReason::GuardFailure {
                guard_index: 1,
                description: "test".to_string(),
            },
            DeoptReason::TraceInvalidated { total_failures: 5 },
            DeoptReason::JitDisabled,
            DeoptReason::NotCompiled,
            DeoptReason::SafetyVeto,
        ];

        for reason in &reasons {
            let value = serde_json::to_value(reason).expect("serialize to value");
            let parsed: DeoptReason =
                serde_json::from_value(value).expect("deserialize from value");
            assert_eq!(*reason, parsed);
        }
    }

    // ── Multiple plans ──

    #[test]
    fn multiple_plans_compile_independently() {
        let config = TraceJitConfig::new(true, 2, 64, 4);
        let mut jit = TraceJitCompiler::new(config);

        let p1 = make_plan("p1", &["a", "b"], 10);
        let p2 = make_plan("p2", &["c", "d"], 10);

        jit.record_plan_execution(&p1);
        jit.record_plan_execution(&p1);
        jit.record_plan_execution(&p2);
        jit.record_plan_execution(&p2);
        assert_eq!(jit.cache_size(), 2);

        let r1 = jit.try_jit_dispatch("p1", &trace(&["a", "b"]), &default_ctx());
        let r2 = jit.try_jit_dispatch("p2", &trace(&["c", "d"]), &default_ctx());
        assert!(r1.jit_hit);
        assert!(r2.jit_hit);
        assert_eq!(jit.telemetry().jit_hits, 2);
    }

    /// Property tests. Assertions use `prop_assert*!` so failures are reported
    /// and shrunk through proptest rather than raised as raw panics.
    mod proptest_trace_jit {
        use super::*;

        use proptest::prelude::*;

        fn arb_opcode() -> impl Strategy<Value = String> {
            prop::sample::select(vec![
                "session.get_state".to_string(),
                "session.get_messages".to_string(),
                "events.list".to_string(),
                "tool.read".to_string(),
                "tool.write".to_string(),
                "events.emit".to_string(),
            ])
        }

        fn arb_window() -> impl Strategy<Value = Vec<String>> {
            prop::collection::vec(arb_opcode(), 2..6)
        }

        fn arb_plan() -> impl Strategy<Value = HostcallSuperinstructionPlan> {
            (arb_window(), 2..100u32).prop_map(|(window, support)| {
                let width = window.len();
                let baseline = i64::try_from(width).unwrap_or(0) * 10;
                let fused = 6 + i64::try_from(width).unwrap_or(0) * 2;
                HostcallSuperinstructionPlan {
                    schema: HOSTCALL_SUPERINSTRUCTION_SCHEMA_VERSION.to_string(),
                    version: HOSTCALL_SUPERINSTRUCTION_PLAN_VERSION,
                    plan_id: format!("arb_{width}_{support}"),
                    trace_signature: format!("sig_arb_{width}_{support}"),
                    opcode_window: window,
                    support_count: support,
                    estimated_cost_baseline: baseline,
                    estimated_cost_fused: fused,
                    expected_cost_delta: baseline - fused,
                }
            })
        }

        fn arb_guard_context() -> impl Strategy<Value = GuardContext> {
            (any::<bool>(), 0..200u32).prop_map(|(vetoing, support)| GuardContext {
                safety_envelope_vetoing: vetoing,
                current_support_count: support,
            })
        }

        fn arb_config() -> impl Strategy<Value = TraceJitConfig> {
            (1..16u64, 2..32usize, 1..8u64).prop_map(|(min_exec, max_traces, max_failures)| {
                TraceJitConfig::new(true, min_exec, max_traces, max_failures)
            })
        }

        proptest! {
            #[test]
            fn jit_cost_less_than_fused_for_width_ge_2(width in 2..1000usize) {
                let jit_cost = estimated_jit_cost(width);
                let fused_cost = 6 + i64::try_from(width).unwrap() * 2;
                prop_assert!(
                    jit_cost < fused_cost,
                    "JIT cost ({}) must be < fused cost ({}) at width {}",
                    jit_cost,
                    fused_cost,
                    width,
                );
            }

            #[test]
            fn compiled_trace_tier_improvement_nonnegative(plan in arb_plan()) {
                let compiled = CompiledTrace::from_plan(&plan);
                prop_assert!(
                    compiled.tier_improvement_delta >= 0,
                    "tier_improvement_delta must be non-negative, got {}",
                    compiled.tier_improvement_delta,
                );
            }

            #[test]
            fn compiled_trace_always_has_three_guards(plan in arb_plan()) {
                let compiled = CompiledTrace::from_plan(&plan);
                prop_assert!(
                    compiled.guards.len() == 3,
                    "compiled trace must have 3 guards (OpcodePrefix, SafetyEnvelope, MinSupport)"
                );
            }

            #[test]
            fn compiled_trace_width_matches_plan(plan in arb_plan()) {
                let compiled = CompiledTrace::from_plan(&plan);
                prop_assert!(
                    compiled.width == plan.width(),
                    "compiled width {} != plan width {}",
                    compiled.width,
                    plan.width(),
                );
            }

            #[test]
            fn disabled_jit_never_promotes(
                plan in arb_plan(),
                executions in 1..50usize,
            ) {
                let config = TraceJitConfig::new(false, 1, 64, 4);
                let mut jit = TraceJitCompiler::new(config);
                for _ in 0..executions {
                    let promoted = jit.record_plan_execution(&plan);
                    prop_assert!(!promoted, "disabled JIT must never promote");
                }
                prop_assert!(
                    jit.cache_size() == 0,
                    "disabled JIT must have empty cache"
                );
            }

            #[test]
            fn cache_size_never_exceeds_max(
                config in arb_config(),
                plans in prop::collection::vec(arb_plan(), 1..20),
            ) {
                let max = config.max_compiled_traces;
                let min_exec = config.min_jit_executions;
                let mut jit = TraceJitCompiler::new(config);
                for plan in &plans {
                    for _ in 0..min_exec {
                        jit.record_plan_execution(plan);
                    }
                }
                prop_assert!(
                    jit.cache_size() <= max,
                    "cache size {} exceeds max {}",
                    jit.cache_size(),
                    max,
                );
            }

            #[test]
            fn telemetry_traces_compiled_matches_cache_plus_evictions(
                config in arb_config(),
                plans in prop::collection::vec(arb_plan(), 1..10),
            ) {
                let min_exec = config.min_jit_executions;
                let mut jit = TraceJitCompiler::new(config);
                for plan in &plans {
                    for _ in 0..min_exec {
                        jit.record_plan_execution(plan);
                    }
                }
                let t = jit.telemetry();
                // Without dispatches nothing is invalidated, so eviction is the
                // only way a compiled trace can leave the cache.
                prop_assert_eq!(
                    t.traces_compiled,
                    t.cache_size + t.evictions,
                    "traces_compiled must equal cache_size + evictions"
                );
                prop_assert_eq!(t.cache_size, u64::try_from(jit.cache_size()).unwrap());
            }

            #[test]
            fn guard_check_is_deterministic(
                plan in arb_plan(),
                trace_opcodes in arb_window(),
                ctx in arb_guard_context(),
            ) {
                let compiled = CompiledTrace::from_plan(&plan);
                let r1 = compiled.guards_pass(&trace_opcodes, &ctx);
                let r2 = compiled.guards_pass(&trace_opcodes, &ctx);
                prop_assert_eq!(r1, r2, "guard check must be deterministic");
            }

            #[test]
            fn jit_hit_implies_zero_deopt_reason(
                config in arb_config(),
                plan in arb_plan(),
            ) {
                let min_exec = config.min_jit_executions;
                let mut jit = TraceJitCompiler::new(config);
                // Exactly enough executions to promote the plan.
                for _ in 0..min_exec {
                    jit.record_plan_execution(&plan);
                }
                // Matching trace, no veto, full support: every guard passes.
                let ctx = GuardContext {
                    safety_envelope_vetoing: false,
                    current_support_count: plan.support_count,
                };
                let result = jit.try_jit_dispatch(&plan.plan_id, &plan.opcode_window, &ctx);
                prop_assert!(
                    result.jit_hit,
                    "promoted trace with satisfied guards must hit"
                );
                prop_assert!(
                    result.deopt_reason.is_none(),
                    "JIT hit must have no deopt reason"
                );
                prop_assert!(
                    result.cost_delta >= 0,
                    "JIT hit must have non-negative cost delta"
                );
            }

            #[test]
            fn deopts_stop_growing_after_invalidation(
                max_guard_failures in 1..8u64,
                attempts in 1..40u64,
            ) {
                let config = TraceJitConfig::new(true, 1, 8, max_guard_failures);
                let mut jit = TraceJitCompiler::new(config);
                let plan = make_plan("prop_invalidation", &["a", "b"], 10);
                jit.record_plan_execution(&plan);

                for _ in 0..attempts {
                    let _ = jit.try_jit_dispatch("prop_invalidation", &trace(&["x"]), &default_ctx());
                }

                let telemetry = jit.telemetry();
                let expected_deopts = attempts.min(max_guard_failures);
                prop_assert_eq!(telemetry.deopts, expected_deopts);
                prop_assert_eq!(telemetry.jit_misses, expected_deopts);
                prop_assert!(telemetry.invalidations <= 1);

                if attempts >= max_guard_failures {
                    prop_assert!(jit.is_invalidated("prop_invalidation"));
                    prop_assert_eq!(telemetry.invalidations, 1);
                } else {
                    prop_assert!(!jit.is_invalidated("prop_invalidation"));
                    prop_assert_eq!(telemetry.invalidations, 0);
                }
            }
        }
    }
}