use crate::active_set::ActiveSet;
use crate::warp::Warp;
/// Error returned when a `DynWarp` operation is attempted with masks that do
/// not satisfy the operation's requirement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WarpError {
/// Name of the operation that was rejected (for example `"shuffle_xor"` or
/// `"merge"`); it is the leading part of the `Display` output.
pub operation: &'static str,
/// Mask value the operation required. Which mask this describes depends on
/// the operation: usually the full-warp mask, but `0` for a merge overlap.
pub expected_mask: u64,
/// Mask value that was actually observed and caused the rejection.
pub actual_mask: u64,
}
impl core::fmt::Display for WarpError {
    /// Renders the error as `<operation>: expected mask 0x…, got 0x…`.
    ///
    /// Both masks are printed as upper-case hexadecimal, zero-padded to at
    /// least eight digits; wider values are printed in full.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            operation,
            expected_mask,
            actual_mask,
        } = self;
        f.write_fmt(format_args!(
            "{}: expected mask 0x{:08X}, got 0x{:08X}",
            operation, expected_mask, actual_mask
        ))
    }
}
/// Error returned by `DynWarp::ascribe` when the runtime active mask does not
/// equal the mask of the requested static active set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AscribeError {
/// `NAME` of the active set that was requested.
pub expected_name: &'static str,
/// `MASK` of the active set that was requested.
pub expected_mask: u64,
/// Active mask the `DynWarp` actually held.
pub actual_mask: u64,
}
impl core::fmt::Display for AscribeError {
    /// Renders the error as `ascribe to <name>: expected mask 0x…, got 0x…`.
    ///
    /// Both masks are printed as upper-case hexadecimal, zero-padded to at
    /// least eight digits; wider values are printed in full.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            expected_name,
            expected_mask,
            actual_mask,
        } = self;
        f.write_fmt(format_args!(
            "ascribe to {}: expected mask 0x{:08X}, got 0x{:08X}",
            expected_name, expected_mask, actual_mask
        ))
    }
}
// Both error types opt in to `std::error::Error`, but only when the target
// architecture is not `nvptx64`.
// NOTE(review): presumably `std` is unavailable when compiling for nvptx64,
// which is why the impls are gated — confirm against the crate's build setup.
#[cfg(not(target_arch = "nvptx64"))]
impl std::error::Error for WarpError {}
#[cfg(not(target_arch = "nvptx64"))]
impl std::error::Error for AscribeError {}
/// A warp handle whose active lanes are tracked at run time as a bitmask,
/// as opposed to `Warp<S>`, where the active set `S` is part of the type.
///
/// Collective operations on this type check the mask when called and return
/// a `WarpError` instead of being rejected at compile time.
#[derive(Debug)]
#[must_use = "dropping a DynWarp without merging may indicate a missing merge — \
the compiler cannot enforce linear use, but this warns on accidental drops"]
pub struct DynWarp {
// One bit per lane; a set bit marks that lane as currently active.
active_mask: u64,
// Mask with every lane of this warp set. It encodes the warp width and is
// the value `active_mask` must equal for collective operations to succeed.
full_mask: u64,
}
impl DynWarp {
    /// Mask with the low 32 bits set: every lane of a 32-lane warp.
    const FULL_32: u64 = 0xFFFF_FFFF;
    /// Mask with all 64 bits set: every lane of a 64-lane warp.
    const FULL_64: u64 = u64::MAX;

    /// Creates a warp with every lane active, using `All::MASK` as both the
    /// active mask and the full-warp mask.
    pub fn all() -> Self {
        let full = crate::active_set::All::MASK;
        DynWarp {
            active_mask: full,
            full_mask: full,
        }
    }

    /// Creates a 64-lane warp with every lane active.
    pub fn all_64() -> Self {
        DynWarp {
            active_mask: Self::FULL_64,
            full_mask: Self::FULL_64,
        }
    }

