1use smol_str::{format_smolstr, SmolStr};
5use std::collections::{BTreeMap, HashMap};
6use std::rc::Rc;
7
8use crate::expression_tree::Expression;
9use crate::langtype::{Struct, Type};
10
11pub fn remove_return(doc: &crate::object_tree::Document) {
12 doc.visit_all_used_components(|component| {
13 crate::object_tree::visit_all_expressions(component, |e, _| {
14 let mut ret_ty = None;
15 fn visit(e: &Expression, ret_ty: &mut Option<Type>) {
16 if ret_ty.is_some() {
17 return;
18 }
19 match e {
20 Expression::ReturnStatement(x) => {
21 *ret_ty = Some(x.as_ref().map_or(Type::Void, |x| x.ty()));
22 }
23 _ => e.visit(|e| visit(e, ret_ty)),
24 };
25 }
26 visit(e, &mut ret_ty);
27 let Some(ret_ty) = ret_ty else { return };
28 let ctx = RemoveReturnContext { ret_ty };
29 *e = process_expression(std::mem::take(e), true, &ctx, &ctx.ret_ty)
30 .to_expression(&ctx.ret_ty);
31 })
32 });
33}
34
35fn process_expression(
36 e: Expression,
37 toplevel: bool,
38 ctx: &RemoveReturnContext,
39 ty: &Type,
40) -> ExpressionResult {
41 match e {
42 Expression::DebugHook { expression, .. } => {
43 process_expression(*expression, toplevel, ctx, ty)
44 }
45 Expression::ReturnStatement(expr) => ExpressionResult::Return(expr.map(|e| *e)),
46 Expression::CodeBlock(expr) => {
47 process_codeblock(expr.into_iter().peekable(), toplevel, ty, ctx)
48 }
49 Expression::Condition { condition, true_expr, false_expr } => {
50 let te = process_expression(*true_expr, false, ctx, ty);
51 let fe = process_expression(*false_expr, false, ctx, ty);
52 match (te, fe) {
53 (ExpressionResult::Just(te), ExpressionResult::Just(fe)) => {
54 Expression::Condition { condition, true_expr: te.into(), false_expr: fe.into() }
55 .into()
56 }
57 (ExpressionResult::Just(te), ExpressionResult::Return(fe)) => {
58 ExpressionResult::MaybeReturn {
59 pre_statements: vec![],
60 condition: *condition,
61 returned_value: fe,
62 actual_value: cleanup_empty_block(te),
63 }
64 }
65 (ExpressionResult::Return(te), ExpressionResult::Just(fe)) => {
66 ExpressionResult::MaybeReturn {
67 pre_statements: vec![],
68 condition: Expression::UnaryOp { sub: condition, op: '!' },
69 returned_value: te,
70 actual_value: cleanup_empty_block(fe),
71 }
72 }
73 (ExpressionResult::Return(te), ExpressionResult::Return(fe)) => {
74 ExpressionResult::Return(Some(Expression::Condition {
75 condition,
76 true_expr: te.unwrap_or(Expression::CodeBlock(vec![])).into(),
77 false_expr: fe.unwrap_or(Expression::CodeBlock(vec![])).into(),
78 }))
79 }
80 (te, fe) => {
81 let has_value = has_value(ty) && (te.has_value() || fe.has_value());
82 let ty = if has_value { ty } else { &Type::Void };
83 let te = te.into_return_object(ty, &ctx.ret_ty);
84 let fe = fe.into_return_object(ty, &ctx.ret_ty);
85 ExpressionResult::ReturnObject {
86 has_value,
87 has_return_value: self::has_value(&ctx.ret_ty),
88 value: Expression::Condition {
89 condition,
90 true_expr: te.into(),
91 false_expr: fe.into(),
92 },
93 }
94 }
95 }
96 }
97 Expression::Cast { from, to } => {
98 let ty = if !has_value(ty) { ty.clone() } else { from.ty() };
99 process_expression(*from, toplevel, ctx, &ty)
100 .map_value(|e| Expression::Cast { from: e.into(), to })
101 }
102 e => {
103 #[cfg(debug_assertions)]
105 {
106 e.visit_recursive(&mut |e| assert!(!matches!(e, Expression::ReturnStatement(_))));
107 }
108 ExpressionResult::Just(e)
109 }
110 }
111}
112
113fn cleanup_empty_block(te: Expression) -> Option<Expression> {
115 if matches!(&te, Expression::CodeBlock(stmts) if stmts.is_empty()) {
116 None
117 } else {
118 Some(te)
119 }
120}
121
/// Lowers the statements of a code block that may contain `return` statements.
///
/// The statements are supplied as a peekable iterator, so the function can tell
/// which statement is the last one. It can also hand the not-yet-processed
/// statements to a recursive call.
///
/// * `toplevel` — forwarded to every statement. When set, a conditional return in
///   a non-last statement is resolved by nesting the rest of the block in the
///   continuing branch instead of going through a return object.
/// * `ty` — the type expected from the last statement; all other statements are
///   lowered with `Type::Void`.
/// * `ctx` — carries the type of the value returned by `return` statements.
fn process_codeblock(
    mut iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
    toplevel: bool,
    ty: &Type,
    ctx: &RemoveReturnContext,
) -> ExpressionResult {
    let mut stmts = vec![];
    while let Some(e) = iter.next() {
        let is_last = iter.peek().is_none();
        // Only the last statement's value is the value of the block.
        match process_expression(e, toplevel, ctx, if is_last { ty } else { &Type::Void }) {
            ExpressionResult::Just(x) => stmts.push(x),
            ExpressionResult::Return(x) => {
                // Unconditional return: the statements after it are never emitted.
                stmts.extend(x);
                return ExpressionResult::Return(
                    (!stmts.is_empty()).then_some(Expression::CodeBlock(stmts)),
                );
            }
            ExpressionResult::MaybeReturn {
                mut pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                stmts.append(&mut pre_statements);
                if is_last {
                    return ExpressionResult::MaybeReturn {
                        pre_statements: stmts,
                        condition,
                        returned_value,
                        actual_value,
                    };
                } else if toplevel {
                    // At top level, the remaining statements are lowered and moved
                    // into the continuing branch. That branch is cast to the return
                    // type when its type differs.
                    let rest = process_codeblock(iter, true, ty, ctx).to_expression(&ctx.ret_ty);
                    let mut rest_ex = Expression::CodeBlock(
                        actual_value.into_iter().chain(core::iter::once(rest)).collect(),
                    );
                    if rest_ex.ty() != ctx.ret_ty {
                        rest_ex =
                            Expression::Cast { from: Box::new(rest_ex), to: ctx.ret_ty.clone() }
                    }
                    return ExpressionResult::MaybeReturn {
                        pre_statements: stmts,
                        condition,
                        returned_value,
                        actual_value: Some(rest_ex),
                    };
                } else {
                    // Nested block: materialize the conditional return as a return
                    // object and merge the remaining statements behind a check of it.
                    return continue_codeblock(
                        iter,
                        ty,
                        ctx,
                        ExpressionResult::MaybeReturn {
                            pre_statements: vec![],
                            condition,
                            returned_value,
                            actual_value,
                        }
                        .into_return_object(ty, &ctx.ret_ty),
                        stmts,
                        has_value(&ctx.ret_ty),
                    );
                }
            }
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                if is_last {
                    return ExpressionResult::ReturnObject {
                        value: codeblock_with_expr(stmts, value),
                        has_value,
                        has_return_value,
                    };
                } else {
                    return continue_codeblock(iter, ty, ctx, value, stmts, has_return_value);
                }
            }
        }
    }
    // No statement returned: the block is kept as a plain code block.
    ExpressionResult::Just(Expression::CodeBlock(stmts))
}
200
201fn continue_codeblock(
202 iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
203 ty: &Type,
204 ctx: &RemoveReturnContext,
205 return_object: Expression,
206 mut stmts: Vec<Expression>,
207 has_return_value: bool,
208) -> ExpressionResult {
209 let rest = process_codeblock(iter, false, ty, ctx).into_return_object(ty, &ctx.ret_ty);
210 static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
211 let unique_name = format_smolstr!(
212 "return_check_merge{}",
213 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
214 );
215 let load = Box::new(Expression::ReadLocalVariable {
216 name: unique_name.clone(),
217 ty: return_object.ty(),
218 });
219 stmts.push(Expression::StoreLocalVariable { name: unique_name, value: return_object.into() });
220 stmts.push(Expression::Condition {
221 condition: Expression::StructFieldAccess {
222 base: load.clone(),
223 name: FIELD_CONDITION.into(),
224 }
225 .into(),
226 true_expr: rest.into(),
227 false_expr: ExpressionResult::Return(has_return_value.then(|| {
228 Expression::StructFieldAccess { base: load.clone(), name: FIELD_RETURNED.into() }
229 }))
230 .into_return_object(ty, &ctx.ret_ty)
231 .into(),
232 });
233 ExpressionResult::ReturnObject {
234 value: Expression::CodeBlock(stmts),
235 has_value: has_value(ty),
236 has_return_value,
237 }
238}
239
/// State shared by the whole rewrite of one expression.
struct RemoveReturnContext {
    /// Type of the value carried by `return` statements. `remove_return` takes it
    /// from the first `return` found, and uses `Type::Void` for a bare `return`.
    ret_ty: Type,
}

