warp-types 0.3.2

Type-safe GPU warp programming via linear typestate: compile-time prevention of shuffle-from-inactive-lane bugs
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
//! Arbitrary Predicates: Beyond Marker Types
//!
//! **STATUS: Validated** — Research exploration complete. See conclusions below.
//!
//! THE LIMITATION:
//! Our marker types (Even, Odd, LowHalf) only cover predefined patterns.
//! What about arbitrary runtime predicates?
//!
//! ```ignore
//! let threshold = compute_threshold();  // Runtime value!
//! let (below, above) = warp.diverge(|lane| lane < threshold);
//! // What type is `below`? Can't be a marker type - threshold is runtime!
//! ```
//!
//! This module explores the design space for handling arbitrary predicates.

use std::marker::PhantomData;

// ============================================================================
// APPROACH 1: EXISTENTIAL TYPES (Type erasure with witness)
// ============================================================================

/// When we can't name the exact active set statically, use existential:
/// "There exists some set S such that this is a Warp<S>"
///
/// This loses precision but maintains the invariant that SOME set exists.
pub mod existential {

    pub trait ActiveSet {
        fn mask(&self) -> u32;
    }

    /// A warp with an unknown (existentially quantified) active set.
    ///
    /// We know:
    /// - Some subset of lanes is active
    /// - The mask is consistent
    ///
    /// We DON'T know:
    /// - Which specific lanes (at compile time)
    /// - Whether it complements another set
    pub struct SomeWarp {
        mask: u32,
    }

    impl SomeWarp {
        /// Create from runtime predicate
        pub fn from_predicate<F: Fn(u32) -> bool>(pred: F) -> Self {
            let mut mask = 0u32;
            for lane in 0..32 {
                if pred(lane) {
                    mask |= 1 << lane;
                }
            }
            SomeWarp { mask }
        }

        pub fn mask(&self) -> u32 {
            self.mask
        }

        pub fn population(&self) -> u32 {
            self.mask.count_ones()
        }

        /// Check at runtime if this complements another SomeWarp
        pub fn complements(&self, other: &SomeWarp) -> bool {
            // Disjoint and covering
            (self.mask & other.mask) == 0 && (self.mask | other.mask) == 0xFFFFFFFF
        }
    }

    /// Diverge with arbitrary predicate - returns existential warps
    pub fn diverge_arbitrary<F: Fn(u32) -> bool>(pred: F) -> (SomeWarp, SomeWarp) {
        let true_branch = SomeWarp::from_predicate(&pred);
        let false_branch = SomeWarp::from_predicate(|lane| !pred(lane));
        (true_branch, false_branch)
    }

    /// Merge with runtime complement check
    pub fn merge_checked(left: SomeWarp, right: SomeWarp) -> Result<SomeWarp, &'static str> {
        if left.complements(&right) {
            Ok(SomeWarp {
                mask: left.mask | right.mask,
            })
        } else {
            Err("Warps are not complementary")
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_arbitrary_predicate() {
            let threshold = 10u32; // Runtime value!

            let (below, above) = diverge_arbitrary(|lane| lane < threshold);

            assert_eq!(below.population(), 10); // Lanes 0-9
            assert_eq!(above.population(), 22); // Lanes 10-31
            assert!(below.complements(&above));
        }

        #[test]
        fn test_merge_checked() {
            let (a, b) = diverge_arbitrary(|lane| lane % 3 == 0);

            // These complement
            assert!(merge_checked(a, b).is_ok());

            // These overlap — different predicates, not complements
            let (c, _) = diverge_arbitrary(|lane| lane < 5);
            let (d, _) = diverge_arbitrary(|lane| lane < 10);
            assert_ne!(
                c.mask, d.mask,
                "overlapping predicates produce different masks"
            );

            // Merging non-complements must fail
            assert!(
                merge_checked(c, d).is_err(),
                "overlapping warps should fail merge"
            );
        }
    }
}

// ============================================================================
// APPROACH 2: REFINEMENT TYPES (Predicate in the type)
// ============================================================================

/// Refinement types: Warp<{s | P(s)}>
///
/// The predicate P is part of the type. Two warps with the same predicate
/// have the same type; different predicates = different types.
///
/// Challenge: Predicate equality is undecidable in general.
/// Solution: Use syntactic equality or canonical forms.
pub mod refinement {
    use super::*;

    /// A predicate on lane IDs, represented as a trait
    pub trait LanePredicate: Copy {
        fn test(lane: u32) -> bool;
        fn name() -> &'static str;
    }

    /// Warp refined by a predicate
    pub struct RefinedWarp<P: LanePredicate> {
        _marker: PhantomData<P>,
    }

