// ringkernel_wgpu_codegen/loops.rs

/// The shape of a loop whose header is to be emitted as WGSL text.
///
/// Every string field holds identifier or expression text that is spliced
/// verbatim into the generated header; nothing is validated or escaped here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopPattern {
    /// A counted loop that steps an `i32` counter by one.
    ForRange {
        /// Name of the loop counter variable.
        var: String,
        /// Expression text for the counter's initial value.
        start: String,
        /// Expression text for the bound the counter is compared against.
        end: String,
        /// `true` compares with `<=`, `false` compares with `<`.
        inclusive: bool,
    },
    /// A loop guarded by a condition.
    While {
        /// Expression text placed inside `while (...)`.
        condition: String,
    },
    /// An unconditional loop, rendered as the bare keyword `loop`.
    Loop,
}

impl LoopPattern {
    /// Renders the loop header (everything up to, but not including, the body).
    ///
    /// * `ForRange` becomes `for (var <var>: i32 = <start>; <var> <op> <end>; <var> = <var> + 1)`,
    ///   where `<op>` is `<=` for inclusive ranges and `<` otherwise.
    /// * `While` becomes `while (<condition>)`.
    /// * `Loop` becomes `loop`.
    #[must_use]
    pub fn to_wgsl_header(&self) -> String {
        match self {
            Self::ForRange {
                var,
                start,
                end,
                inclusive,
            } => {
                // The counter type is fixed to i32 and the step to +1.
                let op = if *inclusive { "<=" } else { "<" };
                format!("for (var {var}: i32 = {start}; {var} {op} {end}; {var} = {var} + 1)")
            }
            Self::While { condition } => format!("while ({condition})"),
            Self::Loop => "loop".to_string(),
        }
    }
}

/// Textual description of a range: optional bounds plus an inclusivity flag.
///
/// The bounds are stored as expression text and are not parsed or validated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RangeInfo {
    /// Start expression text, or `None` when no start was given.
    pub start: Option<String>,
    /// End expression text, or `None` when no end was given.
    pub end: Option<String>,
    /// Whether the end bound is part of the range.
    pub inclusive: bool,
}

impl RangeInfo {
    /// Creates a `RangeInfo` from its three parts, stored as given.
    #[must_use]
    pub fn new(start: Option<String>, end: Option<String>, inclusive: bool) -> Self {
        Self {
            start,
            end,
            inclusive,
        }
    }

    /// Returns the start expression as an owned string, or `"0"` when no
    /// start was given.
    #[must_use]
    pub fn start_or_default(&self) -> String {
        // Borrow first so the fallback literal is only allocated once, here.
        self.start.as_deref().unwrap_or("0").to_string()
    }

    /// Returns the end expression as a borrowed `&str`, or `None` when no end
    /// was given.
    #[must_use]
    pub fn end_expr(&self) -> Option<&str> {
        self.end.as_deref()
    }
}

74pub fn range_to_for_loop(var: &str, range: &RangeInfo) -> LoopPattern {
76 LoopPattern::ForRange {
77 var: var.to_string(),
78 start: range.start_or_default(),
79 end: range
80 .end
81 .clone()
82 .unwrap_or_else(|| "/* unbounded */".to_string()),
83 inclusive: range.inclusive,
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Exclusive ranges compare with `<`.
    #[test]
    fn test_for_range_exclusive() {
        let pattern = LoopPattern::ForRange {
            var: "i".to_string(),
            start: "0".to_string(),
            end: "10".to_string(),
            inclusive: false,
        };
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var i: i32 = 0; i < 10; i = i + 1)"
        );
    }

    /// Inclusive ranges compare with `<=`.
    #[test]
    fn test_for_range_inclusive() {
        let pattern = LoopPattern::ForRange {
            var: "j".to_string(),
            start: "1".to_string(),
            end: "n".to_string(),
            inclusive: true,
        };
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var j: i32 = 1; j <= n; j = j + 1)"
        );
    }

    #[test]
    fn test_while_loop() {
        let pattern = LoopPattern::While {
            condition: "x > 0".to_string(),
        };
        assert_eq!(pattern.to_wgsl_header(), "while (x > 0)");
    }

    #[test]
    fn test_infinite_loop() {
        let pattern = LoopPattern::Loop;
        assert_eq!(pattern.to_wgsl_header(), "loop");
    }

    #[test]
    fn test_range_info() {
        let range = RangeInfo::new(Some("0".to_string()), Some("10".to_string()), false);
        let pattern = range_to_for_loop("i", &range);
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var i: i32 = 0; i < 10; i = i + 1)"
        );
    }

    /// A range without a start begins at the `"0"` fallback.
    #[test]
    fn test_range_info_default_start() {
        let range = RangeInfo::new(None, Some("n".to_string()), true);
        assert_eq!(range.start_or_default(), "0");
        let pattern = range_to_for_loop("k", &range);
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var k: i32 = 0; k <= n; k = k + 1)"
        );
    }

    /// A range without an end carries the placeholder text into the header.
    #[test]
    fn test_range_info_unbounded_end() {
        let range = RangeInfo::new(Some("0".to_string()), None, false);
        assert_eq!(range.end_expr(), None);
        let pattern = range_to_for_loop("i", &range);
        assert_eq!(
            pattern.to_wgsl_header(),
            "for (var i: i32 = 0; i < /* unbounded */; i = i + 1)"
        );
    }
}