1use smol_str::{SmolStr, format_smolstr};
5use std::collections::{BTreeMap, HashMap};
6use std::rc::Rc;
7
8use crate::expression_tree::Expression;
9use crate::langtype::{Struct, StructName, Type};
10
11pub fn remove_return(doc: &crate::object_tree::Document) {
12 doc.visit_all_used_components(|component| {
13 crate::object_tree::visit_all_expressions(component, |e, _| {
14 let mut ret_ty = None;
15 fn visit(e: &Expression, ret_ty: &mut Option<Type>) {
16 if ret_ty.is_some() {
17 return;
18 }
19 match e {
20 Expression::ReturnStatement(x) => {
21 *ret_ty = Some(x.as_ref().map_or(Type::Void, |x| x.ty()));
22 }
23 _ => e.visit(|e| visit(e, ret_ty)),
24 };
25 }
26 visit(e, &mut ret_ty);
27 let Some(ret_ty) = ret_ty else { return };
28 let ctx = RemoveReturnContext { ret_ty };
29 *e = process_expression(std::mem::take(e), true, &ctx, &ctx.ret_ty)
30 .to_expression(&ctx.ret_ty);
31 })
32 });
33}
34
35fn process_expression(
36 e: Expression,
37 toplevel: bool,
38 ctx: &RemoveReturnContext,
39 ty: &Type,
40) -> ExpressionResult {
41 match e {
42 Expression::DebugHook { expression, .. } => {
43 process_expression(*expression, toplevel, ctx, ty)
44 }
45 Expression::ReturnStatement(expr) => ExpressionResult::Return(expr.map(|e| *e)),
46 Expression::CodeBlock(expr) => {
47 process_codeblock(expr.into_iter().peekable(), toplevel, ty, ctx)
48 }
49 Expression::Condition { condition, true_expr, false_expr } => {
50 let te = process_expression(*true_expr, false, ctx, ty);
51 let fe = process_expression(*false_expr, false, ctx, ty);
52 match (te, fe) {
53 (ExpressionResult::Just(te), ExpressionResult::Just(fe)) => {
54 Expression::Condition { condition, true_expr: te.into(), false_expr: fe.into() }
55 .into()
56 }
57 (ExpressionResult::Just(te), ExpressionResult::Return(fe)) => {
58 ExpressionResult::MaybeReturn {
59 pre_statements: Vec::new(),
60 condition: *condition,
61 returned_value: fe,
62 actual_value: cleanup_empty_block(te),
63 }
64 }
65 (ExpressionResult::Return(te), ExpressionResult::Just(fe)) => {
66 ExpressionResult::MaybeReturn {
67 pre_statements: Vec::new(),
68 condition: Expression::UnaryOp { sub: condition, op: '!' },
69 returned_value: te,
70 actual_value: cleanup_empty_block(fe),
71 }
72 }
73 (ExpressionResult::Return(te), ExpressionResult::Return(fe)) => {
74 ExpressionResult::Return(Some(Expression::Condition {
75 condition,
76 true_expr: te.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
77 false_expr: fe.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
78 }))
79 }
80 (te, fe) => {
81 let has_value = has_value(ty) && (te.has_value() || fe.has_value());
82 let ty = if has_value { ty } else { &Type::Void };
83 let te = te.into_return_object(ty, &ctx.ret_ty);
84 let fe = fe.into_return_object(ty, &ctx.ret_ty);
85 ExpressionResult::ReturnObject {
86 has_value,
87 has_return_value: self::has_value(&ctx.ret_ty),
88 value: Expression::Condition {
89 condition,
90 true_expr: te.into(),
91 false_expr: fe.into(),
92 },
93 }
94 }
95 }
96 }
97 Expression::Cast { from, to } => {
98 let ty = if !has_value(ty) { ty.clone() } else { from.ty() };
99 process_expression(*from, toplevel, ctx, &ty)
100 .map_value(|e| Expression::Cast { from: e.into(), to })
101 }
102 e => {
103 #[cfg(debug_assertions)]
105 {
106 e.visit_recursive(&mut |e| assert!(!matches!(e, Expression::ReturnStatement(_))));
107 }
108 ExpressionResult::Just(e)
109 }
110 }
111}
112
113fn cleanup_empty_block(te: Expression) -> Option<Expression> {
115 if matches!(&te, Expression::CodeBlock(stmts) if stmts.is_empty()) { None } else { Some(te) }
116}
117
118fn process_codeblock(
119 mut iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
120 toplevel: bool,
121 ty: &Type,
122 ctx: &RemoveReturnContext,
123) -> ExpressionResult {
124 let mut stmts = Vec::new();
125 while let Some(e) = iter.next() {
126 let is_last = iter.peek().is_none();
127 match process_expression(e, toplevel, ctx, if is_last { ty } else { &Type::Void }) {
128 ExpressionResult::Just(x) => stmts.push(x),
129 ExpressionResult::Return(x) => {
130 stmts.extend(x);
131 return ExpressionResult::Return(
132 (!stmts.is_empty()).then_some(Expression::CodeBlock(stmts)),
133 );
134 }
135 ExpressionResult::MaybeReturn {
136 mut pre_statements,
137 condition,
138 returned_value,
139 actual_value,
140 } => {
141 stmts.append(&mut pre_statements);
142 if is_last {
143 return ExpressionResult::MaybeReturn {
144 pre_statements: stmts,
145 condition,
146 returned_value,
147 actual_value,
148 };
149 } else if toplevel {
150 let rest = process_codeblock(iter, true, ty, ctx).to_expression(&ctx.ret_ty);
151 let mut rest_ex = Expression::CodeBlock(
152 actual_value.into_iter().chain(core::iter::once(rest)).collect(),
153 );
154 if rest_ex.ty() != ctx.ret_ty {
155 rest_ex =
156 Expression::Cast { from: Box::new(rest_ex), to: ctx.ret_ty.clone() }
157 }
158 return ExpressionResult::MaybeReturn {
159 pre_statements: stmts,
160 condition,
161 returned_value,
162 actual_value: Some(rest_ex),
163 };
164 } else {
165 return continue_codeblock(
166 iter,
167 ty,
168 ctx,
169 ExpressionResult::MaybeReturn {
170 pre_statements: Vec::new(),
171 condition,
172 returned_value,
173 actual_value,
174 }
175 .into_return_object(ty, &ctx.ret_ty),
176 stmts,
177 has_value(&ctx.ret_ty),
178 );
179 }
180 }
181 ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
182 if is_last {
183 return ExpressionResult::ReturnObject {
184 value: codeblock_with_expr(stmts, value),
185 has_value,
186 has_return_value,
187 };
188 } else {
189 return continue_codeblock(iter, ty, ctx, value, stmts, has_return_value);
190 }
191 }
192 }
193 }
194 ExpressionResult::Just(Expression::CodeBlock(stmts))
195}
196
197fn continue_codeblock(
198 iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
199 ty: &Type,
200 ctx: &RemoveReturnContext,
201 return_object: Expression,
202 mut stmts: Vec<Expression>,
203 has_return_value: bool,
204) -> ExpressionResult {
205 let rest = process_codeblock(iter, false, ty, ctx).into_return_object(ty, &ctx.ret_ty);
206 static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
207 let unique_name = format_smolstr!(
208 "return_check_merge{}",
209 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
210 );
211 let load = Box::new(Expression::ReadLocalVariable {
212 name: unique_name.clone(),
213 ty: return_object.ty(),
214 });
215 stmts.push(Expression::StoreLocalVariable { name: unique_name, value: return_object.into() });
216 stmts.push(Expression::Condition {
217 condition: Expression::StructFieldAccess {
218 base: load.clone(),
219 name: FIELD_CONDITION.into(),
220 }
221 .into(),
222 true_expr: rest.into(),
223 false_expr: ExpressionResult::Return(has_return_value.then(|| {
224 Expression::StructFieldAccess { base: load.clone(), name: FIELD_RETURNED.into() }
225 }))
226 .into_return_object(ty, &ctx.ret_ty)
227 .into(),
228 });
229 ExpressionResult::ReturnObject {
230 value: Expression::CodeBlock(stmts),
231 has_value: has_value(ty),
232 has_return_value,
233 }
234}
235
236struct RemoveReturnContext {
237 ret_ty: Type,
238}
239
240#[derive(Debug)]
241enum ExpressionResult {
242 Just(Expression),
244 MaybeReturn {
246 pre_statements: Vec<Expression>,
248 condition: Expression,
250 returned_value: Option<Expression>,
252 actual_value: Option<Expression>,
254 },
255 Return(Option<Expression>),
257 ReturnObject { value: Expression, has_value: bool, has_return_value: bool },
260}
261
262impl From<Expression> for ExpressionResult {
263 fn from(v: Expression) -> Self {
264 Self::Just(v)
265 }
266}
267
/// Field names of the anonymous "return object" structs built by this pass.
///
/// `condition`: `true` when evaluation continues normally, `false` when a return happened.
const FIELD_CONDITION: &str = "condition";
/// `actual`: the value produced when no return happened.
const FIELD_ACTUAL: &str = "actual";
/// `returned`: the value returned from the function.
const FIELD_RETURNED: &str = "returned";
271
impl ExpressionResult {
    /// Collapses this result into one expression whose value is either the normal value
    /// or the returned value. Missing values of a `ReturnObject` are replaced by
    /// `Expression::default_value_for_type(ty)`.
    fn to_expression(self, ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => e,
            // Unconditional return: the returned value is the value of the expression.
            ExpressionResult::Return(e) => e.unwrap_or(Expression::CodeBlock(Vec::new())),
            ExpressionResult::MaybeReturn {
                mut pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                // `condition` true => no return, yield the actual value; else the returned one.
                pre_statements.push(Expression::Condition {
                    condition: condition.into(),
                    true_expr: actual_value.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
                    false_expr: returned_value.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
                });
                Expression::CodeBlock(pre_statements)
            }
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                // Counter only used to make the generated local variable name unique.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "returned_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let load =
                    Box::new(Expression::ReadLocalVariable { name: name.clone(), ty: value.ty() });
                // Store the struct once, then pick its `actual` or `returned` field depending
                // on `condition`. A field that is absent is replaced by a default of `ty`.
                Expression::CodeBlock(vec![
                    Expression::StoreLocalVariable { name, value: value.into() },
                    Expression::Condition {
                        condition: Expression::StructFieldAccess {
                            base: load.clone(),
                            name: FIELD_CONDITION.into(),
                        }
                        .into(),
                        true_expr: if has_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_ACTUAL.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                        false_expr: if has_return_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_RETURNED.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                    },
                ])
            }
        }
    }

    /// Converts this result into an expression evaluating to an anonymous struct.
    ///
    /// The struct has a bool `condition` field (true on the normal path, false when
    /// returning) and a `returned` field of type `ret_ty`. It also has an `actual` field when
    /// `ty` has a value. Fields whose type is void are left out of the struct by
    /// `make_struct`, but their expressions are still evaluated.
    fn into_return_object(self, ty: &Type, ret_ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => {
                // Normal path: `returned` only gets a placeholder default value.
                let ret_value = Expression::default_value_for_type(ret_ty);
                if has_value(ty) {
                    make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                            (FIELD_ACTUAL, e.ty(), e),
                        ]
                        .into_iter(),
                    )
                } else {
                    let object = make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                        ]
                        .into_iter(),
                    );
                    // The value is unused: keep `e` only when it is not a constant.
                    if e.is_constant(None) {
                        object
                    } else {
                        Expression::CodeBlock(vec![e, object])
                    }
                }
            }
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                // Without an actual value, the "no return" branch only carries `condition`.
                let mut true_expr = match actual_value {
                    Some(e) => ExpressionResult::Just(e).into_return_object(ty, ret_ty),
                    None => make_struct(
                        [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true))].into_iter(),
                    ),
                };
                let mut false_expr =
                    ExpressionResult::Return(returned_value).into_return_object(ty, ret_ty);
                // The two branches may have different field sets: convert both to a common
                // struct type (missing fields get default values in `convert_struct`).
                let true_ty = true_expr.ty();
                let false_ty = false_expr.ty();
                if true_ty != false_ty {
                    let common_ty = Expression::common_target_type_for_type_list(
                        [&true_ty, &false_ty].into_iter().cloned(),
                    );
                    if common_ty != true_ty {
                        true_expr =
                            convert_struct(std::mem::take(&mut true_expr), common_ty.clone())
                    }
                    if common_ty != false_ty {
                        false_expr = convert_struct(std::mem::take(&mut false_expr), common_ty)
                    }
                }
                let o = Expression::Condition {
                    condition: condition.into(),
                    true_expr: true_expr.into(),
                    false_expr: false_expr.into(),
                };
                codeblock_with_expr(pre_statements, o)
            }
            // Returning path: `condition` is false and `actual` gets a placeholder default.
            ExpressionResult::Return(r) => make_struct(
                [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(false))]
                    .into_iter()
                    .chain(r.map(|r| (FIELD_RETURNED, ret_ty.clone(), r)))
                    .chain(has_value(ty).then(|| {
                        (FIELD_ACTUAL, ty.clone(), Expression::default_value_for_type(ty))
                    })),
            ),
            ExpressionResult::ReturnObject { value, .. } => value,
        }
    }

    /// Applies `f` to the value of the normal (non-returning) path and leaves any returned
    /// value untouched. A `ReturnObject` without an `actual` field is returned unchanged.
    fn map_value(self, f: impl FnOnce(Expression) -> Expression) -> Self {
        match self {
            ExpressionResult::Just(e) => ExpressionResult::Just(f(e)),
            ExpressionResult::Return(e) => ExpressionResult::Return(e),
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value: actual_value.map(f),
            },
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                if !has_value {
                    return ExpressionResult::ReturnObject { value, has_value, has_return_value };
                }
                // Store the object in a uniquely named local, then rebuild a struct whose
                // `actual` field is `f` applied to the stored `actual` field.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "mapped_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let value_ty = value.ty();
                let load = |field: &str| Expression::StructFieldAccess {
                    base: Box::new(Expression::ReadLocalVariable {
                        name: name.clone(),
                        ty: value_ty.clone(),
                    }),
                    name: field.into(),
                };
                let condition = (FIELD_CONDITION, Type::Bool, load(FIELD_CONDITION));
                let actual = f(load(FIELD_ACTUAL));
                let actual = (FIELD_ACTUAL, actual.ty(), actual);
                let ret = has_return_value.then(|| {
                    let r = load(FIELD_RETURNED);
                    (FIELD_RETURNED, r.ty(), r)
                });
                ExpressionResult::ReturnObject {
                    value: Expression::CodeBlock(vec![
                        Expression::StoreLocalVariable { name, value: value.into() },
                        make_struct([condition, actual].into_iter().chain(ret.into_iter())),
                    ]),
                    has_value,
                    has_return_value,
                }
            }
        }
    }

    /// Whether the normal (non-returning) path of this result yields a non-void value.
    fn has_value(&self) -> bool {
        match self {
            ExpressionResult::Just(expression) => has_value(&expression.ty()),
            ExpressionResult::MaybeReturn { actual_value, .. } => {
                actual_value.as_ref().is_some_and(|x| has_value(&x.ty()))
            }
            ExpressionResult::Return(..) => false,
            ExpressionResult::ReturnObject { has_value, .. } => *has_value,
        }
    }
}
469
470fn codeblock_with_expr(mut pre_statements: Vec<Expression>, expr: Expression) -> Expression {
471 if pre_statements.is_empty() {
472 expr
473 } else {
474 pre_statements.push(expr);
475 Expression::CodeBlock(pre_statements)
476 }
477}
478
479fn make_struct(it: impl Iterator<Item = (&'static str, Type, Expression)>) -> Expression {
480 let mut fields = BTreeMap::<SmolStr, Type>::new();
481 let mut values = HashMap::<SmolStr, Expression>::new();
482 let mut voids = Vec::new();
483 for (name, ty, expr) in it {
484 if !has_value(&ty) {
485 if ty != Type::Invalid {
486 voids.push(expr);
487 }
488 continue;
489 }
490 fields.insert(name.into(), ty);
491 values.insert(name.into(), expr);
492 }
493 codeblock_with_expr(
494 voids,
495 Expression::Struct { ty: Rc::new(Struct { fields, name: StructName::None }), values },
496 )
497}
498
499fn convert_struct(from: Expression, to: Type) -> Expression {
502 let Type::Struct(to) = to else {
503 assert_eq!(to, Type::Invalid);
504 return Expression::Invalid;
505 };
506 if let Expression::Struct { mut values, .. } = from {
507 let mut new_values = HashMap::new();
508 for (key, ty) in &to.fields {
509 let (key, expression) = values
510 .remove_entry(key)
511 .unwrap_or_else(|| (key.clone(), Expression::default_value_for_type(ty)));
512 new_values.insert(key, expression);
513 }
514 return Expression::Struct { values: new_values, ty: to };
515 }
516 static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
517 let var_name = format_smolstr!(
518 "tmpobj_ret_conv_{}",
519 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
520 );
521 let from_ty = from.ty();
522 let mut new_values = HashMap::new();
523 let Type::Struct(from_s) = &from_ty else {
524 assert_eq!(from_ty, Type::Invalid);
525 return Expression::Invalid;
526 };
527 for (key, ty) in &to.fields {
528 let expression = if from_s.fields.contains_key(key) {
529 Expression::StructFieldAccess {
530 base: Box::new(Expression::ReadLocalVariable {
531 name: var_name.clone(),
532 ty: from_ty.clone(),
533 }),
534 name: key.clone(),
535 }
536 } else {
537 Expression::default_value_for_type(ty)
538 };
539 new_values.insert(key.clone(), expression);
540 }
541 Expression::CodeBlock(vec![
542 Expression::StoreLocalVariable { name: var_name, value: Box::new(from) },
543 Expression::Struct { values: new_values, ty: to },
544 ])
545}
546
547fn has_value(ty: &Type) -> bool {
548 !matches!(ty, Type::Void | Type::Invalid)
549}