    impl<P: LanePredicate> RefinedWarp<P> {
        pub fn new() -> Self {
            RefinedWarp {
                _marker: PhantomData,
            }
        }

        pub fn mask() -> u32 {
            let mut m = 0u32;
            for lane in 0..32 {
                if P::test(lane) {
                    m |= 1 << lane;
                }
            }
            m
        }
    }

    /// Negation of a predicate
    pub struct Not<P: LanePredicate>(PhantomData<P>);

    impl<P: LanePredicate> Copy for Not<P> {}
    impl<P: LanePredicate> Clone for Not<P> {
        fn clone(&self) -> Self {
            *self
        }
    }

    impl<P: LanePredicate> LanePredicate for Not<P> {
        fn test(lane: u32) -> bool {
            !P::test(lane)
        }
        fn name() -> &'static str {
            "Not<P>" // Would need const generics for real name
        }
    }

    /// Diverge produces a warp and its negation
    pub fn diverge<P: LanePredicate>() -> (RefinedWarp<P>, RefinedWarp<Not<P>>) {
        (RefinedWarp::new(), RefinedWarp::new())
    }

    /// P and Not<P> are always complementary - STATICALLY KNOWN!
    pub fn merge<P: LanePredicate>(
        _left: RefinedWarp<P>,
        _right: RefinedWarp<Not<P>>,
    ) -> RefinedWarp<All> {
        RefinedWarp::new()
    }

    // Concrete predicates
    #[derive(Copy, Clone)]
    pub struct All;
    impl LanePredicate for All {
        fn test(_: u32) -> bool {
            true
        }
        fn name() -> &'static str {
            "All"
        }
    }

    #[derive(Copy, Clone)]
    pub struct LessThan<const N: u32>;
    impl<const N: u32> LanePredicate for LessThan<N> {
        fn test(lane: u32) -> bool {
            lane < N
        }
        fn name() -> &'static str {
            "LessThan<N>"
        }
    }

    /// THE LIMITATION: N must be a const, not a runtime value!
    /// LessThan<threshold> where threshold is runtime doesn't work.

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_refinement_static() {
            // With const, this works beautifully
            let (below, above) = diverge::<LessThan<10>>();

            assert_eq!(RefinedWarp::<LessThan<10>>::mask(), 0x000003FF); // Lanes 0-9
            assert_eq!(RefinedWarp::<Not<LessThan<10>>>::mask(), 0xFFFFFC00); // Lanes 10-31

            // Merge is statically verified
            let _all = merge(below, above);
        }
    }
}

// ============================================================================
// APPROACH 3: INDEXED TYPES (Predicate as type parameter)
// ============================================================================

/// Instead of embedding the predicate, use an INDEX that identifies it.
///
/// Runtime: A table maps indices to predicates.
/// Compile-time: Indices have complement relationships.
pub mod indexed {

    /// An index identifying a predicate (opaque token)
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
    pub struct PredicateId(u32);

    /// A warp with a predicate identified by index
    pub struct IndexedWarp {
        predicate_id: PredicateId,
        mask: u32,
    }

    /// Registry of predicates and their relationships
    pub struct PredicateRegistry {
        next_id: u32,
        masks: std::collections::HashMap<PredicateId, u32>,
        complements: std::collections::HashMap<PredicateId, PredicateId>,
    }

    impl PredicateRegistry {
        pub fn new() -> Self {
            PredicateRegistry {
                next_id: 0,
                masks: std::collections::HashMap::new(),
                complements: std::collections::HashMap::new(),
            }
        }

        /// Register a predicate and its complement
        pub fn register<F: Fn(u32) -> bool>(&mut self, pred: F) -> (PredicateId, PredicateId) {
            let id_true = PredicateId(self.next_id);
            let id_false = PredicateId(self.next_id + 1);
            self.next_id += 2;

            let mut mask_true = 0u32;
            for lane in 0..32 {
                if pred(lane) {
                    mask_true |= 1 << lane;
                }
            }
            let mask_false = !mask_true;

            self.masks.insert(id_true, mask_true);
            self.masks.insert(id_false, mask_false);
            self.complements.insert(id_true, id_false);
            self.complements.insert(id_false, id_true);

            (id_true, id_false)
        }

        pub fn are_complements(&self, a: PredicateId, b: PredicateId) -> bool {
            self.complements.get(&a) == Some(&b)
        }

        pub fn mask(&self, id: PredicateId) -> u32 {
            self.masks.get(&id).copied().unwrap_or(0)
        }
    }

