decy_codegen/
pattern_gen.rs1use decy_hir::{BinaryOperator, HirExpression, HirStatement};
36
/// Generates Rust `match` expressions from `if`/`else if` chains that compare a field
/// of a variable against a tag identifier (`v.tag == INT`).
///
/// The generator holds no state, so it is freely copyable.
#[derive(Debug, Clone, Copy)]
pub struct PatternGenerator;
47
48impl PatternGenerator {
49 pub fn new() -> Self {
51 Self
52 }
53
54 pub fn transform_tag_check(&self, stmt: &HirStatement) -> String {
64 if let HirStatement::If {
65 condition,
66 then_block: _,
67 else_block: _,
68 } = stmt
69 {
70 if let Some((var_name, _tag_field, _tag_value)) = Self::extract_tag_check(condition) {
72 let mut result = String::new();
73
74 result.push_str(&format!("match {} {{\n", var_name));
76
77 let arms = self.collect_match_arms(stmt);
79
80 for arm in arms {
81 result.push_str(&format!(" {},\n", arm));
82 }
83
84 result.push('}');
85 return result;
86 }
87 }
88
89 String::new()
91 }
92
93 fn extract_tag_check(condition: &HirExpression) -> Option<(String, String, String)> {
96 if let HirExpression::BinaryOp { left, op, right } = condition {
97 if !matches!(op, BinaryOperator::Equal) {
98 return None;
99 }
100
101 if let HirExpression::FieldAccess { object, field } = &**left {
103 if let HirExpression::Variable(var_name) = &**object {
104 if let HirExpression::Variable(tag_value) = &**right {
106 return Some((var_name.clone(), field.clone(), tag_value.clone()));
107 }
108 }
109 }
110 }
111 None
112 }
113
114 fn collect_match_arms(&self, stmt: &HirStatement) -> Vec<String> {
116 let mut arms = Vec::new();
117 self.collect_arms_recursive(stmt, &mut arms);
118 arms
119 }
120
121 fn collect_arms_recursive(&self, stmt: &HirStatement, arms: &mut Vec<String>) {
123 if let HirStatement::If {
124 condition,
125 then_block,
126 else_block,
127 } = stmt
128 {
129 if let Some((_var_name, _tag_field, tag_value)) = Self::extract_tag_check(condition) {
130 let variant_name = Self::capitalize_tag_value(&tag_value);
132 let binding = Self::extract_union_field_binding(then_block);
133
134 let arm_body = self.generate_arm_body(then_block);
135
136 let arm = if let Some(field) = binding {
137 format!("Value::{}({}) => {}", variant_name, field, arm_body)
138 } else {
139 format!("Value::{} => {}", variant_name, arm_body)
140 };
141
142 arms.push(arm);
143
144 if let Some(else_stmts) = else_block {
146 if else_stmts.len() == 1 {
147 if matches!(else_stmts[0], HirStatement::If { .. }) {
149 self.collect_arms_recursive(&else_stmts[0], arms);
150 return;
151 }
152 }
153
154 let else_body = self.generate_block_body(else_stmts);
156 arms.push(format!("_ => {}", else_body));
157 }
158 }
159 }
160 }
161
162 fn capitalize_tag_value(tag_value: &str) -> String {
164 let parts: Vec<String> = tag_value
165 .split('_')
166 .filter(|s| !s.is_empty())
167 .map(|part| {
168 let mut chars = part.chars();
169 match chars.next() {
170 None => String::new(),
171 Some(first) => {
172 let rest: String = chars.collect::<String>().to_lowercase();
173 first.to_uppercase().collect::<String>() + &rest
174 }
175 }
176 })
177 .collect();
178
179 if parts.is_empty() {
180 tag_value.to_string()
181 } else {
182 parts.join("")
183 }
184 }
185
186 fn extract_union_field_binding(then_block: &[HirStatement]) -> Option<String> {
188 for stmt in then_block {
189 if let Some(field) = Self::find_union_field_in_stmt(stmt) {
190 return Some(field);
191 }
192 }
193 None
194 }
195
196 fn find_union_field_in_stmt(stmt: &HirStatement) -> Option<String> {
198 match stmt {
199 HirStatement::Return(Some(expr)) => Self::find_union_field_in_expr(expr),
200 HirStatement::Expression(expr) => Self::find_union_field_in_expr(expr),
201 _ => None,
202 }
203 }
204
205 fn find_union_field_in_expr(expr: &HirExpression) -> Option<String> {
207 if let HirExpression::FieldAccess { object, field } = expr {
208 if let HirExpression::FieldAccess {
210 object: _inner_obj,
211 field: inner_field,
212 } = &**object
213 {
214 if inner_field == "data" {
215 return Some(field.clone());
216 }
217 }
218 }
219 None
220 }
221
222 fn generate_arm_body(&self, then_block: &[HirStatement]) -> String {
224 if then_block.len() == 1 {
225 if let HirStatement::Return(Some(expr)) = &then_block[0] {
226 let field = Self::find_union_field_in_expr(expr);
228 if let Some(f) = field {
229 return format!("return {}", f);
230 }
231 return "return /* value */".to_string();
232 }
233 }
234
235 self.generate_block_body(then_block)
236 }
237
238 fn generate_block_body(&self, stmts: &[HirStatement]) -> String {
240 if stmts.is_empty() {
241 return "{}".to_string();
242 }
243
244 if stmts.len() == 1 && matches!(&stmts[0], HirStatement::Return(Some(_))) {
245 return "/* return value */".to_string();
246 }
247
248 "{ /* block */ }".to_string()
249 }
250}
251
252impl Default for PatternGenerator {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the HIR for the condition `<var>.tag == <tag>`.
    fn tag_condition(var: &str, tag: &str) -> HirExpression {
        HirExpression::BinaryOp {
            left: Box::new(HirExpression::FieldAccess {
                object: Box::new(HirExpression::Variable(var.to_string())),
                field: "tag".to_string(),
            }),
            op: BinaryOperator::Equal,
            right: Box::new(HirExpression::Variable(tag.to_string())),
        }
    }

    #[test]
    fn test_pattern_generator_default_impl() {
        let generator = PatternGenerator::default();
        // `break` is not an `if`, so nothing is generated.
        assert!(generator
            .transform_tag_check(&HirStatement::Break)
            .is_empty());
    }

    #[test]
    fn test_capitalize_tag_value_single_underscore() {
        // Only separators, no segments: the tag is returned verbatim.
        assert_eq!(PatternGenerator::capitalize_tag_value("_"), "_");
    }

    #[test]
    fn test_capitalize_tag_value_empty() {
        assert_eq!(PatternGenerator::capitalize_tag_value(""), "");
    }

    #[test]
    fn test_generate_block_body_single_return() {
        let stmt = HirStatement::If {
            condition: tag_condition("v", "INT"),
            then_block: Vec::new(),
            else_block: Some(vec![HirStatement::Return(Some(
                HirExpression::IntLiteral(-1),
            ))]),
        };

        let generated = PatternGenerator::new().transform_tag_check(&stmt);
        assert!(generated.contains("_ => /* return value */"));
    }

    #[test]
    fn test_find_union_field_in_non_field_access() {
        let literal = HirExpression::IntLiteral(42);
        assert!(PatternGenerator::find_union_field_in_expr(&literal).is_none());
    }

    #[test]
    fn test_find_union_field_single_field_access() {
        // `v.value` has no intermediate `.data` hop, so it is not a union member read.
        let shallow = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable("v".to_string())),
            field: "value".to_string(),
        };
        assert!(PatternGenerator::find_union_field_in_expr(&shallow).is_none());
    }
}