Skip to main content

ringkernel_wgpu_codegen/
loops.rs

//! Loop transpilation for WGSL code generation.
//!
//! Handles Rust `for`/`while`/`loop` constructs and converts them to WGSL equivalents.
4
/// A loop construct recognized in the Rust DSL, ready to be rendered as WGSL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopPattern {
    /// `for var in start..end` (exclusive) or `for var in start..=end`
    /// (inclusive), selected by `inclusive`.
    ForRange {
        /// Name of the loop counter variable.
        var: String,
        /// Expression text used to initialise the counter.
        start: String,
        /// Expression text the counter is compared against.
        end: String,
        /// `true` for `..=` (compare with `<=`), `false` for `..` (compare with `<`).
        inclusive: bool,
    },
    /// `while condition { ... }`
    While {
        /// Expression text placed verbatim inside `while (...)`.
        condition: String,
    },
    /// `loop { ... }` - infinite loop with break
    Loop,
}

impl LoopPattern {
    /// Generate the WGSL loop header.
    ///
    /// Only the header is produced; the caller is responsible for emitting the
    /// braces and the loop body.
    ///
    /// * `ForRange` becomes `for (var <var>: i32 = <start>; <var> <op> <end>; <var> = <var> + 1)`,
    ///   where `<op>` is `<=` for inclusive ranges and `<` otherwise. The counter
    ///   is always declared as `i32`.
    /// * `While` becomes `while (<condition>)`.
    /// * `Loop` becomes `loop`.
    pub fn to_wgsl_header(&self) -> String {
        match self {
            LoopPattern::ForRange {
                var,
                start,
                end,
                inclusive,
            } => {
                // The comparison operator is the only difference between `..` and `..=`.
                let op = if *inclusive { "<=" } else { "<" };
                format!("for (var {var}: i32 = {start}; {var} {op} {end}; {var} = {var} + 1)")
            }
            LoopPattern::While { condition } => format!("while ({condition})"),
            LoopPattern::Loop => "loop".to_string(),
        }
    }
}
41
/// Information about a range expression.
///
/// Both bounds are kept as expression text; either may be absent to model
/// `..end`, `start..` or `..`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RangeInfo {
    /// Start of range (or None for `..end`)
    pub start: Option<String>,
    /// End of range (or None for `start..`)
    pub end: Option<String>,
    /// Whether the range is inclusive (`..=`)
    pub inclusive: bool,
}

impl RangeInfo {
    /// Create a new range info from its three components.
    pub fn new(start: Option<String>, end: Option<String>, inclusive: bool) -> Self {
        Self {
            start,
            end,
            inclusive,
        }
    }

    /// Get the start expression, defaulting to `"0"` if not specified.
    pub fn start_or_default(&self) -> String {
        // Borrow first so the only allocation is the returned string.
        self.start.as_deref().unwrap_or("0").to_owned()
    }

    /// Get the end expression, or `None` if the range has no upper bound.
    pub fn end_expr(&self) -> Option<&str> {
        self.end.as_deref()
    }
}
73
74/// Convert a Rust range to a WGSL for loop pattern.
75pub fn range_to_for_loop(var: &str, range: &RangeInfo) -> LoopPattern {
76    LoopPattern::ForRange {
77        var: var.to_string(),
78        start: range.start_or_default(),
79        end: range
80            .end
81            .clone()
82            .unwrap_or_else(|| "/* unbounded */".to_string()),
83        inclusive: range.inclusive,
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Exclusive ranges compare the counter with `<`.
    #[test]
    fn test_for_range_exclusive() {
        let pattern = LoopPattern::ForRange {
            var: "i".to_string(),
            start: "0".to_string(),
            end: "10".to_string(),
            inclusive: false,
        };
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var i: i32 = 0; i < 10; i = i + 1)"
        );
    }

    /// Inclusive ranges compare the counter with `<=`.
    #[test]
    fn test_for_range_inclusive() {
        let pattern = LoopPattern::ForRange {
            var: "j".to_string(),
            start: "1".to_string(),
            end: "n".to_string(),
            inclusive: true,
        };
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var j: i32 = 1; j <= n; j = j + 1)"
        );
    }

    /// The condition text is wrapped in `while (...)` unchanged.
    #[test]
    fn test_while_loop() {
        let pattern = LoopPattern::While {
            condition: "x > 0".to_string(),
        };
        assert_eq!(pattern.to_wgsl_header(), "while (x > 0)");
    }

    /// An unconditional loop renders as the bare `loop` keyword.
    #[test]
    fn test_infinite_loop() {
        let pattern = LoopPattern::Loop;
        assert_eq!(pattern.to_wgsl_header(), "loop");
    }

    /// A fully specified range round-trips through `range_to_for_loop`.
    #[test]
    fn test_range_info() {
        let range = RangeInfo::new(Some("0".to_string()), Some("10".to_string()), false);
        let pattern = range_to_for_loop("i", &range);
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var i: i32 = 0; i < 10; i = i + 1)"
        );
    }

    /// A range without a start (`..end`) begins at `0`.
    #[test]
    fn test_range_info_default_start() {
        let range = RangeInfo::new(None, Some("len".to_string()), false);
        assert_eq!(range.start_or_default(), "0");
        assert_eq!(range.end_expr(), Some("len"));
        let pattern = range_to_for_loop("k", &range);
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var k: i32 = 0; k < len; k = k + 1)"
        );
    }

    /// The inclusive flag is carried from `RangeInfo` into the header.
    #[test]
    fn test_range_info_inclusive() {
        let range = RangeInfo::new(Some("1".to_string()), Some("n".to_string()), true);
        let pattern = range_to_for_loop("j", &range);
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var j: i32 = 1; j <= n; j = j + 1)"
        );
    }
}