use crate::GpuValue;
/// Identifier of a single lane, stored as a `u8`.
///
/// Values built through [`LaneId::new`] are always in `0..64`; the
/// constructor's panic message gives the rationale (NVIDIA 32-lane and
/// AMD 64-lane hardware).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct LaneId(u8);

impl LaneId {
    /// Creates a lane identifier from a raw index.
    ///
    /// Usable in `const` contexts.
    ///
    /// # Panics
    ///
    /// Panics if `id` is 64 or larger.
    pub const fn new(id: u8) -> Self {
        // Exclusive upper bound on valid lane indices.
        const LANE_LIMIT: u8 = 64;
        assert!(
            id < LANE_LIMIT,
            "Lane ID must be < 64 (supports NVIDIA 32-lane and AMD 64-lane)"
        );
        Self(id)
    }

    /// Returns the raw lane index as a `u8`.
    pub const fn get(self) -> u8 {
        let Self(raw) = self;
        raw
    }

    /// Returns the lane index widened to `usize`, e.g. for slice indexing.
    pub const fn index(self) -> usize {
        let Self(raw) = self;
        raw as usize
    }
}
/// Identifier of a warp, stored as a `u16`.
///
/// Unlike the lane identifier, no range restriction is applied: every
/// `u16` value is accepted.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct WarpId(u16);

impl WarpId {
    /// Wraps a raw warp index. Usable in `const` contexts; never panics.
    pub const fn new(id: u16) -> Self {
        Self(id)
    }

    /// Returns the raw warp index.
    pub const fn get(self) -> u16 {
        let Self(raw) = self;
        raw
    }
}
/// Wrapper around a single `T` tagged as "uniform".
///
/// Per the `must_use` note, it is intended to hold warp-wide reduction
/// results. The wrapper is `repr(transparent)`, so it has the same layout
/// as the wrapped `T`.
#[must_use = "Uniform values represent warp-wide reduction results — dropping discards the result"]
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(transparent)]
pub struct Uniform<T: GpuValue> {
    value: T,
}

impl<T: GpuValue> Uniform<T> {
    /// Wraps `value` as a uniform value. Usable in `const` contexts.
    pub const fn from_const(value: T) -> Self {
        Self { value }
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn get(self) -> T {
        let Self { value } = self;
        value
    }

    /// Converts into a [`PerLane`] that carries the same inner value.
    pub fn broadcast(self) -> PerLane<T> {
        let Self { value } = self;
        PerLane { value }
    }
}
/// Wrapper around a single `T` tagged as "per-lane".
///
/// Per the `must_use` note, it is intended to carry per-lane GPU data. The
/// wrapper is `repr(transparent)`, so it has the same layout as the wrapped
/// `T`.
#[must_use = "PerLane values carry per-lane GPU data — dropping discards computation"]
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(transparent)]
pub struct PerLane<T: GpuValue> {
    value: T,
}

impl<T: GpuValue> PerLane<T> {
    /// Wraps `value` as a per-lane value.
    pub fn new(value: T) -> Self {
        Self { value }
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn get(self) -> T {
        let Self { value } = self;
        value
    }

    /// Re-tags this value as a [`Uniform`] without performing any check.
    ///
    /// # Safety
    ///
    /// The inner value is moved into the `Uniform` wrapper as-is; nothing is
    /// verified. NOTE(review): presumably the caller must guarantee that the
    /// value really is identical across all lanes — confirm the exact
    /// contract with the crate's design docs.
    pub unsafe fn assume_uniform(self) -> Uniform<T> {
        let Self { value } = self;
        Uniform { value }
    }
}
/// Addition of two per-lane values: the inner values are added with `T`'s
/// own `Add` implementation (left operand first) and the sum is re-wrapped.
impl<T> core::ops::Add for PerLane<T>
where
    T: GpuValue + core::ops::Add<Output = T>,
{
    type Output = PerLane<T>;

    fn add(self, rhs: PerLane<T>) -> PerLane<T> {
        let (left, right) = (self.value, rhs.value);
        PerLane {
            value: left + right,
        }
    }
}
/// Wrapper around a single `T` tagged with a lane number `LANE` at the type
/// level.
///
/// Per the `must_use` note, it is intended for a reduction result that
/// exists in one lane only. `LANE` is a const generic and is not stored at
/// runtime; the wrapper is `repr(transparent)` over `T`.
#[must_use = "SingleLane values exist in one lane only — dropping discards the reduction result"]
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(transparent)]
pub struct SingleLane<T: GpuValue, const LANE: u8> {
    value: T,
}

