// i_slint_compiler/passes/remove_return.rs
// Copyright © SixtyFPS GmbH <info@slint.dev>
// SPDX-License-Identifier: GPL-3.0-only OR LicenseRef-Slint-Royalty-free-2.0 OR LicenseRef-Slint-Software-3.0

4use smol_str::{SmolStr, format_smolstr};
5use std::collections::{BTreeMap, HashMap};
6use std::rc::Rc;
7
8use crate::expression_tree::Expression;
9use crate::langtype::{Struct, StructName, Type};
10
11pub fn remove_return(doc: &crate::object_tree::Document) {
12    doc.visit_all_used_components(|component| {
13        crate::object_tree::visit_all_expressions(component, |e, _| {
14            let mut ret_ty = None;
15            fn visit(e: &Expression, ret_ty: &mut Option<Type>) {
16                if ret_ty.is_some() {
17                    return;
18                }
19                match e {
20                    Expression::ReturnStatement(x) => {
21                        *ret_ty = Some(x.as_ref().map_or(Type::Void, |x| x.ty()));
22                    }
23                    _ => e.visit(|e| visit(e, ret_ty)),
24                };
25            }
26            visit(e, &mut ret_ty);
27            let Some(ret_ty) = ret_ty else { return };
28            let ctx = RemoveReturnContext { ret_ty };
29            *e = process_expression(std::mem::take(e), true, &ctx, &ctx.ret_ty)
30                .into_expression(&ctx.ret_ty);
31        })
32    });
33}
34
/// Rewrite `e` so that it no longer contains `return` statements.
///
/// The result (see `ExpressionResult`) tells the caller how control can leave `e`:
/// - normally (`Just`);
/// - always through a `return` (`Return`);
/// - conditionally through a `return` (`MaybeReturn` / `ReturnObject`).
///
/// * `toplevel` — `true` when called by `remove_return` on a whole expression. It is kept
///   through `DebugHook`, `Cast` and `CodeBlock` nesting. Condition branches and stored
///   values are processed with `false`. When set, `process_codeblock` moves the remaining
///   statements of a block into the non-returning branch instead of building a return object.
/// * `ctx` — holds the return type of the rewritten expression.
/// * `ty` — the type expected for the value of `e`. Callers pass `Type::Void` when the value
///   is unused (e.g. non-last statements of a code block).
fn process_expression(
    e: Expression,
    toplevel: bool,
    ctx: &RemoveReturnContext,
    ty: &Type,
) -> ExpressionResult {
    match e {
        // The debug hook itself is dropped; only the wrapped expression is processed and kept.
        Expression::DebugHook { expression, .. } => {
            process_expression(*expression, toplevel, ctx, ty)
        }
        // An explicit `return`: an unconditional return of the (optional) value.
        Expression::ReturnStatement(expr) => ExpressionResult::Return(expr.map(|e| *e)),
        Expression::CodeBlock(expr) => {
            process_codeblock(expr.into_iter().peekable(), toplevel, ty, ctx)
        }
        Expression::Condition { condition, true_expr, false_expr } => {
            let te = process_expression(*true_expr, false, ctx, ty);
            let fe = process_expression(*false_expr, false, ctx, ty);
            match (te, fe) {
                // Neither branch returns: rebuild the condition unchanged.
                (ExpressionResult::Just(te), ExpressionResult::Just(fe)) => {
                    Expression::Condition { condition, true_expr: te.into(), false_expr: fe.into() }
                        .into()
                }
                // Only the false branch returns: execution continues when `condition` holds.
                (ExpressionResult::Just(te), ExpressionResult::Return(fe)) => {
                    ExpressionResult::MaybeReturn {
                        pre_statements: Vec::new(),
                        condition: *condition,
                        returned_value: fe,
                        actual_value: cleanup_empty_block(te),
                    }
                }
                // Only the true branch returns: execution continues when `!condition` holds.
                (ExpressionResult::Return(te), ExpressionResult::Just(fe)) => {
                    ExpressionResult::MaybeReturn {
                        pre_statements: Vec::new(),
                        condition: Expression::UnaryOp { sub: condition, op: '!' },
                        returned_value: te,
                        actual_value: cleanup_empty_block(fe),
                    }
                }
                // Both branches return: return whichever value the condition selects.
                (ExpressionResult::Return(te), ExpressionResult::Return(fe)) => {
                    ExpressionResult::Return(Some(Expression::Condition {
                        condition,
                        true_expr: te.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
                        false_expr: fe.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
                    }))
                }
                // At least one branch returns only conditionally. Turn both branches into
                // return objects and select between them.
                (te, fe) => {
                    let has_value = has_value(ty) && (te.has_value() || fe.has_value());
                    let ty = if has_value { ty } else { &Type::Void };
                    let te = te.into_return_object(ty, &ctx.ret_ty);
                    let fe = fe.into_return_object(ty, &ctx.ret_ty);
                    ExpressionResult::ReturnObject {
                        has_value,
                        // `self::` disambiguates the function from the local `has_value` bool.
                        has_return_value: self::has_value(&ctx.ret_ty),
                        value: Expression::Condition {
                            condition,
                            true_expr: te.into(),
                            false_expr: fe.into(),
                        },
                    }
                }
            }
        }
        Expression::Cast { from, to } => {
            // When the value is used, process the operand with its own type; the cast is then
            // re-applied to the non-returning value only.
            let ty = if !has_value(ty) { ty.clone() } else { from.ty() };
            process_expression(*from, toplevel, ctx, &ty)
                .map_value(|e| Expression::Cast { from: e.into(), to })
        }
        // The stored value may return; the store itself only happens on the non-returning path.
        Expression::StoreLocalVariable { name, value } => {
            let inner_ty = value.ty();
            match process_expression(*value, false, ctx, &inner_ty) {
                ExpressionResult::Just(e) => {
                    ExpressionResult::Just(Expression::StoreLocalVariable {
                        name,
                        value: Box::new(e),
                    })
                }
                ExpressionResult::Return(r) => ExpressionResult::Return(r),
                ExpressionResult::MaybeReturn {
                    pre_statements,
                    condition,
                    returned_value,
                    actual_value,
                } => ExpressionResult::MaybeReturn {
                    pre_statements,
                    condition,
                    returned_value,
                    actual_value: Some(Expression::StoreLocalVariable {
                        name,
                        value: Box::new(
                            actual_value.unwrap_or(Expression::default_value_for_type(&inner_ty)),
                        ),
                    }),
                },
                // Spill the return object into a temporary (`return_check_storeN`), then unpack
                // it into a `MaybeReturn` whose non-returning path stores its `actual` field.
                ExpressionResult::ReturnObject { value, has_return_value, .. } => {
                    // Counter used to give each temporary a unique name.
                    static COUNT: std::sync::atomic::AtomicUsize =
                        std::sync::atomic::AtomicUsize::new(0);
                    let tmp_name: SmolStr = format_smolstr!(
                        "return_check_store{}",
                        COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                    );
                    let value_ty = value.ty();
                    let load = |field: &str| Expression::StructFieldAccess {
                        base: Box::new(Expression::ReadLocalVariable {
                            name: tmp_name.clone(),
                            ty: value_ty.clone(),
                        }),
                        name: field.into(),
                    };
                    let condition = load(FIELD_CONDITION);
                    let returned_value = has_return_value.then(|| load(FIELD_RETURNED));
                    let actual_value = Some(Expression::StoreLocalVariable {
                        name,
                        value: Box::new(load(FIELD_ACTUAL)),
                    });
                    ExpressionResult::MaybeReturn {
                        pre_statements: vec![Expression::StoreLocalVariable {
                            name: tmp_name,
                            value: Box::new(value),
                        }],
                        condition,
                        returned_value,
                        actual_value,
                    }
                }
            }
        }
        e => {
            // Normally there shouldn't be any 'return' statements in there, since returns are
            // not allowed in arbitrary expressions. Debug builds verify this.
            #[cfg(debug_assertions)]
            {
                e.visit_recursive(&mut |e| assert!(!matches!(e, Expression::ReturnStatement(_))));
            }
            ExpressionResult::Just(e)
        }
    }
}
171
172/// Return the expression, unless it is an empty codeblock, then return None
173fn cleanup_empty_block(te: Expression) -> Option<Expression> {
174    if matches!(&te, Expression::CodeBlock(stmts) if stmts.is_empty()) { None } else { Some(te) }
175}
176
/// Rewrite the statements of a code block (yielded by `iter`) so that none contains a `return`.
///
/// Statements are processed in order. Only the last one gets the block's type `ty`; the
/// others are processed with `Type::Void`. Once a statement may return, the remaining
/// statements are folded into the result according to its outcome:
/// * `Return` — the block ends here. The preceding statements and the returned value (if
///   any) are bundled into the returned code block. The remaining statements are never
///   processed.
/// * `MaybeReturn`:
///   - As the last statement, it is propagated with all preceding statements as
///     `pre_statements`.
///   - At the top level, the rest of the block is processed recursively and placed after
///     `actual_value` in the non-returning branch, cast to the return type when the
///     resulting type differs.
///   - Otherwise, it is turned into a return object and the rest is handled by
///     `continue_codeblock`.
/// * `ReturnObject` — propagated (prefixed by the preceding statements) when last, otherwise
///   handed to `continue_codeblock`.
///
/// If no statement returns, the result is `Just` of the rewritten block (also for an empty block).
fn process_codeblock(
    mut iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
    toplevel: bool,
    ty: &Type,
    ctx: &RemoveReturnContext,
) -> ExpressionResult {
    let mut stmts = Vec::new();
    // `while let` rather than `for`: `peek` is needed to know whether `e` is the last statement.
    while let Some(e) = iter.next() {
        let is_last = iter.peek().is_none();
        match process_expression(e, toplevel, ctx, if is_last { ty } else { &Type::Void }) {
            ExpressionResult::Just(x) => stmts.push(x),
            ExpressionResult::Return(x) => {
                stmts.extend(x);
                return ExpressionResult::Return(
                    (!stmts.is_empty()).then_some(Expression::CodeBlock(stmts)),
                );
            }
            ExpressionResult::MaybeReturn {
                mut pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                stmts.append(&mut pre_statements);
                if is_last {
                    return ExpressionResult::MaybeReturn {
                        pre_statements: stmts,
                        condition,
                        returned_value,
                        actual_value,
                    };
                } else if toplevel {
                    // The rest of the block only runs when no return happened, so it becomes
                    // part of the non-returning branch.
                    let rest = process_codeblock(iter, true, ty, ctx).into_expression(&ctx.ret_ty);
                    let mut rest_ex = Expression::CodeBlock(
                        actual_value.into_iter().chain(core::iter::once(rest)).collect(),
                    );
                    if rest_ex.ty() != ctx.ret_ty {
                        rest_ex =
                            Expression::Cast { from: Box::new(rest_ex), to: ctx.ret_ty.clone() }
                    }
                    return ExpressionResult::MaybeReturn {
                        pre_statements: stmts,
                        condition,
                        returned_value,
                        actual_value: Some(rest_ex),
                    };
                } else {
                    // NOTE(review): the object is built with the block's `ty` although this
                    // statement was processed as `Type::Void`. Only its `condition` and
                    // `returned` fields are read by `continue_codeblock`.
                    return continue_codeblock(
                        iter,
                        ty,
                        ctx,
                        ExpressionResult::MaybeReturn {
                            pre_statements: Vec::new(),
                            condition,
                            returned_value,
                            actual_value,
                        }
                        .into_return_object(ty, &ctx.ret_ty),
                        stmts,
                        has_value(&ctx.ret_ty),
                    );
                }
            }
            // `has_value` here is the field of the variant; it shadows the free function.
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                if is_last {
                    return ExpressionResult::ReturnObject {
                        value: codeblock_with_expr(stmts, value),
                        has_value,
                        has_return_value,
                    };
                } else {
                    return continue_codeblock(iter, ty, ctx, value, stmts, has_return_value);
                }
            }
        }
    }
    ExpressionResult::Just(Expression::CodeBlock(stmts))
}
255
/// Finish a code block after a statement that may have returned.
///
/// * `return_object` — that statement's outcome as a return object: a struct with a
///   `condition` field and, when `has_return_value` is true, a `returned` field.
/// * `stmts` — the statements emitted so far.
/// * `iter` — the remaining statements. They are processed (not at the top level) with type
///   `ty` and turned into a return object.
///
/// The object is stored in a fresh local `return_check_mergeN`, then the block continues with
/// `if tmp.condition { <rest as return object> } else { <return object returning tmp.returned> }`.
/// The result is a `ReturnObject` for the whole block.
fn continue_codeblock(
    iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
    ty: &Type,
    ctx: &RemoveReturnContext,
    return_object: Expression,
    mut stmts: Vec<Expression>,
    has_return_value: bool,
) -> ExpressionResult {
    let rest = process_codeblock(iter, false, ty, ctx).into_return_object(ty, &ctx.ret_ty);
    // Counter used to give each temporary a unique name.
    static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    let unique_name = format_smolstr!(
        "return_check_merge{}",
        COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    );
    let load = Box::new(Expression::ReadLocalVariable {
        name: unique_name.clone(),
        ty: return_object.ty(),
    });
    stmts.push(Expression::StoreLocalVariable { name: unique_name, value: return_object.into() });
    stmts.push(Expression::Condition {
        condition: Expression::StructFieldAccess {
            base: load.clone(),
            name: FIELD_CONDITION.into(),
        }
        .into(),
        true_expr: rest.into(),
        // Propagate the pending return: an object with `condition == false` carrying the
        // stored `returned` field (when there is one).
        false_expr: ExpressionResult::Return(has_return_value.then(|| {
            Expression::StructFieldAccess { base: load.clone(), name: FIELD_RETURNED.into() }
        }))
        .into_return_object(ty, &ctx.ret_ty)
        .into(),
    });
    ExpressionResult::ReturnObject {
        value: Expression::CodeBlock(stmts),
        has_value: has_value(ty),
        has_return_value,
    }
}
294
/// State shared by the whole rewrite of one expression.
struct RemoveReturnContext {
    /// Return type of the rewritten expression: the type of the first `return` statement
    /// found in it (`Type::Void` for a `return` without value).
    ret_ty: Type,
}
298
/// Outcome of rewriting an expression that may contain `return` statements.
#[derive(Debug)]
// The variants differ a lot in size. NOTE(review): presumably allowed because these values
// are only passed between the functions of this pass, so boxing was not considered worth it.
#[allow(clippy::large_enum_variant)]
enum ExpressionResult {
    /// The expression does not return: it evaluates to the contained expression.
    Just(Expression),
    /// The expression used `return` so we need to check for the return slot
    MaybeReturn {
        /// Statements evaluated before `condition`, e.g. the preceding statements of a block or
        /// the store of a temporary variable that `condition` reads.
        pre_statements: Vec<Expression>,
        /// Boolean expression: false means return
        condition: Expression,
        /// Value being returned if condition is false (`None` when nothing is returned)
        returned_value: Option<Expression>,
        /// The value when we don't return (`None` when there is no value)
        actual_value: Option<Expression>,
    },
    /// The expression returns unconditionally, with the given value if any
    Return(Option<Expression>),
    /// The expression is of type `{ condition: bool, actual: ty, returned: ret_ty}`
    /// which is the equivalent of `if condition { actual } else { return R }`.
    /// The `actual` field is only present when `has_value` is true, and the `returned`
    /// field only when `has_return_value` is true.
    ReturnObject { value: Expression, has_value: bool, has_return_value: bool },
}
321
322impl From<Expression> for ExpressionResult {
323    fn from(v: Expression) -> Self {
324        Self::Just(v)
325    }
326}
327
/// Return-object field (`bool`): `true` when execution continues, `false` when a `return` was taken.
const FIELD_CONDITION: &str = "condition";
/// Return-object field holding the value on the non-returning path.
const FIELD_ACTUAL: &str = "actual";
/// Return-object field holding the value being returned when the condition is `false`.
const FIELD_RETURNED: &str = "returned";
331
impl ExpressionResult {
    /// Lower this result into a single `return`-free expression.
    ///
    /// `ty` provides default values for absent `actual`/`returned` fields of a `ReturnObject`.
    /// The callers in this file pass the return type of the rewritten expression.
    /// * `Just` — the expression itself.
    /// * `Return` — the returned value, or an empty code block for a bare `return`.
    /// * `MaybeReturn` — `pre_statements` followed by `if condition { actual } else { returned }`,
    ///   with empty code blocks standing in for missing values.
    /// * `ReturnObject` — the object is stored in a local `returned_expressionN`, and its
    ///   `actual` or `returned` field is selected according to its `condition` field.
    fn into_expression(self, ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => e,
            ExpressionResult::Return(e) => e.unwrap_or(Expression::CodeBlock(Vec::new())),
            ExpressionResult::MaybeReturn {
                mut pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                pre_statements.push(Expression::Condition {
                    condition: condition.into(),
                    true_expr: actual_value.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
                    false_expr: returned_value.unwrap_or(Expression::CodeBlock(Vec::new())).into(),
                });
                Expression::CodeBlock(pre_statements)
            }
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                // Counter used to give each temporary a unique name.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "returned_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let load =
                    Box::new(Expression::ReadLocalVariable { name: name.clone(), ty: value.ty() });
                Expression::CodeBlock(vec![
                    Expression::StoreLocalVariable { name, value: value.into() },
                    Expression::Condition {
                        condition: Expression::StructFieldAccess {
                            base: load.clone(),
                            name: FIELD_CONDITION.into(),
                        }
                        .into(),
                        // Absent fields are replaced by the default value of `ty`.
                        true_expr: if has_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_ACTUAL.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                        false_expr: if has_return_value {
                            Expression::StructFieldAccess {
                                base: load.clone(),
                                name: FIELD_RETURNED.into(),
                            }
                        } else {
                            Expression::default_value_for_type(ty)
                        }
                        .into(),
                    },
                ])
            }
        }
    }

    /// Lower this result into a "return object": an anonymous struct expression with the fields
    /// `condition` (`true` when no `return` was taken), `returned` (of type `ret_ty`) and
    /// `actual` (the non-returning value, only when `ty` has a value).
    ///
    /// Fields whose type carries no value are left out by `make_struct`.
    fn into_return_object(self, ty: &Type, ret_ty: &Type) -> Expression {
        match self {
            ExpressionResult::Just(e) => {
                let ret_value = Expression::default_value_for_type(ret_ty);
                if has_value(ty) {
                    make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                            (FIELD_ACTUAL, e.ty(), e),
                        ]
                        .into_iter(),
                    )
                } else {
                    let object = make_struct(
                        [
                            (FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true)),
                            (FIELD_RETURNED, ret_ty.clone(), ret_value),
                        ]
                        .into_iter(),
                    );
                    // The value is unused. A constant can be dropped; anything else is still
                    // evaluated before the object (presumably for its side effects).
                    if e.is_constant(None) {
                        object
                    } else {
                        Expression::CodeBlock(vec![e, object])
                    }
                }
            }
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => {
                let mut true_expr = match actual_value {
                    Some(e) => ExpressionResult::Just(e).into_return_object(ty, ret_ty),
                    // No value: only the condition. Any fields present on the other branch are
                    // added with default values by the type unification below.
                    None => make_struct(
                        [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(true))].into_iter(),
                    ),
                };
                let mut false_expr =
                    ExpressionResult::Return(returned_value).into_return_object(ty, ret_ty);
                // The two branches may build struct types with different fields. Convert them to
                // a common type (presumably so that the resulting condition is well typed).
                let true_ty = true_expr.ty();
                let false_ty = false_expr.ty();
                if true_ty != false_ty {
                    let common_ty = Expression::common_target_type_for_type_list(
                        [&true_ty, &false_ty].into_iter().cloned(),
                    );
                    if common_ty != true_ty {
                        true_expr =
                            convert_struct(std::mem::take(&mut true_expr), common_ty.clone())
                    }
                    if common_ty != false_ty {
                        false_expr = convert_struct(std::mem::take(&mut false_expr), common_ty)
                    }
                }
                let o = Expression::Condition {
                    condition: condition.into(),
                    true_expr: true_expr.into(),
                    false_expr: false_expr.into(),
                };
                codeblock_with_expr(pre_statements, o)
            }
            // `condition` is false. `returned` is present only when a value is returned, and
            // `actual` is filled with the default value of `ty` when `ty` has a value.
            ExpressionResult::Return(r) => make_struct(
                [(FIELD_CONDITION, Type::Bool, Expression::BoolLiteral(false))]
                    .into_iter()
                    .chain(r.map(|r| (FIELD_RETURNED, ret_ty.clone(), r)))
                    .chain(has_value(ty).then(|| {
                        (FIELD_ACTUAL, ty.clone(), Expression::default_value_for_type(ty))
                    })),
            ),
            ExpressionResult::ReturnObject { value, .. } => value,
        }
    }

    /// Apply `f` to the value produced when no `return` is taken; returned values are unaffected.
    ///
    /// `f` is not called when there is no such value: for `Return`, for `MaybeReturn` without
    /// `actual_value`, and for a `ReturnObject` without an `actual` field.
    fn map_value(self, f: impl FnOnce(Expression) -> Expression) -> Self {
        match self {
            ExpressionResult::Just(e) => ExpressionResult::Just(f(e)),
            ExpressionResult::Return(e) => ExpressionResult::Return(e),
            ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value,
            } => ExpressionResult::MaybeReturn {
                pre_statements,
                condition,
                returned_value,
                actual_value: actual_value.map(f),
            },
            ExpressionResult::ReturnObject { value, has_value, has_return_value } => {
                if !has_value {
                    return ExpressionResult::ReturnObject { value, has_value, has_return_value };
                }
                // Store the object in a temporary `mapped_expressionN`, then rebuild it with `f`
                // applied to its `actual` field.
                static COUNT: std::sync::atomic::AtomicUsize =
                    std::sync::atomic::AtomicUsize::new(0);
                let name = format_smolstr!(
                    "mapped_expression{}",
                    COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                );
                let value_ty = value.ty();
                let load = |field: &str| Expression::StructFieldAccess {
                    base: Box::new(Expression::ReadLocalVariable {
                        name: name.clone(),
                        ty: value_ty.clone(),
                    }),
                    name: field.into(),
                };
                let condition = (FIELD_CONDITION, Type::Bool, load(FIELD_CONDITION));
                let actual = f(load(FIELD_ACTUAL));
                let actual = (FIELD_ACTUAL, actual.ty(), actual);
                let ret = has_return_value.then(|| {
                    let r = load(FIELD_RETURNED);
                    (FIELD_RETURNED, r.ty(), r)
                });
                ExpressionResult::ReturnObject {
                    value: Expression::CodeBlock(vec![
                        Expression::StoreLocalVariable { name, value: value.into() },
                        make_struct([condition, actual].into_iter().chain(ret.into_iter())),
                    ]),
                    has_value,
                    has_return_value,
                }
            }
        }
    }

    /// Whether the non-returning path yields a value (of a type other than void/invalid).
    fn has_value(&self) -> bool {
        match self {
            ExpressionResult::Just(expression) => has_value(&expression.ty()),
            ExpressionResult::MaybeReturn { actual_value, .. } => {
                actual_value.as_ref().is_some_and(|x| has_value(&x.ty()))
            }
            ExpressionResult::Return(..) => false,
            ExpressionResult::ReturnObject { has_value, .. } => *has_value,
        }
    }
}
529
530fn codeblock_with_expr(mut pre_statements: Vec<Expression>, expr: Expression) -> Expression {
531    if pre_statements.is_empty() {
532        expr
533    } else {
534        pre_statements.push(expr);
535        Expression::CodeBlock(pre_statements)
536    }
537}
538
539fn make_struct(it: impl Iterator<Item = (&'static str, Type, Expression)>) -> Expression {
540    let mut fields = BTreeMap::<SmolStr, Type>::new();
541    let mut values = HashMap::<SmolStr, Expression>::new();
542    let mut voids = Vec::new();
543    for (name, ty, expr) in it {
544        if !has_value(&ty) {
545            if ty != Type::Invalid {
546                voids.push(expr);
547            }
548            continue;
549        }
550        fields.insert(name.into(), ty);
551        values.insert(name.into(), expr);
552    }
553    codeblock_with_expr(
554        voids,
555        Expression::Struct { ty: Rc::new(Struct { fields, name: StructName::None }), values },
556    )
557}
558
559/// Given an expression `from` of type Struct, convert to another type struct with more fields
560/// Add missing members in `from`
561fn convert_struct(from: Expression, to: Type) -> Expression {
562    let Type::Struct(to) = to else {
563        assert_eq!(to, Type::Invalid);
564        return Expression::Invalid;
565    };
566    if let Expression::Struct { mut values, .. } = from {
567        let mut new_values = HashMap::new();
568        for (key, ty) in &to.fields {
569            let (key, expression) = values
570                .remove_entry(key)
571                .unwrap_or_else(|| (key.clone(), Expression::default_value_for_type(ty)));
572            new_values.insert(key, expression);
573        }
574        return Expression::Struct { values: new_values, ty: to };
575    }
576    static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
577    let var_name = format_smolstr!(
578        "tmpobj_ret_conv_{}",
579        COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
580    );
581    let from_ty = from.ty();
582    let mut new_values = HashMap::new();
583    let Type::Struct(from_s) = &from_ty else {
584        assert_eq!(from_ty, Type::Invalid);
585        return Expression::Invalid;
586    };
587    for (key, ty) in &to.fields {
588        let expression = if from_s.fields.contains_key(key) {
589            Expression::StructFieldAccess {
590                base: Box::new(Expression::ReadLocalVariable {
591                    name: var_name.clone(),
592                    ty: from_ty.clone(),
593                }),
594                name: key.clone(),
595            }
596        } else {
597            Expression::default_value_for_type(ty)
598        };
599        new_values.insert(key.clone(), expression);
600    }
601    Expression::CodeBlock(vec![
602        Expression::StoreLocalVariable { name: var_name, value: Box::new(from) },
603        Expression::Struct { values: new_values, ty: to },
604    ])
605}
606
607fn has_value(ty: &Type) -> bool {
608    !matches!(ty, Type::Void | Type::Invalid)
609}