use std::cmp::Ordering;
/// A region identified by an owner id (`pid`) and its `start`/`end` bounds.
///
/// Ordering is lexicographic over `(pid, start, end)`: all regions sharing a
/// `pid` sort together, in ascending `start` order, then ascending `end`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RegionTriple {
    /// Owner id; the dedup routines only merge regions that share this id.
    pub pid: u32,
    /// Lower bound of the region.
    pub start: u32,
    /// Upper bound of the region.
    pub end: u32,
}

impl RegionTriple {
    /// Creates a region from its owner id and bounds; no validation is performed.
    #[must_use]
    pub const fn new(pid: u32, start: u32, end: u32) -> Self {
        RegionTriple { pid, start, end }
    }
}

impl Ord for RegionTriple {
    /// Compares by `pid`, then `start`, then `end`.
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs = (self.pid, self.start, self.end);
        let rhs = (other.pid, other.start, other.end);
        lhs.cmp(&rhs)
    }
}

impl PartialOrd for RegionTriple {
    /// Delegates to the total order defined by `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Returns `input` sorted, with overlapping or touching same-`pid` regions merged.
///
/// Owning convenience wrapper around [`dedup_regions_inplace`]; the work happens
/// in the caller's allocation, which is handed back.
#[must_use]
pub fn dedup_regions_cpu(mut input: Vec<RegionTriple>) -> Vec<RegionTriple> {
    dedup_regions_inplace(&mut input);
    input
}
/// Builds a GPU program that marks which entries of a sorted region list begin a new
/// merged region, using the same merge rule as [`dedup_regions_inplace`].
///
/// `pids`, `starts` and `ends` name parallel read-only `u32` buffers of length `count`:
/// entry `k` is the region `(pids[k], starts[k], ends[k])`. The entries are expected to
/// be sorted by `(pid, start, end)` already, for example the outputs of
/// [`region_sort_program`]. For every `t < count` the program writes `survivors[t]`:
///
/// * `1` when no earlier same-pid entry reaches `starts[t]`, so entry `t` begins a new
///   merged region;
/// * `0` when some earlier same-pid entry `j` has `ends[j] >= starts[t]` (overlapping or
///   touching), so entry `t` is absorbed into an earlier region.
///
/// Every earlier entry is checked, not only the immediate predecessor. A region nested
/// inside a wider earlier region (e.g. `(0,1,10), (0,2,3), (0,5,6)`) must still be
/// reported as absorbed, exactly as the CPU path merges it. Each invocation therefore
/// scans the whole input, giving O(count²) total work like `region_sort_program`.
///
/// With `count == 0` no invocation passes the bounds guard and nothing is written.
#[must_use]
pub fn dedup_regions_flag_program(
    pids: &str,
    starts: &str,
    ends: &str,
    survivors: &str,
    count: u32,
) -> vyre_foundation::ir::Program {
    use std::sync::Arc;
    use vyre_foundation::ir::model::expr::Ident;
    use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
    let t = Expr::InvocationId { axis: 0 };
    // Earlier entry `j` clears the flag when it shares this entry's pid and its end is
    // not strictly before this entry's start. Touching counts as overlap, matching the
    // `start <= end` test in `dedup_regions_inplace`.
    let scan_earlier = Node::loop_for(
        "j",
        Expr::u32(0),
        Expr::u32(count),
        vec![Node::if_then(
            // Only entries preceding `t` in sorted order can absorb it.
            Expr::lt(Expr::var("j"), t.clone()),
            vec![
                Node::let_bind("pid_j", Expr::load(pids, Expr::var("j"))),
                Node::let_bind("end_j", Expr::load(ends, Expr::var("j"))),
                Node::if_then(
                    Expr::eq(Expr::var("pid_j"), Expr::var("pid_self")),
                    vec![Node::assign(
                        "flag",
                        Expr::select(
                            Expr::lt(Expr::var("end_j"), Expr::var("start_self")),
                            Expr::var("flag"),
                            Expr::u32(0),
                        ),
                    )],
                ),
            ],
        )],
    );
    let body = vec![Node::if_then(
        Expr::lt(t.clone(), Expr::u32(count)),
        vec![
            Node::let_bind("pid_self", Expr::load(pids, t.clone())),
            Node::let_bind("start_self", Expr::load(starts, t.clone())),
            // Entry 0, and any entry with no absorbing predecessor, keeps flag = 1.
            Node::let_bind("flag", Expr::u32(1)),
            scan_earlier,
            Node::store(survivors, t, Expr::var("flag")),
        ],
    )];
    Program::wrapped(
        vec![
            BufferDecl::storage(pids, 0, BufferAccess::ReadOnly, DataType::U32).with_count(count),
            BufferDecl::storage(starts, 1, BufferAccess::ReadOnly, DataType::U32).with_count(count),
            BufferDecl::storage(ends, 2, BufferAccess::ReadOnly, DataType::U32).with_count(count),
            BufferDecl::storage(survivors, 3, BufferAccess::WriteOnly, DataType::U32)
                .with_count(count),
        ],
        [count.clamp(1, 64), 1, 1],
        vec![Node::Region {
            generator: Ident::from("vyre-primitives::matching::region::dedup_regions_flag"),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// Sorts regions in place by `(pid, start, end)`, the same order that
/// `region_sort_program` produces.
///
/// An unstable sort is enough here: the ordering compares every field and
/// equality is field-wise, so elements that compare equal are identical and
/// their relative order cannot be observed.
pub fn sort_regions_cpu(input: &mut [RegionTriple]) {
    input.sort_unstable();
}
/// Builds a GPU rank-sort program over parallel `(pid, start, end)` `u32` buffers.
///
/// Invocation `i` (for `i < count`) counts the entries `j` that must come before entry
/// `i` in lexicographic `(pid, start, end)` order. Fully equal keys are ordered by input
/// index (`j < i`). Entry `i` is then written to that rank in the three output buffers.
/// Every entry gets a distinct rank, so the outputs are a permutation of the inputs in
/// the same order `sort_regions_cpu` produces. Each invocation scans all `count`
/// entries, so total work is O(count²).
///
/// When `count == 0`, returns the trapping program built by
/// `crate::invalid_output_program` instead.
#[must_use]
pub fn region_sort_program(
    pids_in: &str,
    starts_in: &str,
    ends_in: &str,
    pids_out: &str,
    starts_out: &str,
    ends_out: &str,
    count: u32,
) -> vyre_foundation::ir::Program {
    use std::sync::Arc;
    use vyre_foundation::ir::model::expr::Ident;
    use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};

    if count == 0 {
        return crate::invalid_output_program(
            "vyre-primitives::matching::region::sort_regions",
            pids_out,
            DataType::U32,
            format!("Fix: region_sort_program requires count > 0, got {count}."),
        );
    }

    // Key equalities shared by the strict comparison and the index tiebreak.
    let pid_eq = Expr::eq(Expr::var("pid_j"), Expr::var("pid_i"));
    let start_eq = Expr::eq(Expr::var("start_j"), Expr::var("start_i"));

    // (pid_j, start_j, end_j) < (pid_i, start_i, end_i), spelled out key by key.
    let pid_lt = Expr::lt(Expr::var("pid_j"), Expr::var("pid_i"));
    let start_lt = Expr::and(
        pid_eq.clone(),
        Expr::lt(Expr::var("start_j"), Expr::var("start_i")),
    );
    let end_lt = Expr::and(
        pid_eq.clone(),
        Expr::and(
            start_eq.clone(),
            Expr::lt(Expr::var("end_j"), Expr::var("end_i")),
        ),
    );
    let key_lt = Expr::or(pid_lt, Expr::or(start_lt, end_lt));

    // Identical keys: the lower input index wins, which keeps every rank unique.
    let index_tiebreak = Expr::and(
        pid_eq,
        Expr::and(
            start_eq,
            Expr::and(
                Expr::eq(Expr::var("end_j"), Expr::var("end_i")),
                Expr::lt(Expr::var("j"), Expr::var("i")),
            ),
        ),
    );
    let j_precedes_i = Expr::or(key_lt, index_tiebreak);

    let count_predecessors = Node::loop_for(
        "j",
        Expr::u32(0),
        Expr::u32(count),
        vec![
            Node::let_bind("pid_j", Expr::load(pids_in, Expr::var("j"))),
            Node::let_bind("start_j", Expr::load(starts_in, Expr::var("j"))),
            Node::let_bind("end_j", Expr::load(ends_in, Expr::var("j"))),
            Node::if_then(
                j_precedes_i,
                vec![Node::assign(
                    "rank",
                    Expr::add(Expr::var("rank"), Expr::u32(1)),
                )],
            ),
        ],
    );

    let invocation = Expr::InvocationId { axis: 0 };
    let in_bounds = Expr::lt(invocation.clone(), Expr::u32(count));
    let per_item = vec![
        Node::let_bind("i", invocation),
        Node::let_bind("pid_i", Expr::load(pids_in, Expr::var("i"))),
        Node::let_bind("start_i", Expr::load(starts_in, Expr::var("i"))),
        Node::let_bind("end_i", Expr::load(ends_in, Expr::var("i"))),
        Node::let_bind("rank", Expr::u32(0)),
        count_predecessors,
        // Scatter entry `i` to its final sorted position.
        Node::store(pids_out, Expr::var("rank"), Expr::var("pid_i")),
        Node::store(starts_out, Expr::var("rank"), Expr::var("start_i")),
        Node::store(ends_out, Expr::var("rank"), Expr::var("end_i")),
    ];
    let body = vec![Node::if_then(in_bounds, per_item)];

    Program::wrapped(
        vec![
            BufferDecl::storage(pids_in, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(count),
            BufferDecl::storage(starts_in, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(count),
            BufferDecl::storage(ends_in, 2, BufferAccess::ReadOnly, DataType::U32)
                .with_count(count),
            BufferDecl::storage(pids_out, 3, BufferAccess::ReadWrite, DataType::U32)
                .with_count(count),
            BufferDecl::storage(starts_out, 4, BufferAccess::ReadWrite, DataType::U32)
                .with_count(count),
            BufferDecl::storage(ends_out, 5, BufferAccess::ReadWrite, DataType::U32)
                .with_count(count),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from("vyre-primitives::matching::region::region_sort"),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// Sorts `input` and merges, in place, every run of same-`pid` regions that
/// overlap or touch.
///
/// After sorting, a region joins the current merged region when it has the
/// same `pid` and its `start` is `<=` that merged region's running `end`.
/// A merged region keeps the smallest `start` and the largest `end` of its
/// members. Regions with different `pid`s are never merged, and the result
/// stays sorted by `(pid, start, end)`.
pub fn dedup_regions_inplace(input: &mut Vec<RegionTriple>) {
    input.sort_unstable();
    // `dedup_by` passes (candidate, last kept). Returning `true` drops the
    // candidate after its extent has been folded into the kept region.
    input.dedup_by(|candidate, kept| {
        let absorbed = candidate.pid == kept.pid && candidate.start <= kept.end;
        if absorbed {
            kept.end = kept.end.max(candidate.end);
        }
        absorbed
    });
}
#[cfg(test)]
mod tests {
    // Unit tests for the CPU dedup/sort references and the shape of the
    // generated GPU programs.
    use super::*;
    #[test]
    fn empty_input() {
        assert!(dedup_regions_cpu(vec![]).is_empty());
    }
    #[test]
    fn single_pass_through() {
        let r = RegionTriple::new(0, 5, 10);
        assert_eq!(dedup_regions_cpu(vec![r]), vec![r]);
    }
    #[test]
    fn exact_duplicate_collapses() {
        let r = RegionTriple::new(0, 5, 10);
        assert_eq!(dedup_regions_cpu(vec![r, r]), vec![r]);
    }
    #[test]
    fn overlapping_same_pid_merges() {
        let a = RegionTriple::new(0, 5, 10);
        let b = RegionTriple::new(0, 7, 12);
        assert_eq!(
            dedup_regions_cpu(vec![a, b]),
            vec![RegionTriple::new(0, 5, 12)]
        );
    }
    #[test]
    fn touching_same_pid_merges() {
        let a = RegionTriple::new(0, 5, 10);
        let b = RegionTriple::new(0, 10, 15);
        assert_eq!(
            dedup_regions_cpu(vec![a, b]),
            vec![RegionTriple::new(0, 5, 15)]
        );
    }
    #[test]
    fn different_pids_never_merge() {
        let a = RegionTriple::new(0, 5, 10);
        let b = RegionTriple::new(1, 5, 10);
        let mut got = dedup_regions_cpu(vec![a, b]);
        got.sort_unstable();
        assert_eq!(got, vec![a, b]);
    }
    #[test]
    fn unsorted_input_handled() {
        let a = RegionTriple::new(0, 5, 10);
        let b = RegionTriple::new(0, 7, 12);
        let c = RegionTriple::new(1, 3, 4);
        let got = dedup_regions_cpu(vec![b, a, c]);
        assert_eq!(got, vec![RegionTriple::new(0, 5, 12), c]);
    }
    #[test]
    fn cluster_of_three_merges() {
        let a = RegionTriple::new(0, 1, 3);
        let b = RegionTriple::new(0, 2, 5);
        let c = RegionTriple::new(0, 4, 8);
        assert_eq!(
            dedup_regions_cpu(vec![a, b, c]),
            vec![RegionTriple::new(0, 1, 8)]
        );
    }
    #[test]
    fn zero_width_matches_preserved() {
        let a = RegionTriple::new(0, 5, 5);
        let b = RegionTriple::new(1, 5, 5);
        let mut got = dedup_regions_cpu(vec![a, b]);
        got.sort_unstable();
        assert_eq!(got, vec![a, b]);
    }
    #[test]
    fn nested_region_absorbed_by_enclosing_region() {
        // (0, 5, 6) follows (0, 2, 3) in sorted order but is still covered by (0, 1, 10).
        let got = dedup_regions_cpu(vec![
            RegionTriple::new(0, 1, 10),
            RegionTriple::new(0, 2, 3),
            RegionTriple::new(0, 5, 6),
        ]);
        assert_eq!(got, vec![RegionTriple::new(0, 1, 10)]);
    }
    #[test]
    fn disjoint_same_pid_regions_stay_separate() {
        let a = RegionTriple::new(0, 1, 3);
        let b = RegionTriple::new(0, 4, 6);
        assert_eq!(dedup_regions_cpu(vec![b, a]), vec![a, b]);
    }
    #[test]
    fn sort_regions_cpu_matches_ord_impl() {
        let mut a = vec![
            RegionTriple::new(2, 0, 1),
            RegionTriple::new(0, 5, 10),
            RegionTriple::new(1, 3, 4),
            RegionTriple::new(0, 5, 8),
            RegionTriple::new(0, 5, 10),
        ];
        sort_regions_cpu(&mut a);
        assert_eq!(
            a,
            vec![
                RegionTriple::new(0, 5, 8),
                RegionTriple::new(0, 5, 10),
                RegionTriple::new(0, 5, 10),
                RegionTriple::new(1, 3, 4),
                RegionTriple::new(2, 0, 1),
            ]
        );
    }
    #[test]
    fn sort_regions_cpu_is_stable_for_equal_triples() {
        let mut a = vec![
            RegionTriple::new(0, 5, 10),
            RegionTriple::new(0, 5, 10),
            RegionTriple::new(0, 5, 10),
        ];
        sort_regions_cpu(&mut a);
        assert_eq!(a.len(), 3);
        for r in &a {
            assert_eq!(*r, RegionTriple::new(0, 5, 10));
        }
    }
    #[test]
    fn region_sort_program_emits_expected_buffers() {
        let p = region_sort_program("pi", "si", "ei", "po", "so", "eo", 64);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["pi", "si", "ei", "po", "so", "eo"]);
        for buf in p.buffers.iter() {
            assert_eq!(buf.count(), 64);
        }
    }
    #[test]
    fn region_sort_program_zero_count_traps() {
        let p = region_sort_program("pi", "si", "ei", "po", "so", "eo", 0);
        assert!(p.stats().trap());
    }
    #[test]
    fn dedup_flag_program_emits_expected_buffers() {
        let p = dedup_regions_flag_program("p", "s", "e", "f", 32);
        assert_eq!(p.workgroup_size, [32, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["p", "s", "e", "f"]);
        for buf in p.buffers.iter() {
            assert_eq!(buf.count(), 32);
        }
        let wide = dedup_regions_flag_program("p", "s", "e", "f", 1000);
        assert_eq!(wide.workgroup_size, [64, 1, 1]);
    }
    #[test]
    fn region_sort_program_pipeline_composes_with_dedup_flags() {
        let sort_p = region_sort_program("pi", "si", "ei", "ps", "ss", "es", 32);
        let flag_p = dedup_regions_flag_program("ps", "ss", "es", "flags", 32);
        let sort_outputs: Vec<&str> = sort_p
            .buffers
            .iter()
            .filter(|b| b.access() == vyre_foundation::ir::BufferAccess::ReadWrite)
            .map(|b| b.name())
            .collect();
        assert_eq!(sort_outputs, vec!["ps", "ss", "es"]);
        let flag_inputs: Vec<&str> = flag_p
            .buffers
            .iter()
            .filter(|b| b.access() == vyre_foundation::ir::BufferAccess::ReadOnly)
            .map(|b| b.name())
            .collect();
        assert_eq!(flag_inputs, vec!["ps", "ss", "es"]);
    }
}