use std::collections::HashSet;
/// Index of a lane within a warp; lanes are numbered `0..WARP_SIZE`.
pub type Lane = u32;
/// Number of lanes in a warp, re-exported from the crate-level `WARP_SIZE`.
///
/// NOTE(review): this must resolve to a constant defined in a *different* module
/// than this one. If this file is itself the crate root, the definition refers to
/// itself and will not compile — confirm.
pub const WARP_SIZE: u32 = crate::WARP_SIZE;
/// A set of warp lanes, used as a type-level witness of which lanes are active.
///
/// Every constructor below only produces lanes drawn from `0..WARP_SIZE`, and the
/// set operations preserve that. Sets built through this API are therefore always
/// subsets of the full warp.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProofActiveSet(HashSet<Lane>);

impl ProofActiveSet {
    /// Every lane of the warp: `{0, 1, ..., WARP_SIZE - 1}`.
    pub fn all() -> Self {
        Self::from_predicate(|_| true)
    }

    /// The set with no active lanes.
    pub fn empty() -> Self {
        Self(HashSet::default())
    }

    /// The lanes in `0..WARP_SIZE` for which `pred` returns `true`.
    pub fn from_predicate<F: Fn(Lane) -> bool>(pred: F) -> Self {
        let mut lanes = HashSet::new();
        for lane in 0..WARP_SIZE {
            if pred(lane) {
                lanes.insert(lane);
            }
        }
        Self(lanes)
    }

    /// Lanes active in `self`, in `other`, or in both.
    pub fn union(&self, other: &Self) -> Self {
        let mut lanes = self.0.clone();
        lanes.extend(other.0.iter().copied());
        Self(lanes)
    }

    /// Lanes active in both `self` and `other`.
    pub fn intersection(&self, other: &Self) -> Self {
        Self(
            self.0
                .iter()
                .copied()
                .filter(|lane| other.0.contains(lane))
                .collect(),
        )
    }

    /// Lanes of the warp (`0..WARP_SIZE`) that are *not* in `self`.
    pub fn complement(&self) -> Self {
        Self::from_predicate(|lane| !self.0.contains(&lane))
    }

    /// `true` when no lane is active in both sets.
    pub fn is_disjoint(&self, other: &Self) -> bool {
        !self.0.iter().any(|lane| other.0.contains(lane))
    }

    /// `true` when the set holds `WARP_SIZE` lanes. Given the subset invariant
    /// above, that means the whole warp is active.
    pub fn is_all(&self) -> bool {
        WARP_SIZE as usize == self.0.len()
    }

    /// `true` when `lane` is in the set.
    pub fn contains(&self, lane: Lane) -> bool {
        self.0.contains(&lane)
    }
}
/// Types assigned by `type_check` to expressions of the warp language.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Type {
/// A warp token whose active lanes are exactly the given set.
Warp(ProofActiveSet),
/// Per-lane data: a `PerLaneVal` with exactly `WARP_SIZE` entries.
PerLane,
/// The unit type.
Unit,
/// The type of a pair of expressions.
Pair(Box<Type>, Box<Type>),
}
/// Expressions of the warp language.
#[derive(Clone, Debug)]
pub enum Expr {
/// A warp literal with the given active lanes (a value).
WarpVal(ProofActiveSet),
/// A per-lane data literal (a value); `type_check` requires exactly `WARP_SIZE` entries.
PerLaneVal(Vec<i32>),
/// The unit value.
UnitVal,
/// A pair; it is a value only when both components are values.
PairVal(Box<Expr>, Box<Expr>),
/// A linear variable: `type_check` removes its binding from the context when it is used.
Var(String),
/// Splits a warp into `(lanes where the predicate holds, lanes where it does not)`.
Diverge(Box<Expr>, Predicate),
/// Rejoins two warps; their active sets must be disjoint.
Merge(Box<Expr>, Box<Expr>),
/// XOR shuffle `(warp, data, mask)`: lane `l` reads lane `l ^ mask`, keeping its own
/// value when that source lane is outside the warp. Requires a fully active warp.
Shuffle(Box<Expr>, Box<Expr>, u32),
/// `let x = e1 in e2`: `x` must be fresh and must be used exactly once in `e2`.
Let(String, Box<Expr>, Box<Expr>),
}
/// Lane predicates used by `Diverge` to decide which side of the split each lane takes.
#[derive(Clone, Debug)]
pub enum Predicate {
/// Lanes with an even index.
Even,
/// Lanes whose index is strictly less than `n`.
LessThan(u32),
/// Lanes for which the given function returns `true`.
Custom(fn(Lane) -> bool),
}
impl Predicate {
    /// Returns `true` when `lane` satisfies this predicate, i.e. takes the
    /// "true" side of a divergence.
    pub fn eval(&self, lane: Lane) -> bool {
        match *self {
            Self::Even => (lane & 1) == 0,
            Self::LessThan(bound) => lane < bound,
            Self::Custom(test) => test(lane),
        }
    }

