1use std::sync::Arc;
17
/// One whole expressed in basis points (1 bp = 0.01 %).
pub const BASIS_POINTS: u64 = 10_000;
/// Denominator for a product of two basis-point factors (`BASIS_POINTS` squared).
const SCALED_BASIS_POINTS: u64 = BASIS_POINTS.pow(2);
/// Number of tokens a model price is quoted for (prices are per million tokens).
const COST_SCALE: u64 = 1_000_000;
/// Added to a cost numerator before dividing by `COST_SCALE` so the division rounds up.
const COST_ROUNDING: u64 = COST_SCALE - 1;
/// `KernelInput::allowed_provider_mask` value that switches the provider filter off.
pub const ALL_PROVIDERS: u64 = u64::MAX;
/// Region mask with every region bit set.
pub const ALL_REGIONS: u64 = u64::MAX;
/// Largest valid `KernelModel::provider_id`; provider ids are bit positions in a `u64` mask.
pub const MAX_PROVIDER_ID: u16 = 63;
29
/// One catalog entry: a model's quality, risk, latency, capability, placement and pricing
/// attributes, laid out with `repr(C)` field order.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KernelModel {
    /// Catalog-unique identifier; duplicates fail `PolicySnapshot::validate`.
    pub model_id: u32,
    /// Bit position in `KernelInput::allowed_provider_mask`; must be `<= MAX_PROVIDER_ID`.
    pub provider_id: u16,
    /// Model quality in basis points (`<= MAX_BPS` in a validated catalog).
    pub quality_bps: u16,
    /// Largest request `risk_bps` this model may serve.
    pub risk_ceiling_bps: u16,
    /// `0` disables the model; any other value enables it.
    pub enabled: u8,
    /// 95th-percentile latency in milliseconds (latency filter and latency penalty).
    pub p95_latency_ms: u32,
    /// Capability bit set; must contain every bit of `KernelInput::required_capabilities`.
    pub capabilities: u64,
    /// Region bit set; must intersect a non-zero `KernelInput::required_region_mask`.
    pub region_mask: u64,
    /// Input-token price in microunits per million tokens.
    pub input_cost_microunits_per_million_tokens: u64,
    /// Output-token price in microunits per million tokens.
    pub output_cost_microunits_per_million_tokens: u64,
}
56
/// One routing request: the model the caller asked for, token counts, economic value,
/// budget, and the hard constraints every candidate model must satisfy.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KernelInput {
    /// Caller-supplied sequence number, echoed into `KernelDecision::request_sequence`.
    pub request_sequence: u64,
    /// Model the caller asked for; if it wins, the action is `ExecuteRequested`.
    pub requested_model_id: u32,
    /// Token count priced with the model's input rate.
    pub input_tokens: u32,
    /// Token count priced with the model's output rate.
    pub output_tokens: u32,
    /// Value of the request in microunits; negative values are treated as zero.
    pub business_value_microunits: i64,
    /// Models whose estimated cost exceeds this limit are rejected.
    pub budget_limit_microunits: u64,
    /// Request risk in basis points: checked against the policy hard limit and each model's
    /// ceiling, and scales the risk penalty.
    pub risk_bps: u16,
    /// Confidence in basis points: must reach the policy minimum and scales the quality term.
    pub confidence_bps: u16,
    /// Minimum acceptable `KernelModel::quality_bps`.
    pub minimum_quality_bps: u16,
    /// Maximum acceptable model p95 latency in ms; `0` disables the latency filter.
    pub max_p95_latency_ms: u32,
    /// Capability bits every selected model must have.
    pub required_capabilities: u64,
    /// Bit `n` admits provider `n`; `ALL_PROVIDERS` disables the provider filter.
    pub allowed_provider_mask: u64,
    /// Region bits, at least one of which the model must cover; `0` disables the filter.
    pub required_region_mask: u64,
}
89
/// What the caller should do with a request.
///
/// The enum is `repr(u8)` with explicit discriminants.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum KernelAction {
    /// Run the model the caller requested; it had the highest utility.
    ExecuteRequested = 1,
    /// Run a different catalog model that had the highest utility.
    Substitute = 2,
    /// Run nothing; `KernelDecision::reason` explains why.
    Reject = 3,
}
102
103impl std::fmt::Display for KernelAction {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 Self::ExecuteRequested => write!(f, "execute_requested"),
107 Self::Substitute => write!(f, "substitute"),
108 Self::Reject => write!(f, "reject"),
109 }
110 }
111}
112
/// Why a `KernelDecision` has its action.
///
/// Values below 100 accompany a selected model; values from 100 upward accompany a
/// rejection. The enum is `repr(u16)` with explicit discriminants.
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum KernelReason {
    /// The requested model had the highest utility.
    RequestedModelMaximizesUtility = 1,
    /// Another eligible model had the highest utility.
    AlternativeMaximizesUtility = 2,
    /// The request's `risk_bps` reached the policy's hard risk limit.
    RiskHardLimit = 100,
    /// The request's `confidence_bps` was below the policy minimum.
    ConfidenceHardLimit = 101,
    /// Disabled models dominated the rejection counts, or nothing was counted at all.
    NoEnabledModel = 102,
    /// Dominant rejection: model quality below the request minimum.
    QualityConstraint = 103,
    /// Dominant rejection: model p95 latency above the request bound.
    LatencyConstraint = 104,
    /// Dominant rejection: model lacks required capability bits.
    CapabilityConstraint = 105,
    /// Dominant rejection: provider not allowed (or provider id out of range).
    ProviderConstraint = 106,
    /// Dominant rejection: model covers none of the required regions.
    RegionConstraint = 107,
    /// Dominant rejection: estimated cost above the request budget.
    BudgetConstraint = 108,
    /// Dominant rejection: computed utility was zero or negative.
    NonPositiveUtility = 109,
    /// Dominant rejection: request risk above the model's risk ceiling.
    RiskCeilingConstraint = 110,
}
145
146impl std::fmt::Display for KernelReason {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 match self {
149 Self::RequestedModelMaximizesUtility => write!(f, "requested_model_maximizes_utility"),
150 Self::AlternativeMaximizesUtility => write!(f, "alternative_maximizes_utility"),
151 Self::RiskHardLimit => write!(f, "risk_hard_limit"),
152 Self::ConfidenceHardLimit => write!(f, "confidence_hard_limit"),
153 Self::NoEnabledModel => write!(f, "no_enabled_model"),
154 Self::QualityConstraint => write!(f, "quality_constraint"),
155 Self::LatencyConstraint => write!(f, "latency_constraint"),
156 Self::CapabilityConstraint => write!(f, "capability_constraint"),
157 Self::ProviderConstraint => write!(f, "provider_constraint"),
158 Self::RegionConstraint => write!(f, "region_constraint"),
159 Self::BudgetConstraint => write!(f, "budget_constraint"),
160 Self::NonPositiveUtility => write!(f, "non_positive_utility"),
161 Self::RiskCeilingConstraint => write!(f, "risk_ceiling_constraint"),
162 }
163 }
164}
165
166#[repr(C)]
168#[derive(Clone, Copy, Debug, Eq, PartialEq)]
169#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
170pub struct KernelDecision {
171 pub request_sequence: u64,
173 pub action: KernelAction,
175 pub reason: KernelReason,
177 pub selected_model_id: u32,
179 pub selected_model_index: u16,
181 pub estimated_cost_microunits: u64,
183 pub expected_utility_microunits: i64,
185 pub counterfactual_model_id: u32,
187 pub counterfactual_utility_microunits: i64,
189 pub evaluated_models: u16,
191 pub eligible_models: u16,
193 pub policy_epoch: u64,
195 pub catalog_epoch: u64,
197}
198
199#[derive(Clone, Debug)]
205pub struct PolicySnapshot {
206 pub policy_epoch: u64,
207 pub catalog_epoch: u64,
208 pub hard_risk_limit_bps: u16,
209 pub minimum_confidence_bps: u16,
210 pub risk_penalty_multiplier_bps: u16,
211 pub latency_penalty_microunits_per_ms: u64,
212 max_quality_bps: u16,
213 max_p95_latency_ms: u32,
214 max_input_cost: u64,
215 max_output_cost: u64,
216 models: Arc<[KernelModel]>,
217}
218
/// An eligible model together with the keys used to rank it against other candidates.
#[derive(Copy, Clone)]
struct Candidate {
    /// Catalog model id (final tie-breaker: lower wins).
    model_id: u32,
    /// Position in the catalog, saturated to `u16::MAX`.
    model_index: u16,
    /// Model quality (third ranking key: higher wins).
    quality_bps: u16,
    /// Estimated cost in microunits (second ranking key: lower wins).
    cost: u64,
    /// Utility in microunits (primary ranking key: higher wins).
    utility: i64,
}
227
/// Per-constraint counts of why catalog models were excluded from a decision.
///
/// A model is counted once, under the first constraint it failed.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RejectionHistogram {
    /// Models with `enabled == 0`.
    pub disabled: u16,
    /// Models below the request's minimum quality.
    pub quality: u16,
    /// Models whose risk ceiling is below the request risk.
    pub risk_ceiling: u16,
    /// Models slower (p95) than the request's latency bound.
    pub latency: u16,
    /// Models missing required capability bits.
    pub capability: u16,
    /// Models with an out-of-range or disallowed provider.
    pub provider: u16,
    /// Models covering none of the required regions.
    pub region: u16,
    /// Models whose estimated cost exceeds the budget.
    pub budget: u16,
    /// Models whose utility was zero or negative.
    pub utility: u16,
}
251
252#[derive(Clone, Copy, Debug, PartialEq, Eq)]
254#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
255pub struct DecisionTrace {
256 pub rejections: RejectionHistogram,
257 pub evaluated_models: u16,
258 pub eligible_models: u16,
259}
260
/// Largest value `PolicySnapshot::validate` accepts for fractional basis-point settings
/// (risk limit, confidence minimum, model quality, model risk ceiling): 100 %.
pub const MAX_BPS: u16 = 10_000;

