warp-types 0.3.2

Type-safe GPU warp programming via linear typestate: compile-time prevention of shuffle-from-inactive-lane bugs
Documentation
//! Register Pressure Tracking
//!
//! **STATUS: Research** — Exploratory prototype, not promoted to main API.
//!
//! Research question: "Can we track register pressure linearly?"
//!
//! # Background
//!
//! GPU register allocation is critical for performance:
//! - Each thread has limited registers (64-255 depending on GPU)
//! - More registers per thread = fewer concurrent threads (occupancy)
//! - Spilling to local memory is very slow
//!
//! Can the type system track register usage to prevent spills?
//!
//! # Key Insight
//!
//! Linear types can track resources that must be used exactly once.
//! Registers are such resources:
//! - Allocated when variable is created
//! - Freed when variable goes out of scope
//! - Total usage must not exceed limit
//!
//! # Challenges
//!
//! 1. Register count depends on variable types (f32 = 1, f64 = 2, etc.)
//! 2. Compiler may use more registers than source suggests (temporaries)
//! 3. Control flow affects live ranges (different paths = different pressure)
//!
//! # Approach
//!
//! Track APPROXIMATE pressure using type-level natural numbers.
//! This won't be exact but can catch obvious overuse.

use std::marker::PhantomData;

// ============================================================================
// TYPE-LEVEL NATURAL NUMBERS (Peano)
// ============================================================================

/// Zero registers
pub struct Z;

/// Successor: S<N> = N + 1
pub struct S<N>(PhantomData<N>);

// Type aliases for convenience
pub type N0 = Z;
pub type N1 = S<N0>;
pub type N2 = S<N1>;
pub type N3 = S<N2>;
pub type N4 = S<N3>;
pub type N5 = S<N4>;
pub type N8 = S<S<S<N5>>>; // 5 + 3 = 8
pub type N16 = S<S<S<S<S<S<S<S<N8>>>>>>>>; // 8 + 8 = 16
pub type N32 = S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<S<N16>>>>>>>>>>>>>>>>; // 16 + 16 = 32

// ============================================================================
// REGISTER BUDGET
// ============================================================================

/// A register budget with N registers remaining
pub struct Budget<N> {
    _remaining: PhantomData<N>,
}

impl<N> Budget<N> {
    pub fn new() -> Self {
        Budget {
            _remaining: PhantomData,
        }
    }
}

/// Trait for types that can be subtracted
pub trait Sub<Rhs> {
    type Output;
}

// N - 0 = N
impl<N> Sub<Z> for N {
    type Output = N;
}

// S<N> - S<M> = N - M (for N >= M)
impl<N, M> Sub<S<M>> for S<N>
where
    N: Sub<M>,
{
    type Output = <N as Sub<M>>::Output;
}

// ============================================================================
// REGISTER-TRACKED VALUES
// ============================================================================

/// A value that uses R registers
pub struct Reg<T: Copy, R> {
    value: T,
    _registers: PhantomData<R>,
}

impl<T: Copy, R> Reg<T, R> {
    /// Create a register-tracked value
    pub fn new(value: T) -> Self {
        Reg {
            value,
            _registers: PhantomData,
        }
    }

    pub fn get(&self) -> T {
        self.value
    }
}

// Size of common types in registers
pub trait RegisterSize {
    type Size;
}

impl RegisterSize for i32 {
    type Size = N1;
}
impl RegisterSize for u32 {
    type Size = N1;
}
impl RegisterSize for f32 {
    type Size = N1;
}
impl RegisterSize for i64 {
    type Size = N2;
}
impl RegisterSize for u64 {
    type Size = N2;
}
impl RegisterSize for f64 {
    type Size = N2;
}
impl RegisterSize for bool {
    type Size = N1;
} // Predicate register

// Arrays use N * element_size registers
impl<T: RegisterSize, const N: usize> RegisterSize for [T; N]
where
// This would need type-level multiplication which is complex
// Simplified: treat arrays as using many registers
{
    type Size = N32; // Conservative estimate
}

// ============================================================================
// ALLOCATION AND DEALLOCATION
// ============================================================================

/// Allocate registers for a value
pub fn alloc<T, R, B>(_budget: Budget<B>, value: T) -> (Budget<<B as Sub<R>>::Output>, Reg<T, R>)
where
    T: Copy + RegisterSize<Size = R>,
    B: Sub<R>,
{
    (Budget::new(), Reg::new(value))
}

/// Frees one register slot (simplified — a full implementation would track R
/// and add it back to B, but this prototype only increments by one).
pub fn free<T: Copy, R, B>(_budget: Budget<B>, _reg: Reg<T, R>) -> Budget<S<B>>
where
    // This is simplified - real implementation would add R to B
{
    Budget::new()
}

// ============================================================================
// SIMPLIFIED RUNTIME TRACKING
// ============================================================================

/// Runtime register tracker (simpler, doesn't use type-level numbers)
pub mod runtime {
    /// Track register usage at runtime
    #[derive(Debug, Clone)]
    pub struct RegisterTracker {
        used: usize,
        limit: usize,
        peak: usize,
    }

    impl RegisterTracker {
        pub fn new(limit: usize) -> Self {
            RegisterTracker {
                used: 0,
                limit,
                peak: 0,
            }
        }

