i_slint_compiler/passes/
remove_return.rs

1// Copyright © SixtyFPS GmbH <info@slint.dev>
2// SPDX-License-Identifier: GPL-3.0-only OR LicenseRef-Slint-Royalty-free-2.0 OR LicenseRef-Slint-Software-3.0
3
4use smol_str::{format_smolstr, SmolStr};
5use std::collections::{BTreeMap, HashMap};
6use std::rc::Rc;
7
8use crate::expression_tree::Expression;
9use crate::langtype::{Struct, Type};
10
11pub fn remove_return(doc: &crate::object_tree::Document) {
12    doc.visit_all_used_components(|component| {
13        crate::object_tree::visit_all_expressions(component, |e, _| {
14            let mut ret_ty = None;
15            fn visit(e: &Expression, ret_ty: &mut Option<Type>) {
16                if ret_ty.is_some() {
17                    return;
18                }
19                match e {
20                    Expression::ReturnStatement(x) => {
21                        *ret_ty = Some(x.as_ref().map_or(Type::Void, |x| x.ty()));
22                    }
23                    _ => e.visit(|e| visit(e, ret_ty)),
24                };
25            }
26            visit(e, &mut ret_ty);
27            let Some(ret_ty) = ret_ty else { return };
28            let ctx = RemoveReturnContext { ret_ty };
29            *e = process_expression(std::mem::take(e), true, &ctx, &ctx.ret_ty)
30                .to_expression(&ctx.ret_ty);
31        })
32    });
33}
34
35fn process_expression(
36    e: Expression,
37    toplevel: bool,
38    ctx: &RemoveReturnContext,
39    ty: &Type,
40) -> ExpressionResult {
41    match e {
42        Expression::DebugHook { expression, .. } => {
43            process_expression(*expression, toplevel, ctx, ty)
44        }
45        Expression::ReturnStatement(expr) => ExpressionResult::Return(expr.map(|e| *e)),
46        Expression::CodeBlock(expr) => {
47            process_codeblock(expr.into_iter().peekable(), toplevel, ty, ctx)
48        }
49        Expression::Condition { condition, true_expr, false_expr } => {
50            let te = process_expression(*true_expr, false, ctx, ty);
51            let fe = process_expression(*false_expr, false, ctx, ty);
52            match (te, fe) {
53                (ExpressionResult::Just(te), ExpressionResult::Just(fe)) => {
54                    Expression::Condition { condition, true_expr: te.into(), false_expr: fe.into() }
55                        .into()
56                }
57                (ExpressionResult::Just(te), ExpressionResult::Return(fe)) => {
58                    ExpressionResult::MaybeReturn {
59                        pre_statements: vec![],
60                        condition: *condition,
61                        returned_value: fe,
62                        actual_value: cleanup_empty_block(te),
63                    }
64                }
65                (ExpressionResult::Return(te), ExpressionResult::Just(fe)) => {
66                    ExpressionResult::MaybeReturn {
67                        pre_statements: vec![],
68                        condition: Expression::UnaryOp { sub: condition, op: '!' },
69                        returned_value: te,
70                        actual_value: cleanup_empty_block(fe),
71                    }
72                }
73                (ExpressionResult::Return(te), ExpressionResult::Return(fe)) => {
74                    ExpressionResult::Return(Some(Expression::Condition {
75                        condition,
76                        true_expr: te.unwrap_or(Expression::CodeBlock(vec![])).into(),
77                        false_expr: fe.unwrap_or(Expression::CodeBlock(vec![])).into(),
78                    }))
79                }
80                (te, fe) => {
81                    let has_value = has_value(ty) && (te.has_value() || fe.has_value());
82                    let ty = if has_value { ty } else { &Type::Void };
83                    let te = te.into_return_object(ty, &ctx.ret_ty);
84                    let fe = fe.into_return_object(ty, &ctx.ret_ty);
85                    ExpressionResult::ReturnObject {
86                        has_value,
87                        has_return_value: self::has_value(&ctx.ret_ty),
88                        value: Expression::Condition {
89                            condition,
90                            true_expr: te.into(),
91                            false_expr: fe.into(),
92                        },
93                    }
94                }
95            }
96        }
97        Expression::Cast { from, to } => {
98            let ty = if !has_value(ty) { ty.clone() } else { from.ty() };
99            process_expression(*from, toplevel, ctx, &ty)
100                .map_value(|e| Expression::Cast { from: e.into(), to })
101        }
102        e => {
103            // Normally there shouldn't be any 'return' statements in there since return are not allowed in arbitrary expressions
104            #[cfg(debug_assertions)]
105            {
106                e.visit_recursive(&mut |e| assert!(!matches!(e, Expression::ReturnStatement(_))));
107            }
108            ExpressionResult::Just(e)
109        }
110    }
111}
112
113/// Return the expression, unless it is an empty codeblock, then return None
114fn cleanup_empty_block(te: Expression) -> Option<Expression> {
115    if matches!(&te, Expression::CodeBlock(stmts) if stmts.is_empty()) {
116        None
117    } else {
118        Some(te)
119    }
120}
121
122fn process_codeblock(
123    mut iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
124    toplevel: bool,
125    ty: &Type,
126    ctx: &RemoveReturnContext,
127) -> ExpressionResult {
128    let mut stmts = vec![];
129    while let Some(e) = iter.next() {
130        let is_last = iter.peek().is_none();
131        match process_expression(e, toplevel, ctx, if is_last { ty } else { &Type::Void }) {
132            ExpressionResult::Just(x) => stmts.push(x),
133            ExpressionResult::Return(x) => {
134                stmts.extend(x);
135                return ExpressionResult::Return(
136                    (!stmts.is_empty()).then_some(Expression::CodeBlock(stmts)),
137                );
138            }
139            ExpressionResult::MaybeReturn {
140                mut pre_statements,
141                condition,
142                returned_value,
143                actual_value,
144            } => {
145                stmts.append(&mut pre_statements);
146                if is_last {
147                    return ExpressionResult::MaybeReturn {
148                        pre_statements: stmts,
149                        condition,
150                        returned_value,
151                        actual_value,
152                    };
153                } else if toplevel {
154                    let rest = process_codeblock(iter, true, ty, ctx).to_expression(&ctx.ret_ty);
155                    let mut rest_ex = Expression::CodeBlock(
156                        actual_value.into_iter().chain(core::iter::once(rest)).collect(),
157                    );
158                    if rest_ex.ty() != ctx.ret_ty {
159                        rest_ex =
160                            Expression::Cast { from: Box::new(rest_ex), to: ctx.ret_ty.clone() }
161                    }
162                    return ExpressionResult::MaybeReturn {
163                        pre_statements: stmts,
164                        condition,
165                        returned_value,
166                        actual_value: Some(rest_ex),
167                    };
168                } else {
169                    return continue_codeblock(
170                        iter,
171                        ty,
172                        ctx,
173                        ExpressionResult::MaybeReturn {
174                            pre_statements: vec![],
175                            condition,
176                            returned_value,
177                            actual_value,
178                        }
179                        .into_return_object(ty, &ctx.ret_ty),
180                        stmts,
181                        has_value(&ctx.ret_ty),
182                    );
183                }
184            }
185            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
186                if is_last {
187                    return ExpressionResult::ReturnObject {
188                        value: codeblock_with_expr(stmts, value),
189                        has_value,
190                        has_return_value,
191                    };
192                } else {
193                    return continue_codeblock(iter, ty, ctx, value, stmts, has_return_value);
194                }
195            }
196        }
197    }
198    ExpressionResult::Just(Expression::CodeBlock(stmts))
199}
200
201fn continue_codeblock(
202    iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
203    ty: &Type,
204    ctx: &RemoveReturnContext,
205    return_object: Expression,
206    mut stmts: Vec<Expression>,
207    has_return_value: bool,
208) -> ExpressionResult {
209    let rest = process_codeblock(iter, false, ty, ctx).into_return_object(ty, &ctx.ret_ty);
210    static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
211    let unique_name = format_smolstr!(
212        "return_check_merge{}",
213        COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
214    );
215    let load = Box::new(Expression::ReadLocalVariable {
216        name: unique_name.clone(),
217        ty: return_object.ty(),
218    });
219    stmts.push(Expression::StoreLocalVariable { name: unique_name, value: return_object.into() });
220    stmts.push(Expression::Condition {
221        condition: Expression::StructFieldAccess {
222            base: load.clone(),
223            name: FIELD_CONDITION.into(),
224        }
225        .into(),
226        true_expr: rest.into(),
227        false_expr: ExpressionResult::Return(has_return_value.then(|| {
228            Expression::StructFieldAccess { base: load.clone(), name: FIELD_RETURNED.into() }
229        }))
230        .into_return_object(ty, &ctx.ret_ty)
231        .into(),
232    });
233    ExpressionResult::ReturnObject {
234        value: Expression::CodeBlock(stmts),
235        has_value: has_value(ty),
236        has_return_value,
237    }
238}
239
240struct RemoveReturnContext {
241    ret_ty: Type,
242}
243
244#[derive(Debug)]
245enum ExpressionResult {
246    /// The expression maps directly to a LLR expression
247    Just(Expression),
248    /// The expression used `return` so we need to check for the return slot
249    MaybeReturn {
250        /// Some statements that initializes some temporary variable (eg arguments to something called later)
251        pre_statements: Vec<Expression>,
252        /// Boolean expression: false means return
253        condition: Expression,
254        /// Value being returned if condition is false
255        returned_value: Option<Expression>,
256        /// The value when we don't return
257        actual_value: Option<Expression>,
258    },
259    /// The expression returns unconditionally
260    Return(Option<Expression>),
261    /// The expression is of type `{ condition: bool, actual: ty, returned: ret_ty}`
262    /// which is the equivalent of `if condition { actual } else { return R }`
263    ReturnObject { value: Expression, has_value: bool, has_return_value: bool },
264}
265
266impl From<Expression> for ExpressionResult {
267    fn from(v: Expression) -> Self {
268        Self::Just(v)
269    }
270}
271
/// Field of a return object holding a `bool`: `true` when no return happened.
const FIELD_CONDITION: &str = "condition";
/// Field of a return object holding the value produced when no return happened.
const FIELD_ACTUAL: &str = "actual";
/// Field of a return object holding the returned value.
const FIELD_RETURNED: &str = "returned";
275
impl ExpressionResult {
    /// Converts this result into a plain expression whose value is the final value of the
    /// whole expression: on the returning path, the returned value becomes that value.
    ///
    /// `ty` is used for default values when a path carries no value; the callers in this file
    /// pass the return type (`ctx.ret_ty`).
    fn to_expression(self, ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => e,
            // Unconditional return: the returned value (if any) is the value of the expression.
            ExpressionResult::Return(e) => e.unwrap_or(Expression::CodeBlock(vec![])),
            ExpressionResult::MaybeReturn {
                mut pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                // `{ pre_statements; if condition { actual_value } else { returned_value } }`
                pre_statements.push(Expression::Condition {
                    condition: condition.into(),
                    true_expr: actual_value.unwrap_or(Expression::CodeBlock(vec![])).into(),
                    false_expr: returned_value.unwrap_or(Expression::CodeBlock(vec![])).into(),
                });
                Expression::CodeBlock(pre_statements)
            }
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                // Counter giving each generated local variable a distinct name.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "returned_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let load =
                    Box::new(Expression::ReadLocalVariable { name: name.clone(), ty: value.ty() });
                // Store the object once, then pick `actual` or `returned` based on `condition`.
                // Fields missing from the object are replaced by the default value of `ty`.
                Expression::CodeBlock(vec![
                    Expression::StoreLocalVariable { name, value: value.into() },
                    Expression::Condition {
                        condition: Expression::StructFieldAccess {
                            base: load.clone(),
                            name: FIELD_CONDITION.into(),
                        }
                        .into(),
                        true_expr: if has_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_ACTUAL.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                        false_expr: if has_return_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_RETURNED.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                    },
                ])
            }
        }
    }

    /// Encodes this result as an expression evaluating to a return object: a struct with a
    /// `condition` field (`true` when no return happened), a `returned` field of type `ret_ty`
    /// and an `actual` field of type `ty`. Fields whose type carries no value are omitted
    /// (see `make_struct`).
    fn into_return_object(self, ty: &Type, ret_ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => {
                let ret_value = Expression::default_value_for_type(ret_ty);
                if has_value(ty) {
                    make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                            (FIELD_ACTUAL, e.ty(), e),
                        ]
                        .into_iter(),
                    )
                } else {
                    let object = make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                        ]
                        .into_iter(),
                    );
                    // The value is not needed: keep `e` only when it is not a constant,
                    // evaluating it before the object.
                    if e.is_constant(None) {
                        object
                    } else {
                        Expression::CodeBlock(vec![e, object])
                    }
                }
            }
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                let mut true_expr = match actual_value {
                    Some(e) => ExpressionResult::Just(e).into_return_object(ty, ret_ty),
                    None => make_struct(
                        [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true))].into_iter(),
                    ),
                };
                let mut false_expr =
                    ExpressionResult::Return(returned_value).into_return_object(ty, ret_ty);
                // The two objects may have different sets of fields; bring both to the type
                // computed by `common_target_type_for_type_list` so the condition is well-typed.
                let true_ty = true_expr.ty();
                let false_ty = false_expr.ty();
                if true_ty != false_ty {
                    let common_ty = Expression::common_target_type_for_type_list(
                        [&true_ty, &false_ty].into_iter().cloned(),
                    );
                    if common_ty != true_ty {
                        true_expr =
                            convert_struct(std::mem::take(&mut true_expr), common_ty.clone())
                    }
                    if common_ty != false_ty {
                        false_expr = convert_struct(std::mem::take(&mut false_expr), common_ty)
                    }
                }
                let o = Expression::Condition {
                    condition: condition.into(),
                    true_expr: true_expr.into(),
                    false_expr: false_expr.into(),
                };
                codeblock_with_expr(pre_statements, o)
            }
            // `condition` is false; `returned` holds the value (if any) and `actual` is a
            // placeholder default of `ty` when `ty` carries a value.
            ExpressionResult::Return(r) => make_struct(
                [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(false))]
                    .into_iter()
                    .chain(r.map(|r| (FIELD_RETURNED, ret_ty.clone(), r)))
                    .chain(has_value(ty).then(|| {
                        (FIELD_ACTUAL, ty.clone(), Expression::default_value_for_type(ty))
                    })),
            ),
            // Already a return object.
            ExpressionResult::ReturnObject { value, .. } => value,
        }
    }

    /// Applies `f` to the value produced when no return happens, leaving the returning path
    /// unchanged.
    fn map_value(self, f: impl FnOnce(Expression) -> Expression) -> Self {
        match self {
            ExpressionResult::Just(e) => ExpressionResult::Just(f(e)),
            ExpressionResult::Return(e) => ExpressionResult::Return(e),
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value: actual_value.map(f),
            },
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                // No `actual` field to map.
                if !has_value {
                    return ExpressionResult::ReturnObject { value, has_value, has_return_value };
                }
                // Store the object in a local and rebuild it with `f` applied to `actual`.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "mapped_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let value_ty = value.ty();
                let load = |field: &str| Expression::StructFieldAccess {
                    base: Box::new(Expression::ReadLocalVariable {
                        name: name.clone(),
                        ty: value_ty.clone(),
                    }),
                    name: field.into(),
                };
                let condition = (FIELD_CONDITION, Type::Bool, load(FIELD_CONDITION));
                let actual = f(load(FIELD_ACTUAL));
                let actual = (FIELD_ACTUAL, actual.ty(), actual);
                let ret = has_return_value.then(|| {
                    let r = load(FIELD_RETURNED);
                    (FIELD_RETURNED, r.ty(), r)
                });
                ExpressionResult::ReturnObject {
                    value: Expression::CodeBlock(vec![
                        Expression::StoreLocalVariable { name, value: value.into() },
                        make_struct([condition, actual].into_iter().chain(ret.into_iter())),
                    ]),
                    has_value,
                    has_return_value,
                }
            }
        }
    }

    /// Whether the non-returning path of this result produces a value (of a type that is
    /// neither `Void` nor `Invalid`).
    fn has_value(&self) -> bool {
        match self {
            ExpressionResult::Just(expression) => has_value(&expression.ty()),
            ExpressionResult::MaybeReturn { actual_value, .. } => {
                actual_value.as_ref().is_some_and(|x| has_value(&x.ty()))
            }
            ExpressionResult::Return(..) => false,
            ExpressionResult::ReturnObject { has_value, .. } => *has_value,
        }
    }
}
473
474fn codeblock_with_expr(mut pre_statements: Vec<Expression>, expr: Expression) -> Expression {
475    if pre_statements.is_empty() {
476        expr
477    } else {
478        pre_statements.push(expr);
479        Expression::CodeBlock(pre_statements)
480    }
481}
482
483fn make_struct(it: impl Iterator<Item = (&'static str, Type, Expression)>) -> Expression {
484    let mut fields = BTreeMap::<SmolStr, Type>::new();
485    let mut values = HashMap::<SmolStr, Expression>::new();
486    let mut voids = Vec::new();
487    for (name, ty, expr) in it {
488        if !has_value(&ty) {
489            if ty != Type::Invalid {
490                voids.push(expr);
491            }
492            continue;
493        }
494        fields.insert(name.into(), ty);
495        values.insert(name.into(), expr);
496    }
497    codeblock_with_expr(
498        voids,
499        Expression::Struct {
500            ty: Rc::new(Struct { fields, name: None, node: None, rust_attributes: None }),
501            values,
502        },
503    )
504}
505
506/// Given an expression `from` of type Struct, convert to another type struct with more fields
507/// Add missing members in `from`
508fn convert_struct(from: Expression, to: Type) -> Expression {
509    let Type::Struct(to) = to else {
510        assert_eq!(to, Type::Invalid);
511        return Expression::Invalid;
512    };
513    if let Expression::Struct { mut values, .. } = from {
514        let mut new_values = HashMap::new();
515        for (key, ty) in &to.fields {
516            let (key, expression) = values
517                .remove_entry(key)
518                .unwrap_or_else(|| (key.clone(), Expression::default_value_for_type(ty)));
519            new_values.insert(key, expression);
520        }
521        return Expression::Struct { values: new_values, ty: to };
522    }
523    static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
524    let var_name = format_smolstr!(
525        "tmpobj_ret_conv_{}",
526        COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
527    );
528    let from_ty = from.ty();
529    let mut new_values = HashMap::new();
530    let Type::Struct(from_s) = &from_ty else {
531        assert_eq!(from_ty, Type::Invalid);
532        return Expression::Invalid;
533    };
534    for (key, ty) in &to.fields {
535        let expression = if from_s.fields.contains_key(key) {
536            Expression::StructFieldAccess {
537                base: Box::new(Expression::ReadLocalVariable {
538                    name: var_name.clone(),
539                    ty: from_ty.clone(),
540                }),
541                name: key.clone(),
542            }
543        } else {
544            Expression::default_value_for_type(ty)
545        };
546        new_values.insert(key.clone(), expression);
547    }
548    Expression::CodeBlock(vec![
549        Expression::StoreLocalVariable { name: var_name, value: Box::new(from) },
550        Expression::Struct { values: new_values, ty: to },
551    ])
552}
553
554fn has_value(ty: &Type) -> bool {
555    !matches!(ty, Type::Void | Type::Invalid)
556}