decy_codegen/
pattern_gen.rs1use decy_hir::{BinaryOperator, HirExpression, HirStatement};
36
/// Stateless generator that renders chains of `if` tag checks as the text of
/// a Rust `match` expression.
///
/// The type has no fields, so it is trivially copyable; `Debug` is derived so
/// values can appear in diagnostics and in `assert!` output.
#[derive(Debug, Clone, Copy)]
pub struct PatternGenerator;
47
48impl PatternGenerator {
49 pub fn new() -> Self {
51 Self
52 }
53
54 pub fn transform_tag_check(&self, stmt: &HirStatement) -> String {
64 if let HirStatement::If { condition, then_block: _, else_block: _ } = stmt {
65 if let Some((var_name, _tag_field, _tag_value)) = Self::extract_tag_check(condition) {
67 let mut result = String::new();
68
69 result.push_str(&format!("match {} {{\n", var_name));
71
72 let arms = self.collect_match_arms(stmt);
74
75 for arm in arms {
76 result.push_str(&format!(" {},\n", arm));
77 }
78
79 result.push('}');
80 return result;
81 }
82 }
83
84 String::new()
86 }
87
88 fn extract_tag_check(condition: &HirExpression) -> Option<(String, String, String)> {
91 if let HirExpression::BinaryOp { left, op, right } = condition {
92 if !matches!(op, BinaryOperator::Equal) {
93 return None;
94 }
95
96 if let HirExpression::FieldAccess { object, field } = &**left {
98 if let HirExpression::Variable(var_name) = &**object {
99 if let HirExpression::Variable(tag_value) = &**right {
101 return Some((var_name.clone(), field.clone(), tag_value.clone()));
102 }
103 }
104 }
105 }
106 None
107 }
108
109 fn collect_match_arms(&self, stmt: &HirStatement) -> Vec<String> {
111 let mut arms = Vec::new();
112 self.collect_arms_recursive(stmt, &mut arms);
113 arms
114 }
115
116 fn collect_arms_recursive(&self, stmt: &HirStatement, arms: &mut Vec<String>) {
118 if let HirStatement::If { condition, then_block, else_block } = stmt {
119 if let Some((_var_name, _tag_field, tag_value)) = Self::extract_tag_check(condition) {
120 let variant_name = Self::capitalize_tag_value(&tag_value);
122 let binding = Self::extract_union_field_binding(then_block);
123
124 let arm_body = self.generate_arm_body(then_block);
125
126 let arm = if let Some(field) = binding {
127 format!("Value::{}({}) => {}", variant_name, field, arm_body)
128 } else {
129 format!("Value::{} => {}", variant_name, arm_body)
130 };
131
132 arms.push(arm);
133
134 if let Some(else_stmts) = else_block {
136 if else_stmts.len() == 1 {
137 if matches!(else_stmts[0], HirStatement::If { .. }) {
139 self.collect_arms_recursive(&else_stmts[0], arms);
140 return;
141 }
142 }
143
144 let else_body = self.generate_block_body(else_stmts);
146 arms.push(format!("_ => {}", else_body));
147 }
148 }
149 }
150 }
151
/// Converts an underscore-separated tag name into a concatenation of
/// capitalised segments, e.g. `TYPE_INT` becomes `TypeInt`.
///
/// Each non-empty segment between underscores has its first character
/// upper-cased and the remainder lower-cased. Empty segments are skipped.
/// When the input has no non-empty segment at all (an empty string, or
/// underscores only) it is returned unchanged.
fn capitalize_tag_value(tag_value: &str) -> String {
    let mut camel = String::new();

    for segment in tag_value.split('_') {
        let mut letters = segment.chars();
        // `next()` is `None` exactly for the empty segments we want to skip.
        if let Some(initial) = letters.next() {
            camel.extend(initial.to_uppercase());
            camel.push_str(&letters.as_str().to_lowercase());
        }
    }

    // Every kept segment contributes at least one character, so an empty
    // buffer means no segment was kept.
    if camel.is_empty() {
        tag_value.to_string()
    } else {
        camel
    }
}
175
176 fn extract_union_field_binding(then_block: &[HirStatement]) -> Option<String> {
178 for stmt in then_block {
179 if let Some(field) = Self::find_union_field_in_stmt(stmt) {
180 return Some(field);
181 }
182 }
183 None
184 }
185
186 fn find_union_field_in_stmt(stmt: &HirStatement) -> Option<String> {
188 match stmt {
189 HirStatement::Return(Some(expr)) => Self::find_union_field_in_expr(expr),
190 HirStatement::Expression(expr) => Self::find_union_field_in_expr(expr),
191 _ => None,
192 }
193 }
194
195 fn find_union_field_in_expr(expr: &HirExpression) -> Option<String> {
197 if let HirExpression::FieldAccess { object, field } = expr {
198 if let HirExpression::FieldAccess { object: _inner_obj, field: inner_field } = &**object
200 {
201 if inner_field == "data" {
202 return Some(field.clone());
203 }
204 }
205 }
206 None
207 }
208
209 fn generate_arm_body(&self, then_block: &[HirStatement]) -> String {
211 if then_block.len() == 1 {
212 if let HirStatement::Return(Some(expr)) = &then_block[0] {
213 let field = Self::find_union_field_in_expr(expr);
215 if let Some(f) = field {
216 return format!("return {}", f);
217 }
218 return "return /* value */".to_string();
219 }
220 }
221
222 self.generate_block_body(then_block)
223 }
224
225 fn generate_block_body(&self, stmts: &[HirStatement]) -> String {
227 if stmts.is_empty() {
228 return "{}".to_string();
229 }
230
231 if stmts.len() == 1 && matches!(&stmts[0], HirStatement::Return(Some(_))) {
232 return "/* return value */".to_string();
233 }
234
235 "{ /* block */ }".to_string()
236 }
237}
238
239impl Default for PatternGenerator {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the condition `<var>.tag == <tag>` used by the tag-check tests.
    fn tag_equals(var: &str, tag: &str) -> HirExpression {
        let tag_access = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable(var.to_string())),
            field: "tag".to_string(),
        };
        HirExpression::BinaryOp {
            left: Box::new(tag_access),
            op: BinaryOperator::Equal,
            right: Box::new(HirExpression::Variable(tag.to_string())),
        }
    }

    // A statement that is not an `if` produces no output at all.
    #[test]
    fn test_pattern_generator_default_impl() {
        let generator = PatternGenerator::default();
        let output = generator.transform_tag_check(&HirStatement::Break);
        assert!(output.is_empty());
    }

    // Input without any non-empty segment is passed through untouched.
    #[test]
    fn test_capitalize_tag_value_single_underscore() {
        let converted = PatternGenerator::capitalize_tag_value("_");
        assert_eq!(converted, "_");
    }

    #[test]
    fn test_capitalize_tag_value_empty() {
        let converted = PatternGenerator::capitalize_tag_value("");
        assert_eq!(converted, "");
    }

    // An `else` holding a single `return` becomes the wildcard arm.
    #[test]
    fn test_generate_block_body_single_return() {
        let generator = PatternGenerator::new();
        let fallback = vec![HirStatement::Return(Some(HirExpression::IntLiteral(-1)))];
        let if_stmt = HirStatement::If {
            condition: tag_equals("v", "INT"),
            then_block: vec![],
            else_block: Some(fallback),
        };

        let output = generator.transform_tag_check(&if_stmt);
        assert!(output.contains("_ => /* return value */"));
    }

    #[test]
    fn test_find_union_field_in_non_field_access() {
        let literal = HirExpression::IntLiteral(42);
        let found = PatternGenerator::find_union_field_in_expr(&literal);
        assert!(found.is_none());
    }

    // A single-level access (`v.value`) has no intermediate `data` field.
    #[test]
    fn test_find_union_field_single_field_access() {
        let single_access = HirExpression::FieldAccess {
            object: Box::new(HirExpression::Variable("v".to_string())),
            field: "value".to_string(),
        };
        let found = PatternGenerator::find_union_field_in_expr(&single_access);
        assert!(found.is_none());
    }
}