/// Result of lowering an expression that may contain `return` statements.
#[derive(Debug)]
enum ExpressionResult {
    /// The expression does not return; it is used as is.
    Just(Expression),
    /// Runs `pre_statements`, then continues with `actual_value` when `condition` is
    /// true, or returns `returned_value` when it is false.
    MaybeReturn {
        /// Statements evaluated before `condition`.
        pre_statements: Vec<Expression>,
        /// True: execution continues normally; false: the expression returns.
        condition: Expression,
        /// Value returned when `condition` is false (`None` for a bare `return`).
        returned_value: Option<Expression>,
        /// Value produced when `condition` is true (`None` stands for an empty block).
        actual_value: Option<Expression>,
    },
    /// The expression always returns, with the given value (`None` for a bare `return`).
    Return(Option<Expression>),
    /// `value` evaluates to an anonymous struct whose boolean `condition` field
    /// selects between continuing with the `actual` field (true) and returning the
    /// `returned` field (false).
    ReturnObject {
        value: Expression,
        /// Whether the `actual` field is present and should be read.
        has_value: bool,
        /// Whether the `returned` field is present and should be read.
        has_return_value: bool,
    },
}

/// Wraps `v` as `ExpressionResult::Just`.
impl From<Expression> for ExpressionResult {
    fn from(v: Expression) -> Self {
        Self::Just(v)
    }
}