    impl IndexedWarp {
        pub fn new(id: PredicateId, mask: u32) -> Self {
            IndexedWarp {
                predicate_id: id,
                mask,
            }
        }

        pub fn id(&self) -> PredicateId {
            self.predicate_id
        }

        pub fn mask(&self) -> u32 {
            self.mask
        }
    }

    /// Merge with registry lookup
    pub fn merge_indexed(
        registry: &PredicateRegistry,
        left: IndexedWarp,
        right: IndexedWarp,
    ) -> Result<IndexedWarp, &'static str> {
        if registry.are_complements(left.id(), right.id()) {
            Ok(IndexedWarp::new(
                PredicateId(u32::MAX), // "All" sentinel
                left.mask | right.mask,
            ))
        } else {
            Err("Predicates are not registered complements")
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_indexed_predicates() {
            let mut registry = PredicateRegistry::new();

            // Register a runtime predicate
            let threshold = 15u32; // Could be computed at runtime!
            let (below_id, above_id) = registry.register(|lane| lane < threshold);

            let below = IndexedWarp::new(below_id, registry.mask(below_id));
            let above = IndexedWarp::new(above_id, registry.mask(above_id));

            assert!(registry.are_complements(below_id, above_id));
            assert!(merge_indexed(&registry, below, above).is_ok());
        }
    }
}

// ============================================================================
// APPROACH 4: HYBRID - Static shape, dynamic mask
// ============================================================================

/// Key insight: We often know the SHAPE of divergence statically,
/// even when the exact MASK is dynamic.
///
/// Shape = "split by threshold" (two contiguous ranges)
/// Mask = which specific lanes (depends on runtime threshold)
///
/// We can verify shapes statically and masks dynamically.
pub mod hybrid_shape {

    /// Shape of an active set - known statically
    #[derive(Copy, Clone, Debug, PartialEq)]
    pub enum Shape {
        All,
        /// Contiguous range [0, threshold)
        LowRange,
        /// Contiguous range [threshold, 32)
        HighRange,
        /// Arbitrary (no static guarantees)
        Arbitrary,
    }

    /// A warp with known shape but dynamic mask
    pub struct ShapedWarp {
        shape: Shape,
        mask: u32,
    }

    impl ShapedWarp {
        pub fn all() -> Self {
            ShapedWarp {
                shape: Shape::All,
                mask: 0xFFFFFFFF,
            }
        }

        pub fn shape(&self) -> Shape {
            self.shape
        }

        pub fn mask(&self) -> u32 {
            self.mask
        }
    }

    /// Diverge by threshold - shape is known, mask is dynamic
    pub fn diverge_by_threshold(threshold: u32) -> (ShapedWarp, ShapedWarp) {
        let low_mask = if threshold >= 32 {
            u32::MAX
        } else {
            (1u32 << threshold) - 1
        };
        let high_mask = !low_mask;

        (
            ShapedWarp {
                shape: Shape::LowRange,
                mask: low_mask,
            },
            ShapedWarp {
                shape: Shape::HighRange,
                mask: high_mask,
            },
        )
    }