impl<T: GpuValue, const LANE: u8> SingleLane<T, LANE> {
    /// Wraps `value` as a single-lane value.
    pub fn new(value: T) -> Self {
        Self { value }
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn get(self) -> T {
        let Self { value } = self;
        value
    }

    /// Converts into a [`Uniform`] that carries the same inner value.
    pub fn broadcast(self) -> Uniform<T> {
        let Self { value } = self;
        Uniform::from_const(value)
    }
}
/// A named subset of the 64 possible lanes, stored as a bitmask.
///
/// Bit `i` of the mask is set exactly when lane `i` belongs to the role;
/// this is the test performed by [`Role::contains`]. Equality compares both
/// the mask and the name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Role {
    mask: u64,
    name: &'static str,
}

impl Role {
    /// Builds a role covering the half-open lane range `start..end`.
    ///
    /// Usable in `const` contexts.
    ///
    /// # Panics
    ///
    /// Panics if `start >= 64`, if `end > 64`, or if the range is empty
    /// (`start >= end`).
    pub const fn lanes(start: u8, end: u8, name: &'static str) -> Self {
        // Each precondition is checked on its own so the panic message names
        // the bound that was violated.
        assert!(start < 64, "Role lane range start must be < 64");
        assert!(end <= 64, "Role lane range end must be <= 64");
        assert!(
            start < end,
            "Role lane range must be non-empty (start must be below end)"
        );
        let width = (end - start) as u64;
        // `1u64 << 64` would overflow, so the full-width range is handled
        // separately; it can only occur for `0..64`.
        let mask = if width >= 64 {
            u64::MAX
        } else {
            ((1u64 << width) - 1) << start
        };
        Role { mask, name }
    }

    /// Builds a role containing exactly one lane.
    ///
    /// Usable in `const` contexts.
    ///
    /// # Panics
    ///
    /// Panics if `id` is 64 or larger.
    pub const fn lane(id: u8, name: &'static str) -> Self {
        assert!(id < 64, "Role lane ID must be < 64");
        Role {
            mask: 1u64 << id,
            name,
        }
    }

    /// Returns the raw lane bitmask.
    pub const fn mask(self) -> u64 {
        self.mask
    }

    /// Returns the name given at construction.
    pub const fn name(self) -> &'static str {
        self.name
    }

    /// Returns `true` when `lane`'s bit is set in this role's mask.
    pub const fn contains(self, lane: LaneId) -> bool {
        (self.mask & (1u64 << lane.0)) != 0
    }

    /// Returns the number of lanes in this role (set bits in the mask).
    pub const fn count(self) -> u32 {
        self.mask.count_ones()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A mid-range lane id round-trips through both accessors.
    #[test]
    fn test_lane_id() {
        let id = LaneId::new(15);
        assert_eq!(id.get(), 15);
        assert_eq!(id.index(), 15);
    }

    /// Lane 31 is accepted and round-trips through both accessors.
    #[test]
    fn test_lane_id_boundary_31() {
        let id = LaneId::new(31);
        assert_eq!(id.get(), 31);
        assert_eq!(id.index(), 31);
    }

    /// Lane 63 is the largest accepted value.
    #[test]
    fn test_lane_id_boundary_63() {
        let id = LaneId::new(63);
        assert_eq!(id.get(), 63);
    }

    /// Lane 64 is rejected by the constructor.
    #[test]
    #[should_panic]
    fn test_lane_id_out_of_range() {
        let _ = LaneId::new(64);
    }

    /// Broadcasting a uniform value keeps the inner value.
    #[test]
    fn test_uniform_broadcast() {
        let source: Uniform<i32> = Uniform::from_const(42);
        let spread: PerLane<i32> = source.broadcast();
        assert_eq!(spread.get(), 42);
    }

    /// Broadcasting a single-lane value keeps the inner value.
    #[test]
    fn test_single_lane_broadcast() {
        let in_lane_zero: SingleLane<i32, 0> = SingleLane::new(42);
        let everywhere: Uniform<i32> = in_lane_zero.broadcast();
        assert_eq!(everywhere.get(), 42);
    }

    /// Adjacent lane ranges are half-open and do not overlap.
    #[test]
    fn test_role_coverage() {
        let lane = |raw: u8| LaneId::new(raw);
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");

        assert!(coordinator.contains(lane(0)));
        assert!(coordinator.contains(lane(3)));
        assert!(!coordinator.contains(lane(4)));

        assert!(!worker.contains(lane(3)));
        assert!(worker.contains(lane(4)));
        assert!(worker.contains(lane(31)));

        assert_eq!(coordinator.count(), 4);
        assert_eq!(worker.count(), 28);
    }
}