// Field names of the anonymous struct built for `ExpressionResult::ReturnObject`.
const FIELD_CONDITION: &str = "condition";
const FIELD_ACTUAL: &str = "actual";
const FIELD_RETURNED: &str = "returned";
275
impl ExpressionResult {
    /// Turns the result back into a single expression.
    ///
    /// `ty` supplies default values for a `ReturnObject` whose `actual` or `returned`
    /// field is absent. Callers in this file pass the enclosing return type.
    fn to_expression(self, ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => e,
            // The returned value becomes the expression's value; a bare `return`
            // becomes an empty block.
            ExpressionResult::Return(e) => e.unwrap_or(Expression::CodeBlock(vec![])),
            ExpressionResult::MaybeReturn {
                mut pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                // `{ pre_statements; if condition { actual } else { returned } }`
                pre_statements.push(Expression::Condition {
                    condition: condition.into(),
                    true_expr: actual_value.unwrap_or(Expression::CodeBlock(vec![])).into(),
                    false_expr: returned_value.unwrap_or(Expression::CodeBlock(vec![])).into(),
                });
                Expression::CodeBlock(pre_statements)
            }
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                // Store the object once in a uniquely named local, then select the
                // `actual` or `returned` field depending on its `condition` field.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "returned_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let load =
                    Box::new(Expression::ReadLocalVariable { name: name.clone(), ty: value.ty() });
                Expression::CodeBlock(vec![
                    Expression::StoreLocalVariable { name, value: value.into() },
                    Expression::Condition {
                        condition: Expression::StructFieldAccess {
                            base: load.clone(),
                            name: FIELD_CONDITION.into(),
                        }
                        .into(),
                        true_expr: if has_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_ACTUAL.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                        false_expr: if has_return_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_RETURNED.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                    },
                ])
            }
        }
    }

    /// Converts the result into an expression evaluating to an anonymous struct.
    ///
    /// * `condition` is `true` when execution continues normally and `false` when a
    ///   `return` was taken.
    /// * `returned` holds the returned value (typed `ret_ty`).
    /// * `actual` holds the non-returning value; it is only added when `ty` has a
    ///   value.
    ///
    /// `make_struct` omits fields whose type carries no value.
    fn into_return_object(self, ty: &Type, ret_ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => {
                let ret_value = Expression::default_value_for_type(ret_ty);
                if has_value(ty) {
                    make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                            (FIELD_ACTUAL, e.ty(), e),
                        ]
                        .into_iter(),
                    )
                } else {
                    let object = make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                        ]
                        .into_iter(),
                    );
                    // The value is not needed. A constant is dropped; anything else is
                    // still evaluated as a statement before the object.
                    if e.is_constant(None) {
                        object
                    } else {
                        Expression::CodeBlock(vec![e, object])
                    }
                }
            }
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                // With nothing to evaluate on the continuing path, only `condition` is set.
                let mut true_expr = match actual_value {
                    Some(e) => ExpressionResult::Just(e).into_return_object(ty, ret_ty),
                    None => make_struct(
                        [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true))].into_iter(),
                    ),
                };
                let mut false_expr =
                    ExpressionResult::Return(returned_value).into_return_object(ty, ret_ty);
                // The two branches may build structs of different types; convert the
                // ones that differ to the common type. Missing fields get default
                // values in `convert_struct`.
                // NOTE(review): this assumes the common struct type contains the fields
                // of both sides — confirm against `common_target_type_for_type_list`.
                let true_ty = true_expr.ty();
                let false_ty = false_expr.ty();
                if true_ty != false_ty {
                    let common_ty = Expression::common_target_type_for_type_list(
                        [&true_ty, &false_ty].into_iter().cloned(),
                    );
                    if common_ty != true_ty {
                        true_expr =
                            convert_struct(std::mem::take(&mut true_expr), common_ty.clone())
                    }
                    if common_ty != false_ty {
                        false_expr = convert_struct(std::mem::take(&mut false_expr), common_ty)
                    }
                }
                let o = Expression::Condition {
                    condition: condition.into(),
                    true_expr: true_expr.into(),
                    false_expr: false_expr.into(),
                };
                codeblock_with_expr(pre_statements, o)
            }
            // A return taken: `condition` is false, `actual` (if any) is a placeholder.
            ExpressionResult::Return(r) => make_struct(
                [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(false))]
                    .into_iter()
                    .chain(r.map(|r| (FIELD_RETURNED, ret_ty.clone(), r)))
                    .chain(has_value(ty).then(|| {
                        (FIELD_ACTUAL, ty.clone(), Expression::default_value_for_type(ty))
                    })),
            ),
            // Already in return-object form.
            ExpressionResult::ReturnObject { value, .. } => value,
        }
    }

    /// Applies `f` to the value produced when the expression does not return.
    ///
    /// Returned values are left untouched.
    fn map_value(self, f: impl FnOnce(Expression) -> Expression) -> Self {
        match self {
            ExpressionResult::Just(e) => ExpressionResult::Just(f(e)),
            ExpressionResult::Return(e) => ExpressionResult::Return(e),
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value: actual_value.map(f),
            },
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                // No `actual` field: there is nothing to map.
                if !has_value {
                    return ExpressionResult::ReturnObject { value, has_value, has_return_value };
                }
                // Store the object in a local and rebuild it, with `f` applied to the
                // `actual` field and the other fields copied over.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "mapped_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let value_ty = value.ty();
                let load = |field: &str| Expression::StructFieldAccess {
                    base: Box::new(Expression::ReadLocalVariable {
                        name: name.clone(),
                        ty: value_ty.clone(),
                    }),
                    name: field.into(),
                };
                let condition = (FIELD_CONDITION, Type::Bool, load(FIELD_CONDITION));
                let actual = f(load(FIELD_ACTUAL));
                let actual = (FIELD_ACTUAL, actual.ty(), actual);
                let ret = has_return_value.then(|| {
                    let r = load(FIELD_RETURNED);
                    (FIELD_RETURNED, r.ty(), r)
                });
                ExpressionResult::ReturnObject {
                    value: Expression::CodeBlock(vec![
                        Expression::StoreLocalVariable { name, value: value.into() },
                        make_struct([condition, actual].into_iter().chain(ret.into_iter())),
                    ]),
                    has_value,
                    has_return_value,
                }
            }
        }
    }

    /// Whether the non-returning path produces a value (a type other than
    /// `Void`/`Invalid`).
    fn has_value(&self) -> bool {
        match self {
            ExpressionResult::Just(expression) => has_value(&expression.ty()),
            ExpressionResult::MaybeReturn { actual_value, .. } => {
                actual_value.as_ref().is_some_and(|x| has_value(&x.ty()))
            }
            ExpressionResult::Return(..) => false,
            ExpressionResult::ReturnObject { has_value, .. } => *has_value,
        }
    }
}
473
474fn codeblock_with_expr(mut pre_statements: Vec<Expression>, expr: Expression) -> Expression {
475 if pre_statements.is_empty() {
476 expr
477 } else {
478 pre_statements.push(expr);
479 Expression::CodeBlock(pre_statements)
480 }
481}
482
483fn make_struct(it: impl Iterator<Item = (&'static str, Type, Expression)>) -> Expression {
484 let mut fields = BTreeMap::<SmolStr, Type>::new();
485 let mut values = HashMap::<SmolStr, Expression>::new();
486 let mut voids = Vec::new();
487 for (name, ty, expr) in it {
488 if !has_value(&ty) {
489 if ty != Type::Invalid {
490 voids.push(expr);
491 }
492 continue;
493 }
494 fields.insert(name.into(), ty);
495 values.insert(name.into(), expr);
496 }
497 codeblock_with_expr(
498 voids,
499 Expression::Struct {
500 ty: Rc::new(Struct { fields, name: None, node: None, rust_attributes: None }),
501 values,
502 },
503 )
504}
505
506fn convert_struct(from: Expression, to: Type) -> Expression {
509 let Type::Struct(to) = to else {
510 assert_eq!(to, Type::Invalid);
511 return Expression::Invalid;
512 };
513 if let Expression::Struct { mut values, .. } = from {
514 let mut new_values = HashMap::new();
515 for (key, ty) in &to.fields {
516 let (key, expression) = values
517 .remove_entry(key)
518 .unwrap_or_else(|| (key.clone(), Expression::default_value_for_type(ty)));
519 new_values.insert(key, expression);
520 }
521 return Expression::Struct { values: new_values, ty: to };
522 }
523 static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
524 let var_name = format_smolstr!(
525 "tmpobj_ret_conv_{}",
526 COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
527 );
528 let from_ty = from.ty();
529 let mut new_values = HashMap::new();
530 let Type::Struct(from_s) = &from_ty else {
531 assert_eq!(from_ty, Type::Invalid);
532 return Expression::Invalid;
533 };
534 for (key, ty) in &to.fields {
535 let expression = if from_s.fields.contains_key(key) {
536 Expression::StructFieldAccess {
537 base: Box::new(Expression::ReadLocalVariable {
538 name: var_name.clone(),
539 ty: from_ty.clone(),
540 }),
541 name: key.clone(),
542 }
543 } else {
544 Expression::default_value_for_type(ty)
545 };
546 new_values.insert(key.clone(), expression);
547 }
548 Expression::CodeBlock(vec![
549 Expression::StoreLocalVariable { name: var_name, value: Box::new(from) },
550 Expression::Struct { values: new_values, ty: to },
551 ])
552}
553
554fn has_value(ty: &Type) -> bool {
555 !matches!(ty, Type::Void | Type::Invalid)
556}