ringkernel_cuda_codegen/
loops.rs1use syn::{Expr, ExprRange, RangeLimits};
20
21#[derive(Debug, Clone)]
23pub struct RangeInfo {
24 pub start: Option<String>,
26 pub end: Option<String>,
28 pub inclusive: bool,
30}
31
32impl RangeInfo {
33 pub fn from_range<F>(range: &ExprRange, transpile_expr: F) -> Self
44 where
45 F: Fn(&Expr) -> Result<String, crate::TranspileError>,
46 {
47 let start = range.start.as_ref().and_then(|e| transpile_expr(e).ok());
48
49 let end = range.end.as_ref().and_then(|e| transpile_expr(e).ok());
50
51 let inclusive = matches!(range.limits, RangeLimits::Closed(_));
52
53 RangeInfo {
54 start,
55 end,
56 inclusive,
57 }
58 }
59
60 pub fn comparison_op(&self) -> &'static str {
62 if self.inclusive {
63 "<="
64 } else {
65 "<"
66 }
67 }
68
69 pub fn to_cuda_for_header(&self, var_name: &str, var_type: &str) -> String {
80 let start = self.start.as_deref().unwrap_or("0");
81 let end = self.end.as_deref().unwrap_or("/* end */");
82 let op = self.comparison_op();
83
84 format!("for ({var_type} {var_name} = {start}; {var_name} {op} {end}; {var_name}++)")
85 }
86}
87
88#[derive(Debug, Clone)]
90pub enum LoopPattern {
91 ForRange { var_name: String, range: RangeInfo },
93 ForIterator { var_name: String, iterator: String },
95 While { condition: String },
97 Infinite,
99}
100
101impl LoopPattern {
102 pub fn to_cuda_header(&self, var_type: &str) -> String {
112 match self {
113 LoopPattern::ForRange { var_name, range } => {
114 range.to_cuda_for_header(var_name, var_type)
115 }
116 LoopPattern::ForIterator { var_name, iterator } => {
117 format!("for ({var_type} {var_name} : {iterator})")
119 }
120 LoopPattern::While { condition } => {
121 format!("while ({condition})")
122 }
123 LoopPattern::Infinite => {
124 "while (true)".to_string()
126 }
127 }
128 }
129}
130
131pub fn is_range_expr(expr: &Expr) -> bool {
133 matches!(expr, Expr::Range(_))
134}
135
136pub fn extract_loop_var(pat: &syn::Pat) -> Option<String> {
138 match pat {
139 syn::Pat::Ident(ident) => Some(ident.ident.to_string()),
140 _ => None,
141 }
142}
143
144pub fn infer_loop_var_type(range: &RangeInfo) -> &'static str {
148 if let Some(ref end) = range.end {
150 if end.contains("size") || end.contains("len") {
152 return "size_t";
153 }
154 }
155
156 "int"
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Stand-in transpiler for the tests: echoes the expression's Rust tokens.
    fn echo_tokens(expr: &Expr) -> Result<String, crate::TranspileError> {
        Ok(quote::ToTokens::to_token_stream(expr).to_string())
    }

    /// Shorthand for a `RangeInfo` with both bounds present.
    fn bounded(start: &str, end: &str, inclusive: bool) -> RangeInfo {
        RangeInfo {
            start: Some(start.to_string()),
            end: Some(end.to_string()),
            inclusive,
        }
    }

    #[test]
    fn test_range_info_exclusive() {
        let range: ExprRange = parse_quote!(0..10);
        let info = RangeInfo::from_range(&range, echo_tokens);

        assert!(!info.inclusive);
        assert_eq!(info.comparison_op(), "<");
    }

    #[test]
    fn test_range_info_inclusive() {
        let range: ExprRange = parse_quote!(0..=10);
        let info = RangeInfo::from_range(&range, echo_tokens);

        assert!(info.inclusive);
        assert_eq!(info.comparison_op(), "<=");
    }

    #[test]
    fn test_for_header_generation() {
        let header = bounded("0", "n", false).to_cuda_for_header("i", "int");
        assert_eq!(header, "for (int i = 0; i < n; i++)");
    }

    #[test]
    fn test_for_header_inclusive() {
        let header = bounded("1", "10", true).to_cuda_for_header("j", "int");
        assert_eq!(header, "for (int j = 1; j <= 10; j++)");
    }

    #[test]
    fn test_loop_pattern_while() {
        let while_loop = LoopPattern::While {
            condition: String::from("i < 10"),
        };
        assert_eq!(while_loop.to_cuda_header("int"), "while (i < 10)");
    }

    #[test]
    fn test_loop_pattern_infinite() {
        assert_eq!(LoopPattern::Infinite.to_cuda_header("int"), "while (true)");
    }

    #[test]
    fn test_extract_loop_var() {
        let simple: syn::Pat = parse_quote!(i);
        let tuple: syn::Pat = parse_quote!((a, b));

        assert_eq!(extract_loop_var(&simple), Some(String::from("i")));
        assert_eq!(extract_loop_var(&tuple), None);
    }

    #[test]
    fn test_infer_loop_var_type() {
        let cases = [("10", "int"), ("data.len()", "size_t")];
        for (end, expected) in cases {
            assert_eq!(infer_loop_var_type(&bounded("0", end, false)), expected);
        }
    }
}