    /// Shape-aware merge - LowRange + HighRange = All (statically!)
    pub fn merge_shaped(left: ShapedWarp, right: ShapedWarp) -> Result<ShapedWarp, &'static str> {
        match (left.shape, right.shape) {
            // Statically known complement shapes
            (Shape::LowRange, Shape::HighRange) | (Shape::HighRange, Shape::LowRange) => {
                Ok(ShapedWarp {
                    shape: Shape::All,
                    mask: left.mask | right.mask,
                })
            }
            // Runtime check for arbitrary shapes
            (Shape::Arbitrary, Shape::Arbitrary) => {
                if (left.mask & right.mask) == 0 && (left.mask | right.mask) == 0xFFFFFFFF {
                    Ok(ShapedWarp {
                        shape: Shape::All,
                        mask: 0xFFFFFFFF,
                    })
                } else {
                    Err("Arbitrary shapes don't complement")
                }
            }
            _ => Err("Incompatible shapes"),
        }
    }

    /// Operations available depend on shape
    impl ShapedWarp {
        /// Shuffle only on All shape
        pub fn shuffle_xor(&self, data: i32, _mask: u32) -> Option<i32> {
            match self.shape {
                Shape::All => Some(data), // Placeholder
                _ => None,
            }
        }

        /// Broadcast available on any contiguous range
        pub fn broadcast_first(&self, data: i32) -> Option<i32> {
            match self.shape {
                Shape::All | Shape::LowRange | Shape::HighRange => Some(data),
                Shape::Arbitrary => None, // Can't identify "first"
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_shape_merge() {
            let threshold = 20u32; // Runtime!
            let (low, high) = diverge_by_threshold(threshold);

            assert_eq!(low.shape(), Shape::LowRange);
            assert_eq!(high.shape(), Shape::HighRange);

            // Shape merge is STATICALLY verified (LowRange + HighRange = All)
            let merged = merge_shaped(low, high).unwrap();
            assert_eq!(merged.shape(), Shape::All);
        }

        #[test]
        fn test_shape_operations() {
            let (low, _high) = diverge_by_threshold(16);

            // Shuffle not available on LowRange
            assert!(low.shuffle_xor(42, 1).is_none());

            // Broadcast is available
            assert!(low.broadcast_first(42).is_some());
        }
    }
}

// ============================================================================
// APPROACH 5: DEPENDENT TYPES (What a host language could do)
// ============================================================================

/// With full dependent types, we could write:
///
/// ```ignore
/// fn diverge<P: Lane -> Bool>(w: Warp<All>) -> (Warp<{l | P(l)}>, Warp<{l | !P(l)}>) {
///     ...
/// }
///
/// fn merge<S1, S2>(w1: Warp<S1>, w2: Warp<S2>) -> Warp<S1 ∪ S2>
///     where S1 ∩ S2 = ∅
/// {
///     ...
/// }
/// ```
///
/// The type checker would verify:
/// - S1 and S2 are disjoint (compile-time proof obligation)
/// - Result is exactly S1 ∪ S2
///
/// This is what languages like Idris, Agda, or F* can do.
/// Rust can approximate with trait bounds, but lacks full power.
pub mod dependent_sketch {
    // This is a sketch of what dependent types would enable.
    //
    // Key capabilities:
    // 1. Types can depend on runtime values
    // 2. Type equality can require proofs
    // 3. Refinements can express arbitrary predicates
    //
    // For a language with dependent types:
    // - Extend session types with lane predicates
    // - Use SMT solver for predicate satisfiability
    // - Allow runtime fallback for undecidable cases

    // Pseudo-syntax for what we'd want:
    //
    // type ActiveSet = { mask: u32 | popcount(mask) > 0 }
    //
    // fn diverge(w: Warp, pred: Lane -> Bool)
    //   -> (Warp<{ l | pred(l) }>, Warp<{ l | !pred(l) }>)
    //
    // fn merge(w1: Warp<S1>, w2: Warp<S2>)
    //   -> Warp<S1 | S2>
    //   requires S1 & S2 == 0
    //
    // fn shuffle(w: Warp<All>, data: PerLane<T>, perm: Perm) -> PerLane<T>
    //   // Only allowed when active set is All
}

// ============================================================================
// SUMMARY: Recommendations
// ============================================================================
//
// For warp-types research, recommend a LAYERED approach:
//
// LAYER 1: Marker types (current)
// - Static, zero overhead
// - Covers: Even/Odd, LowHalf/HighHalf, common patterns
// - Use for: Performance-critical code with known patterns
//
// LAYER 2: Shaped types (new)
// - Static shape, dynamic mask
// - Covers: Threshold splits, strided patterns
// - Use for: Data-dependent divergence with regular structure
//
// LAYER 3: Indexed types (new)
// - Registry tracks predicates and relationships
// - Covers: Repeated dynamic patterns
// - Use for: Patterns that recur within a kernel
//
// LAYER 4: Existential types (fallback)
// - Runtime checks only
// - Covers: Arbitrary predicates
// - Use for: Rare, irregular patterns
//
// This gives a spectrum from most-static to most-dynamic.
// Most real GPU code uses Layer 1-2. Layer 3-4 are escape hatches.

#[cfg(test)]
mod integration_tests {
    use super::*;

    #[test]
    fn test_layered_approach() {
        // Layer 1: Marker types (compile-time)
        // (covered by existing static_verify.rs)

        // Layer 2: Shaped types
        let threshold = 20u32;
        let (low, high) = hybrid_shape::diverge_by_threshold(threshold);
        let merged = hybrid_shape::merge_shaped(low, high).unwrap();
        assert_eq!(merged.shape(), hybrid_shape::Shape::All);

        // Layer 3: Indexed types
        let mut registry = indexed::PredicateRegistry::new();
        let (id_a, id_b) = registry.register(|lane| lane % 5 == 0);
        assert!(registry.are_complements(id_a, id_b));

        // Layer 4: Existential types
        let (some_a, some_b) = existential::diverge_arbitrary(|lane| lane.count_ones() % 2 == 0);
        assert!(some_a.complements(&some_b));
    }
}