warp-types 0.3.2

Type-safe GPU warp programming via linear typestate: compile-time prevention of shuffle-from-inactive-lane bugs
Documentation
//! Active set types: compile-time lane subset tracking.
//!
//! Active sets are zero-sized marker types that represent subsets of warp lanes.
//! The type system tracks which lanes are active through diverge/merge operations,
//! preventing shuffle-from-inactive-lane bugs at compile time.
//!
//! # Lattice structure
//!
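//! Active sets are lane subsets ordered by inclusion (a sub-poset of the Boolean
//! lattice of lane masks). The named sets form the diverge hierarchy below. Each
//! `A / B` row is one way to split the parent into two disjoint sets that
//! together cover it:
//!
//! ```text
//! All (32)
//! ├── Even (16) / Odd (16)
//! │   ├── Even     → EvenLow (8)  / EvenHigh (8)
//! │   └── Odd      → OddLow (8)   / OddHigh (8)
//! ├── LowHalf (16) / HighHalf (16)
//! │   ├── LowHalf  → EvenLow (8)  / OddLow (8)
//! │   └── HighHalf → EvenHigh (8) / OddHigh (8)
//! └── Lane0 (1) / NotLane0 (31)
//! ```
//!
//! Lane counts are for the default 32-lane warp. With the `warp64` feature
//! (64-lane wavefronts) every count doubles, except `Lane0` (still 1 lane) and
//! `NotLane0` (63 lanes).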
//!
//! Note: `EvenLow` appears under both `Even` and `LowHalf` — same set,
//! reached by different diverge paths. Path independence is a key property.

/// Sealed trait module — prevents external crates from implementing safety-critical traits.
///
/// Hard-sealed: the `_sealed` method has no default body and returns
/// `SealToken`, which is `pub(crate)`. External crates cannot name
/// `SealToken`, so they cannot write the required method. They therefore
/// cannot implement `Sealed`, or any trait that has `Sealed` as a supertrait.
#[doc(hidden)]
pub mod sealed {
    /// Token type that only code inside this crate can name or construct.
    #[doc(hidden)]
    pub(crate) struct SealToken;

    // This public trait mentions a crate-private type in its method signature.
    // The `private_interfaces` lint would flag that, but it is exactly what
    // makes the seal work, so the lint is silenced here.
    #[allow(private_interfaces)]
    pub trait Sealed {
        /// Seal hook: only implementable where `SealToken` is nameable (this crate).
        #[doc(hidden)]
        fn _sealed() -> SealToken;
    }
}

/// Marker trait implemented by every active-lane-set type.
///
/// Implementors are zero-sized: the compiler tracks the lane subset through the
/// type itself. The associated constants repeat that information as plain
/// values, so it can be inspected at runtime (debugging, verification).
///
/// The trait is sealed through the `sealed::Sealed` supertrait. Only this crate
/// can introduce new active sets.
pub trait ActiveSet: Copy + 'static + sealed::Sealed {
    /// Human-readable name of the set.
    const NAME: &'static str;
    /// Lane bitmask: bit `i` is set when lane `i` belongs to the set
    /// (used for runtime debugging/verification).
    const MASK: u64;
}

/// Type-level witness that `Self` and `Other` partition the whole warp: they
/// share no lane, and together they contain every lane.
///
/// This is the central safety trait. `merge(a, b)` demands `A: ComplementOf<B>`,
/// and impls exist only for genuinely complementary pairs. Any other pairing is
/// rejected at compile time with the diagnostic declared on this trait.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not the complement of `{Other}` — cannot merge these sub-warps",
    label = "merge requires complementary active sets (e.g., Even + Odd, LowHalf + HighHalf)",
    note = "use `diverge_even_odd()` or `diverge_halves()` to create valid complement pairs, then merge them"
)]
pub trait ComplementOf<Other>: sealed::Sealed + ActiveSet
where
    Other: ActiveSet,
{
}

/// Proof that `Self` and `Other` are complements within a parent set `Parent`.
///
/// Write `S`, `O` and `P` for the lane sets of `Self`, `Other` and `Parent`.
/// Then `S ∪ O = P` and `S ∩ O = ∅`. This is used for nested divergence, where
/// a merge returns to the parent set rather than to `All`.
///
/// Like the other safety traits here, it has a custom diagnostic, so an invalid
/// nested merge names the offending sets instead of producing a bare
/// "trait bound not satisfied" error.
#[diagnostic::on_unimplemented(
    message = "`{Self}` and `{Other}` are not complements within `{Parent}` — cannot merge these sub-warps",
    label = "nested merge requires two disjoint active sets whose union is `{Parent}`",
    note = "nested pairs are complements only within their parent (e.g., EvenLow + EvenHigh within Even); pairs covering all lanes use `ComplementOf` instead"
)]
pub trait ComplementWithin<Other: ActiveSet, Parent: ActiveSet>:
    sealed::Sealed + ActiveSet
{
}

/// Type-level witness that `Self` splits into two disjoint subsets,
/// `TrueBranch` and `FalseBranch`.
///
/// There is one impl per valid diverge pattern, for example `All` → `Even` + `Odd`.
/// Requesting any other split is a compile error with the diagnostic below.
#[diagnostic::on_unimplemented(
    message = "`{Self}` cannot be split into `{TrueBranch}` + `{FalseBranch}`",
    label = "this diverge pattern is not defined in the active set hierarchy",
    note = "valid diverge patterns: All → Even/Odd, All → LowHalf/HighHalf, Even → EvenLow/EvenHigh, etc."
)]
pub trait CanDiverge<TrueBranch, FalseBranch>: sealed::Sealed + ActiveSet + Sized
where
    TrueBranch: ActiveSet,
    FalseBranch: ActiveSet,
{
    /// Takes `warp` by value and returns the two sub-warps, in
    /// `(TrueBranch, FalseBranch)` order.
    fn diverge(
        warp: crate::warp::Warp<Self>,
    ) -> (crate::warp::Warp<TrueBranch>, crate::warp::Warp<FalseBranch>);
}

/// The active set containing no lanes at all.
///
/// This is a degenerate case. No diverge produces it, so it sits outside the
/// generated hierarchy.
#[derive(Clone, Copy, Debug, Default)]
pub struct Empty;

impl ActiveSet for Empty {
    const NAME: &'static str = "Empty";
    const MASK: u64 = 0;
}

#[allow(private_interfaces)]
impl sealed::Sealed for Empty {
    fn _sealed() -> sealed::SealToken {
        sealed::SealToken
    }
}

// ============================================================================
// Generated active set hierarchy
//
// The warp_sets! macro validates at compile time:
//   - Each pair is disjoint (true_mask & false_mask == 0)
//   - Each pair covers its parent (true_mask | false_mask == parent_mask)
//   - Children are subsets of parent (child_mask & !parent_mask == 0)
//
// Shared types (e.g., EvenLow under both Even and LowHalf) are deduplicated.
// ============================================================================

// 32-lane NVIDIA warps (default)
//
// Each `Parent = mask { A = mask / B = mask, ... }` entry lists the ways
// `Parent` can be split into two sets `A` and `B`. Per the header comment
// above, the macro checks that the two are disjoint and together equal `Parent`.
// Bit `i` of a mask stands for lane `i` (e.g. `Lane0 = 0x00000001`).
// The pairs directly under `All` cover every lane. The nested pairs cover only
// their parent, which is why they are complements "within" that parent.
#[cfg(not(feature = "warp64"))]
warp_types_macros::warp_sets! {
    All = 0xFFFFFFFF {
        Even = 0x55555555 / Odd = 0xAAAAAAAA,
        LowHalf = 0x0000FFFF / HighHalf = 0xFFFF0000,
        Lane0 = 0x00000001 / NotLane0 = 0xFFFFFFFE,
    }
    Even = 0x55555555 {
        EvenLow = 0x00005555 / EvenHigh = 0x55550000,
    }
    Odd = 0xAAAAAAAA {
        OddLow = 0x0000AAAA / OddHigh = 0xAAAA0000,
    }
    LowHalf = 0x0000FFFF {
        EvenLow = 0x00005555 / OddLow = 0x0000AAAA,
    }
    HighHalf = 0xFFFF0000 {
        EvenHigh = 0x55550000 / OddHigh = 0xAAAA0000,
    }
}

// 64-lane AMD wavefronts (warp64 feature)
//
// Same hierarchy and set names as the 32-lane block, with every mask widened
// to 64 bits. The Low/High halves are 32 lanes each, and the quarter sets
// (EvenLow, OddHigh, ...) are 16 lanes each. Must be kept in sync with the
// 32-lane block so both features expose the same set of types.
#[cfg(feature = "warp64")]
warp_types_macros::warp_sets! {
    All = 0xFFFFFFFFFFFFFFFF {
        Even = 0x5555555555555555 / Odd = 0xAAAAAAAAAAAAAAAA,
        LowHalf = 0x00000000FFFFFFFF / HighHalf = 0xFFFFFFFF00000000,
        Lane0 = 0x0000000000000001 / NotLane0 = 0xFFFFFFFFFFFFFFFE,
    }
    Even = 0x5555555555555555 {
        EvenLow = 0x0000000055555555 / EvenHigh = 0x5555555500000000,
    }
    Odd = 0xAAAAAAAAAAAAAAAA {
        OddLow = 0x00000000AAAAAAAA / OddHigh = 0xAAAAAAAA00000000,
    }
    LowHalf = 0x00000000FFFFFFFF {
        EvenLow = 0x0000000055555555 / OddLow = 0x00000000AAAAAAAA,
    }
    HighHalf = 0xFFFFFFFF00000000 {
        EvenHigh = 0x5555555500000000 / OddHigh = 0xAAAAAAAA00000000,
    }
}

// `Empty` and `All` are complements: they share no lanes, and together they
// cover the warp. No diverge ever yields `Empty`, so this pair is outside the
// generated hierarchy and both directions are implemented by hand.
impl ComplementOf<All> for Empty {}
impl ComplementOf<Empty> for All {}

// NOTE: EvenLow/EvenHigh are complements within Even, NOT within All.
// ComplementOf requires covering ALL lanes, so these do NOT get ComplementOf impls.
// Use merge_within<EvenLow, EvenHigh, Even>() for nested merges.
// See also: ComplementWithin impls generated by warp_sets! macro.

#[cfg(test)]
mod tests {
    use super::*;

    /// Lane count of a full warp for the active feature set.
    const WARP_WIDTH: u32 = if cfg!(feature = "warp64") { 64 } else { 32 };

    /// Compiles only if `A: ComplementOf<B>` holds. A missing impl therefore
    /// fails the test build instead of going unnoticed.
    fn assert_complement<A: ComplementOf<B>, B: ActiveSet>() {}

    #[test]
    #[cfg(not(feature = "warp64"))]
    fn test_mask_values() {
        assert_eq!(All::MASK, 0xFFFFFFFF);
        assert_eq!(Empty::MASK, 0x00000000);
        assert_eq!(Even::MASK, 0x55555555);
        assert_eq!(Odd::MASK, 0xAAAAAAAA);
        assert_eq!(LowHalf::MASK, 0x0000FFFF);
        assert_eq!(HighHalf::MASK, 0xFFFF0000);
        assert_eq!(Lane0::MASK, 0x00000001);
        assert_eq!(NotLane0::MASK, 0xFFFFFFFE);
        assert_eq!(EvenLow::MASK, 0x00005555);
        assert_eq!(EvenHigh::MASK, 0x55550000);
        assert_eq!(OddLow::MASK, 0x0000AAAA);
        assert_eq!(OddHigh::MASK, 0xAAAA0000);
    }

    #[test]
    #[cfg(feature = "warp64")]
    fn test_mask_values_64() {
        assert_eq!(All::MASK, 0xFFFFFFFFFFFFFFFF);
        assert_eq!(Empty::MASK, 0x00000000);
        assert_eq!(Even::MASK, 0x5555555555555555);
        assert_eq!(Odd::MASK, 0xAAAAAAAAAAAAAAAA);
        assert_eq!(LowHalf::MASK, 0x00000000FFFFFFFF);
        assert_eq!(HighHalf::MASK, 0xFFFFFFFF00000000);
        assert_eq!(Lane0::MASK, 0x0000000000000001);
        assert_eq!(NotLane0::MASK, 0xFFFFFFFFFFFFFFFE);
        assert_eq!(EvenLow::MASK, 0x0000000055555555);
        assert_eq!(EvenHigh::MASK, 0x5555555500000000);
        assert_eq!(OddLow::MASK, 0x00000000AAAAAAAA);
        assert_eq!(OddHigh::MASK, 0xAAAAAAAA00000000);
    }

    /// Path independence: a quarter set is the same set whichever parent it
    /// is reached from.
    #[test]
    fn test_intersection_properties() {
        assert_eq!(Even::MASK & LowHalf::MASK, EvenLow::MASK);
        assert_eq!(Even::MASK & HighHalf::MASK, EvenHigh::MASK);
        assert_eq!(Odd::MASK & LowHalf::MASK, OddLow::MASK);
        assert_eq!(Odd::MASK & HighHalf::MASK, OddHigh::MASK);
    }

    #[test]
    fn test_union_properties() {
        assert_eq!(EvenLow::MASK | EvenHigh::MASK, Even::MASK);
        assert_eq!(OddLow::MASK | OddHigh::MASK, Odd::MASK);
        assert_eq!(EvenLow::MASK | OddLow::MASK, LowHalf::MASK);
        assert_eq!(EvenHigh::MASK | OddHigh::MASK, HighHalf::MASK);
        assert_eq!(
            EvenLow::MASK | EvenHigh::MASK | OddLow::MASK | OddHigh::MASK,
            All::MASK
        );
    }

    #[test]
    fn test_pairwise_disjoint() {
        // The four quarter sets must not share any lane. Names are carried
        // along so a failure reports which sets overlap.
        let sets = [
            (EvenLow::NAME, EvenLow::MASK),
            (EvenHigh::NAME, EvenHigh::MASK),
            (OddLow::NAME, OddLow::MASK),
            (OddHigh::NAME, OddHigh::MASK),
        ];
        for (i, &(name_a, mask_a)) in sets.iter().enumerate() {
            for &(name_b, mask_b) in &sets[i + 1..] {
                assert_eq!(mask_a & mask_b, 0, "{name_a} and {name_b} overlap");
            }
        }
    }

    /// Each pair directly under `All` is disjoint and covers every lane.
    /// Both properties are symmetric in the pair, so one order suffices.
    #[test]
    fn test_complement_symmetry() {
        assert_eq!(Even::MASK | Odd::MASK, All::MASK);
        assert_eq!(Even::MASK & Odd::MASK, 0);
        assert_eq!(LowHalf::MASK | HighHalf::MASK, All::MASK);
        assert_eq!(LowHalf::MASK & HighHalf::MASK, 0);
        assert_eq!(Lane0::MASK | NotLane0::MASK, All::MASK);
        assert_eq!(Lane0::MASK & NotLane0::MASK, 0);
    }

    /// The hand-written `Empty`/`All` pair: checked at the mask level and at
    /// the type level, in both directions.
    #[test]
    fn test_empty_all_complement() {
        assert_eq!(Empty::MASK & All::MASK, 0);
        assert_eq!(Empty::MASK | All::MASK, All::MASK);
        assert_complement::<All, Empty>();
        assert_complement::<Empty, All>();
    }

    /// Every child in the diverge hierarchy is a subset of its parent.
    #[test]
    fn test_children_within_parents() {
        let edges = [
            (All::MASK, Even::MASK),
            (All::MASK, Odd::MASK),
            (All::MASK, LowHalf::MASK),
            (All::MASK, HighHalf::MASK),
            (All::MASK, Lane0::MASK),
            (All::MASK, NotLane0::MASK),
            (Even::MASK, EvenLow::MASK),
            (Even::MASK, EvenHigh::MASK),
            (Odd::MASK, OddLow::MASK),
            (Odd::MASK, OddHigh::MASK),
            (LowHalf::MASK, EvenLow::MASK),
            (LowHalf::MASK, OddLow::MASK),
            (HighHalf::MASK, EvenHigh::MASK),
            (HighHalf::MASK, OddHigh::MASK),
        ];
        for (parent, child) in edges {
            assert_eq!(
                child & !parent,
                0,
                "child {child:#x} is not a subset of parent {parent:#x}"
            );
        }
    }

    /// Lane counts match the warp width selected by the `warp64` feature.
    #[test]
    fn test_lane_counts() {
        assert_eq!(All::MASK.count_ones(), WARP_WIDTH);
        assert_eq!(Empty::MASK.count_ones(), 0);
        for mask in [Even::MASK, Odd::MASK, LowHalf::MASK, HighHalf::MASK] {
            assert_eq!(mask.count_ones(), WARP_WIDTH / 2, "mask {mask:#x}");
        }
        for mask in [EvenLow::MASK, EvenHigh::MASK, OddLow::MASK, OddHigh::MASK] {
            assert_eq!(mask.count_ones(), WARP_WIDTH / 4, "mask {mask:#x}");
        }
        assert_eq!(Lane0::MASK.count_ones(), 1);
        assert_eq!(NotLane0::MASK.count_ones(), WARP_WIDTH - 1);
    }
}