use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier for this primitive. It is the `generator` of the IR region emitted
/// by [`conformal_threshold`]. It is also passed to `crate::invalid_output_program`
/// when that builder rejects its arguments.
pub const OP_ID: &str = "vyre-primitives::math::conformal_threshold";
/// Builds an IR program that copies the `k`-th smallest calibration score into `q_hat`.
///
/// The program reads `scores_sorted[k - 1]` (a 1-based rank) and stores it at
/// `q_hat[0]`. It does not sort. The caller is expected to pass scores that are
/// already sorted ascending, matching the ordering used by
/// [`conformal_threshold_cpu`]. Only invocation 0 does any work.
///
/// # Parameters
/// - `scores_sorted`: name of the read-only `u32` input buffer (binding 0, `n` elements).
/// - `q_hat`: name of the read-write `u32` output buffer (binding 1, 1 element).
/// - `n`: number of calibration scores; must be non-zero.
/// - `k`: 1-based rank to select; must satisfy `1 <= k <= n`
///   (see [`try_conformal_rank`] for computing it from `alpha`).
///
/// # Invalid input
/// If `n == 0` or `k` is out of range, this returns the program produced by
/// `crate::invalid_output_program` with a `Fix:` message. The tests in this file
/// check that such a program reports a trap.
#[must_use]
pub fn conformal_threshold(scores_sorted: &str, q_hat: &str, n: u32, k: u32) -> Program {
    let rejection = if n == 0 {
        Some(format!("Fix: conformal_threshold requires n > 0, got {n}."))
    } else if !(1..=n).contains(&k) {
        Some(format!(
            "Fix: conformal_threshold k must satisfy 1 <= k <= n, got k={k}, n={n}."
        ))
    } else {
        None
    };
    if let Some(message) = rejection {
        return crate::invalid_output_program(OP_ID, q_hat, DataType::U32, message);
    }

    // A single gather suffices, so only invocation 0 performs it; the other
    // invocations in the workgroup take no branch.
    let is_first_invocation = Expr::eq(Expr::InvocationId { axis: 0 }, Expr::u32(0));
    // `k` is 1-based, and the guard above ensures `k >= 1`, so `k - 1` cannot underflow.
    let gather = Node::store(
        q_hat,
        Expr::u32(0),
        Expr::load(scores_sorted, Expr::u32(k - 1)),
    );
    let body = vec![Node::if_then(is_first_invocation, vec![gather])];

    let buffers = vec![
        BufferDecl::storage(scores_sorted, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(n),
        BufferDecl::storage(q_hat, 1, BufferAccess::ReadWrite, DataType::U32).with_count(1),
    ];
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(body),
    };
    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// Returns the 1-based conformal rank for `n` scores at miscoverage level `alpha`.
///
/// Returns `0` when [`try_conformal_rank`] rejects the inputs (`n == 0`, or
/// `alpha` outside the open interval `(0, 1)`). A valid rank is always at
/// least `1`, so `0` cannot be confused with a real result. Prefer
/// [`try_conformal_rank`] when an explicit `Option` is more convenient.
#[must_use]
pub fn conformal_rank(n: u32, alpha: f64) -> u32 {
    try_conformal_rank(n, alpha).unwrap_or(0)
}
/// Computes the 1-based conformal rank `k = ceil((1 - alpha) * (n + 1))`, clamped to `1..=n`.
///
/// The rank selects the `k`-th smallest of `n` calibration scores as the
/// conformal threshold. If `(1 - alpha) * (n + 1)` exceeds `n` (i.e.
/// `alpha < 1 / (n + 1)`), the rank is clamped to `n`, the largest score.
///
/// Returns `None` when `n == 0` or when `alpha` is not strictly inside `(0, 1)`.
/// That check also rejects NaN.
#[must_use]
pub fn try_conformal_rank(n: u32, alpha: f64) -> Option<u32> {
    if n == 0 || !(alpha > 0.0 && alpha < 1.0) {
        return None;
    }
    let scale = f64::from(n) + 1.0;
    let raw = (1.0 - alpha) * scale;
    // `alpha` is typically a decimal such as 0.7 that f64 cannot represent
    // exactly. A product that should be an exact integer can therefore land a
    // few ulps above it: (1.0 - 0.7) * 10.0 == 3.0000000000000004. `ceil` would
    // turn that into an off-by-one rank. The error from representing alpha,
    // the subtraction, and the multiply is bounded by about 1.5 * EPSILON * scale,
    // so values that close to an integer are snapped to it first.
    let nearest = raw.round();
    let tolerance = 4.0 * f64::EPSILON * scale;
    let rank_f = if (raw - nearest).abs() <= tolerance {
        nearest
    } else {
        raw.ceil()
    };
    // Float-to-int `as` casts saturate, so an out-of-range value becomes
    // u32::MAX (and a snapped 0.0 becomes 0) before the clamp fixes the range.
    let rank = rank_f as u32;
    Some(rank.clamp(1, n))
}
/// CPU reference for the conformal threshold.
///
/// Returns the `k`-th smallest value of `scores` (ascending order, 1-based `k`),
/// where `k` comes from [`try_conformal_rank`]`(scores.len(), alpha)`.
///
/// Returns `0` when `scores` is empty, when `alpha` is not strictly inside
/// `(0, 1)`, or when `scores.len()` does not fit in a `u32`.
#[must_use]
pub fn conformal_threshold_cpu(scores: &[u32], alpha: f64) -> u32 {
    // Lengths above u32::MAX cannot be expressed as a rank here. Reject them
    // instead of truncating, which would compute a rank for the wrong sample size.
    let Ok(n) = u32::try_from(scores.len()) else {
        return 0;
    };
    let Some(k) = try_conformal_rank(n, alpha) else {
        return 0;
    };
    // Only the k-th order statistic is needed. Selection (average O(n)) returns
    // the same value as indexing a fully sorted copy, without the O(n log n) sort.
    // `k` is in `1..=n`, so the index is in bounds.
    let mut scratch = scores.to_vec();
    let (_, kth, _) = scratch.select_nth_unstable((k - 1) as usize);
    *kth
}
/// Builds the closed prediction interval `[y - q_hat, y + q_hat]` around the point `y`.
///
/// Both ends saturate: the lower bound stops at `0` and the upper bound at
/// `u32::MAX`, so neither wraps around.
#[must_use]
pub fn predict_interval(y: u32, q_hat: u32) -> (u32, u32) {
    (y.saturating_sub(q_hat), y.saturating_add(q_hat))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rank_alpha_05_picks_median() {
        let rank = conformal_rank(9, 0.5);
        assert_eq!(rank, 5);
    }

    #[test]
    fn rank_alpha_005_picks_high_quantile() {
        let rank = conformal_rank(99, 0.05);
        assert_eq!(rank, 95);
    }

    #[test]
    fn rank_clamps_to_n() {
        // ceil((1 - 1e-9) * 11) == 11, which must be clamped to n == 10.
        let rank = conformal_rank(10, 1e-9);
        assert_eq!(rank, 10);
    }

    #[test]
    fn cpu_threshold_picks_correct_quantile() {
        let shuffled: [u32; 9] = [1, 5, 3, 8, 2, 9, 4, 7, 6];
        assert_eq!(conformal_threshold_cpu(&shuffled, 0.5), 5);
    }

    #[test]
    fn cpu_threshold_alpha_low_picks_high() {
        let ascending: Vec<u32> = (1..=10).collect();
        assert_eq!(conformal_threshold_cpu(&ascending, 0.1), 10);
    }

    #[test]
    fn predict_interval_is_symmetric() {
        assert_eq!(predict_interval(10, 3), (7, 13));
    }

    #[test]
    fn predict_interval_saturates_low() {
        assert_eq!(predict_interval(2, 5), (0, 7));
    }

    #[test]
    fn predict_interval_saturates_high() {
        assert_eq!(predict_interval(u32::MAX, 5), (u32::MAX - 5, u32::MAX));
    }

    #[test]
    fn ir_program_buffer_layout() {
        let program = conformal_threshold("scores", "q", 100, 95);
        assert_eq!(program.workgroup_size, [256, 1, 1]);
        let layout: Vec<_> = program
            .buffers
            .iter()
            .map(|decl| (decl.name(), decl.count()))
            .collect();
        assert_eq!(layout, vec![("scores", 100), ("q", 1)]);
    }

    #[test]
    fn k_zero_traps() {
        let rejected = conformal_threshold("s", "q", 100, 0);
        assert!(rejected.stats().trap());
    }

    #[test]
    fn k_over_n_traps() {
        let rejected = conformal_threshold("s", "q", 10, 11);
        assert!(rejected.stats().trap());
    }
}