    /// Creates a 32-lane warp whose active lanes are the set bits of `mask`.
    pub fn from_mask_32(mask: u32) -> Self {
        DynWarp {
            active_mask: u64::from(mask),
            full_mask: Self::FULL_32,
        }
    }

    /// Creates a 64-lane warp whose active lanes are the set bits of `mask`.
    pub fn from_mask_64(mask: u64) -> Self {
        DynWarp {
            active_mask: mask,
            full_mask: Self::FULL_64,
        }
    }

    /// Creates a warp from `mask`, inferring the warp width from the mask.
    ///
    /// A mask that fits in 32 bits yields a 32-lane warp; a mask with any
    /// higher bit set yields a 64-lane warp. To describe a 64-lane warp whose
    /// upper lanes are all inactive, use [`DynWarp::from_mask_64`] instead.
    pub fn from_mask(mask: u64) -> Self {
        let full_mask = if mask <= Self::FULL_32 {
            Self::FULL_32
        } else {
            Self::FULL_64
        };
        DynWarp {
            active_mask: mask,
            full_mask,
        }
    }

    /// Converts a statically typed warp into a dynamic one.
    ///
    /// The active mask becomes `S::MASK`; the full-warp mask is `All::MASK`.
    pub fn from_static<S: ActiveSet>(_warp: Warp<S>) -> Self {
        DynWarp {
            active_mask: S::MASK,
            full_mask: crate::active_set::All::MASK,
        }
    }

    /// Converts this dynamic warp into a `Warp<S>` if its active mask is
    /// exactly `S::MASK`.
    ///
    /// # Errors
    ///
    /// Returns an [`AscribeError`] carrying `S::NAME`, `S::MASK` and the
    /// actual active mask when the masks differ.
    ///
    /// NOTE(review): only the active mask is compared; the warp width
    /// (`full_mask`) is not. Confirm whether a 64-lane warp whose active mask
    /// happens to equal `S::MASK` is meant to be accepted.
    pub fn ascribe<S: ActiveSet>(self) -> Result<Warp<S>, AscribeError> {
        if self.active_mask == S::MASK {
            Ok(Warp::new())
        } else {
            Err(AscribeError {
                expected_name: S::NAME,
                expected_mask: S::MASK,
                actual_mask: self.active_mask,
            })
        }
    }

    /// Returns the bitmask of currently active lanes.
    pub fn active_mask(&self) -> u64 {
        self.active_mask
    }

    /// Returns the number of currently active lanes.
    pub fn population(&self) -> u32 {
        self.active_mask.count_ones()
    }

    /// Shared precondition of the collective operations: every lane of the
    /// warp must be active. On failure, builds the `WarpError` attributed to
    /// `operation`.
    fn require_all_active(&self, operation: &'static str) -> Result<(), WarpError> {
        if self.active_mask == self.full_mask {
            Ok(())
        } else {
            Err(WarpError {
                operation,
                expected_mask: self.full_mask,
                actual_mask: self.active_mask,
            })
        }
    }

    /// Checked XOR shuffle of a scalar.
    ///
    /// This implementation performs no lane exchange: it returns `value`
    /// unchanged and ignores `_xor_mask`.
    ///
    /// # Errors
    ///
    /// Returns a [`WarpError`] (`"shuffle_xor"`) unless all lanes are active.
    pub fn shuffle_xor_scalar(&self, value: i32, _xor_mask: u32) -> Result<i32, WarpError> {
        self.require_all_active("shuffle_xor")?;
        Ok(value)
    }

    /// Checked shuffle-down of a scalar.
    ///
    /// This implementation performs no lane exchange: it returns `value`
    /// unchanged and ignores `_delta`.
    ///
    /// # Errors
    ///
    /// Returns a [`WarpError`] (`"shuffle_down"`) unless all lanes are active.
    pub fn shuffle_down_scalar(&self, value: i32, _delta: u32) -> Result<i32, WarpError> {
        self.require_all_active("shuffle_down")?;
        Ok(value)
    }