    /// The lanes of the warp (`0..WARP_SIZE`) on which this predicate holds.
    pub fn active_set(&self) -> ProofActiveSet {
        let holds = |lane: Lane| self.eval(lane);
        ProofActiveSet::from_predicate(holds)
    }
}
/// Linear typing context: maps each in-scope variable name to its type.
/// `type_check` removes a variable's entry when the variable is used.
pub type Context = std::collections::HashMap<String, Type>;
/// Outcome of type checking: the inferred type, or a human-readable error message.
pub type TypeResult = Result<Type, String>;
/// Infers the type of `expr` under the linear context `ctx`.
///
/// The context is linear: using a `Var` removes its binding from `ctx`, so a
/// second use reports it as unbound. A `Let` fails if its variable is already
/// in `ctx`, and also fails if the variable is still there after the body has
/// been checked.
///
/// Warp rules:
/// * `Diverge` splits a `Warp(s)` by the predicate into a pair of warps.
/// * `Merge` needs two warps with disjoint active sets.
/// * `Shuffle` needs a fully active warp plus `PerLane` data.
///
/// # Errors
/// Returns a descriptive message for any rule violation. On error, `ctx` may be
/// left partially modified (e.g. a `Let` variable inserted before its body failed).
pub fn type_check(ctx: &mut Context, expr: &Expr) -> TypeResult {
    match expr {
        Expr::WarpVal(active) => Ok(Type::Warp(active.clone())),
        Expr::UnitVal => Ok(Type::Unit),
        Expr::PerLaneVal(lanes) if lanes.len() == WARP_SIZE as usize => Ok(Type::PerLane),
        Expr::PerLaneVal(lanes) => Err(format!(
            "PerLaneVal has {} elements, expected {WARP_SIZE}",
            lanes.len()
        )),
        Expr::Var(name) => match ctx.remove(name) {
            Some(ty) => Ok(ty),
            None => Err(format!("Unbound variable: {}", name)),
        },
        Expr::PairVal(first, second) => {
            let first_ty = type_check(ctx, first)?;
            let second_ty = type_check(ctx, second)?;
            Ok(Type::Pair(Box::new(first_ty), Box::new(second_ty)))
        }
        Expr::Diverge(warp, pred) => match type_check(ctx, warp)? {
            Type::Warp(active) => {
                let taken = pred.active_set();
                let then_lanes = active.intersection(&taken);
                let else_lanes = active.intersection(&taken.complement());
                Ok(Type::Pair(
                    Box::new(Type::Warp(then_lanes)),
                    Box::new(Type::Warp(else_lanes)),
                ))
            }
            _ => Err("diverge requires a Warp".to_string()),
        },
        Expr::Merge(left, right) => {
            let left_ty = type_check(ctx, left)?;
            let right_ty = type_check(ctx, right)?;
            match (left_ty, right_ty) {
                (Type::Warp(a), Type::Warp(b)) if a.is_disjoint(&b) => {
                    Ok(Type::Warp(a.union(&b)))
                }
                (Type::Warp(_), Type::Warp(_)) => {
                    Err("merge requires disjoint active sets".to_string())
                }
                _ => Err("merge requires two Warps".to_string()),
            }
        }
        Expr::Shuffle(warp, data, _mask) => {
            let warp_ty = type_check(ctx, warp)?;
            let data_ty = type_check(ctx, data)?;
            match (warp_ty, data_ty) {
                (Type::Warp(active), Type::PerLane) if active.is_all() => Ok(Type::PerLane),
                (Type::Warp(_), Type::PerLane) => {
                    Err("shuffle requires Warp<All>".to_string())
                }
                _ => Err("shuffle requires Warp and PerLane".to_string()),
            }
        }
        Expr::Let(name, bound, body) => {
            let bound_ty = type_check(ctx, bound)?;
            if ctx.contains_key(name) {
                return Err(format!(
                    "Let binding '{}' shadows existing variable (not fresh)",
                    name
                ));
            }
            ctx.insert(name.clone(), bound_ty);
            let body_ty = type_check(ctx, body)?;
            // A binding still present after the body means it was never used.
            match ctx.remove(name) {
                Some(_) => Err(format!("Linear variable '{}' not consumed in let body", name)),
                None => Ok(body_ty),
            }
        }
    }
}
/// Returns `true` when `expr` is fully evaluated.
///
/// Values are literals (`WarpVal`, `PerLaneVal`, `UnitVal`) and pairs whose
/// components are both values.
pub fn is_value(expr: &Expr) -> bool {
    if let Expr::PairVal(first, second) = expr {
        return is_value(first) && is_value(second);
    }
    matches!(expr, Expr::WarpVal(_) | Expr::PerLaneVal(_) | Expr::UnitVal)
}
/// Performs one small-step, call-by-value reduction of `expr`.
///
/// Returns `None` when `expr` is already a value (see [`is_value`]) or when no
/// rule applies (the term is stuck). Sub-expressions are reduced left to right:
/// the left operand of a pair, merge, or shuffle becomes a value before the right
/// one is reduced.
///
/// Stuck cases:
/// * a free `Var`;
/// * a `Merge` of two warps whose active sets overlap;
/// * a `Shuffle` on a warp that is not fully active, or on per-lane data without
///   exactly `WARP_SIZE` entries;
/// * operands that are values of the wrong shape (e.g. `Diverge` of a `UnitVal`).
pub fn step(expr: &Expr) -> Option<Expr> {
    match expr {
        _ if is_value(expr) => None,
        // Literals are caught by the guard above; listed so the match is exhaustive.
        Expr::WarpVal(_) | Expr::PerLaneVal(_) | Expr::UnitVal => None,
        // Pair congruence. Without this rule, well-typed terms such as
        // `PairVal(Diverge(..), UnitVal)` would be stuck, breaking progress.
        Expr::PairVal(e1, e2) => {
            if is_value(e1) {
                step(e2).map(|e2_| Expr::PairVal(e1.clone(), Box::new(e2_)))
            } else {
                step(e1).map(|e1_| Expr::PairVal(Box::new(e1_), e2.clone()))
            }
        }
        // Free variables never reduce; closed programs have substituted them away.
        Expr::Var(_) => None,
        Expr::Diverge(w, pred) => match w.as_ref() {
            Expr::WarpVal(s) => {
                let p = pred.active_set();
                let s_true = s.intersection(&p);
                let s_false = s.intersection(&p.complement());
                Some(Expr::PairVal(
                    Box::new(Expr::WarpVal(s_true)),
                    Box::new(Expr::WarpVal(s_false)),
                ))
            }
            _ => step(w).map(|w2| Expr::Diverge(Box::new(w2), pred.clone())),
        },
        Expr::Merge(w1, w2) => match (w1.as_ref(), w2.as_ref()) {
            (Expr::WarpVal(s1), Expr::WarpVal(s2)) => {
                // Overlapping active sets cannot be merged: the term is stuck.
                if !s1.is_disjoint(s2) {
                    return None;
                }
                Some(Expr::WarpVal(s1.union(s2)))
            }
            (Expr::WarpVal(_), _) => step(w2).map(|w2_| Expr::Merge(w1.clone(), Box::new(w2_))),
            _ => step(w1).map(|w1_| Expr::Merge(Box::new(w1_), w2.clone())),
        },
        Expr::Shuffle(w, data, mask) => match (w.as_ref(), data.as_ref()) {
            (Expr::WarpVal(s), Expr::PerLaneVal(vals)) => {
                if !s.is_all() || vals.len() != WARP_SIZE as usize {
                    return None;
                }
                let mask = *mask;
                // XOR shuffle: lane `l` reads lane `l ^ mask`; a source outside the
                // warp leaves the lane's own value in place.
                let shuffled = (0..WARP_SIZE)
                    .map(|lane| {
                        let src = lane ^ mask;
                        if src < WARP_SIZE {
                            vals[src as usize]
                        } else {
                            vals[lane as usize]
                        }
                    })
                    .collect();
                Some(Expr::PerLaneVal(shuffled))
            }
            (Expr::WarpVal(_), _) => {
                step(data).map(|d| Expr::Shuffle(w.clone(), Box::new(d), *mask))
            }
            _ => step(w).map(|w_| Expr::Shuffle(Box::new(w_), data.clone(), *mask)),
        },
        Expr::Let(x, e1, e2) => {
            if is_value(e1) {
                Some(substitute(e2, x, e1))
            } else {
                step(e1).map(|e1_| Expr::Let(x.clone(), Box::new(e1_), e2.clone()))
            }
        }
    }
}
/// Replaces free occurrences of `var` in `expr` with `val`.
///
/// A `Let` that rebinds `var` shadows it: the binding's initializer is still
/// substituted, but its body is left untouched. No capture avoidance is done.
/// `step` only substitutes values, and those contain no variables.
fn substitute(expr: &Expr, var: &str, val: &Expr) -> Expr {
    let recur = |sub: &Expr| Box::new(substitute(sub, var, val));
    match expr {
        Expr::Var(name) => {
            if name == var {
                val.clone()
            } else {
                expr.clone()
            }
        }
        Expr::WarpVal(_) | Expr::PerLaneVal(_) | Expr::UnitVal => expr.clone(),
        Expr::PairVal(left, right) => Expr::PairVal(recur(left), recur(right)),
        Expr::Diverge(warp, pred) => Expr::Diverge(recur(warp), pred.clone()),
        Expr::Merge(lhs, rhs) => Expr::Merge(recur(lhs), recur(rhs)),
        Expr::Shuffle(warp, data, mask) => Expr::Shuffle(recur(warp), recur(data), *mask),
        Expr::Let(bound, init, body) => {
            let new_body = if bound == var { body.clone() } else { recur(body) };
            Expr::Let(bound.clone(), recur(init), new_body)
        }
    }
}
/// Checks the progress property for `expr` in the empty context.
///
/// A well-typed term must either be a value or be able to take a step.
/// Ill-typed terms are outside the theorem's scope and pass vacuously.
pub fn progress_check(expr: &Expr) -> bool {
    let well_typed = type_check(&mut Context::new(), expr).is_ok();
    !well_typed || is_value(expr) || step(expr).is_some()
}
/// Checks the preservation property for a single reduction step of `expr`.
///
/// Preservation only constrains well-typed terms. If `expr` has type `T` in the
/// empty context and `step(expr)` produces `expr'`, then `expr'` must also have
/// type `T`. Two cases pass vacuously, consistent with `progress_check`:
/// * ill-typed terms;
/// * well-typed terms that do not step.
pub fn preservation_check(expr: &Expr) -> bool {
    let original_type = match type_check(&mut Context::new(), expr) {
        Ok(ty) => ty,
        // Ill-typed input may legitimately step to a well-typed (or differently
        // ill-typed) term, so comparing results here would be meaningless.
        Err(_) => return true,
    };
    match step(expr) {
        Some(stepped) => type_check(&mut Context::new(), &stepped) == Ok(original_type),
        None => true,
    }
}
/// Evaluates `expr` to completion, checking type safety at every step.
///
/// Ill-typed programs pass vacuously and return `true`. For a well-typed program
/// of type `T`, each reduction must yield a term that still has type `T`
/// (preservation), and every non-value term must be able to step (progress).
///
/// Returns `true` once a value is reached. Returns `false` in three cases:
/// * a step changes the type;
/// * a well-typed term is stuck;
/// * no value is reached within `MAX_STEPS` reductions.
pub fn type_safety_check(expr: &Expr) -> bool {
    const MAX_STEPS: usize = 1000;
    let expected = match type_check(&mut Context::new(), expr) {
        Ok(ty) => ty,
        Err(_) => return true,
    };
    // Comparing every intermediate term against the initial type is equivalent to
    // comparing consecutive terms (by induction), but checks each term only once.
    let mut current = expr.clone();
    for _ in 0..MAX_STEPS {
        if is_value(&current) {
            return true;
        }
        let next = match step(&current) {
            Some(next) => next,
            // Well-typed but stuck: progress is violated.
            None => return false,
        };
        match type_check(&mut Context::new(), &next) {
            Ok(ty) if ty == expected => current = next,
            _ => return false,
        }
    }
    is_value(&current)
}
/// Checks that diverging `s` on `pred` partitions `s`.
///
/// The "true" and "false" sides must be disjoint, and their union must be
/// exactly `s`.
pub fn diverge_complement_lemma(s: &ProofActiveSet, pred: &Predicate) -> bool {
    let taken = pred.active_set();
    let (then_side, else_side) = (s.intersection(&taken), s.intersection(&taken.complement()));
    then_side.union(&else_side) == *s && then_side.is_disjoint(&else_side)
}
/// Checks that merging the two halves of a divergence on `pred` restores `s` exactly.
pub fn merge_restore_lemma(s: &ProofActiveSet, pred: &Predicate) -> bool {
    let taken = pred.active_set();
    let restored = s
        .intersection(&taken)
        .union(&s.intersection(&taken.complement()));
    restored == *s
}
/// Checks that every active lane in `s` reads from an active lane under an XOR
/// shuffle with `mask`.
///
/// A fully active set is accepted without inspecting individual lanes.
pub fn shuffle_source_lemma(s: &ProofActiveSet, mask: u32) -> bool {
    s.is_all()
        || (0..WARP_SIZE)
            .filter(|&lane| s.contains(lane))
            .all(|lane| s.contains(lane ^ mask))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expression for a warp with every lane active.
    fn full_warp() -> Expr {
        Expr::WarpVal(ProofActiveSet::all())
    }

    /// `Diverge` of a fully active warp on the `Even` predicate.
    fn diverge_full_on_even() -> Expr {
        Expr::Diverge(Box::new(full_warp()), Predicate::Even)
    }

    /// The even-numbered lanes of the warp.
    fn even_lanes() -> ProofActiveSet {
        ProofActiveSet::from_predicate(|lane| lane % 2 == 0)
    }

    #[test]
    fn test_diverge_complement() {
        let everyone = ProofActiveSet::all();
        for pred in [Predicate::Even, Predicate::LessThan(16)] {
            assert!(diverge_complement_lemma(&everyone, &pred), "{pred:?}");
        }
    }

    #[test]
    fn test_merge_restore() {
        let everyone = ProofActiveSet::all();
        for pred in [Predicate::Even, Predicate::LessThan(10)] {
            assert!(merge_restore_lemma(&everyone, &pred), "{pred:?}");
        }
    }

    #[test]
    fn test_shuffle_source_all() {
        let everyone = ProofActiveSet::all();
        for mask in [1, 5, 31] {
            assert!(shuffle_source_lemma(&everyone, mask), "mask = {mask}");
        }
    }

    #[test]
    fn test_shuffle_source_even_fails() {
        let evens = even_lanes();
        // Mask 1 pairs each even lane with an odd, inactive lane.
        assert!(!shuffle_source_lemma(&evens, 1));
        // Mask 2 maps even lanes onto even lanes.
        assert!(shuffle_source_lemma(&evens, 2));
    }

    #[test]
    fn test_type_check_good_program() {
        let program = Expr::Let(
            "pair".to_string(),
            Box::new(diverge_full_on_even()),
            Box::new(Expr::Var("pair".to_string())),
        );
        assert!(type_check(&mut Context::new(), &program).is_ok());
    }

    #[test]
    fn test_type_check_bad_shuffle() {
        let data = Expr::PerLaneVal(vec![0; WARP_SIZE as usize]);
        let bad_shuffle =
            Expr::Shuffle(Box::new(Expr::WarpVal(even_lanes())), Box::new(data), 1);
        let err = type_check(&mut Context::new(), &bad_shuffle)
            .expect_err("shuffle on a partially active warp must be rejected");
        assert!(err.contains("Warp<All>"));
    }

    #[test]
    fn test_progress() {
        let program = diverge_full_on_even();
        assert!(progress_check(&program));
        assert!(step(&program).is_some());
    }

    #[test]
    fn test_preservation() {
        assert!(preservation_check(&diverge_full_on_even()));
    }

    #[test]
    fn test_type_safety_good_program() {
        assert!(type_safety_check(&diverge_full_on_even()));
    }
}