        /// Allocate registers, returns false if would exceed limit
        pub fn alloc(&mut self, count: usize) -> bool {
            if self.used + count > self.limit {
                return false;
            }
            self.used += count;
            self.peak = self.peak.max(self.used);
            true
        }

        /// Free registers
        pub fn free(&mut self, count: usize) {
            self.used = self.used.saturating_sub(count);
        }

        pub fn used(&self) -> usize {
            self.used
        }
        pub fn remaining(&self) -> usize {
            self.limit - self.used
        }
        pub fn peak(&self) -> usize {
            self.peak
        }
    }

    /// A value tracked by a register tracker
    pub struct TrackedValue<T> {
        value: T,
        reg_count: usize,
    }

    impl<T> TrackedValue<T> {
        pub fn new(tracker: &mut RegisterTracker, value: T, reg_count: usize) -> Option<Self> {
            if tracker.alloc(reg_count) {
                Some(TrackedValue { value, reg_count })
            } else {
                None
            }
        }

        pub fn get(&self) -> &T {
            &self.value
        }

        pub fn free(self, tracker: &mut RegisterTracker) -> T {
            tracker.free(self.reg_count);
            self.value
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_basic_tracking() {
            let mut tracker = RegisterTracker::new(64);

            assert_eq!(tracker.used(), 0);
            assert_eq!(tracker.remaining(), 64);

            // Allocate 10 registers
            assert!(tracker.alloc(10));
            assert_eq!(tracker.used(), 10);

            // Allocate 50 more
            assert!(tracker.alloc(50));
            assert_eq!(tracker.used(), 60);

            // Try to allocate 10 more (would exceed limit)
            assert!(!tracker.alloc(10));
            assert_eq!(tracker.used(), 60);

            // Free some
            tracker.free(20);
            assert_eq!(tracker.used(), 40);

            // Peak should be 60
            assert_eq!(tracker.peak(), 60);
        }

        #[test]
        fn test_tracked_values() {
            let mut tracker = RegisterTracker::new(10);

            // Allocate a value using 4 registers
            let v1 = TrackedValue::new(&mut tracker, 42i32, 4).unwrap();
            assert_eq!(tracker.used(), 4);

            // Allocate another using 4
            let v2 = TrackedValue::new(&mut tracker, 3.14f32, 4).unwrap();
            assert_eq!(tracker.used(), 8);

            // Can't allocate 4 more (8 + 4 > 10)
            assert!(TrackedValue::new(&mut tracker, 0u64, 4).is_none());

            // Free v1
            let val1 = v1.free(&mut tracker);
            assert_eq!(val1, 42);
            assert_eq!(tracker.used(), 4);

            // Now we can allocate
            let _v3 = TrackedValue::new(&mut tracker, 0u64, 4).unwrap();
            assert_eq!(tracker.used(), 8);

            let _ = v2; // Use v2 to avoid warning
        }
    }
}

// ============================================================================
// SUMMARY
// ============================================================================

/// Summary: Can we track register pressure linearly?
///
/// ## Answer: Partially, with limitations
///
/// ### What Works
///
/// 1. **Runtime tracking**: Track allocation/deallocation, enforce limits
///    - Simple to implement
///    - Catches actual overuse
///    - Runtime overhead
///
/// 2. **Type-level approximation**: Use Peano numbers to track pressure
///    - Zero runtime overhead
///    - Catches some issues at compile time
///    - Complex type signatures
///
/// ### Limitations
///
/// 1. **Compiler temporaries**: Compiler may use more registers than source
/// 2. **Spilling decisions**: Compiler decides when to spill, not programmer
/// 3. **Live ranges**: Optimal allocation requires whole-function analysis
/// 4. **Branching**: Different branches may have different pressure
///
/// ### Practical Approach
///
/// 1. Use runtime tracking during development to profile pressure
/// 2. Use type-level hints for hot paths (e.g., `#[max_registers(32)]`)
/// 3. Trust compiler for final allocation, but guide it with annotations
///
/// ### Relation to Linear Types
///
/// Register tracking IS linear resource management:
/// - Allocate: Use resource
/// - Free: Return resource
/// - Budget: Finite pool
///
/// But registers are FUNGIBLE (any register works for any value),
/// unlike affine resources which have specific identities.
///
/// ### Verdict
///
/// Type-level register tracking is POSSIBLE but IMPRACTICAL for full precision.
/// Runtime tracking + compiler hints is more realistic.
/// Linear types help conceptually but exact tracking needs compiler support.
pub const _SUMMARY: () = ();

#[cfg(test)]
mod type_level_tests {
    use super::*;

    #[test]
    fn test_type_level_allocation() {
        // Start with 5 register budget
        let budget: Budget<N5> = Budget::new();

        // Allocate an i32 (1 register)
        let (budget, reg1): (Budget<N4>, Reg<i32, N1>) = alloc(budget, 42i32);

        // Allocate another i32
        let (budget, reg2): (Budget<N3>, Reg<i32, N1>) = alloc(budget, 10i32);

        // Use the values
        assert_eq!(reg1.get() + reg2.get(), 52);

        // Budget is now N3 (started with N5, used N2)
        let _: Budget<N3> = budget;
    }
}