i_slint_compiler/
builtin_macros.rs

1// Copyright © SixtyFPS GmbH <info@slint.dev>
2// SPDX-License-Identifier: GPL-3.0-only OR LicenseRef-Slint-Royalty-free-2.0 OR LicenseRef-Slint-Software-3.0
3
//! This module contains the implementation of the builtin macros.
//! They are simple transformations that expand into more complex expression trees.
6
7use crate::diagnostics::{BuildDiagnostics, Spanned};
8use crate::expression_tree::{
9    BuiltinFunction, BuiltinMacroFunction, Callable, EasingCurve, Expression, MinMaxOp, Unit,
10};
11use crate::langtype::{EnumerationValue, Type};
12use crate::parser::NodeOrToken;
13use smol_str::{format_smolstr, ToSmolStr};
14
/// Counter used to generate unique names for the temporary local variables introduced by
/// the macro expansions in this module.
static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(1);
17
18/// "Expand" the macro `mac` (at location `n`) with the arguments `sub_expr`
19pub fn lower_macro(
20    mac: BuiltinMacroFunction,
21    n: &dyn Spanned,
22    mut sub_expr: impl Iterator<Item = (Expression, Option<NodeOrToken>)>,
23    diag: &mut BuildDiagnostics,
24) -> Expression {
25    match mac {
26        BuiltinMacroFunction::Min => min_max_macro(n, MinMaxOp::Min, sub_expr.collect(), diag),
27        BuiltinMacroFunction::Max => min_max_macro(n, MinMaxOp::Max, sub_expr.collect(), diag),
28        BuiltinMacroFunction::Clamp => clamp_macro(n, sub_expr.collect(), diag),
29        BuiltinMacroFunction::Mod => mod_macro(n, sub_expr.collect(), diag),
30        BuiltinMacroFunction::Abs => abs_macro(n, sub_expr.collect(), diag),
31        BuiltinMacroFunction::Debug => debug_macro(n, sub_expr.collect(), diag),
32        BuiltinMacroFunction::CubicBezier => {
33            let mut has_error = None;
34            let expected_argument_type_error =
35                "Arguments to cubic bezier curve must be number literal";
36            // FIXME: this is not pretty to be handling there.
37            // Maybe "cubic_bezier" should be a function that is lowered later
38            let mut a = || match sub_expr.next() {
39                None => {
40                    has_error.get_or_insert((n.to_source_location(), "Not enough arguments"));
41                    0.
42                }
43                Some((Expression::NumberLiteral(val, Unit::None), _)) => val as f32,
44                // handle negative numbers
45                Some((Expression::UnaryOp { sub, op: '-' }, n)) => match *sub {
46                    Expression::NumberLiteral(val, Unit::None) => (-1.0 * val) as f32,
47                    _ => {
48                        has_error
49                            .get_or_insert((n.to_source_location(), expected_argument_type_error));
50                        0.
51                    }
52                },
53                Some((_, n)) => {
54                    has_error.get_or_insert((n.to_source_location(), expected_argument_type_error));
55                    0.
56                }
57            };
58            let expr = Expression::EasingCurve(EasingCurve::CubicBezier(a(), a(), a(), a()));
59            if let Some((_, n)) = sub_expr.next() {
60                has_error
61                    .get_or_insert((n.to_source_location(), "Too many argument for bezier curve"));
62            }
63            if let Some((n, msg)) = has_error {
64                diag.push_error(msg.into(), &n);
65            }
66
67            expr
68        }
69        BuiltinMacroFunction::Rgb => rgb_macro(n, sub_expr.collect(), diag),
70        BuiltinMacroFunction::Hsv => hsv_macro(n, sub_expr.collect(), diag),
71    }
72}
73
74fn min_max_macro(
75    node: &dyn Spanned,
76    op: MinMaxOp,
77    args: Vec<(Expression, Option<NodeOrToken>)>,
78    diag: &mut BuildDiagnostics,
79) -> Expression {
80    if args.is_empty() {
81        diag.push_error("Needs at least one argument".into(), node);
82        return Expression::Invalid;
83    }
84    let ty = Expression::common_target_type_for_type_list(args.iter().map(|expr| expr.0.ty()));
85    if ty.as_unit_product().is_none() {
86        diag.push_error("Invalid argument type".into(), node);
87        return Expression::Invalid;
88    }
89    let mut args = args.into_iter();
90    let (base, arg_node) = args.next().unwrap();
91    let mut base = base.maybe_convert_to(ty.clone(), &arg_node, diag);
92    for (next, arg_node) in args {
93        let rhs = next.maybe_convert_to(ty.clone(), &arg_node, diag);
94        base = min_max_expression(base, rhs, op);
95    }
96    base
97}
98
99fn clamp_macro(
100    node: &dyn Spanned,
101    args: Vec<(Expression, Option<NodeOrToken>)>,
102    diag: &mut BuildDiagnostics,
103) -> Expression {
104    if args.len() != 3 {
105        diag.push_error(
106            "`clamp` needs three values: the `value` to clamp, the `minimum` and the `maximum`"
107                .into(),
108            node,
109        );
110        return Expression::Invalid;
111    }
112    let (value, value_node) = args.first().unwrap().clone();
113    let ty = value.ty();
114    if ty.as_unit_product().is_none() {
115        diag.push_error("Invalid argument type".into(), &value_node);
116        return Expression::Invalid;
117    }
118
119    let (min, min_node) = args.get(1).unwrap().clone();
120    let min = min.maybe_convert_to(ty.clone(), &min_node, diag);
121    let (max, max_node) = args.get(2).unwrap().clone();
122    let max = max.maybe_convert_to(ty.clone(), &max_node, diag);
123
124    let value = min_max_expression(value, max, MinMaxOp::Min);
125    min_max_expression(min, value, MinMaxOp::Max)
126}
127
128fn mod_macro(
129    node: &dyn Spanned,
130    args: Vec<(Expression, Option<NodeOrToken>)>,
131    diag: &mut BuildDiagnostics,
132) -> Expression {
133    if args.len() != 2 {
134        diag.push_error("Needs 2 arguments".into(), node);
135        return Expression::Invalid;
136    }
137    let (lhs_ty, rhs_ty) = (args[0].0.ty(), args[1].0.ty());
138    let common_ty = if lhs_ty.default_unit().is_some() {
139        lhs_ty
140    } else if rhs_ty.default_unit().is_some() {
141        rhs_ty
142    } else if matches!(lhs_ty, Type::UnitProduct(_)) {
143        lhs_ty
144    } else if matches!(rhs_ty, Type::UnitProduct(_)) {
145        rhs_ty
146    } else {
147        Type::Float32
148    };
149
150    let source_location = Some(node.to_source_location());
151    let function = Callable::Builtin(BuiltinFunction::Mod);
152    let arguments = args.into_iter().map(|(e, n)| e.maybe_convert_to(common_ty.clone(), &n, diag));
153    if matches!(common_ty, Type::Float32) {
154        Expression::FunctionCall { function, arguments: arguments.collect(), source_location }
155    } else {
156        Expression::Cast {
157            from: Expression::FunctionCall {
158                function,
159                arguments: arguments
160                    .map(|a| Expression::Cast { from: a.into(), to: Type::Float32 })
161                    .collect(),
162                source_location,
163            }
164            .into(),
165            to: common_ty.clone(),
166        }
167    }
168}
169
170fn abs_macro(
171    node: &dyn Spanned,
172    args: Vec<(Expression, Option<NodeOrToken>)>,
173    diag: &mut BuildDiagnostics,
174) -> Expression {
175    if args.len() != 1 {
176        diag.push_error("Needs 1 argument".into(), node);
177        return Expression::Invalid;
178    }
179    let ty = args[0].0.ty();
180    let ty = if ty.default_unit().is_some() || matches!(ty, Type::UnitProduct(_)) {
181        ty
182    } else {
183        Type::Float32
184    };
185
186    let source_location = Some(node.to_source_location());
187    let function = Callable::Builtin(BuiltinFunction::Abs);
188    if matches!(ty, Type::Float32) {
189        let arguments =
190            args.into_iter().map(|(e, n)| e.maybe_convert_to(ty.clone(), &n, diag)).collect();
191        Expression::FunctionCall { function, arguments, source_location }
192    } else {
193        Expression::Cast {
194            from: Expression::FunctionCall {
195                function,
196                arguments: args
197                    .into_iter()
198                    .map(|(a, _)| Expression::Cast { from: a.into(), to: Type::Float32 })
199                    .collect(),
200                source_location,
201            }
202            .into(),
203            to: ty,
204        }
205    }
206}
207
208fn rgb_macro(
209    node: &dyn Spanned,
210    args: Vec<(Expression, Option<NodeOrToken>)>,
211    diag: &mut BuildDiagnostics,
212) -> Expression {
213    if args.len() < 3 || args.len() > 4 {
214        diag.push_error(
215            format!("This function needs 3 or 4 arguments, but {} were provided", args.len()),
216            node,
217        );
218        return Expression::Invalid;
219    }
220    let mut arguments: Vec<_> = args
221        .into_iter()
222        .enumerate()
223        .map(|(i, (expr, n))| {
224            if i < 3 {
225                if expr.ty() == Type::Percent {
226                    Expression::BinaryExpression {
227                        lhs: Box::new(expr.maybe_convert_to(Type::Float32, &n, diag)),
228                        rhs: Box::new(Expression::NumberLiteral(255., Unit::None)),
229                        op: '*',
230                    }
231                } else {
232                    expr.maybe_convert_to(Type::Float32, &n, diag)
233                }
234            } else {
235                expr.maybe_convert_to(Type::Float32, &n, diag)
236            }
237        })
238        .collect();
239    if arguments.len() < 4 {
240        arguments.push(Expression::NumberLiteral(1., Unit::None))
241    }
242    Expression::FunctionCall {
243        function: BuiltinFunction::Rgb.into(),
244        arguments,
245        source_location: Some(node.to_source_location()),
246    }
247}
248
249fn hsv_macro(
250    node: &dyn Spanned,
251    args: Vec<(Expression, Option<NodeOrToken>)>,
252    diag: &mut BuildDiagnostics,
253) -> Expression {
254    if args.len() < 3 || args.len() > 4 {
255        diag.push_error(
256            format!("This function needs 3 or 4 arguments, but {} were provided", args.len()),
257            node,
258        );
259        return Expression::Invalid;
260    }
261    let mut arguments: Vec<_> =
262        args.into_iter().map(|(expr, n)| expr.maybe_convert_to(Type::Float32, &n, diag)).collect();
263    if arguments.len() < 4 {
264        arguments.push(Expression::NumberLiteral(1., Unit::None))
265    }
266    Expression::FunctionCall {
267        function: BuiltinFunction::Hsv.into(),
268        arguments,
269        source_location: Some(node.to_source_location()),
270    }
271}
272
273fn debug_macro(
274    node: &dyn Spanned,
275    args: Vec<(Expression, Option<NodeOrToken>)>,
276    diag: &mut BuildDiagnostics,
277) -> Expression {
278    let mut string = None;
279    for (expr, node) in args {
280        let val = to_debug_string(expr, &node, diag);
281        string = Some(match string {
282            None => val,
283            Some(string) => Expression::BinaryExpression {
284                lhs: Box::new(string),
285                op: '+',
286                rhs: Box::new(Expression::BinaryExpression {
287                    lhs: Box::new(Expression::StringLiteral(" ".into())),
288                    op: '+',
289                    rhs: Box::new(val),
290                }),
291            },
292        });
293    }
294    Expression::FunctionCall {
295        function: BuiltinFunction::Debug.into(),
296        arguments: vec![string.unwrap_or_else(|| Expression::default_value_for_type(&Type::String))],
297        source_location: Some(node.to_source_location()),
298    }
299}
300
301fn to_debug_string(
302    expr: Expression,
303    node: &dyn Spanned,
304    diag: &mut BuildDiagnostics,
305) -> Expression {
306    let ty = expr.ty();
307    match &ty {
308        Type::Invalid => Expression::Invalid,
309        Type::Void
310        | Type::InferredCallback
311        | Type::InferredProperty
312        | Type::Callback { .. }
313        | Type::ComponentFactory
314        | Type::Function { .. }
315        | Type::ElementReference
316        | Type::LayoutCache
317        | Type::Model
318        | Type::PathData => {
319            diag.push_error("Cannot debug this expression".into(), node);
320            Expression::Invalid
321        }
322        Type::Float32 | Type::Int32 => expr.maybe_convert_to(Type::String, node, diag),
323        Type::String => expr,
324        // TODO
325        Type::Color | Type::Brush | Type::Image | Type::Easing | Type::Array(_) => {
326            Expression::StringLiteral("<debug-of-this-type-not-yet-implemented>".into())
327        }
328        Type::Duration
329        | Type::PhysicalLength
330        | Type::LogicalLength
331        | Type::Rem
332        | Type::Angle
333        | Type::Percent
334        | Type::UnitProduct(_) => Expression::BinaryExpression {
335            lhs: Box::new(
336                Expression::Cast { from: Box::new(expr), to: Type::Float32 }.maybe_convert_to(
337                    Type::String,
338                    node,
339                    diag,
340                ),
341            ),
342            op: '+',
343            rhs: Box::new(Expression::StringLiteral(
344                Type::UnitProduct(ty.as_unit_product().unwrap()).to_smolstr(),
345            )),
346        },
347        Type::Bool => Expression::Condition {
348            condition: Box::new(expr),
349            true_expr: Box::new(Expression::StringLiteral("true".into())),
350            false_expr: Box::new(Expression::StringLiteral("false".into())),
351        },
352        Type::Struct(s) => {
353            let local_object = format_smolstr!(
354                "debug_struct{}",
355                COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
356            );
357            let mut string = None;
358            for k in s.fields.keys() {
359                let field_name = if string.is_some() {
360                    format_smolstr!(", {}: ", k)
361                } else {
362                    format_smolstr!("{{ {}: ", k)
363                };
364                let value = to_debug_string(
365                    Expression::StructFieldAccess {
366                        base: Box::new(Expression::ReadLocalVariable {
367                            name: local_object.clone(),
368                            ty: ty.clone(),
369                        }),
370                        name: k.clone(),
371                    },
372                    node,
373                    diag,
374                );
375                let field = Expression::BinaryExpression {
376                    lhs: Box::new(Expression::StringLiteral(field_name)),
377                    op: '+',
378                    rhs: Box::new(value),
379                };
380                string = Some(match string {
381                    None => field,
382                    Some(x) => Expression::BinaryExpression {
383                        lhs: Box::new(x),
384                        op: '+',
385                        rhs: Box::new(field),
386                    },
387                });
388            }
389            match string {
390                None => Expression::StringLiteral("{}".into()),
391                Some(string) => Expression::CodeBlock(vec![
392                    Expression::StoreLocalVariable { name: local_object, value: Box::new(expr) },
393                    Expression::BinaryExpression {
394                        lhs: Box::new(string),
395                        op: '+',
396                        rhs: Box::new(Expression::StringLiteral(" }".into())),
397                    },
398                ]),
399            }
400        }
401        Type::Enumeration(enu) => {
402            let local_object = "debug_enum";
403            let mut v = vec![Expression::StoreLocalVariable {
404                name: local_object.into(),
405                value: Box::new(expr),
406            }];
407            let mut cond =
408                Expression::StringLiteral(format_smolstr!("Error: invalid value for {}", ty));
409            for (idx, val) in enu.values.iter().enumerate() {
410                cond = Expression::Condition {
411                    condition: Box::new(Expression::BinaryExpression {
412                        lhs: Box::new(Expression::ReadLocalVariable {
413                            name: local_object.into(),
414                            ty: ty.clone(),
415                        }),
416                        rhs: Box::new(Expression::EnumerationValue(EnumerationValue {
417                            value: idx,
418                            enumeration: enu.clone(),
419                        })),
420                        op: '=',
421                    }),
422                    true_expr: Box::new(Expression::StringLiteral(val.clone())),
423                    false_expr: Box::new(cond),
424                };
425            }
426            v.push(cond);
427            Expression::CodeBlock(v)
428        }
429    }
430}
431
432/// Generate an expression which is like `min(lhs, rhs)` if op is '<' or `max(lhs, rhs)` if op is '>'.
433/// counter is an unique id.
434/// The rhs and lhs of the expression must have the same numerical type
435pub fn min_max_expression(lhs: Expression, rhs: Expression, op: MinMaxOp) -> Expression {
436    let lhs_ty = lhs.ty();
437    let rhs_ty = rhs.ty();
438    let ty = match (lhs_ty, rhs_ty) {
439        (a, b) if a == b => a,
440        (Type::Int32, Type::Float32) | (Type::Float32, Type::Int32) => Type::Float32,
441        _ => Type::Invalid,
442    };
443    Expression::MinMax { ty, op, lhs: Box::new(lhs), rhs: Box::new(rhs) }
444}