ringkernel_cuda_codegen/
loops.rs

1//! Loop transpilation helpers for CUDA code generation.
2//!
3//! This module provides utilities for transpiling Rust loop constructs
4//! to their CUDA C equivalents.
5//!
6//! # Supported Loop Types
7//!
//! - `for i in start..end` → `for (int i = start; i < end; i++)`
//! - `for i in start..=end` → `for (int i = start; i <= end; i++)`
//!   (the loop variable may be declared `size_t` instead of `int`; see `infer_loop_var_type`)
//! - `while condition` → `while (condition)`
//! - `loop` → `while (true)`
12//!
13//! # Control Flow
14//!
15//! - `break` → `break;`
16//! - `continue` → `continue;`
17//! - `break 'label` → Not yet supported (would require goto)
18
19use syn::{Expr, ExprRange, RangeLimits};
20
/// Information about a parsed range expression.
///
/// Bounds are stored as the strings produced by the transpiler (see
/// `RangeInfo::from_range`), so two values compare equal exactly when their
/// generated bound text and inclusivity match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RangeInfo {
    /// Start value of the range. `None` for an unbounded start, or (when built
    /// by `from_range`) for a start bound that failed to transpile.
    pub start: Option<String>,
    /// End value of the range. `None` for an unbounded end, or (when built by
    /// `from_range`) for an end bound that failed to transpile.
    pub end: Option<String>,
    /// Whether the range is inclusive (`..=`) or exclusive (`..`).
    pub inclusive: bool,
}
31
32impl RangeInfo {
33    /// Parse a range expression into RangeInfo.
34    ///
35    /// # Arguments
36    ///
37    /// * `range` - The range expression to parse
38    /// * `transpile_expr` - Function to transpile sub-expressions to CUDA strings
39    ///
40    /// # Returns
41    ///
42    /// A RangeInfo struct containing the parsed range bounds.
43    pub fn from_range<F>(range: &ExprRange, transpile_expr: F) -> Self
44    where
45        F: Fn(&Expr) -> Result<String, crate::TranspileError>,
46    {
47        let start = range.start.as_ref().and_then(|e| transpile_expr(e).ok());
48
49        let end = range.end.as_ref().and_then(|e| transpile_expr(e).ok());
50
51        let inclusive = matches!(range.limits, RangeLimits::Closed(_));
52
53        RangeInfo {
54            start,
55            end,
56            inclusive,
57        }
58    }
59
60    /// Generate the CUDA comparison operator for the loop condition.
61    pub fn comparison_op(&self) -> &'static str {
62        if self.inclusive {
63            "<="
64        } else {
65            "<"
66        }
67    }
68
69    /// Generate a complete CUDA for loop header.
70    ///
71    /// # Arguments
72    ///
73    /// * `var_name` - The loop variable name
74    /// * `var_type` - The CUDA type for the loop variable (e.g., "int")
75    ///
76    /// # Returns
77    ///
78    /// A string like `for (int i = 0; i < n; i++)`
79    pub fn to_cuda_for_header(&self, var_name: &str, var_type: &str) -> String {
80        let start = self.start.as_deref().unwrap_or("0");
81        let end = self.end.as_deref().unwrap_or("/* end */");
82        let op = self.comparison_op();
83
84        format!("for ({var_type} {var_name} = {start}; {var_name} {op} {end}; {var_name}++)")
85    }
86}
87
88/// Represents different loop patterns that can be transpiled.
89#[derive(Debug, Clone)]
90pub enum LoopPattern {
91    /// A for loop over a range: `for i in start..end`
92    ForRange { var_name: String, range: RangeInfo },
93    /// A for loop over an iterator (not fully supported yet)
94    ForIterator { var_name: String, iterator: String },
95    /// A while loop: `while condition { ... }`
96    While { condition: String },
97    /// An infinite loop: `loop { ... }`
98    Infinite,
99}
100
101impl LoopPattern {
102    /// Generate the CUDA loop header for this pattern.
103    ///
104    /// # Arguments
105    ///
106    /// * `var_type` - The type to use for loop variables (e.g., "int")
107    ///
108    /// # Returns
109    ///
110    /// The CUDA loop header string.
111    pub fn to_cuda_header(&self, var_type: &str) -> String {
112        match self {
113            LoopPattern::ForRange { var_name, range } => {
114                range.to_cuda_for_header(var_name, var_type)
115            }
116            LoopPattern::ForIterator { var_name, iterator } => {
117                // Basic iterator support - treat as range-like
118                format!("for ({var_type} {var_name} : {iterator})")
119            }
120            LoopPattern::While { condition } => {
121                format!("while ({condition})")
122            }
123            LoopPattern::Infinite => {
124                // Using while(true) for clarity; could also use for(;;)
125                "while (true)".to_string()
126            }
127        }
128    }
129}
130
131/// Check if an expression is a simple range (start..end or start..=end).
132pub fn is_range_expr(expr: &Expr) -> bool {
133    matches!(expr, Expr::Range(_))
134}
135
136/// Extract the loop variable name from a for loop pattern.
137pub fn extract_loop_var(pat: &syn::Pat) -> Option<String> {
138    match pat {
139        syn::Pat::Ident(ident) => Some(ident.ident.to_string()),
140        _ => None,
141    }
142}
143
144/// Determine the appropriate CUDA type for a loop variable.
145///
146/// This uses heuristics based on the range bounds to pick int vs size_t.
147pub fn infer_loop_var_type(range: &RangeInfo) -> &'static str {
148    // Check if the range bounds suggest a specific type
149    if let Some(ref end) = range.end {
150        // If the end contains "size" or looks like a size type, use size_t
151        if end.contains("size") || end.contains("len") {
152            return "size_t";
153        }
154    }
155
156    // Default to int for most cases
157    "int"
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Stand-in transpiler: renders a sub-expression's tokens verbatim.
    fn tokens(e: &Expr) -> Result<String, crate::TranspileError> {
        Ok(quote::ToTokens::to_token_stream(e).to_string())
    }

    #[test]
    fn test_range_info_exclusive() {
        let range: ExprRange = parse_quote!(0..10);
        let info = RangeInfo::from_range(&range, tokens);

        assert_eq!(info.start.as_deref(), Some("0"));
        assert_eq!(info.end.as_deref(), Some("10"));
        assert!(!info.inclusive);
        assert_eq!(info.comparison_op(), "<");
    }

    #[test]
    fn test_range_info_inclusive() {
        let range: ExprRange = parse_quote!(0..=10);
        let info = RangeInfo::from_range(&range, tokens);

        assert_eq!(info.start.as_deref(), Some("0"));
        assert_eq!(info.end.as_deref(), Some("10"));
        assert!(info.inclusive);
        assert_eq!(info.comparison_op(), "<=");
    }

    #[test]
    fn test_range_info_unbounded_start_defaults_to_zero() {
        let range: ExprRange = parse_quote!(..10);
        let info = RangeInfo::from_range(&range, tokens);

        assert_eq!(info.start, None);
        assert_eq!(info.to_cuda_for_header("i", "int"), "for (int i = 0; i < 10; i++)");
    }

    #[test]
    fn test_for_header_generation() {
        let range = RangeInfo {
            start: Some("0".to_string()),
            end: Some("n".to_string()),
            inclusive: false,
        };

        let header = range.to_cuda_for_header("i", "int");
        assert_eq!(header, "for (int i = 0; i < n; i++)");
    }

    #[test]
    fn test_for_header_inclusive() {
        let range = RangeInfo {
            start: Some("1".to_string()),
            end: Some("10".to_string()),
            inclusive: true,
        };

        let header = range.to_cuda_for_header("j", "int");
        assert_eq!(header, "for (int j = 1; j <= 10; j++)");
    }

    #[test]
    fn test_loop_pattern_for_range() {
        let pattern = LoopPattern::ForRange {
            var_name: "k".to_string(),
            range: RangeInfo {
                start: Some("2".to_string()),
                end: Some("m".to_string()),
                inclusive: false,
            },
        };

        assert_eq!(pattern.to_cuda_header("int"), "for (int k = 2; k < m; k++)");
    }

    #[test]
    fn test_loop_pattern_for_iterator() {
        let pattern = LoopPattern::ForIterator {
            var_name: "x".to_string(),
            iterator: "values".to_string(),
        };

        assert_eq!(pattern.to_cuda_header("float"), "for (float x : values)");
    }

    #[test]
    fn test_loop_pattern_while() {
        let pattern = LoopPattern::While {
            condition: "i < 10".to_string(),
        };

        assert_eq!(pattern.to_cuda_header("int"), "while (i < 10)");
    }

    #[test]
    fn test_loop_pattern_infinite() {
        let pattern = LoopPattern::Infinite;
        assert_eq!(pattern.to_cuda_header("int"), "while (true)");
    }

    #[test]
    fn test_extract_loop_var() {
        let pat: syn::Pat = parse_quote!(i);
        assert_eq!(extract_loop_var(&pat), Some("i".to_string()));

        let pat_mut: syn::Pat = parse_quote!(mut i);
        assert_eq!(extract_loop_var(&pat_mut), Some("i".to_string()));

        let pat_complex: syn::Pat = parse_quote!((a, b));
        assert_eq!(extract_loop_var(&pat_complex), None);
    }

    #[test]
    fn test_infer_loop_var_type() {
        let range_int = RangeInfo {
            start: Some("0".to_string()),
            end: Some("10".to_string()),
            inclusive: false,
        };
        assert_eq!(infer_loop_var_type(&range_int), "int");

        let range_size = RangeInfo {
            start: Some("0".to_string()),
            end: Some("data.len()".to_string()),
            inclusive: false,
        };
        assert_eq!(infer_loop_var_type(&range_size), "size_t");

        let range_open = RangeInfo {
            start: Some("0".to_string()),
            end: None,
            inclusive: false,
        };
        assert_eq!(infer_loop_var_type(&range_open), "int");
    }
}