Skip to main content

decy_codegen/
pattern_gen.rs

1//! Pattern matching generation from tag checks (DECY-082).
2//!
3//! Transforms C tagged union access patterns (if-else-if chains checking tag fields)
4//! into safe Rust match expressions with exhaustive pattern matching.
5//!
6//! # Transformation
7//!
8//! C code:
9//! ```c
10//! if (v.tag == INT) {
11//!     return v.data.i;
12//! } else if (v.tag == FLOAT) {
13//!     return v.data.f;
14//! } else {
15//!     return -1;
16//! }
17//! ```
18//!
19//! Rust code:
20//! ```rust,ignore
21//! match v {
22//!     Value::Int(i) => return i,
23//!     Value::Float(f) => return f,
24//!     _ => return -1,
25//! }
26//! ```
27//!
28//! # Benefits
29//!
30//! - **Type safety**: Compiler verifies correct variant access
31//! - **Exhaustiveness**: All possible cases must be handled
32//! - **Zero unsafe**: No unsafe union field access
33//! - **Pattern binding**: Direct access to variant payloads
34
35use decy_hir::{BinaryOperator, HirExpression, HirStatement};
36
/// Generator for Rust pattern matching from C tag checks.
///
/// The generator is a stateless unit type: it holds no configuration and
/// every transformation depends only on the HIR statement passed in.
///
/// # Algorithm
///
/// 1. **Detect tag check**: Identify if-conditions comparing tag fields
/// 2. **Extract variant info**: Parse tag value and union field access
/// 3. **Generate match arms**: Convert if-else-if chain to match arms
/// 4. **Capitalize variants**: Convert C constants to PascalCase
/// 5. **Add wildcard**: Emit a `_` arm for the trailing `else` block, when
///    the chain has one
#[derive(Debug, Clone)]
pub struct PatternGenerator;
47
48impl PatternGenerator {
49    /// Create a new pattern generator.
50    pub fn new() -> Self {
51        Self
52    }
53
54    /// Transform C tag check (if statement) into Rust match expression.
55    ///
56    /// # Arguments
57    ///
58    /// * `stmt` - HIR if statement checking a tag field
59    ///
60    /// # Returns
61    ///
62    /// Rust match expression as a string, or empty string if not a tag check
63    pub fn transform_tag_check(&self, stmt: &HirStatement) -> String {
64        if let HirStatement::If { condition, then_block: _, else_block: _ } = stmt {
65            // Check if this is a tag comparison
66            if let Some((var_name, _tag_field, _tag_value)) = Self::extract_tag_check(condition) {
67                let mut result = String::new();
68
69                // Start match expression
70                result.push_str(&format!("match {} {{\n", var_name));
71
72                // Collect all arms from if-else-if chain
73                let arms = self.collect_match_arms(stmt);
74
75                for arm in arms {
76                    result.push_str(&format!("    {},\n", arm));
77                }
78
79                result.push('}');
80                return result;
81            }
82        }
83
84        // Not a tag check - return empty
85        String::new()
86    }
87
88    /// Extract tag check components from condition.
89    /// Returns (variable_name, tag_field, tag_value) if valid tag check.
90    fn extract_tag_check(condition: &HirExpression) -> Option<(String, String, String)> {
91        if let HirExpression::BinaryOp { left, op, right } = condition {
92            if !matches!(op, BinaryOperator::Equal) {
93                return None;
94            }
95
96            // Left should be field access: v.tag
97            if let HirExpression::FieldAccess { object, field } = &**left {
98                if let HirExpression::Variable(var_name) = &**object {
99                    // Right should be enum constant
100                    if let HirExpression::Variable(tag_value) = &**right {
101                        return Some((var_name.clone(), field.clone(), tag_value.clone()));
102                    }
103                }
104            }
105        }
106        None
107    }
108
109    /// Collect all match arms from if-else-if chain.
110    fn collect_match_arms(&self, stmt: &HirStatement) -> Vec<String> {
111        let mut arms = Vec::new();
112        self.collect_arms_recursive(stmt, &mut arms);
113        arms
114    }
115
116    /// Recursively collect match arms from nested if statements.
117    fn collect_arms_recursive(&self, stmt: &HirStatement, arms: &mut Vec<String>) {
118        if let HirStatement::If { condition, then_block, else_block } = stmt {
119            if let Some((_var_name, _tag_field, tag_value)) = Self::extract_tag_check(condition) {
120                // Generate match arm for this variant
121                let variant_name = Self::capitalize_tag_value(&tag_value);
122                let binding = Self::extract_union_field_binding(then_block);
123
124                let arm_body = self.generate_arm_body(then_block);
125
126                let arm = if let Some(field) = binding {
127                    format!("Value::{}({}) => {}", variant_name, field, arm_body)
128                } else {
129                    format!("Value::{} => {}", variant_name, arm_body)
130                };
131
132                arms.push(arm);
133
134                // Process else block
135                if let Some(else_stmts) = else_block {
136                    if else_stmts.len() == 1 {
137                        // Check if it's another tag check (else-if)
138                        if matches!(else_stmts[0], HirStatement::If { .. }) {
139                            self.collect_arms_recursive(&else_stmts[0], arms);
140                            return;
141                        }
142                    }
143
144                    // Not another tag check - generate default arm
145                    let else_body = self.generate_block_body(else_stmts);
146                    arms.push(format!("_ => {}", else_body));
147                }
148            }
149        }
150    }
151
152    /// Capitalize tag value to PascalCase variant name.
153    fn capitalize_tag_value(tag_value: &str) -> String {
154        let parts: Vec<String> = tag_value
155            .split('_')
156            .filter(|s| !s.is_empty())
157            .map(|part| {
158                let mut chars = part.chars();
159                match chars.next() {
160                    None => String::new(),
161                    Some(first) => {
162                        let rest: String = chars.collect::<String>().to_lowercase();
163                        first.to_uppercase().collect::<String>() + &rest
164                    }
165                }
166            })
167            .collect();
168
169        if parts.is_empty() {
170            tag_value.to_string()
171        } else {
172            parts.join("")
173        }
174    }
175
176    /// Extract union field name from then block if it accesses v.data.field.
177    fn extract_union_field_binding(then_block: &[HirStatement]) -> Option<String> {
178        for stmt in then_block {
179            if let Some(field) = Self::find_union_field_in_stmt(stmt) {
180                return Some(field);
181            }
182        }
183        None
184    }
185
186    /// Find union field access in a statement.
187    fn find_union_field_in_stmt(stmt: &HirStatement) -> Option<String> {
188        match stmt {
189            HirStatement::Return(Some(expr)) => Self::find_union_field_in_expr(expr),
190            HirStatement::Expression(expr) => Self::find_union_field_in_expr(expr),
191            _ => None,
192        }
193    }
194
195    /// Find union field access in an expression (v.data.field_name).
196    fn find_union_field_in_expr(expr: &HirExpression) -> Option<String> {
197        if let HirExpression::FieldAccess { object, field } = expr {
198            // Check if object is v.data
199            if let HirExpression::FieldAccess { object: _inner_obj, field: inner_field } = &**object
200            {
201                if inner_field == "data" {
202                    return Some(field.clone());
203                }
204            }
205        }
206        None
207    }
208
209    /// Generate arm body from then block.
210    fn generate_arm_body(&self, then_block: &[HirStatement]) -> String {
211        if then_block.len() == 1 {
212            if let HirStatement::Return(Some(expr)) = &then_block[0] {
213                // Simple return - just return the expression
214                let field = Self::find_union_field_in_expr(expr);
215                if let Some(f) = field {
216                    return format!("return {}", f);
217                }
218                return "return /* value */".to_string();
219            }
220        }
221
222        self.generate_block_body(then_block)
223    }
224
225    /// Generate body for a block of statements.
226    fn generate_block_body(&self, stmts: &[HirStatement]) -> String {
227        if stmts.is_empty() {
228            return "{}".to_string();
229        }
230
231        if stmts.len() == 1 && matches!(&stmts[0], HirStatement::Return(Some(_))) {
232            return "/* return value */".to_string();
233        }
234
235        "{ /* block */ }".to_string()
236    }
237}
238
239impl Default for PatternGenerator {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the condition `<variable>.tag == <tag>` used by the tag-check tests.
    fn tag_equals(variable: &str, tag: &str) -> HirExpression {
        let tag_access = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable(variable.to_string())),
            field: "tag".to_string(),
        };
        HirExpression::BinaryOp {
            left: Box::new(tag_access),
            op: BinaryOperator::Equal,
            right: Box::new(HirExpression::Variable(tag.to_string())),
        }
    }

    #[test]
    fn test_pattern_generator_default_impl() {
        // A statement that is not an `if` produces no output at all.
        let output = PatternGenerator::default().transform_tag_check(&HirStatement::Break);
        assert!(output.is_empty());
    }

    #[test]
    fn test_capitalize_tag_value_single_underscore() {
        // "_" splits into two empty segments, both filtered out, so the
        // input comes back unchanged.
        assert_eq!(PatternGenerator::capitalize_tag_value("_"), "_");
    }

    #[test]
    fn test_capitalize_tag_value_empty() {
        assert_eq!(PatternGenerator::capitalize_tag_value(""), "");
    }

    #[test]
    fn test_generate_block_body_single_return() {
        // Tag check with an empty then-block and an else block holding a
        // single `return <value>`.
        let fallback = HirStatement::Return(Some(HirExpression::IntLiteral(-1)));
        let tag_check = HirStatement::If {
            condition: tag_equals("v", "INT"),
            then_block: Vec::new(),
            else_block: Some(vec![fallback]),
        };

        let generated = PatternGenerator::new().transform_tag_check(&tag_check);

        // The else block is rendered as the wildcard arm with the
        // single-return placeholder body.
        assert!(generated.contains("_ => /* return value */"));
    }

    #[test]
    fn test_find_union_field_in_non_field_access() {
        // An expression that is not a field access never yields a union field.
        let literal = HirExpression::IntLiteral(42);
        assert!(PatternGenerator::find_union_field_in_expr(&literal).is_none());
    }

    #[test]
    fn test_find_union_field_single_field_access() {
        // `v.value` is a field access, but its object is a plain variable
        // rather than a `.data` access, so nothing is found.
        let single_access = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable("v".to_string())),
            field: "value".to_string(),
        };
        assert!(PatternGenerator::find_union_field_in_expr(&single_access).is_none());
    }
}