1use smol_str::{SmolStr, format_smolstr};
5use std::collections::{BTreeMap, HashMap};
6use std::rc::Rc;
7
8use crate::expression_tree::Expression;
9use crate::langtype::{Struct, StructName, Type};
10
11pub fn remove_return(doc: &crate::object_tree::Document) {
12 doc.visit_all_used_components(|component| {
13 crate::object_tree::visit_all_expressions(component, |e, _| {
14 let mut ret_ty = None;
15 fn visit(e: &Expression, ret_ty: &mut Option<Type>) {
16 if ret_ty.is_some() {
17 return;
18 }
19 match e {
20 Expression::ReturnStatement(x) => {
21 *ret_ty = Some(x.as_ref().map_or(Type::Void, |x| x.ty()));
22 }
23 _ => e.visit(|e| visit(e, ret_ty)),
24 };
25 }
26 visit(e, &mut ret_ty);
27 let Some(ret_ty) = ret_ty else { return };
28 let ctx = RemoveReturnContext { ret_ty };
29 *e = process_expression(std::mem::take(e), true, &ctx, &ctx.ret_ty)
30 .into_expression(&ctx.ret_ty);
31 })
32 });
33}
34
35fn process_expression(
36 e: Expression,
37 toplevel: bool,
38 ctx: &RemoveReturnContext,
39 ty: &Type,
40) -> ExpressionResult {
41 match e {
42 Expression::DebugHook { expression, .. } => {
43 process_expression(*expression, toplevel, ctx, ty)
44 }
45 Expression::ReturnStatement(expr) => ExpressionResult::Return(expr.map(|e| *e)),
46 Expression::CodeBlock(expr) => {
47 process_codeblock(expr.into_iter().peekable(), toplevel, ty, ctx)
48 }
49 Expression::Condition { condition, true_expr, false_expr } => {
50 let te = process_expression(*true_expr, false, ctx, ty);
51 let fe = process_expression(*false_expr, false, ctx, ty);
52 match (te, fe) {
53 (ExpressionResult::Just(te), ExpressionResult::Just(fe)) => {
54 Expression::Condition { condition, true_expr: te.into(), false_expr: fe.into() }
55 .into()
56 }
57 (ExpressionResult::Just(te), ExpressionResult::Return(fe)) => {
58 ExpressionResult::MaybeReturn {
59 pre_statements: Vec::new(),
60 condition: *condition,
61 returned_value: fe,
62 actual_value: cleanup_empty_block(te),
63 }
64 }
65 (ExpressionResult::Return(te), ExpressionResult::Just(fe)) => {
66 ExpressionResult::MaybeReturn {
67 pre_statements: Vec::new(),
68 condition: Expression::UnaryOp { sub: condition, op: '!' },
69 returned_value: te,
70 actual_value: cleanup_empty_block(fe),
71 }
72 }
73 (ExpressionResult::Return(te), ExpressionResult::Return(fe)) => {
74 ExpressionResult::Return(Some(Expression::Condition {
75 condition,
76 true_expr: te.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
77 false_expr: fe.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
78 }))
79 }
80 (te, fe) => {
81 let has_value = has_value(ty) && (te.has_value() || fe.has_value());
82 let ty = if has_value { ty } else { &Type::Void };
83 let te = te.into_return_object(ty, &ctx.ret_ty);
84 let fe = fe.into_return_object(ty, &ctx.ret_ty);
85 ExpressionResult::ReturnObject {
86 has_value,
87 has_return_value: self::has_value(&ctx.ret_ty),
88 value: Expression::Condition {
89 condition,
90 true_expr: te.into(),
91 false_expr: fe.into(),
92 },
93 }
94 }
95 }
96 }
97 Expression::Cast { from, to } => {
98 let ty = if !has_value(ty) { ty.clone() } else { from.ty() };
99 process_expression(*from, toplevel, ctx, &ty)
100 .map_value(|e| Expression::Cast { from: e.into(), to })
101 }
102 Expression::StoreLocalVariable { name, value } => {
103 let inner_ty = value.ty();
104 match process_expression(*value, false, ctx, &inner_ty) {
105 ExpressionResult::Just(e) => {
106 ExpressionResult::Just(Expression::StoreLocalVariable {
107 name,
108 value: Box::new(e),
109 })
110 }
111 ExpressionResult::Return(r) => ExpressionResult::Return(r),
112 ExpressionResult::MaybeReturn {
113 pre_statements,
114 condition,
115 returned_value,
116 actual_value,
117 } => ExpressionResult::MaybeReturn {
118 pre_statements,
119 condition,
120 returned_value,
121 actual_value: Some(Expression::StoreLocalVariable {
122 name,
123 value: Box::new(
124 actual_value.unwrap_or(Expression::default_value_for_type(&inner_ty)),
125 ),
126 }),
127 },
128 ExpressionResult::ReturnObject { value, has_return_value, .. } => {
129 static COUNT: std::sync::atomic::AtomicUsize =
130 std::sync::atomic::AtomicUsize::new(0);
131 let tmp_name: SmolStr = format_smolstr!(
132 "return_check_store{}",
133 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
134 );
135 let value_ty = value.ty();
136 let load = |field: &str| Expression::StructFieldAccess {
137 base: Box::new(Expression::ReadLocalVariable {
138 name: tmp_name.clone(),
139 ty: value_ty.clone(),
140 }),
141 name: field.into(),
142 };
143 let condition = load(FIELD_CONDITION);
144 let returned_value = has_return_value.then(|| load(FIELD_RETURNED));
145 let actual_value = Some(Expression::StoreLocalVariable {
146 name,
147 value: Box::new(load(FIELD_ACTUAL)),
148 });
149 ExpressionResult::MaybeReturn {
150 pre_statements: vec![Expression::StoreLocalVariable {
151 name: tmp_name,
152 value: Box::new(value),
153 }],
154 condition,
155 returned_value,
156 actual_value,
157 }
158 }
159 }
160 }
161 e => {
162 #[cfg(debug_assertions)]
164 {
165 e.visit_recursive(&mut |e| assert!(!matches!(e, Expression::ReturnStatement(_))));
166 }
167 ExpressionResult::Just(e)
168 }
169 }
170}
171
172fn cleanup_empty_block(te: Expression) -> Option<Expression> {
174 if matches!(&te, Expression::CodeBlock(stmts) if stmts.is_empty()) { None } else { Some(te) }
175}
176
/// Lowers the statements of a code block produced by `iter`.
///
/// Statements that cannot return are accumulated into `stmts`. The first
/// statement that returns (always or maybe) decides the shape of the result:
/// * an unconditional return drops all remaining statements;
/// * a possible return on the last statement is passed up as is;
/// * a possible return in the middle is handled in one of two ways. At top
///   level, the rest of the block is lowered and placed in the "continue"
///   branch. Otherwise it is delegated to [`continue_codeblock`], which builds a
///   return object.
///
/// `ty` is the expected type of the block's value; non-final statements are
/// lowered with `Type::Void`.
fn process_codeblock(
    mut iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
    toplevel: bool,
    ty: &Type,
    ctx: &RemoveReturnContext,
) -> ExpressionResult {
    let mut stmts = Vec::new();
    while let Some(e) = iter.next() {
        let is_last = iter.peek().is_none();
        match process_expression(e, toplevel, ctx, if is_last { ty } else { &Type::Void }) {
            ExpressionResult::Just(x) => stmts.push(x),
            ExpressionResult::Return(x) => {
                // Statements after an unconditional return are never processed.
                stmts.extend(x);
                return ExpressionResult::Return(
                    (!stmts.is_empty()).then_some(Expression::CodeBlock(stmts)),
                );
            }
            ExpressionResult::MaybeReturn {
                mut pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                stmts.append(&mut pre_statements);
                if is_last {
                    return ExpressionResult::MaybeReturn {
                        pre_statements: stmts,
                        condition,
                        returned_value,
                        actual_value,
                    };
                } else if toplevel {
                    // At top level, the remaining statements become part of the
                    // "continue" branch, cast to the return type if needed.
                    let rest = process_codeblock(iter, true, ty, ctx).into_expression(&ctx.ret_ty);
                    let mut rest_ex = Expression::CodeBlock(
                        actual_value.into_iter().chain(core::iter::once(rest)).collect(),
                    );
                    if rest_ex.ty() != ctx.ret_ty {
                        rest_ex =
                            Expression::Cast { from: Box::new(rest_ex), to: ctx.ret_ty.clone() }
                    }
                    return ExpressionResult::MaybeReturn {
                        pre_statements: stmts,
                        condition,
                        returned_value,
                        actual_value: Some(rest_ex),
                    };
                } else {
                    // Nested: materialise this statement as a return object and let
                    // `continue_codeblock` evaluate the rest conditionally.
                    return continue_codeblock(
                        iter,
                        ty,
                        ctx,
                        ExpressionResult::MaybeReturn {
                            pre_statements: Vec::new(),
                            condition,
                            returned_value,
                            actual_value,
                        }
                        .into_return_object(ty, &ctx.ret_ty),
                        stmts,
                        has_value(&ctx.ret_ty),
                    );
                }
            }
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                if is_last {
                    return ExpressionResult::ReturnObject {
                        value: codeblock_with_expr(stmts, value),
                        has_value,
                        has_return_value,
                    };
                } else {
                    return continue_codeblock(iter, ty, ctx, value, stmts, has_return_value);
                }
            }
        }
    }
    ExpressionResult::Just(Expression::CodeBlock(stmts))
}
255
/// Builds the return object for the remainder of a code block after a statement
/// that may return.
///
/// * `return_object` — that statement, already materialised as a return object.
/// * `stmts` — the statements preceding it.
/// * `has_return_value` — whether the `returned` field exists.
///
/// The object is stored in a fresh local. If its `condition` field is true, the
/// rest of the block (`iter`) is evaluated as a return object. Otherwise a return
/// object carrying the stored `returned` field (when `has_return_value`) is
/// produced.
fn continue_codeblock(
    iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
    ty: &Type,
    ctx: &RemoveReturnContext,
    return_object: Expression,
    mut stmts: Vec<Expression>,
    has_return_value: bool,
) -> ExpressionResult {
    // The rest is lowered before this call reserves its own variable number.
    let rest = process_codeblock(iter, false, ty, ctx).into_return_object(ty, &ctx.ret_ty);
    static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    let unique_name = format_smolstr!(
        "return_check_merge{}",
        COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    );
    let load = Box::new(Expression::ReadLocalVariable {
        name: unique_name.clone(),
        ty: return_object.ty(),
    });
    stmts.push(Expression::StoreLocalVariable { name: unique_name, value: return_object.into() });
    stmts.push(Expression::Condition {
        condition: Expression::StructFieldAccess {
            base: load.clone(),
            name: FIELD_CONDITION.into(),
        }
        .into(),
        true_expr: rest.into(),
        // The earlier statement returned: forward its returned value.
        false_expr: ExpressionResult::Return(has_return_value.then(|| {
            Expression::StructFieldAccess { base: load.clone(), name: FIELD_RETURNED.into() }
        }))
        .into_return_object(ty, &ctx.ret_ty)
        .into(),
    });
    ExpressionResult::ReturnObject {
        value: Expression::CodeBlock(stmts),
        has_value: has_value(ty),
        has_return_value,
    }
}
294
/// State shared by the lowering functions while processing one expression.
struct RemoveReturnContext {
    /// Return type of the expression being lowered.
    /// It is the type of its first `return` statement, or `Type::Void` for a bare `return`.
    ret_ty: Type,
}
298
/// Result of lowering an expression that may contain `return` statements.
#[derive(Debug)]
// NOTE(review): the lint fires because `MaybeReturn` holds several expressions.
// It is presumably allowed because values of this enum only live during this pass.
#[allow(clippy::large_enum_variant)]
enum ExpressionResult {
    /// An expression that does not return.
    Just(Expression),
    /// An expression that returns depending on a runtime condition.
    ///
    /// `pre_statements` run first, then `condition` is evaluated. If it is true,
    /// execution continues with `actual_value`; otherwise `returned_value` is
    /// returned.
    MaybeReturn {
        /// Statements evaluated before `condition`.
        pre_statements: Vec<Expression>,
        /// `true` to continue, `false` to return.
        condition: Expression,
        /// Value returned when `condition` is false (`None` for a bare `return`).
        returned_value: Option<Expression>,
        /// Value produced when `condition` is true (`None` if there is nothing to evaluate).
        actual_value: Option<Expression>,
    },
    /// An expression that always returns, with its optional value. The value may be
    /// a code block that also holds the statements preceding the `return`.
    Return(Option<Expression>),
    /// `value` evaluates to a "return object" struct with these fields:
    /// * a boolean `condition` (true: continue, false: return);
    /// * a `returned` field when `has_return_value`;
    /// * an `actual` field when `has_value`.
    ReturnObject { value: Expression, has_value: bool, has_return_value: bool },
}
321
322impl From<Expression> for ExpressionResult {
323 fn from(v: Expression) -> Self {
324 Self::Just(v)
325 }
326}
327
/// Field of a return object holding the boolean check: `true` to continue, `false` to return.
const FIELD_CONDITION: &str = "condition";
/// Field of a return object holding the value produced when execution continues.
const FIELD_ACTUAL: &str = "actual";
/// Field of a return object holding the value being returned.
const FIELD_RETURNED: &str = "returned";
331
impl ExpressionResult {
    /// Turns the result back into a single expression. In this file it is always
    /// called with the function's return type as `ty`.
    ///
    /// Possible returns become a `Condition` choosing between the continued value
    /// and the returned value. `ty` supplies default values for fields missing from
    /// a return object.
    fn into_expression(self, ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => e,
            ExpressionResult::Return(e) => e.unwrap_or(Expression::CodeBlock(Vec::new())),
            ExpressionResult::MaybeReturn {
                mut pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                pre_statements.push(Expression::Condition {
                    condition: condition.into(),
                    true_expr: actual_value.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
                    false_expr: returned_value.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
                });
                Expression::CodeBlock(pre_statements)
            }
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                // Store the return object in a uniquely named local so that several of
                // its fields can be read.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "returned_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let load =
                    Box::new(Expression::ReadLocalVariable { name: name.clone(), ty: value.ty() });
                Expression::CodeBlock(vec![
                    Expression::StoreLocalVariable { name, value: value.into() },
                    Expression::Condition {
                        condition: Expression::StructFieldAccess {
                            base: load.clone(),
                            name: FIELD_CONDITION.into(),
                        }
                        .into(),
                        // Missing fields are replaced by the default value of `ty`.
                        true_expr: if has_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_ACTUAL.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                        false_expr: if has_return_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_RETURNED.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                    },
                ])
            }
        }
    }

    /// Converts the result into an expression evaluating to a return-object struct.
    ///
    /// The struct has a `condition` field (true: continue, false: return), a
    /// `returned` field of type `ret_ty`, and an `actual` field holding the value
    /// of type `ty`. Fields whose type carries no value are omitted by `make_struct`.
    fn into_return_object(self, ty: &Type, ret_ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => {
                let ret_value = Expression::default_value_for_type(ret_ty);
                if has_value(ty) {
                    make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                            (FIELD_ACTUAL, e.ty(), e),
                        ]
                        .into_iter(),
                    )
                } else {
                    let object = make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                        ]
                        .into_iter(),
                    );
                    // The value is unused. A constant expression is dropped; anything
                    // else is kept before the object (presumably for side effects).
                    if e.is_constant(None) {
                        object
                    } else {
                        Expression::CodeBlock(vec![e, object])
                    }
                }
            }
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                let mut true_expr = match actual_value {
                    Some(e) => ExpressionResult::Just(e).into_return_object(ty, ret_ty),
                    None => make_struct(
                        [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true))].into_iter(),
                    ),
                };
                let mut false_expr =
                    ExpressionResult::Return(returned_value).into_return_object(ty, ret_ty);
                // Make both branches produce the same struct type by converting them to
                // their common type; convert_struct fills missing fields with defaults.
                let true_ty = true_expr.ty();
                let false_ty = false_expr.ty();
                if true_ty != false_ty {
                    let common_ty = Expression::common_target_type_for_type_list(
                        [&true_ty, &false_ty].into_iter().cloned(),
                    );
                    if common_ty != true_ty {
                        true_expr =
                            convert_struct(std::mem::take(&mut true_expr), common_ty.clone())
                    }
                    if common_ty != false_ty {
                        false_expr = convert_struct(std::mem::take(&mut false_expr), common_ty)
                    }
                }
                let o = Expression::Condition {
                    condition: condition.into(),
                    true_expr: true_expr.into(),
                    false_expr: false_expr.into(),
                };
                codeblock_with_expr(pre_statements, o)
            }
            // A return: `condition` is false, `returned` is only present with a value,
            // and `actual` is filled with a default of `ty`.
            ExpressionResult::Return(r) => make_struct(
                [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(false))]
                    .into_iter()
                    .chain(r.map(|r| (FIELD_RETURNED, ret_ty.clone(), r)))
                    .chain(has_value(ty).then(|| {
                        (FIELD_ACTUAL, ty.clone(), Expression::default_value_for_type(ty))
                    })),
            ),
            ExpressionResult::ReturnObject { value, .. } => value,
        }
    }

    /// Applies `f` to the value produced on the non-returning path. Any return path
    /// is left untouched.
    fn map_value(self, f: impl FnOnce(Expression) -> Expression) -> Self {
        match self {
            ExpressionResult::Just(e) => ExpressionResult::Just(f(e)),
            ExpressionResult::Return(e) => ExpressionResult::Return(e),
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value: actual_value.map(f),
            },
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                // No `actual` field: nothing to map.
                if !has_value {
                    return ExpressionResult::ReturnObject { value, has_value, has_return_value };
                }
                // Store the object in a local and rebuild it with `f` applied to `actual`.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "mapped_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let value_ty = value.ty();
                let load = |field: &str| Expression::StructFieldAccess {
                    base: Box::new(Expression::ReadLocalVariable {
                        name: name.clone(),
                        ty: value_ty.clone(),
                    }),
                    name: field.into(),
                };
                let condition = (FIELD_CONDITION, Type::Bool, load(FIELD_CONDITION));
                let actual = f(load(FIELD_ACTUAL));
                let actual = (FIELD_ACTUAL, actual.ty(), actual);
                let ret = has_return_value.then(|| {
                    let r = load(FIELD_RETURNED);
                    (FIELD_RETURNED, r.ty(), r)
                });
                ExpressionResult::ReturnObject {
                    value: Expression::CodeBlock(vec![
                        Expression::StoreLocalVariable { name, value: value.into() },
                        make_struct([condition, actual].into_iter().chain(ret.into_iter())),
                    ]),
                    has_value,
                    has_return_value,
                }
            }
        }
    }

    /// Whether the non-returning path produces a value (a type other than `Void`/`Invalid`).
    fn has_value(&self) -> bool {
        match self {
            ExpressionResult::Just(expression) => has_value(&expression.ty()),
            ExpressionResult::MaybeReturn { actual_value, .. } => {
                actual_value.as_ref().is_some_and(|x| has_value(&x.ty()))
            }
            ExpressionResult::Return(..) => false,
            ExpressionResult::ReturnObject { has_value, .. } => *has_value,
        }
    }
}
529
530fn codeblock_with_expr(mut pre_statements: Vec<Expression>, expr: Expression) -> Expression {
531 if pre_statements.is_empty() {
532 expr
533 } else {
534 pre_statements.push(expr);
535 Expression::CodeBlock(pre_statements)
536 }
537}
538
539fn make_struct(it: impl Iterator<Item = (&'static str, Type, Expression)>) -> Expression {
540 let mut fields = BTreeMap::<SmolStr, Type>::new();
541 let mut values = HashMap::<SmolStr, Expression>::new();
542 let mut voids = Vec::new();
543 for (name, ty, expr) in it {
544 if !has_value(&ty) {
545 if ty != Type::Invalid {
546 voids.push(expr);
547 }
548 continue;
549 }
550 fields.insert(name.into(), ty);
551 values.insert(name.into(), expr);
552 }
553 codeblock_with_expr(
554 voids,
555 Expression::Struct { ty: Rc::new(Struct { fields, name: StructName::None }), values },
556 )
557}
558
559fn convert_struct(from: Expression, to: Type) -> Expression {
562 let Type::Struct(to) = to else {
563 assert_eq!(to, Type::Invalid);
564 return Expression::Invalid;
565 };
566 if let Expression::Struct { mut values, .. } = from {
567 let mut new_values = HashMap::new();
568 for (key, ty) in &to.fields {
569 let (key, expression) = values
570 .remove_entry(key)
571 .unwrap_or_else(|| (key.clone(), Expression::default_value_for_type(ty)));
572 new_values.insert(key, expression);
573 }
574 return Expression::Struct { values: new_values, ty: to };
575 }
576 static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
577 let var_name = format_smolstr!(
578 "tmpobj_ret_conv_{}",
579 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
580 );
581 let from_ty = from.ty();
582 let mut new_values = HashMap::new();
583 let Type::Struct(from_s) = &from_ty else {
584 assert_eq!(from_ty, Type::Invalid);
585 return Expression::Invalid;
586 };
587 for (key, ty) in &to.fields {
588 let expression = if from_s.fields.contains_key(key) {
589 Expression::StructFieldAccess {
590 base: Box::new(Expression::ReadLocalVariable {
591 name: var_name.clone(),
592 ty: from_ty.clone(),
593 }),
594 name: key.clone(),
595 }
596 } else {
597 Expression::default_value_for_type(ty)
598 };
599 new_values.insert(key.clone(), expression);
600 }
601 Expression::CodeBlock(vec![
602 Expression::StoreLocalVariable { name: var_name, value: Box::new(from) },
603 Expression::Struct { values: new_values, ty: to },
604 ])
605}
606
607fn has_value(ty: &Type) -> bool {
608 !matches!(ty, Type::Void | Type::Invalid)
609}