/// Largest value `PolicySnapshot::validate` accepts for `risk_penalty_multiplier_bps` (5x).
pub const MAX_RISK_PENALTY_MULTIPLIER_BPS: u16 = 5 * MAX_BPS;
266
267#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
269pub enum PolicyError {
270 #[error("model catalog is empty")]
271 EmptyCatalog,
272 #[error("duplicate model_id {model_id}")]
273 DuplicateModelId { model_id: u32 },
274 #[error("model_id {model_id} has provider_id {provider_id} > MAX_PROVIDER_ID")]
275 InvalidProviderId { model_id: u32, provider_id: u16 },
276 #[error("no enabled models in catalog")]
277 NoEnabledModels,
278 #[error("{field} must be <= {max}, got {value}")]
279 OutOfRangeBps {
280 field: &'static str,
281 value: u16,
282 max: u16,
283 },
284}
285
/// Internal alias for `RejectionHistogram`, used where per-constraint counts are accumulated.
type RejectionCounts = RejectionHistogram;
287
288impl PolicySnapshot {
289 pub fn new_unchecked(
299 policy_epoch: u64,
300 catalog_epoch: u64,
301 hard_risk_limit_bps: u16,
302 minimum_confidence_bps: u16,
303 risk_penalty_multiplier_bps: u16,
304 latency_penalty_microunits_per_ms: u64,
305 models: Vec<KernelModel>,
306 ) -> Self {
307 let max_quality_bps = models
308 .iter()
309 .map(|model| model.quality_bps)
310 .max()
311 .unwrap_or_default();
312 let max_p95_latency_ms = models
313 .iter()
314 .map(|model| model.p95_latency_ms)
315 .max()
316 .unwrap_or_default();
317 let max_input_cost = models
318 .iter()
319 .map(|model| model.input_cost_microunits_per_million_tokens)
320 .max()
321 .unwrap_or_default();
322 let max_output_cost = models
323 .iter()
324 .map(|model| model.output_cost_microunits_per_million_tokens)
325 .max()
326 .unwrap_or_default();
327 Self {
328 policy_epoch,
329 catalog_epoch,
330 hard_risk_limit_bps,
331 minimum_confidence_bps,
332 risk_penalty_multiplier_bps,
333 latency_penalty_microunits_per_ms,
334 max_quality_bps,
335 max_p95_latency_ms,
336 max_input_cost,
337 max_output_cost,
338 models: Arc::from(models),
339 }
340 }
341
342 #[deprecated(
344 since = "0.3.9",
345 note = "use PolicySnapshot::try_new for validated snapshots or new_unchecked for tests"
346 )]
347 pub fn new(
348 policy_epoch: u64,
349 catalog_epoch: u64,
350 hard_risk_limit_bps: u16,
351 minimum_confidence_bps: u16,
352 risk_penalty_multiplier_bps: u16,
353 latency_penalty_microunits_per_ms: u64,
354 models: Vec<KernelModel>,
355 ) -> Self {
356 Self::new_unchecked(
357 policy_epoch,
358 catalog_epoch,
359 hard_risk_limit_bps,
360 minimum_confidence_bps,
361 risk_penalty_multiplier_bps,
362 latency_penalty_microunits_per_ms,
363 models,
364 )
365 }
366
367 pub fn models(&self) -> &[KernelModel] {
369 &self.models
370 }
371
372 pub fn validate(&self) -> Result<(), PolicyError> {
374 if self.hard_risk_limit_bps > MAX_BPS {
375 return Err(PolicyError::OutOfRangeBps {
376 field: "hard_risk_limit_bps",
377 value: self.hard_risk_limit_bps,
378 max: MAX_BPS,
379 });
380 }
381 if self.minimum_confidence_bps > MAX_BPS {
382 return Err(PolicyError::OutOfRangeBps {
383 field: "minimum_confidence_bps",
384 value: self.minimum_confidence_bps,
385 max: MAX_BPS,
386 });
387 }
388 if self.risk_penalty_multiplier_bps > MAX_RISK_PENALTY_MULTIPLIER_BPS {
389 return Err(PolicyError::OutOfRangeBps {
390 field: "risk_penalty_multiplier_bps",
391 value: self.risk_penalty_multiplier_bps,
392 max: MAX_RISK_PENALTY_MULTIPLIER_BPS,
393 });
394 }
395 if self.models.is_empty() {
396 return Err(PolicyError::EmptyCatalog);
397 }
398 let mut seen = std::collections::HashSet::new();
399 let mut any_enabled = false;
400 for model in self.models.iter() {
401 if !seen.insert(model.model_id) {
402 return Err(PolicyError::DuplicateModelId {
403 model_id: model.model_id,
404 });
405 }
406 if model.provider_id > MAX_PROVIDER_ID {
407 return Err(PolicyError::InvalidProviderId {
408 model_id: model.model_id,
409 provider_id: model.provider_id,
410 });
411 }
412 if model.quality_bps > MAX_BPS {
413 return Err(PolicyError::OutOfRangeBps {
414 field: "model.quality_bps",
415 value: model.quality_bps,
416 max: MAX_BPS,
417 });
418 }
419 if model.risk_ceiling_bps > MAX_BPS {
420 return Err(PolicyError::OutOfRangeBps {
421 field: "model.risk_ceiling_bps",
422 value: model.risk_ceiling_bps,
423 max: MAX_BPS,
424 });
425 }
426 if model.enabled != 0 {
427 any_enabled = true;
428 }
429 }
430 if !any_enabled {
431 return Err(PolicyError::NoEnabledModels);
432 }
433 Ok(())
434 }
435
436 pub fn try_new(
438 policy_epoch: u64,
439 catalog_epoch: u64,
440 hard_risk_limit_bps: u16,
441 minimum_confidence_bps: u16,
442 risk_penalty_multiplier_bps: u16,
443 latency_penalty_microunits_per_ms: u64,
444 models: Vec<KernelModel>,
445 ) -> Result<Self, PolicyError> {
446 let snapshot = Self::new_unchecked(
447 policy_epoch,
448 catalog_epoch,
449 hard_risk_limit_bps,
450 minimum_confidence_bps,
451 risk_penalty_multiplier_bps,
452 latency_penalty_microunits_per_ms,
453 models,
454 );
455 snapshot.validate()?;
456 Ok(snapshot)
457 }
458
459 pub fn prescribe_batch(&self, inputs: &[KernelInput]) -> Vec<KernelDecision> {
461 inputs.iter().map(|&input| self.prescribe(input)).collect()
462 }
463
464 pub fn prescribe_with_trace(&self, input: KernelInput) -> (KernelDecision, DecisionTrace) {
466 let (decision, rejections) = self.prescribe_inner(input);
467 let trace = DecisionTrace {
468 rejections,
469 evaluated_models: decision.evaluated_models,
470 eligible_models: decision.eligible_models,
471 };
472 (decision, trace)
473 }
474
475 #[must_use]
486 pub fn prescribe(&self, input: KernelInput) -> KernelDecision {
487 self.prescribe_inner(input).0
488 }
489
490 #[must_use]
495 pub fn utility_for_model(&self, input: KernelInput, model_id: u32) -> Option<i64> {
496 if input.risk_bps >= self.hard_risk_limit_bps {
497 return None;
498 }
499 if input.confidence_bps < self.minimum_confidence_bps {
500 return None;
501 }
502 let model = self.models.iter().find(|m| m.model_id == model_id)?;
503
504 if model.enabled == 0 {
505 return None;
506 }
507 if model.quality_bps < input.minimum_quality_bps {
508 return None;
509 }
510 if input.max_p95_latency_ms > 0 && model.p95_latency_ms > input.max_p95_latency_ms {
511 return None;
512 }
513 if model.capabilities & input.required_capabilities != input.required_capabilities {
514 return None;
515 }
516 if model.provider_id > MAX_PROVIDER_ID {
517 return None;
518 }
519 if input.allowed_provider_mask != ALL_PROVIDERS
520 && input.allowed_provider_mask & (1_u64 << model.provider_id) == 0
521 {
522 return None;
523 }
524 if input.required_region_mask != 0 && model.region_mask & input.required_region_mask == 0 {
525 return None;
526 }
527 if input.risk_bps > model.risk_ceiling_bps {
528 return None;
529 }
530
531 let all_costs_fit = self.all_costs_fit_u64(input.input_tokens, input.output_tokens);
532 let cost = if all_costs_fit {
533 model_cost_fast(model, input.input_tokens, input.output_tokens)
534 } else {
535 model_cost_reference(model, input.input_tokens, input.output_tokens)
536 };
537 if cost > input.budget_limit_microunits {
538 return None;
539 }
540
541 let value = input.business_value_microunits.max(0) as u64;
542 let confidence_bps = u64::from(input.confidence_bps);
543 let quality_prefix = value.checked_mul(confidence_bps).filter(|prefix| {
544 prefix
545 .checked_mul(u64::from(self.max_quality_bps))
546 .is_some()
547 });
548 let risk_penalty = scaled_term_exact(
549 value,
550 u64::from(input.risk_bps),
551 u64::from(self.risk_penalty_multiplier_bps),
552 );
553 let all_latencies_fit = u64::from(self.max_p95_latency_ms)
554 .checked_mul(self.latency_penalty_microunits_per_ms)
555 .is_some();
556 let quality_adjusted = quality_prefix.map_or_else(
557 || scaled_term_reference(value, confidence_bps, u64::from(model.quality_bps)),
558 |prefix| {
559 i128::from(prefix.wrapping_mul(u64::from(model.quality_bps)) / SCALED_BASIS_POINTS)
560 },
561 );
562 let latency_penalty = if all_latencies_fit {
563 i128::from(
564 u64::from(model.p95_latency_ms)
565 .wrapping_mul(self.latency_penalty_microunits_per_ms),
566 )
567 } else {
568 i128::from(model.p95_latency_ms) * i128::from(self.latency_penalty_microunits_per_ms)
569 };
570 let utility =
571 clamp_i128_to_i64(quality_adjusted - risk_penalty - i128::from(cost) - latency_penalty);
572 if utility <= 0 {
573 return None;
574 }
575 Some(utility)
576 }
577
578 fn prescribe_inner(&self, input: KernelInput) -> (KernelDecision, RejectionHistogram) {
579 if input.risk_bps >= self.hard_risk_limit_bps {
580 return self.reject(input, KernelReason::RiskHardLimit, 0, 0);
581 }
582 if input.confidence_bps < self.minimum_confidence_bps {
583 return self.reject(input, KernelReason::ConfidenceHardLimit, 0, 0);
584 }
585
586 let mut best: Option<Candidate> = None;
587 let mut second: Option<Candidate> = None;
588 let mut eligible_models = 0_u16;
589 let mut rejected = RejectionCounts::default();
590
591 let value = input.business_value_microunits.max(0) as u64;
592 let confidence_bps = u64::from(input.confidence_bps);
593 let quality_prefix = value.checked_mul(confidence_bps).filter(|prefix| {
594 prefix
595 .checked_mul(u64::from(self.max_quality_bps))
596 .is_some()
597 });
598 let risk_penalty = scaled_term_exact(
599 value,
600 u64::from(input.risk_bps),
601 u64::from(self.risk_penalty_multiplier_bps),
602 );
603 let all_costs_fit = self.all_costs_fit_u64(input.input_tokens, input.output_tokens);
604 let all_latencies_fit = u64::from(self.max_p95_latency_ms)
605 .checked_mul(self.latency_penalty_microunits_per_ms)
606 .is_some();
607
608 let check_provider = input.allowed_provider_mask != ALL_PROVIDERS;
609 let check_region = input.required_region_mask != 0;
610 let check_latency = input.max_p95_latency_ms > 0;
611 let latency_pen_per_ms = self.latency_penalty_microunits_per_ms;
612
613 for (index, model) in self.models.iter().enumerate() {
614 if model.enabled == 0 {
616 rejected.disabled += 1;
617 continue;
618 }
619 if model.quality_bps < input.minimum_quality_bps {
620 rejected.quality += 1;
621 continue;
622 }
623 if check_latency && model.p95_latency_ms > input.max_p95_latency_ms {
624 rejected.latency += 1;
625 continue;
626 }
627 if model.capabilities & input.required_capabilities != input.required_capabilities {
628 rejected.capability += 1;
629 continue;
630 }
631 if model.provider_id > MAX_PROVIDER_ID {
634 rejected.provider += 1;
635 continue;
636 }
637 if check_provider && input.allowed_provider_mask & (1_u64 << model.provider_id) == 0 {
638 rejected.provider += 1;
639 continue;
640 }
641 if check_region && model.region_mask & input.required_region_mask == 0 {
642 rejected.region += 1;
643 continue;
644 }
645 if input.risk_bps > model.risk_ceiling_bps {
646 rejected.risk_ceiling += 1;
647 continue;
648 }
649
650 let cost = if all_costs_fit {
651 model_cost_fast(model, input.input_tokens, input.output_tokens)
652 } else {
653 model_cost_reference(model, input.input_tokens, input.output_tokens)
654 };
655 if cost > input.budget_limit_microunits {
656 rejected.budget += 1;
657 continue;
658 }
659
660 let quality_adjusted = quality_prefix.map_or_else(
661 || scaled_term_reference(value, confidence_bps, u64::from(model.quality_bps)),
662 |prefix| {
663 i128::from(
665 prefix.wrapping_mul(u64::from(model.quality_bps)) / SCALED_BASIS_POINTS,
666 )
667 },
668 );
669 let latency_penalty = if all_latencies_fit {
670 i128::from(u64::from(model.p95_latency_ms).wrapping_mul(latency_pen_per_ms))
671 } else {
672 i128::from(model.p95_latency_ms) * i128::from(latency_pen_per_ms)
673 };
674 let utility = clamp_i128_to_i64(
675 quality_adjusted - risk_penalty - i128::from(cost) - latency_penalty,
676 );
677
678 if utility <= 0 {
679 rejected.utility += 1;
680 continue;
681 }
682 eligible_models = eligible_models.saturating_add(1);
683
684 let candidate = Candidate {
685 model_id: model.model_id,
686 model_index: u16::try_from(index).unwrap_or(u16::MAX),
687 quality_bps: model.quality_bps,
688 cost,
689 utility,
690 };
691 match best {
692 None => best = Some(candidate),
693 Some(current) if candidate_better(candidate, current) => {
694 second = best;
695 best = Some(candidate);
696 }
697 _ => {
698 if second.is_none_or(|s| candidate_better(candidate, s)) {
699 second = Some(candidate);
700 }
701 }
702 }
703 }
704
705 let evaluated_models = u16::try_from(self.models.len()).unwrap_or(u16::MAX);
706 let Some(best) = best else {
707 return self.reject(
708 input,
709 dominant_rejection_reason(&rejected),
710 evaluated_models,
711 eligible_models,
712 );
713 };
714 let action = if best.model_id == input.requested_model_id {
715 KernelAction::ExecuteRequested
716 } else {
717 KernelAction::Substitute
718 };
719 (
720 KernelDecision {
721 request_sequence: input.request_sequence,
722 action,
723 reason: if action == KernelAction::ExecuteRequested {
724 KernelReason::RequestedModelMaximizesUtility
725 } else {
726 KernelReason::AlternativeMaximizesUtility
727 },
728 selected_model_id: best.model_id,
729 selected_model_index: best.model_index,
730 estimated_cost_microunits: best.cost,
731 expected_utility_microunits: best.utility,
732 counterfactual_model_id: second.map_or(0, |candidate| candidate.model_id),
733 counterfactual_utility_microunits: second.map_or(0, |candidate| candidate.utility),
734 evaluated_models,
735 eligible_models,
736 policy_epoch: self.policy_epoch,
737 catalog_epoch: self.catalog_epoch,
738 },
739 rejected,
740 )
741 }
742
743 #[inline]
744 fn all_costs_fit_u64(&self, input_tokens: u32, output_tokens: u32) -> bool {
745 let input = u64::from(input_tokens)
746 .checked_mul(self.max_input_cost)
747 .and_then(|value| value.checked_add(COST_ROUNDING));
748 let output = u64::from(output_tokens)
749 .checked_mul(self.max_output_cost)
750 .and_then(|value| value.checked_add(COST_ROUNDING));
751 input
752 .zip(output)
753 .is_some_and(|(input, output)| input.checked_add(output).is_some())
754 }
755
756 fn reject(
757 &self,
758 input: KernelInput,
759 reason: KernelReason,
760 evaluated_models: u16,
761 eligible_models: u16,
762 ) -> (KernelDecision, RejectionHistogram) {
763 (
764 KernelDecision {
765 request_sequence: input.request_sequence,
766 action: KernelAction::Reject,
767 reason,
768 selected_model_id: 0,
769 selected_model_index: u16::MAX,
770 estimated_cost_microunits: 0,
771 expected_utility_microunits: 0,
772 counterfactual_model_id: 0,
773 counterfactual_utility_microunits: 0,
774 evaluated_models,
775 eligible_models,
776 policy_epoch: self.policy_epoch,
777 catalog_epoch: self.catalog_epoch,
778 },
779 RejectionHistogram::default(),
780 )
781 }
782}
783
784#[inline(always)]
785fn model_cost_fast(model: &KernelModel, input_tokens: u32, output_tokens: u32) -> u64 {
786 let input = u64::from(input_tokens)
787 .wrapping_mul(model.input_cost_microunits_per_million_tokens)
788 .wrapping_add(COST_ROUNDING)
789 / COST_SCALE;
790 let output = u64::from(output_tokens)
791 .wrapping_mul(model.output_cost_microunits_per_million_tokens)
792 .wrapping_add(COST_ROUNDING)
793 / COST_SCALE;
794 input.wrapping_add(output)
795}
796
797fn model_cost_reference(model: &KernelModel, input_tokens: u32, output_tokens: u32) -> u64 {
798 let input = u128::from(input_tokens)
799 .saturating_mul(u128::from(model.input_cost_microunits_per_million_tokens))
800 .saturating_add(u128::from(COST_ROUNDING))
801 / u128::from(COST_SCALE);
802 let output = u128::from(output_tokens)
803 .saturating_mul(u128::from(model.output_cost_microunits_per_million_tokens))
804 .saturating_add(u128::from(COST_ROUNDING))
805 / u128::from(COST_SCALE);
806 u64::try_from(input.saturating_add(output)).unwrap_or(u64::MAX)
807}
808
809#[inline]
810fn scaled_term_exact(value: u64, first_bps: u64, second_bps: u64) -> i128 {
811 value
812 .checked_mul(first_bps)
813 .and_then(|value| value.checked_mul(second_bps))
814 .map_or_else(
815 || scaled_term_reference(value, first_bps, second_bps),
816 |numerator| i128::from(numerator / SCALED_BASIS_POINTS),
817 )
818}
819
820#[inline]
821fn scaled_term_reference(value: u64, first_bps: u64, second_bps: u64) -> i128 {
822 i128::from(value) * i128::from(first_bps) * i128::from(second_bps)
823 / i128::from(SCALED_BASIS_POINTS)
824}
825
826#[inline(always)]
828fn candidate_better(left: Candidate, right: Candidate) -> bool {
829 left.utility > right.utility
830 || (left.utility == right.utility && left.cost < right.cost)
831 || (left.utility == right.utility
832 && left.cost == right.cost
833 && left.quality_bps > right.quality_bps)
834 || (left.utility == right.utility
835 && left.cost == right.cost
836 && left.quality_bps == right.quality_bps
837 && left.model_id < right.model_id)
838}
839
840fn dominant_rejection_reason(counts: &RejectionCounts) -> KernelReason {
841 let candidates = [
842 (counts.capability, KernelReason::CapabilityConstraint),
843 (counts.region, KernelReason::RegionConstraint),
844 (counts.provider, KernelReason::ProviderConstraint),
845 (counts.quality, KernelReason::QualityConstraint),
846 (counts.risk_ceiling, KernelReason::RiskCeilingConstraint),
847 (counts.latency, KernelReason::LatencyConstraint),
848 (counts.budget, KernelReason::BudgetConstraint),
849 (counts.utility, KernelReason::NonPositiveUtility),
850 (counts.disabled, KernelReason::NoEnabledModel),
851 ];
852 candidates
853 .into_iter()
854 .max_by_key(|(count, _)| *count)
855 .filter(|(count, _)| *count > 0)
856 .map_or(KernelReason::NoEnabledModel, |(_, reason)| reason)
857}
858
/// Saturating narrowing: values outside the `i64` range become `i64::MIN` / `i64::MAX`.
fn clamp_i128_to_i64(value: i128) -> i64 {
    if value > i128::from(i64::MAX) {
        i64::MAX
    } else if value < i128::from(i64::MIN) {
        i64::MIN
    } else {
        value as i64
    }
}
862
863#[cfg(test)]
864mod tests {
865 use std::{hint::black_box, time::Instant};
866
867 use proptest::prelude::*;
868
869 use super::*;
870
871 const TOOLS: u64 = 1 << 0;
872 const REGION_EU: u64 = 1 << 0;
873
874 fn snapshot() -> PolicySnapshot {
875 PolicySnapshot::new_unchecked(
876 7,
877 11,
878 9_600,
879 5_500,
880 10_000,
881 2,
882 vec![
883 KernelModel {
884 model_id: 10,
885 provider_id: 0,
886 quality_bps: 7_500,
887 risk_ceiling_bps: 9_500,
888 enabled: 1,
889 p95_latency_ms: 180,
890 capabilities: TOOLS,
891 region_mask: REGION_EU,
892 input_cost_microunits_per_million_tokens: 150_000,
893 output_cost_microunits_per_million_tokens: 600_000,
894 },
895 KernelModel {
896 model_id: 20,
897 provider_id: 1,
898 quality_bps: 9_500,
899 risk_ceiling_bps: 9_500,
900 enabled: 1,
901 p95_latency_ms: 450,
902 capabilities: TOOLS,
903 region_mask: REGION_EU,
904 input_cost_microunits_per_million_tokens: 2_500_000,
905 output_cost_microunits_per_million_tokens: 10_000_000,
906 },
907 ],
908 )
909 }
910
911 fn input() -> KernelInput {
912 KernelInput {
913 request_sequence: 1,
914 requested_model_id: 20,
915 input_tokens: 2_000,
916 output_tokens: 500,
917 business_value_microunits: 100_000_000,
918 budget_limit_microunits: 20_000_000,
919 risk_bps: 1_000,
920 confidence_bps: 9_000,
921 minimum_quality_bps: 7_000,
922 max_p95_latency_ms: 1_000,
923 required_capabilities: TOOLS,
924 allowed_provider_mask: ALL_PROVIDERS,
925 required_region_mask: REGION_EU,
926 }
927 }
928
929 fn prescribe_reference(snapshot: &PolicySnapshot, input: KernelInput) -> KernelDecision {
930 if input.risk_bps >= snapshot.hard_risk_limit_bps {
931 return snapshot.reject(input, KernelReason::RiskHardLimit, 0, 0).0;
932 }
933 if input.confidence_bps < snapshot.minimum_confidence_bps {
934 return snapshot
935 .reject(input, KernelReason::ConfidenceHardLimit, 0, 0)
936 .0;
937 }
938
939 let mut best: Option<Candidate> = None;
940 let mut second: Option<Candidate> = None;
941 let mut eligible_models = 0_u16;
942 let mut rejected = RejectionCounts::default();
943 let value = input.business_value_microunits.max(0) as u64;
944 let risk_penalty = scaled_term_reference(
945 value,
946 u64::from(input.risk_bps),
947 u64::from(snapshot.risk_penalty_multiplier_bps),
948 );
949
950 for (index, model) in snapshot.models.iter().enumerate() {
951 if model.enabled == 0 {
952 rejected.disabled += 1;
953 continue;
954 }
955 if model.quality_bps < input.minimum_quality_bps {
956 rejected.quality += 1;
957 continue;
958 }
959 if input.max_p95_latency_ms > 0 && model.p95_latency_ms > input.max_p95_latency_ms {
960 rejected.latency += 1;
961 continue;
962 }
963 if model.capabilities & input.required_capabilities != input.required_capabilities {
964 rejected.capability += 1;
965 continue;
966 }
967 if model.provider_id > MAX_PROVIDER_ID {
968 rejected.provider += 1;
969 continue;
970 }
971 if input.allowed_provider_mask != ALL_PROVIDERS
972 && input.allowed_provider_mask & (1_u64 << model.provider_id) == 0
973 {
974 rejected.provider += 1;
975 continue;
976 }
977 if input.required_region_mask != 0
978 && model.region_mask & input.required_region_mask == 0
979 {
980 rejected.region += 1;
981 continue;
982 }
983 if input.risk_bps > model.risk_ceiling_bps {
984 rejected.risk_ceiling += 1;
985 continue;
986 }
987
988 let cost = model_cost_reference(model, input.input_tokens, input.output_tokens);
989 if cost > input.budget_limit_microunits {
990 rejected.budget += 1;
991 continue;
992 }
993 let quality_adjusted = scaled_term_reference(
994 value,
995 u64::from(input.confidence_bps),
996 u64::from(model.quality_bps),
997 );
998 let latency_penalty = i128::from(model.p95_latency_ms)
999 * i128::from(snapshot.latency_penalty_microunits_per_ms);
1000 let utility = clamp_i128_to_i64(
1001 quality_adjusted - risk_penalty - i128::from(cost) - latency_penalty,
1002 );
1003 if utility <= 0 {
1004 rejected.utility += 1;
1005 continue;
1006 }
1007 eligible_models = eligible_models.saturating_add(1);
1008 let candidate = Candidate {
1009 model_id: model.model_id,
1010 model_index: u16::try_from(index).unwrap_or(u16::MAX),
1011 quality_bps: model.quality_bps,
1012 cost,
1013 utility,
1014 };
1015 if best.is_none_or(|current| candidate_better(candidate, current)) {
1016 second = best;
1017 best = Some(candidate);
1018 } else if second.is_none_or(|current| candidate_better(candidate, current)) {
1019 second = Some(candidate);
1020 }
1021 }
1022
1023 let evaluated_models = u16::try_from(snapshot.models.len()).unwrap_or(u16::MAX);
1024 let Some(best) = best else {
1025 return snapshot
1026 .reject(
1027 input,
1028 dominant_rejection_reason(&rejected),
1029 evaluated_models,
1030 eligible_models,
1031 )
1032 .0;
1033 };
1034 let action = if best.model_id == input.requested_model_id {
1035 KernelAction::ExecuteRequested
1036 } else {
1037 KernelAction::Substitute
1038 };
1039 KernelDecision {
1040 request_sequence: input.request_sequence,
1041 action,
1042 reason: if action == KernelAction::ExecuteRequested {
1043 KernelReason::RequestedModelMaximizesUtility
1044 } else {
1045 KernelReason::AlternativeMaximizesUtility
1046 },
1047 selected_model_id: best.model_id,
1048 selected_model_index: best.model_index,
1049 estimated_cost_microunits: best.cost,
1050 expected_utility_microunits: best.utility,
1051 counterfactual_model_id: second.map_or(0, |candidate| candidate.model_id),
1052 counterfactual_utility_microunits: second.map_or(0, |candidate| candidate.utility),
1053 evaluated_models,
1054 eligible_models,
1055 policy_epoch: snapshot.policy_epoch,
1056 catalog_epoch: snapshot.catalog_epoch,
1057 }
1058 }
1059
1060 #[test]
1061 fn prescribes_maximum_utility_not_minimum_price() {
1062 let decision = snapshot().prescribe(input());
1063 assert_eq!(decision.action, KernelAction::ExecuteRequested);
1064 assert_eq!(decision.selected_model_id, 20);
1065 assert_eq!(decision.counterfactual_model_id, 10);
1066 assert!(decision.expected_utility_microunits > decision.counterfactual_utility_microunits);
1067 }
1068
1069 #[test]
1070 fn hard_budget_can_prescribe_substitution() {
1071 let mut request = input();
1072 request.budget_limit_microunits = 1_000;
1073 let decision = snapshot().prescribe(request);
1074 assert_eq!(decision.action, KernelAction::Substitute);
1075 assert_eq!(decision.selected_model_id, 10);
1076 }
1077
1078 fn base_model(model_id: u32, enabled: u8) -> KernelModel {
1079 KernelModel {
1080 model_id,
1081 provider_id: 0,
1082 quality_bps: 8_000,
1083 risk_ceiling_bps: 9_500,
1084 enabled,
1085 p95_latency_ms: 200,
1086 capabilities: 0,
1087 region_mask: ALL_REGIONS,
1088 input_cost_microunits_per_million_tokens: 100,
1089 output_cost_microunits_per_million_tokens: 400,
1090 }
1091 }
1092
1093 #[test]
1094 fn policy_error_empty_catalog() {
1095 let snap = PolicySnapshot::new_unchecked(1, 1, 9_600, 5_500, 3_500, 0, vec![]);
1096 assert_eq!(snap.validate(), Err(PolicyError::EmptyCatalog));
1097 assert!(matches!(
1098 PolicySnapshot::try_new(1, 1, 9_600, 5_500, 3_500, 0, vec![]),
1099 Err(PolicyError::EmptyCatalog)
1100 ));
1101 }
1102
1103 #[test]
1104 fn policy_error_duplicate_model_id() {
1105 let snap = PolicySnapshot::new_unchecked(
1106 1,
1107 1,
1108 9_600,
1109 5_500,
1110 3_500,
1111 0,
1112 vec![base_model(1, 1), base_model(1, 1)],
1113 );
1114 assert_eq!(
1115 snap.validate(),
1116 Err(PolicyError::DuplicateModelId { model_id: 1 })
1117 );
1118 }
1119
1120 #[test]
1121 fn policy_error_invalid_provider_id() {
1122 let mut model = base_model(1, 1);
1123 model.provider_id = MAX_PROVIDER_ID + 1;
1124 let snap = PolicySnapshot::new_unchecked(1, 1, 9_600, 5_500, 3_500, 0, vec![model]);
1125 assert_eq!(
1126 snap.validate(),
1127 Err(PolicyError::InvalidProviderId {
1128 model_id: 1,
1129 provider_id: MAX_PROVIDER_ID + 1,
1130 })
1131 );
1132 }
1133
1134 #[test]
1135 fn policy_error_no_enabled_models() {
1136 let snap = PolicySnapshot::new_unchecked(
1137 1,
1138 1,
1139 9_600,
1140 5_500,
1141 3_500,
1142 0,
1143 vec![base_model(1, 0), base_model(2, 0)],
1144 );
1145 assert_eq!(snap.validate(), Err(PolicyError::NoEnabledModels));
1146 }
1147
1148 #[test]
1149 fn policy_error_out_of_range_bps() {
1150 let models = vec![base_model(1, 1)];
1151 assert!(matches!(
1152 PolicySnapshot::try_new(1, 1, 10_001, 5_500, 3_500, 0, models.clone()),
1153 Err(PolicyError::OutOfRangeBps { .. })
1154 ));
1155 assert!(matches!(
1156 PolicySnapshot::try_new(1, 1, 9_600, 10_001, 3_500, 0, models.clone()),
1157 Err(PolicyError::OutOfRangeBps { .. })
1158 ));
1159 assert!(matches!(
1160 PolicySnapshot::try_new(1, 1, 9_600, 5_500, 50_001, 0, models.clone()),
1161 Err(PolicyError::OutOfRangeBps { .. })
1162 ));
1163 let mut bad_quality = base_model(2, 1);
1164 bad_quality.quality_bps = 10_001;
1165 assert!(matches!(
1166 PolicySnapshot::try_new(1, 1, 9_600, 5_500, 3_500, 0, vec![bad_quality]),
1167 Err(PolicyError::OutOfRangeBps { .. })
1168 ));
1169 }
1170
1171 #[test]
1172 fn utility_for_model_matches_eligible_catalog_entry() {
1173 let snap = snapshot();
1174 let input = input();
1175 let utility = snap.utility_for_model(input, 20);
1176 assert!(utility.is_some());
1177 assert_eq!(
1178 utility,
1179 Some(snap.prescribe(input).expected_utility_microunits)
1180 );
1181 }
1182
1183 #[test]
1184 fn utility_for_model_none_for_missing_id() {
1185 let snap = snapshot();
1186 assert!(snap.utility_for_model(input(), 999).is_none());
1187 }
1188
1189 #[test]
1190 fn prescribe_batch_matches_individual() {
1191 let snap = snapshot();
1192 let inputs = [
1193 input(),
1194 KernelInput {
1195 request_sequence: 2,
1196 requested_model_id: 10,
1197 input_tokens: 500,
1198 output_tokens: 100,
1199 business_value_microunits: 50_000_000,
1200 budget_limit_microunits: 5_000_000,
1201 risk_bps: 500,
1202 confidence_bps: 9_500,
1203 minimum_quality_bps: 7_000,
1204 max_p95_latency_ms: 500,
1205 required_capabilities: TOOLS,
1206 allowed_provider_mask: ALL_PROVIDERS,
1207 required_region_mask: REGION_EU,
1208 },
1209 ];
1210 let batch = snap.prescribe_batch(&inputs);
1211 assert_eq!(batch.len(), inputs.len());
1212 for (i, &inp) in inputs.iter().enumerate() {
1213 assert_eq!(batch[i], snap.prescribe(inp));
1214 }
1215 }
1216
1217 #[test]
1218 fn hard_constraints_fail_closed() {
1219 let mut request = input();
1220 request.risk_bps = 9_900;
1221 let decision = snapshot().prescribe(request);
1222 assert_eq!(decision.action, KernelAction::Reject);
1223 assert_eq!(decision.reason, KernelReason::RiskHardLimit);
1224
1225 request.risk_bps = 1_000;
1226 request.required_capabilities = 1 << 9;
1227 let decision = snapshot().prescribe(request);
1228 assert_eq!(decision.action, KernelAction::Reject);
1229 assert_eq!(decision.reason, KernelReason::CapabilityConstraint);
1230 }
1231
1232 #[test]
1233 fn extreme_inputs_saturate_without_panicking() {
1234 let mut request = input();
1235 request.input_tokens = u32::MAX;
1236 request.output_tokens = u32::MAX;
1237 request.business_value_microunits = i64::MAX;
1238 request.budget_limit_microunits = u64::MAX;
1239 let decision = snapshot().prescribe(request);
1240 assert_eq!(decision.request_sequence, request.request_sequence);
1241 }
1242
1243 #[test]
1244 fn exact_fast_path_preserves_single_rounding_step() {
1245 assert_eq!(scaled_term_reference(2, 5_001, 9_999), 1);
1246 assert_eq!(scaled_term_exact(2, 5_001, 9_999), 1);
1247 }
1248
1249 proptest! {
1250 #[test]
1251 fn optimized_scaled_term_matches_i128_reference(
1252 value in any::<u64>(),
1253 first_bps in any::<u16>(),
1254 second_bps in any::<u16>(),
1255 ) {
1256 prop_assert_eq!(
1257 scaled_term_exact(value, u64::from(first_bps), u64::from(second_bps)),
1258 scaled_term_reference(value, u64::from(first_bps), u64::from(second_bps)),
1259 );
1260 }
1261
1262 #[test]
1263 fn optimized_cost_matches_u128_reference_when_guard_admits(
1264 input_tokens in any::<u32>(),
1265 output_tokens in any::<u32>(),
1266 input_price in any::<u64>(),
1267 output_price in any::<u64>(),
1268 ) {
1269 let model = KernelModel {
1270 model_id: 1,
1271 provider_id: 0,
1272 quality_bps: 10_000,
1273 risk_ceiling_bps: u16::MAX,
1274 enabled: 1,
1275 p95_latency_ms: 1,
1276 capabilities: 0,
1277 region_mask: ALL_REGIONS,
1278 input_cost_microunits_per_million_tokens: input_price,
1279 output_cost_microunits_per_million_tokens: output_price,
1280 };
1281 let snapshot = PolicySnapshot::new_unchecked(1, 1, u16::MAX, 0, 0, 0, vec![model]);
1282 if snapshot.all_costs_fit_u64(input_tokens, output_tokens) {
1283 prop_assert_eq!(
1284 model_cost_fast(&model, input_tokens, output_tokens),
1285 model_cost_reference(&model, input_tokens, output_tokens),
1286 );
1287 }
1288 }
1289
1290 #[test]
1291 fn optimized_kernel_matches_reference_decision(
1292 input_tokens in any::<u32>(),
1293 output_tokens in any::<u32>(),
1294 value in any::<i64>(),
1295 budget in any::<u64>(),
1296 risk in any::<u16>(),
1297 confidence in any::<u16>(),
1298 minimum_quality in any::<u16>(),
1299 maximum_latency in any::<u32>(),
1300 provider_mask in any::<u64>(),
1301 region_mask in any::<u64>(),
1302 ) {
1303 let mut request = input();
1304 request.input_tokens = input_tokens;
1305 request.output_tokens = output_tokens;
1306 request.business_value_microunits = value;
1307 request.budget_limit_microunits = budget;
1308 request.risk_bps = risk;
1309 request.confidence_bps = confidence;
1310 request.minimum_quality_bps = minimum_quality;
1311 request.max_p95_latency_ms = maximum_latency;
1312 request.allowed_provider_mask = provider_mask;
1313 request.required_region_mask = region_mask;
1314 let snapshot = snapshot();
1315 prop_assert_eq!(snapshot.prescribe(request), prescribe_reference(&snapshot, request));
1316 }
1317
1318 #[test]
1319 fn arbitrary_inputs_never_bypass_provider_fence(
1320 input_tokens in any::<u32>(),
1321 output_tokens in any::<u32>(),
1322 value in any::<i64>(),
1323 budget in any::<u64>(),
1324 risk in any::<u16>(),
1325 confidence in any::<u16>(),
1326 ) {
1327 let mut request = input();
1328 request.input_tokens = input_tokens;
1329 request.output_tokens = output_tokens;
1330 request.business_value_microunits = value;
1331 request.budget_limit_microunits = budget;
1332 request.risk_bps = risk;
1333 request.confidence_bps = confidence;
1334 request.allowed_provider_mask = 0;
1335 let decision = snapshot().prescribe(request);
1336 prop_assert_eq!(decision.action, KernelAction::Reject);
1337 }
1338 }
1339
1340 #[test]
1341 fn provider_id_above_64_rejected_even_with_all_providers() {
1342 let models = vec![KernelModel {
1343 model_id: 1,
1344 provider_id: 65,
1345 quality_bps: 9500,
1346 risk_ceiling_bps: 10000,
1347 enabled: 1,
1348 p95_latency_ms: 500,
1349 capabilities: 0,
1350 region_mask: ALL_REGIONS,
1351 input_cost_microunits_per_million_tokens: 100,
1352 output_cost_microunits_per_million_tokens: 400,
1353 }];
1354 let snapshot = PolicySnapshot::new_unchecked(1, 1, 9600, 5500, 3500, 0, models);
1355 let mut request = input();
1356 request.allowed_provider_mask = ALL_PROVIDERS;
1357 let decision = snapshot.prescribe(request);
1358 assert_eq!(
1359 decision.action,
1360 KernelAction::Reject,
1361 "provider_id >= 64 must be rejected even when mask is ALL_PROVIDERS"
1362 );
1363 }
1364
1365 #[test]
1366 fn provider_id_below_64_accepted_with_all_providers() {
1367 let mut request = input();
1368 request.allowed_provider_mask = ALL_PROVIDERS;
1369 let decision = snapshot().prescribe(request);
1370 assert_ne!(
1371 decision.action,
1372 KernelAction::Reject,
1373 "provider_id < 64 with ALL_PROVIDERS should not be rejected by provider fence"
1374 );
1375 }
1376
1377 #[test]
1378 #[ignore = "release-only kernel guard"]
1379 fn prescriptive_kernel_latency_guard() {
1380 let snapshot = snapshot();
1381 let base = input();
1382 let iterations = 1_000_000_u64;
1383 let started = Instant::now();
1384 for sequence in 0..iterations {
1385 let mut request = base;
1386 request.request_sequence = sequence;
1387 request.input_tokens = 1_000 + (sequence % 1_024) as u32;
1388 black_box(snapshot.prescribe(black_box(request)));
1389 }
1390 let average_ns = started.elapsed().as_nanos() / u128::from(iterations);
1391 assert!(
1392 average_ns < 2_000,
1393 "prescriptive kernel exceeded 2us average guard: {average_ns}ns"
1394 );
1395 }
1396}