    /// Checked sum reduction of a scalar.
    ///
    /// Returns `value` multiplied by the number of lanes in the warp, i.e. the
    /// sum obtained when every lane contributes the same `value`. The product
    /// wraps on overflow (two's complement), so the result is the same in
    /// debug and release builds and the call never panics.
    ///
    /// # Errors
    ///
    /// Returns a [`WarpError`] (`"reduce_sum"`) unless all lanes are active.
    pub fn reduce_sum_scalar(&self, value: i32) -> Result<i32, WarpError> {
        self.require_all_active("reduce_sum")?;
        // At most 64 lanes, so the lane count always fits in an i32.
        let warp_width = self.full_mask.count_ones() as i32;
        // A plain `*` would panic in debug builds and wrap in release builds.
        Ok(value.wrapping_mul(warp_width))
    }

    /// Checked broadcast of a scalar; returns `value` unchanged.
    ///
    /// # Errors
    ///
    /// Returns a [`WarpError`] (`"broadcast"`) unless all lanes are active.
    pub fn broadcast_scalar(&self, value: i32) -> Result<i32, WarpError> {
        self.require_all_active("broadcast")?;
        Ok(value)
    }

    /// Checked ballot: packs `predicate` into a `u32`, setting bit `i` when
    /// `predicate[i]` is `true`.
    ///
    /// # Errors
    ///
    /// Returns a [`WarpError`] when the warp is wider than 32 lanes (the
    /// result could not hold one bit per lane), or (`"ballot"`) when not all
    /// lanes are active. The width check is performed first.
    pub fn ballot(&self, predicate: &[bool; 32]) -> Result<u32, WarpError> {
        if self.full_mask > Self::FULL_32 {
            return Err(WarpError {
                operation: "ballot (64-lane warp incompatible with u32 result)",
                expected_mask: Self::FULL_32,
                actual_mask: self.full_mask,
            });
        }
        self.require_all_active("ballot")?;

        let bits = predicate
            .iter()
            .enumerate()
            .filter(|&(_, &is_set)| is_set)
            .fold(0u32, |bits, (lane, _)| bits | (1u32 << lane));
        Ok(bits)
    }

    /// Splits the warp by `predicate_mask`.
    ///
    /// Returns `(taken, not_taken)`: the active lanes whose bit is set in
    /// `predicate_mask`, and the active lanes whose bit is clear. Both halves
    /// keep the width of the original warp.
    pub fn diverge(self, predicate_mask: u64) -> (DynWarp, DynWarp) {
        let taken = DynWarp {
            active_mask: self.active_mask & predicate_mask,
            full_mask: self.full_mask,
        };
        let not_taken = DynWarp {
            active_mask: self.active_mask & !predicate_mask,
            full_mask: self.full_mask,
        };
        (taken, not_taken)
    }

    /// Joins two warps with disjoint active lanes into one.
    ///
    /// # Errors
    ///
    /// Returns a [`WarpError`] when the two warps have different widths
    /// (`"merge (full_mask mismatch)"`, reporting both full masks), or when
    /// they share an active lane (`"merge"`, with `expected_mask` of `0` and
    /// `actual_mask` set to the overlapping lanes).
    pub fn merge(self, other: DynWarp) -> Result<DynWarp, WarpError> {
        if self.full_mask != other.full_mask {
            return Err(WarpError {
                operation: "merge (full_mask mismatch)",
                expected_mask: self.full_mask,
                actual_mask: other.full_mask,
            });
        }

        let overlap = self.active_mask & other.active_mask;
        if overlap != 0 {
            return Err(WarpError {
                operation: "merge",
                expected_mask: 0,
                actual_mask: overlap,
            });
        }

        Ok(DynWarp {
            active_mask: self.active_mask | other.active_mask,
            full_mask: self.full_mask,
        })
    }

