Skip to main content

decy_codegen/
pattern_gen.rs

1//! Pattern matching generation from tag checks (DECY-082).
2//!
3//! Transforms C tagged union access patterns (if-else-if chains checking tag fields)
4//! into safe Rust match expressions with exhaustive pattern matching.
5//!
6//! # Transformation
7//!
8//! C code:
9//! ```c
10//! if (v.tag == INT) {
11//!     return v.data.i;
12//! } else if (v.tag == FLOAT) {
13//!     return v.data.f;
14//! } else {
15//!     return -1;
16//! }
17//! ```
18//!
19//! Rust code:
20//! ```rust,ignore
21//! match v {
22//!     Value::Int(i) => return i,
23//!     Value::Float(f) => return f,
24//!     _ => return -1,
25//! }
26//! ```
27//!
28//! # Benefits
29//!
30//! - **Type safety**: Compiler verifies correct variant access
31//! - **Exhaustiveness**: All possible cases must be handled
32//! - **Zero unsafe**: No unsafe union field access
33//! - **Pattern binding**: Direct access to variant payloads
34
35use decy_hir::{BinaryOperator, HirExpression, HirStatement};
36
/// Generator for Rust pattern matching from C tag checks.
///
/// The generator is a stateless unit struct, so it is cheap to copy and can
/// be formatted with `{:?}` for diagnostics.
///
/// # Algorithm
///
/// 1. **Detect tag check**: Identify if-conditions comparing tag fields
/// 2. **Extract variant info**: Parse tag value and union field access
/// 3. **Generate match arms**: Convert if-else-if chain to match arms
/// 4. **Capitalize variants**: Convert C constants to PascalCase
/// 5. **Add wildcard**: Turn a trailing `else` branch into a `_` arm
#[derive(Debug, Clone, Copy)]
pub struct PatternGenerator;
47
48impl PatternGenerator {
49    /// Create a new pattern generator.
50    pub fn new() -> Self {
51        Self
52    }
53
54    /// Transform C tag check (if statement) into Rust match expression.
55    ///
56    /// # Arguments
57    ///
58    /// * `stmt` - HIR if statement checking a tag field
59    ///
60    /// # Returns
61    ///
62    /// Rust match expression as a string, or empty string if not a tag check
63    pub fn transform_tag_check(&self, stmt: &HirStatement) -> String {
64        if let HirStatement::If {
65            condition,
66            then_block: _,
67            else_block: _,
68        } = stmt
69        {
70            // Check if this is a tag comparison
71            if let Some((var_name, _tag_field, _tag_value)) = Self::extract_tag_check(condition) {
72                let mut result = String::new();
73
74                // Start match expression
75                result.push_str(&format!("match {} {{\n", var_name));
76
77                // Collect all arms from if-else-if chain
78                let arms = self.collect_match_arms(stmt);
79
80                for arm in arms {
81                    result.push_str(&format!("    {},\n", arm));
82                }
83
84                result.push('}');
85                return result;
86            }
87        }
88
89        // Not a tag check - return empty
90        String::new()
91    }
92
93    /// Extract tag check components from condition.
94    /// Returns (variable_name, tag_field, tag_value) if valid tag check.
95    fn extract_tag_check(condition: &HirExpression) -> Option<(String, String, String)> {
96        if let HirExpression::BinaryOp { left, op, right } = condition {
97            if !matches!(op, BinaryOperator::Equal) {
98                return None;
99            }
100
101            // Left should be field access: v.tag
102            if let HirExpression::FieldAccess { object, field } = &**left {
103                if let HirExpression::Variable(var_name) = &**object {
104                    // Right should be enum constant
105                    if let HirExpression::Variable(tag_value) = &**right {
106                        return Some((var_name.clone(), field.clone(), tag_value.clone()));
107                    }
108                }
109            }
110        }
111        None
112    }
113
114    /// Collect all match arms from if-else-if chain.
115    fn collect_match_arms(&self, stmt: &HirStatement) -> Vec<String> {
116        let mut arms = Vec::new();
117        self.collect_arms_recursive(stmt, &mut arms);
118        arms
119    }
120
121    /// Recursively collect match arms from nested if statements.
122    fn collect_arms_recursive(&self, stmt: &HirStatement, arms: &mut Vec<String>) {
123        if let HirStatement::If {
124            condition,
125            then_block,
126            else_block,
127        } = stmt
128        {
129            if let Some((_var_name, _tag_field, tag_value)) = Self::extract_tag_check(condition) {
130                // Generate match arm for this variant
131                let variant_name = Self::capitalize_tag_value(&tag_value);
132                let binding = Self::extract_union_field_binding(then_block);
133
134                let arm_body = self.generate_arm_body(then_block);
135
136                let arm = if let Some(field) = binding {
137                    format!("Value::{}({}) => {}", variant_name, field, arm_body)
138                } else {
139                    format!("Value::{} => {}", variant_name, arm_body)
140                };
141
142                arms.push(arm);
143
144                // Process else block
145                if let Some(else_stmts) = else_block {
146                    if else_stmts.len() == 1 {
147                        // Check if it's another tag check (else-if)
148                        if matches!(else_stmts[0], HirStatement::If { .. }) {
149                            self.collect_arms_recursive(&else_stmts[0], arms);
150                            return;
151                        }
152                    }
153
154                    // Not another tag check - generate default arm
155                    let else_body = self.generate_block_body(else_stmts);
156                    arms.push(format!("_ => {}", else_body));
157                }
158            }
159        }
160    }
161
162    /// Capitalize tag value to PascalCase variant name.
163    fn capitalize_tag_value(tag_value: &str) -> String {
164        let parts: Vec<String> = tag_value
165            .split('_')
166            .filter(|s| !s.is_empty())
167            .map(|part| {
168                let mut chars = part.chars();
169                match chars.next() {
170                    None => String::new(),
171                    Some(first) => {
172                        let rest: String = chars.collect::<String>().to_lowercase();
173                        first.to_uppercase().collect::<String>() + &rest
174                    }
175                }
176            })
177            .collect();
178
179        if parts.is_empty() {
180            tag_value.to_string()
181        } else {
182            parts.join("")
183        }
184    }
185
186    /// Extract union field name from then block if it accesses v.data.field.
187    fn extract_union_field_binding(then_block: &[HirStatement]) -> Option<String> {
188        for stmt in then_block {
189            if let Some(field) = Self::find_union_field_in_stmt(stmt) {
190                return Some(field);
191            }
192        }
193        None
194    }
195
196    /// Find union field access in a statement.
197    fn find_union_field_in_stmt(stmt: &HirStatement) -> Option<String> {
198        match stmt {
199            HirStatement::Return(Some(expr)) => Self::find_union_field_in_expr(expr),
200            HirStatement::Expression(expr) => Self::find_union_field_in_expr(expr),
201            _ => None,
202        }
203    }
204
205    /// Find union field access in an expression (v.data.field_name).
206    fn find_union_field_in_expr(expr: &HirExpression) -> Option<String> {
207        if let HirExpression::FieldAccess { object, field } = expr {
208            // Check if object is v.data
209            if let HirExpression::FieldAccess {
210                object: _inner_obj,
211                field: inner_field,
212            } = &**object
213            {
214                if inner_field == "data" {
215                    return Some(field.clone());
216                }
217            }
218        }
219        None
220    }
221
222    /// Generate arm body from then block.
223    fn generate_arm_body(&self, then_block: &[HirStatement]) -> String {
224        if then_block.len() == 1 {
225            if let HirStatement::Return(Some(expr)) = &then_block[0] {
226                // Simple return - just return the expression
227                let field = Self::find_union_field_in_expr(expr);
228                if let Some(f) = field {
229                    return format!("return {}", f);
230                }
231                return "return /* value */".to_string();
232            }
233        }
234
235        self.generate_block_body(then_block)
236    }
237
238    /// Generate body for a block of statements.
239    fn generate_block_body(&self, stmts: &[HirStatement]) -> String {
240        if stmts.is_empty() {
241            return "{}".to_string();
242        }
243
244        if stmts.len() == 1 && matches!(&stmts[0], HirStatement::Return(Some(_))) {
245            return "/* return value */".to_string();
246        }
247
248        "{ /* block */ }".to_string()
249    }
250}
251
252impl Default for PatternGenerator {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the condition `<var>.tag == <tag>` used by the tag-check tests.
    fn tag_equals(var: &str, tag: &str) -> HirExpression {
        let tag_access = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable(var.to_string())),
            field: "tag".to_string(),
        };
        HirExpression::BinaryOp {
            left: Box::new(tag_access),
            op: BinaryOperator::Equal,
            right: Box::new(HirExpression::Variable(tag.to_string())),
        }
    }

    #[test]
    fn test_pattern_generator_default_impl() {
        // A statement that is not an `if` produces no output.
        let generated = PatternGenerator::default().transform_tag_check(&HirStatement::Break);
        assert!(generated.is_empty());
    }

    #[test]
    fn test_capitalize_tag_value_single_underscore() {
        // "_" splits into two empty pieces, so the input comes back unchanged.
        assert_eq!(PatternGenerator::capitalize_tag_value("_"), "_");
    }

    #[test]
    fn test_capitalize_tag_value_empty() {
        assert_eq!(PatternGenerator::capitalize_tag_value(""), "");
    }

    #[test]
    fn test_generate_block_body_single_return() {
        // Tag check with an empty then-block and an else block holding a
        // single `return -1`.
        let else_stmts = vec![HirStatement::Return(Some(HirExpression::IntLiteral(-1)))];
        let chain = HirStatement::If {
            condition: tag_equals("v", "INT"),
            then_block: vec![],
            else_block: Some(else_stmts),
        };

        let generated = PatternGenerator::new().transform_tag_check(&chain);

        // The single-return else block renders as the placeholder default arm.
        assert!(generated.contains("_ => /* return value */"));
    }

    #[test]
    fn test_find_union_field_in_non_field_access() {
        // An expression that is not a field access has no union field.
        let literal = HirExpression::IntLiteral(42);
        assert!(PatternGenerator::find_union_field_in_expr(&literal).is_none());
    }

    #[test]
    fn test_find_union_field_single_field_access() {
        // `v.value` has only one level of field access, so there is no
        // intermediate `data` member to match.
        let single_access = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable("v".to_string())),
            field: "value".to_string(),
        };
        assert!(PatternGenerator::find_union_field_in_expr(&single_access).is_none());
    }
}