1use crate::diagnostics::{BuildDiagnostics, Spanned};
8use crate::expression_tree::{
9 BuiltinFunction, BuiltinMacroFunction, Callable, EasingCurve, Expression, MinMaxOp, Unit,
10};
11use crate::langtype::{EnumerationValue, Type};
12use crate::parser::NodeOrToken;
13use smol_str::{format_smolstr, ToSmolStr};
14
/// Source of unique numeric suffixes for the temporary local variables that
/// `to_debug_string` creates when it formats struct values (`debug_struct<N>`).
static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(1);
17
18pub fn lower_macro(
20 mac: BuiltinMacroFunction,
21 n: &dyn Spanned,
22 mut sub_expr: impl Iterator<Item = (Expression, Option<NodeOrToken>)>,
23 diag: &mut BuildDiagnostics,
24) -> Expression {
25 match mac {
26 BuiltinMacroFunction::Min => min_max_macro(n, MinMaxOp::Min, sub_expr.collect(), diag),
27 BuiltinMacroFunction::Max => min_max_macro(n, MinMaxOp::Max, sub_expr.collect(), diag),
28 BuiltinMacroFunction::Clamp => clamp_macro(n, sub_expr.collect(), diag),
29 BuiltinMacroFunction::Mod => mod_macro(n, sub_expr.collect(), diag),
30 BuiltinMacroFunction::Abs => abs_macro(n, sub_expr.collect(), diag),
31 BuiltinMacroFunction::Sign => {
32 let Some((x, arg_node)) = sub_expr.next() else {
33 diag.push_error("Expected one argument".into(), n);
34 return Expression::Invalid;
35 };
36 if sub_expr.next().is_some() {
37 diag.push_error("Expected only one argument".into(), n);
38 }
39 Expression::Condition {
40 condition: Expression::BinaryExpression {
41 lhs: x.maybe_convert_to(Type::Float32, &arg_node, diag).into(),
42 rhs: Expression::NumberLiteral(0., Unit::None).into(),
43 op: '<',
44 }
45 .into(),
46 true_expr: Expression::NumberLiteral(-1., Unit::None).into(),
47 false_expr: Expression::NumberLiteral(1., Unit::None).into(),
48 }
49 }
50 BuiltinMacroFunction::Debug => debug_macro(n, sub_expr.collect(), diag),
51 BuiltinMacroFunction::CubicBezier => {
52 let mut has_error = None;
53 let expected_argument_type_error =
54 "Arguments to cubic bezier curve must be number literal";
55 let mut a = || match sub_expr.next() {
58 None => {
59 has_error.get_or_insert((n.to_source_location(), "Not enough arguments"));
60 0.
61 }
62 Some((Expression::NumberLiteral(val, Unit::None), _)) => val as f32,
63 Some((Expression::UnaryOp { sub, op: '-' }, n)) => match *sub {
65 Expression::NumberLiteral(val, Unit::None) => (-1.0 * val) as f32,
66 _ => {
67 has_error
68 .get_or_insert((n.to_source_location(), expected_argument_type_error));
69 0.
70 }
71 },
72 Some((_, n)) => {
73 has_error.get_or_insert((n.to_source_location(), expected_argument_type_error));
74 0.
75 }
76 };
77 let expr = Expression::EasingCurve(EasingCurve::CubicBezier(a(), a(), a(), a()));
78 if let Some((_, n)) = sub_expr.next() {
79 has_error
80 .get_or_insert((n.to_source_location(), "Too many argument for bezier curve"));
81 }
82 if let Some((n, msg)) = has_error {
83 diag.push_error(msg.into(), &n);
84 }
85
86 expr
87 }
88 BuiltinMacroFunction::Rgb => rgb_macro(n, sub_expr.collect(), diag),
89 BuiltinMacroFunction::Hsv => hsv_macro(n, sub_expr.collect(), diag),
90 }
91}
92
93fn min_max_macro(
94 node: &dyn Spanned,
95 op: MinMaxOp,
96 args: Vec<(Expression, Option<NodeOrToken>)>,
97 diag: &mut BuildDiagnostics,
98) -> Expression {
99 if args.is_empty() {
100 diag.push_error("Needs at least one argument".into(), node);
101 return Expression::Invalid;
102 }
103 let ty = Expression::common_target_type_for_type_list(args.iter().map(|expr| expr.0.ty()));
104 if ty.as_unit_product().is_none() {
105 diag.push_error("Invalid argument type".into(), node);
106 return Expression::Invalid;
107 }
108 let mut args = args.into_iter();
109 let (base, arg_node) = args.next().unwrap();
110 let mut base = base.maybe_convert_to(ty.clone(), &arg_node, diag);
111 for (next, arg_node) in args {
112 let rhs = next.maybe_convert_to(ty.clone(), &arg_node, diag);
113 base = min_max_expression(base, rhs, op);
114 }
115 base
116}
117
118fn clamp_macro(
119 node: &dyn Spanned,
120 args: Vec<(Expression, Option<NodeOrToken>)>,
121 diag: &mut BuildDiagnostics,
122) -> Expression {
123 if args.len() != 3 {
124 diag.push_error(
125 "`clamp` needs three values: the `value` to clamp, the `minimum` and the `maximum`"
126 .into(),
127 node,
128 );
129 return Expression::Invalid;
130 }
131 let (value, value_node) = args.first().unwrap().clone();
132 let ty = value.ty();
133 if ty.as_unit_product().is_none() {
134 diag.push_error("Invalid argument type".into(), &value_node);
135 return Expression::Invalid;
136 }
137
138 let (min, min_node) = args.get(1).unwrap().clone();
139 let min = min.maybe_convert_to(ty.clone(), &min_node, diag);
140 let (max, max_node) = args.get(2).unwrap().clone();
141 let max = max.maybe_convert_to(ty.clone(), &max_node, diag);
142
143 let value = min_max_expression(value, max, MinMaxOp::Min);
144 min_max_expression(min, value, MinMaxOp::Max)
145}
146
147fn mod_macro(
148 node: &dyn Spanned,
149 args: Vec<(Expression, Option<NodeOrToken>)>,
150 diag: &mut BuildDiagnostics,
151) -> Expression {
152 if args.len() != 2 {
153 diag.push_error("Needs 2 arguments".into(), node);
154 return Expression::Invalid;
155 }
156 let (lhs_ty, rhs_ty) = (args[0].0.ty(), args[1].0.ty());
157 let common_ty = if lhs_ty.default_unit().is_some() {
158 lhs_ty
159 } else if rhs_ty.default_unit().is_some() {
160 rhs_ty
161 } else if matches!(lhs_ty, Type::UnitProduct(_)) {
162 lhs_ty
163 } else if matches!(rhs_ty, Type::UnitProduct(_)) {
164 rhs_ty
165 } else {
166 Type::Float32
167 };
168
169 let source_location = Some(node.to_source_location());
170 let function = Callable::Builtin(BuiltinFunction::Mod);
171 let arguments = args.into_iter().map(|(e, n)| e.maybe_convert_to(common_ty.clone(), &n, diag));
172 if matches!(common_ty, Type::Float32) {
173 Expression::FunctionCall { function, arguments: arguments.collect(), source_location }
174 } else {
175 Expression::Cast {
176 from: Expression::FunctionCall {
177 function,
178 arguments: arguments
179 .map(|a| Expression::Cast { from: a.into(), to: Type::Float32 })
180 .collect(),
181 source_location,
182 }
183 .into(),
184 to: common_ty.clone(),
185 }
186 }
187}
188
189fn abs_macro(
190 node: &dyn Spanned,
191 args: Vec<(Expression, Option<NodeOrToken>)>,
192 diag: &mut BuildDiagnostics,
193) -> Expression {
194 if args.len() != 1 {
195 diag.push_error("Needs 1 argument".into(), node);
196 return Expression::Invalid;
197 }
198 let ty = args[0].0.ty();
199 let ty = if ty.default_unit().is_some() || matches!(ty, Type::UnitProduct(_)) {
200 ty
201 } else {
202 Type::Float32
203 };
204
205 let source_location = Some(node.to_source_location());
206 let function = Callable::Builtin(BuiltinFunction::Abs);
207 if matches!(ty, Type::Float32) {
208 let arguments =
209 args.into_iter().map(|(e, n)| e.maybe_convert_to(ty.clone(), &n, diag)).collect();
210 Expression::FunctionCall { function, arguments, source_location }
211 } else {
212 Expression::Cast {
213 from: Expression::FunctionCall {
214 function,
215 arguments: args
216 .into_iter()
217 .map(|(a, _)| Expression::Cast { from: a.into(), to: Type::Float32 })
218 .collect(),
219 source_location,
220 }
221 .into(),
222 to: ty,
223 }
224 }
225}
226
227fn rgb_macro(
228 node: &dyn Spanned,
229 args: Vec<(Expression, Option<NodeOrToken>)>,
230 diag: &mut BuildDiagnostics,
231) -> Expression {
232 if args.len() < 3 || args.len() > 4 {
233 diag.push_error(
234 format!("This function needs 3 or 4 arguments, but {} were provided", args.len()),
235 node,
236 );
237 return Expression::Invalid;
238 }
239 let mut arguments: Vec<_> = args
240 .into_iter()
241 .enumerate()
242 .map(|(i, (expr, n))| {
243 if i < 3 {
244 if expr.ty() == Type::Percent {
245 Expression::BinaryExpression {
246 lhs: Box::new(expr.maybe_convert_to(Type::Float32, &n, diag)),
247 rhs: Box::new(Expression::NumberLiteral(255., Unit::None)),
248 op: '*',
249 }
250 } else {
251 expr.maybe_convert_to(Type::Float32, &n, diag)
252 }
253 } else {
254 expr.maybe_convert_to(Type::Float32, &n, diag)
255 }
256 })
257 .collect();
258 if arguments.len() < 4 {
259 arguments.push(Expression::NumberLiteral(1., Unit::None))
260 }
261 Expression::FunctionCall {
262 function: BuiltinFunction::Rgb.into(),
263 arguments,
264 source_location: Some(node.to_source_location()),
265 }
266}
267
268fn hsv_macro(
269 node: &dyn Spanned,
270 args: Vec<(Expression, Option<NodeOrToken>)>,
271 diag: &mut BuildDiagnostics,
272) -> Expression {
273 if args.len() < 3 || args.len() > 4 {
274 diag.push_error(
275 format!("This function needs 3 or 4 arguments, but {} were provided", args.len()),
276 node,
277 );
278 return Expression::Invalid;
279 }
280 let mut arguments: Vec<_> =
281 args.into_iter().map(|(expr, n)| expr.maybe_convert_to(Type::Float32, &n, diag)).collect();
282 if arguments.len() < 4 {
283 arguments.push(Expression::NumberLiteral(1., Unit::None))
284 }
285 Expression::FunctionCall {
286 function: BuiltinFunction::Hsv.into(),
287 arguments,
288 source_location: Some(node.to_source_location()),
289 }
290}
291
292fn debug_macro(
293 node: &dyn Spanned,
294 args: Vec<(Expression, Option<NodeOrToken>)>,
295 diag: &mut BuildDiagnostics,
296) -> Expression {
297 let mut string = None;
298 for (expr, node) in args {
299 let val = to_debug_string(expr, &node, diag);
300 string = Some(match string {
301 None => val,
302 Some(string) => Expression::BinaryExpression {
303 lhs: Box::new(string),
304 op: '+',
305 rhs: Box::new(Expression::BinaryExpression {
306 lhs: Box::new(Expression::StringLiteral(" ".into())),
307 op: '+',
308 rhs: Box::new(val),
309 }),
310 },
311 });
312 }
313 Expression::FunctionCall {
314 function: BuiltinFunction::Debug.into(),
315 arguments: vec![string.unwrap_or_else(|| Expression::default_value_for_type(&Type::String))],
316 source_location: Some(node.to_source_location()),
317 }
318}
319
/// Builds an expression that evaluates to a human-readable string describing `expr`.
///
/// Called by `debug_macro` for every argument, and recursively for struct fields.
/// Types that cannot be printed report "Cannot debug this expression" on `node` and
/// yield `Expression::Invalid`.
fn to_debug_string(
    expr: Expression,
    node: &dyn Spanned,
    diag: &mut BuildDiagnostics,
) -> Expression {
    let ty = expr.ty();
    match &ty {
        // Propagate invalid expressions without emitting another diagnostic.
        Type::Invalid => Expression::Invalid,
        Type::Void
        | Type::InferredCallback
        | Type::InferredProperty
        | Type::Callback { .. }
        | Type::ComponentFactory
        | Type::Function { .. }
        | Type::ElementReference
        | Type::LayoutCache
        | Type::Model
        | Type::PathData => {
            diag.push_error("Cannot debug this expression".into(), node);
            Expression::Invalid
        }
        Type::Float32 | Type::Int32 => expr.maybe_convert_to(Type::String, node, diag),
        Type::String => expr,
        Type::Color | Type::Brush | Type::Image | Type::Easing | Type::Array(_) => {
            // Placeholder text: these types have no dedicated formatting here yet.
            Expression::StringLiteral("<debug-of-this-type-not-yet-implemented>".into())
        }
        // Unit-carrying values: the number (cast to float, then to string) followed by the
        // textual form of the type's unit product.
        // NOTE(review): the `unwrap()` below assumes `as_unit_product()` is `Some` for every
        // type listed in this arm.
        Type::Duration
        | Type::PhysicalLength
        | Type::LogicalLength
        | Type::Rem
        | Type::Angle
        | Type::Percent
        | Type::UnitProduct(_) => Expression::BinaryExpression {
            lhs: Box::new(
                Expression::Cast { from: Box::new(expr), to: Type::Float32 }.maybe_convert_to(
                    Type::String,
                    node,
                    diag,
                ),
            ),
            op: '+',
            rhs: Box::new(Expression::StringLiteral(
                Type::UnitProduct(ty.as_unit_product().unwrap()).to_smolstr(),
            )),
        },
        Type::Bool => Expression::Condition {
            condition: Box::new(expr),
            true_expr: Box::new(Expression::StringLiteral("true".into())),
            false_expr: Box::new(Expression::StringLiteral("false".into())),
        },
        // Produces "{ a: <a>, b: <b> }". The struct value is stored once in a temporary local
        // and every field is read from it. The local name gets a fresh COUNTER suffix on each
        // call, so recursive calls for nested struct fields use distinct names.
        Type::Struct(s) => {
            let local_object = format_smolstr!(
                "debug_struct{}",
                COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            );
            let mut string = None;
            for k in s.fields.keys() {
                // The first field opens the brace; later fields are comma-separated.
                let field_name = if string.is_some() {
                    format_smolstr!(", {}: ", k)
                } else {
                    format_smolstr!("{{ {}: ", k)
                };
                let value = to_debug_string(
                    Expression::StructFieldAccess {
                        base: Box::new(Expression::ReadLocalVariable {
                            name: local_object.clone(),
                            ty: ty.clone(),
                        }),
                        name: k.clone(),
                    },
                    node,
                    diag,
                );
                let field = Expression::BinaryExpression {
                    lhs: Box::new(Expression::StringLiteral(field_name)),
                    op: '+',
                    rhs: Box::new(value),
                };
                string = Some(match string {
                    None => field,
                    Some(x) => Expression::BinaryExpression {
                        lhs: Box::new(x),
                        op: '+',
                        rhs: Box::new(field),
                    },
                });
            }
            match string {
                // NOTE(review): for a struct without fields, `expr` is not part of the result.
                None => Expression::StringLiteral("{}".into()),
                Some(string) => Expression::CodeBlock(vec![
                    Expression::StoreLocalVariable { name: local_object, value: Box::new(expr) },
                    Expression::BinaryExpression {
                        lhs: Box::new(string),
                        op: '+',
                        rhs: Box::new(Expression::StringLiteral(" }".into())),
                    },
                ]),
            }
        }
        // Stores the value in a local, then selects the matching variant name through a chain
        // of equality conditions (one per variant index). If no variant matches, an
        // "Error: invalid value for ..." string is produced.
        // The local name is fixed, unlike the struct case; this arm does not recurse.
        Type::Enumeration(enu) => {
            let local_object = "debug_enum";
            let mut v = vec![Expression::StoreLocalVariable {
                name: local_object.into(),
                value: Box::new(expr),
            }];
            let mut cond =
                Expression::StringLiteral(format_smolstr!("Error: invalid value for {}", ty));
            for (idx, val) in enu.values.iter().enumerate() {
                cond = Expression::Condition {
                    condition: Box::new(Expression::BinaryExpression {
                        lhs: Box::new(Expression::ReadLocalVariable {
                            name: local_object.into(),
                            ty: ty.clone(),
                        }),
                        rhs: Box::new(Expression::EnumerationValue(EnumerationValue {
                            value: idx,
                            enumeration: enu.clone(),
                        })),
                        op: '=',
                    }),
                    true_expr: Box::new(Expression::StringLiteral(val.clone())),
                    false_expr: Box::new(cond),
                };
            }
            v.push(cond);
            Expression::CodeBlock(v)
        }
    }
}
450
451pub fn min_max_expression(lhs: Expression, rhs: Expression, op: MinMaxOp) -> Expression {
455 let lhs_ty = lhs.ty();
456 let rhs_ty = rhs.ty();
457 let ty = match (lhs_ty, rhs_ty) {
458 (a, b) if a == b => a,
459 (Type::Int32, Type::Float32) | (Type::Float32, Type::Int32) => Type::Float32,
460 _ => Type::Invalid,
461 };
462 Expression::MinMax { ty, op, lhs: Box::new(lhs), rhs: Box::new(rhs) }
463}