    /// Like [`DynWarp::merge`], but additionally requires the merged warp to
    /// have every lane active.
    ///
    /// # Errors
    ///
    /// Propagates any error from `merge`, and returns a [`WarpError`]
    /// (`"merge_covering (not all lanes covered)"`) when the union of the two
    /// active masks does not cover the full warp.
    pub fn merge_covering(self, other: DynWarp) -> Result<DynWarp, WarpError> {
        let merged = self.merge(other)?;
        merged.require_all_active("merge_covering (not all lanes covered)")?;
        Ok(merged)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::active_set::*;

    /// `all()` activates every lane of the default-width warp.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn dyn_warp_all() {
        let full = DynWarp::all();
        assert_eq!(full.active_mask(), All::MASK);
        assert_eq!(full.population(), crate::WARP_SIZE);
    }

    /// `from_mask_32` keeps the given lanes and records a 32-lane width.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn from_mask_32_basic() {
        let low_half = DynWarp::from_mask_32(LowHalf::MASK as u32);
        assert_eq!(low_half.active_mask(), LowHalf::MASK);
        assert_eq!(low_half.full_mask, All::MASK);
        assert_eq!(low_half.population(), crate::WARP_SIZE / 2);
    }

    /// An empty 32-bit mask yields a warp with no active lanes.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn from_mask_32_empty() {
        let empty = DynWarp::from_mask_32(0);
        assert_eq!(empty.active_mask(), 0);
        assert_eq!(empty.full_mask, All::MASK);
        assert_eq!(empty.population(), 0);
    }

    /// `from_mask_64` records a 64-lane width even when only low bits are set.
    #[test]
    fn from_mask_64_low_bits() {
        let low = DynWarp::from_mask_64(0xFFFFFFFF);
        assert_eq!(low.active_mask(), 0xFFFFFFFF);
        assert_eq!(low.full_mask, 0xFFFFFFFFFFFFFFFF);
        assert_eq!(low.population(), 32);
    }

    /// `from_mask_64` with every bit set is a fully active 64-lane warp.
    #[test]
    fn from_mask_64_full() {
        let full = DynWarp::from_mask_64(0xFFFFFFFFFFFFFFFF);
        assert_eq!(full.active_mask(), 0xFFFFFFFFFFFFFFFF);
        assert_eq!(full.full_mask, 0xFFFFFFFFFFFFFFFF);
    }

    /// A fully active dynamic warp can be ascribed to `Warp<All>`.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn ascribe_all_succeeds() {
        let typed: Warp<All> = DynWarp::all().ascribe().unwrap();
        assert_eq!(typed.active_mask(), All::MASK);
    }

    /// Ascribing to a set with a different mask reports both masks and the name.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn ascribe_wrong_type_fails() {
        let failure = DynWarp::all().ascribe::<Even>().unwrap_err();
        assert_eq!(failure.expected_name, "Even");
        assert_eq!(failure.expected_mask, Even::MASK);
        assert_eq!(failure.actual_mask, All::MASK);
    }

    /// Static -> dynamic -> static preserves the active set.
    #[test]
    fn from_static_roundtrip() {
        let typed: Warp<Even> = Warp::new();
        let erased = DynWarp::from_static(typed);
        assert_eq!(erased.active_mask(), Even::MASK);
        let _: Warp<Even> = erased.ascribe().unwrap();
    }

    /// Shuffles are accepted on a fully active warp.
    #[test]
    fn shuffle_all_succeeds() {
        let full = DynWarp::all();
        assert!(full.shuffle_xor_scalar(42, 1).is_ok());
    }

    /// Shuffles are rejected on a partially active warp, naming the operation.
    #[test]
    fn shuffle_partial_fails() {
        let evens = DynWarp::from_mask(Even::MASK);
        let failure = evens.shuffle_xor_scalar(42, 1).unwrap_err();
        assert_eq!(failure.operation, "shuffle_xor");
        assert_eq!(failure.actual_mask, Even::MASK);
    }

    /// An all-true predicate on a full warp ballots to the full mask.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn ballot_all_succeeds() {
        let full = DynWarp::all();
        let predicate = [true; 32];
        assert_eq!(full.ballot(&predicate).unwrap(), All::MASK as u32);
    }

    /// Ballot is rejected when only part of the warp is active.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn ballot_partial_fails() {
        let low_half = DynWarp::from_mask(LowHalf::MASK);
        let predicate = [true; 32];
        assert!(low_half.ballot(&predicate).is_err());
    }

    /// Same as `ballot_partial_fails`, built with an explicit 64-lane width.
    #[test]
    #[cfg(feature = "warp64")]
    fn ballot_partial_fails_64() {
        let low_half = DynWarp::from_mask_64(LowHalf::MASK);
        let predicate = [true; 32];
        assert!(low_half.ballot(&predicate).is_err());
    }

    /// Diverging on the even mask yields the even and odd lanes, with no overlap.
    #[test]
    fn diverge_produces_disjoint_masks() {
        let (evens, odds) = DynWarp::all().diverge(Even::MASK);
        assert_eq!(evens.active_mask(), Even::MASK);
        assert_eq!(odds.active_mask(), Odd::MASK);
        assert_eq!(evens.active_mask() & odds.active_mask(), 0);
    }

    /// Merging disjoint halves restores the full mask.
    #[test]
    fn merge_disjoint_succeeds() {
        let evens = DynWarp::from_mask(Even::MASK);
        let odds = DynWarp::from_mask(Odd::MASK);
        let rejoined = evens.merge(odds).unwrap();
        assert_eq!(rejoined.active_mask(), All::MASK);
    }

    /// Merging warps that share active lanes is rejected.
    #[test]
    fn merge_overlapping_fails() {
        let low_half = DynWarp::from_mask(LowHalf::MASK);
        let evens = DynWarp::from_mask(Even::MASK);
        assert!(low_half.merge(evens).is_err());
    }

    /// End-to-end: diverge, fail a shuffle, merge, shuffle, then move between
    /// the dynamic and static representations.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn gradual_migration_workflow() {
        let (evens, odds) = DynWarp::all().diverge(Even::MASK);
        assert!(evens.shuffle_xor_scalar(42, 1).is_err());

        let rejoined = evens.merge(odds).unwrap();
        assert!(rejoined.shuffle_xor_scalar(42, 1).is_ok());

        let typed: Warp<All> = DynWarp::all().ascribe().unwrap();
        assert_eq!(typed.population(), crate::WARP_SIZE);

        let erased = DynWarp::from_static(typed);
        assert_eq!(erased.active_mask(), All::MASK);
    }

    /// Two levels of divergence can be merged back in reverse order.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn nested_diverge_merge_dynamic() {
        let (low, high) = DynWarp::all().diverge(LowHalf::MASK);
        assert_eq!(low.population(), crate::WARP_SIZE / 2);
        assert_eq!(high.population(), crate::WARP_SIZE / 2);

        let (even_low, odd_low) = low.diverge(Even::MASK);
        assert_eq!(even_low.active_mask(), EvenLow::MASK);
        assert_eq!(odd_low.active_mask(), OddLow::MASK);
        assert!(even_low.shuffle_xor_scalar(1, 1).is_err());

        let low_again = even_low.merge(odd_low).unwrap();
        assert_eq!(low_again.active_mask(), LowHalf::MASK);

        let everything = low_again.merge(high).unwrap();
        assert_eq!(everything.active_mask(), All::MASK);
        let _typed: Warp<All> = everything.ascribe().unwrap();
    }

    /// `all_64()` activates all 64 lanes.
    #[test]
    fn dyn_warp_all_64() {
        let full = DynWarp::all_64();
        assert_eq!(full.active_mask(), 0xFFFFFFFFFFFFFFFF);
        assert_eq!(full.population(), 64);
    }

    /// Reducing `1` over a 64-lane warp counts the lanes.
    #[test]
    fn reduce_sum_64_lane() {
        let full = DynWarp::all_64();
        let total = full.reduce_sum_scalar(1).unwrap();
        assert_eq!(total, 64);
    }

    /// Ballot on a 64-lane warp fails with a message mentioning the width.
    #[test]
    fn ballot_64_lane_errors() {
        let full = DynWarp::all_64();
        let predicate = [true; 32];
        let failure = full.ballot(&predicate).unwrap_err();
        assert!(failure.operation.contains("64-lane"));
    }

    /// Shuffles are accepted on a fully active 64-lane warp.
    #[test]
    fn shuffle_64_lane_succeeds() {
        let full = DynWarp::all_64();
        assert!(full.shuffle_xor_scalar(42, 1).is_ok());
    }

    /// `from_mask` with all 64 bits set infers a fully active 64-lane warp.
    #[test]
    fn from_mask_infers_64_lane_width() {
        let full = DynWarp::from_mask(0xFFFFFFFFFFFFFFFF);
        assert_eq!(full.population(), 64);
        assert!(full.shuffle_xor_scalar(42, 1).is_ok());
    }

    /// A single bit above bit 31 makes `from_mask` infer a 64-lane width, so
    /// the warp is only partially active. Despite its name, this test performs
    /// no merge; it checks the inferred width and the rejected shuffle.
    #[test]
    fn from_mask_high_bits_merge_works() {
        let lane_32_only = DynWarp::from_mask(0x1_0000_0000);
        assert_eq!(lane_32_only.full_mask, 0xFFFFFFFFFFFFFFFF);
        assert!(lane_32_only.shuffle_xor_scalar(42, 1).is_err());
    }

    /// A half taken from `all()` cannot be merged with a half taken from `all_64()`.
    #[test]
    fn merge_mismatched_width_fails() {
        let default_width = DynWarp::all();
        let wide = DynWarp::all_64();
        let (default_evens, _default_odds) = default_width.diverge(Even::MASK);
        let (wide_evens, _wide_rest) = wide.diverge(Even::MASK);
        assert!(default_evens.merge(wide_evens).is_err());
    }

    /// Two halves of the same warp merge without error.
    #[test]
    fn merge_same_width_succeeds() {
        let (evens, odds) = DynWarp::all().diverge(Even::MASK);
        assert!(evens.merge(odds).is_ok());
    }

    /// Rendered error messages contain the operation or set name and both masks.
    #[test]
    fn error_messages_are_clear() {
        let evens = DynWarp::from_mask(Even::MASK);
        let shuffle_failure = evens.shuffle_xor_scalar(42, 1).unwrap_err();
        let rendered = shuffle_failure.to_string();
        assert!(rendered.contains("shuffle_xor"));
        assert!(rendered.contains("FFFFFFFF"));
        assert!(rendered.contains("55555555"));

        let ascribe_failure = DynWarp::from_mask(0x1234).ascribe::<All>().unwrap_err();
        let rendered = ascribe_failure.to_string();
        assert!(rendered.contains("All"));
        assert!(rendered.contains("00001234"));
    }

    /// `merge_covering` accepts two halves that together cover the whole warp.
    #[test]
    #[cfg(not(feature = "warp64"))]
    fn merge_covering_succeeds_on_complements() {
        let evens = DynWarp::from_mask_32(Even::MASK as u32);
        let odds = DynWarp::from_mask_32(Odd::MASK as u32);
        let covered = evens.merge_covering(odds).unwrap();
        assert_eq!(covered.active_mask(), All::MASK);
    }

    /// `merge_covering` rejects a disjoint pair that leaves lanes uncovered.
    #[test]
    fn merge_covering_fails_on_partial() {
        let lane_0 = DynWarp::from_mask_32(0x1);
        let lane_1 = DynWarp::from_mask_32(0x2);
        assert!(lane_0.merge_covering(lane_1).is_err());
    }
}