1use crate::diagnostics::{BuildDiagnostics, Spanned};
8use crate::expression_tree::{
9 BuiltinFunction, BuiltinMacroFunction, Callable, EasingCurve, Expression, MinMaxOp, Unit,
10};
11use crate::langtype::Type;
12use crate::parser::NodeOrToken;
13use smol_str::{ToSmolStr, format_smolstr};
14
/// Source of unique numeric suffixes for the temporary local variables (`debug_struct<N>`)
/// that `to_debug_string` introduces when rendering a struct value for `debug()`.
static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(1);
17
18pub fn lower_macro(
20 mac: BuiltinMacroFunction,
21 n: &dyn Spanned,
22 mut sub_expr: impl Iterator<Item = (Expression, Option<NodeOrToken>)>,
23 diag: &mut BuildDiagnostics,
24) -> Expression {
25 match mac {
26 BuiltinMacroFunction::Min => min_max_macro(n, MinMaxOp::Min, sub_expr.collect(), diag),
27 BuiltinMacroFunction::Max => min_max_macro(n, MinMaxOp::Max, sub_expr.collect(), diag),
28 BuiltinMacroFunction::Clamp => clamp_macro(n, sub_expr.collect(), diag),
29 BuiltinMacroFunction::Mod => mod_macro(n, sub_expr.collect(), diag),
30 BuiltinMacroFunction::Abs => abs_macro(n, sub_expr.collect(), diag),
31 BuiltinMacroFunction::Sign => {
32 let Some((x, arg_node)) = sub_expr.next() else {
33 diag.push_error("Expected one argument".into(), n);
34 return Expression::Invalid;
35 };
36 if sub_expr.next().is_some() {
37 diag.push_error("Expected only one argument".into(), n);
38 }
39 Expression::Condition {
40 condition: Expression::BinaryExpression {
41 lhs: x.maybe_convert_to(Type::Float32, &arg_node, diag).into(),
42 rhs: Expression::NumberLiteral(0., Unit::None).into(),
43 op: '<',
44 }
45 .into(),
46 true_expr: Expression::NumberLiteral(-1., Unit::None).into(),
47 false_expr: Expression::NumberLiteral(1., Unit::None).into(),
48 }
49 }
50 BuiltinMacroFunction::Debug => debug_macro(n, sub_expr.collect(), diag),
51 BuiltinMacroFunction::CubicBezier => {
52 let mut has_error = None;
53 let expected_argument_type_error =
54 "Arguments to cubic bezier curve must be number literal";
55 let mut a = || match sub_expr.next() {
58 None => {
59 has_error.get_or_insert((n.to_source_location(), "Not enough arguments"));
60 0.
61 }
62 Some((Expression::NumberLiteral(val, Unit::None), _)) => val as f32,
63 Some((Expression::UnaryOp { sub, op: '-' }, n)) => match *sub {
65 Expression::NumberLiteral(val, Unit::None) => -val as f32,
66 _ => {
67 has_error
68 .get_or_insert((n.to_source_location(), expected_argument_type_error));
69 0.
70 }
71 },
72 Some((_, n)) => {
73 has_error.get_or_insert((n.to_source_location(), expected_argument_type_error));
74 0.
75 }
76 };
77 let expr = Expression::EasingCurve(EasingCurve::CubicBezier(a(), a(), a(), a()));
78 if let Some((_, n)) = sub_expr.next() {
79 has_error
80 .get_or_insert((n.to_source_location(), "Too many argument for bezier curve"));
81 }
82 if let Some((n, msg)) = has_error {
83 diag.push_error(msg.into(), &n);
84 }
85
86 expr
87 }
88 BuiltinMacroFunction::Rgb => rgb_macro(n, sub_expr.collect(), diag),
89 BuiltinMacroFunction::Hsv => hsv_macro(n, sub_expr.collect(), diag),
90 BuiltinMacroFunction::Oklch => oklch_macro(n, sub_expr.collect(), diag),
91 }
92}
93
94fn min_max_macro(
95 node: &dyn Spanned,
96 op: MinMaxOp,
97 args: Vec<(Expression, Option<NodeOrToken>)>,
98 diag: &mut BuildDiagnostics,
99) -> Expression {
100 if args.is_empty() {
101 diag.push_error("Needs at least one argument".into(), node);
102 return Expression::Invalid;
103 }
104 let ty = Expression::common_target_type_for_type_list(args.iter().map(|expr| expr.0.ty()));
105 if ty.as_unit_product().is_none() {
106 diag.push_error("Invalid argument type".into(), node);
107 return Expression::Invalid;
108 }
109 let mut args = args.into_iter();
110 let (base, arg_node) = args.next().unwrap();
111 let mut base = base.maybe_convert_to(ty.clone(), &arg_node, diag);
112 for (next, arg_node) in args {
113 let rhs = next.maybe_convert_to(ty.clone(), &arg_node, diag);
114 base = min_max_expression(base, rhs, op);
115 }
116 base
117}
118
119fn clamp_macro(
120 node: &dyn Spanned,
121 args: Vec<(Expression, Option<NodeOrToken>)>,
122 diag: &mut BuildDiagnostics,
123) -> Expression {
124 if args.len() != 3 {
125 diag.push_error(
126 "`clamp` needs three values: the `value` to clamp, the `minimum` and the `maximum`"
127 .into(),
128 node,
129 );
130 return Expression::Invalid;
131 }
132 let (value, value_node) = args.first().unwrap().clone();
133 let ty = value.ty();
134 if ty.as_unit_product().is_none() {
135 diag.push_error("Invalid argument type".into(), &value_node);
136 return Expression::Invalid;
137 }
138
139 let (min, min_node) = args.get(1).unwrap().clone();
140 let min = min.maybe_convert_to(ty.clone(), &min_node, diag);
141 let (max, max_node) = args.get(2).unwrap().clone();
142 let max = max.maybe_convert_to(ty.clone(), &max_node, diag);
143
144 let value = min_max_expression(value, max, MinMaxOp::Min);
145 min_max_expression(min, value, MinMaxOp::Max)
146}
147
148fn mod_macro(
149 node: &dyn Spanned,
150 args: Vec<(Expression, Option<NodeOrToken>)>,
151 diag: &mut BuildDiagnostics,
152) -> Expression {
153 if args.len() != 2 {
154 diag.push_error("Needs 2 arguments".into(), node);
155 return Expression::Invalid;
156 }
157 let (lhs_ty, rhs_ty) = (args[0].0.ty(), args[1].0.ty());
158 let common_ty = if lhs_ty.default_unit().is_some() {
159 lhs_ty
160 } else if rhs_ty.default_unit().is_some() {
161 rhs_ty
162 } else if matches!(lhs_ty, Type::UnitProduct(_)) {
163 lhs_ty
164 } else if matches!(rhs_ty, Type::UnitProduct(_)) {
165 rhs_ty
166 } else {
167 Type::Float32
168 };
169
170 let source_location = Some(node.to_source_location());
171 let function = Callable::Builtin(BuiltinFunction::Mod);
172 let arguments = args.into_iter().map(|(e, n)| e.maybe_convert_to(common_ty.clone(), &n, diag));
173 if matches!(common_ty, Type::Float32) {
174 Expression::FunctionCall { function, arguments: arguments.collect(), source_location }
175 } else {
176 Expression::Cast {
177 from: Expression::FunctionCall {
178 function,
179 arguments: arguments
180 .map(|a| Expression::Cast { from: a.into(), to: Type::Float32 })
181 .collect(),
182 source_location,
183 }
184 .into(),
185 to: common_ty.clone(),
186 }
187 }
188}
189
190fn abs_macro(
191 node: &dyn Spanned,
192 args: Vec<(Expression, Option<NodeOrToken>)>,
193 diag: &mut BuildDiagnostics,
194) -> Expression {
195 if args.len() != 1 {
196 diag.push_error("Needs 1 argument".into(), node);
197 return Expression::Invalid;
198 }
199 let ty = args[0].0.ty();
200 let ty = if ty.default_unit().is_some() || matches!(ty, Type::UnitProduct(_)) {
201 ty
202 } else {
203 Type::Float32
204 };
205
206 let source_location = Some(node.to_source_location());
207 let function = Callable::Builtin(BuiltinFunction::Abs);
208 if matches!(ty, Type::Float32) {
209 let arguments =
210 args.into_iter().map(|(e, n)| e.maybe_convert_to(ty.clone(), &n, diag)).collect();
211 Expression::FunctionCall { function, arguments, source_location }
212 } else {
213 Expression::Cast {
214 from: Expression::FunctionCall {
215 function,
216 arguments: args
217 .into_iter()
218 .map(|(a, _)| Expression::Cast { from: a.into(), to: Type::Float32 })
219 .collect(),
220 source_location,
221 }
222 .into(),
223 to: ty,
224 }
225 }
226}
227
228fn rgb_macro(
229 node: &dyn Spanned,
230 args: Vec<(Expression, Option<NodeOrToken>)>,
231 diag: &mut BuildDiagnostics,
232) -> Expression {
233 if args.len() < 3 || args.len() > 4 {
234 diag.push_error(
235 format!("This function needs 3 or 4 arguments, but {} were provided", args.len()),
236 node,
237 );
238 return Expression::Invalid;
239 }
240 let mut arguments: Vec<_> = args
241 .into_iter()
242 .enumerate()
243 .map(|(i, (expr, n))| {
244 if i < 3 {
245 if expr.ty() == Type::Percent {
246 Expression::BinaryExpression {
247 lhs: Box::new(expr.maybe_convert_to(Type::Float32, &n, diag)),
248 rhs: Box::new(Expression::NumberLiteral(255., Unit::None)),
249 op: '*',
250 }
251 } else {
252 expr.maybe_convert_to(Type::Float32, &n, diag)
253 }
254 } else {
255 expr.maybe_convert_to(Type::Float32, &n, diag)
256 }
257 })
258 .collect();
259 if arguments.len() < 4 {
260 arguments.push(Expression::NumberLiteral(1., Unit::None))
261 }
262 Expression::FunctionCall {
263 function: BuiltinFunction::Rgb.into(),
264 arguments,
265 source_location: Some(node.to_source_location()),
266 }
267}
268
269fn hsv_macro(
270 node: &dyn Spanned,
271 args: Vec<(Expression, Option<NodeOrToken>)>,
272 diag: &mut BuildDiagnostics,
273) -> Expression {
274 if args.len() < 3 || args.len() > 4 {
275 diag.push_error(
276 format!("This function needs 3 or 4 arguments, but {} were provided", args.len()),
277 node,
278 );
279 return Expression::Invalid;
280 }
281 let mut arguments: Vec<_> = args
282 .into_iter()
283 .enumerate()
284 .map(|(i, (expr, n))| {
285 if i == 0 && expr.ty() == Type::Angle {
287 Expression::BinaryExpression {
288 lhs: Box::new(expr),
289 rhs: Box::new(Expression::NumberLiteral(1., Unit::Deg)),
290 op: '/',
291 }
292 } else {
293 expr.maybe_convert_to(Type::Float32, &n, diag)
294 }
295 })
296 .collect();
297 if arguments.len() < 4 {
298 arguments.push(Expression::NumberLiteral(1., Unit::None))
299 }
300 Expression::FunctionCall {
301 function: BuiltinFunction::Hsv.into(),
302 arguments,
303 source_location: Some(node.to_source_location()),
304 }
305}
306
307fn oklch_macro(
308 node: &dyn Spanned,
309 args: Vec<(Expression, Option<NodeOrToken>)>,
310 diag: &mut BuildDiagnostics,
311) -> Expression {
312 if args.len() < 3 || args.len() > 4 {
313 diag.push_error(
314 format!("This function needs 3 or 4 arguments, but {} were provided", args.len()),
315 node,
316 );
317 return Expression::Invalid;
318 }
319 let mut arguments: Vec<_> = args
320 .into_iter()
321 .enumerate()
322 .map(|(i, (expr, n))| {
323 if i == 1 && expr.ty() == Type::Percent {
325 Expression::BinaryExpression {
326 lhs: Box::new(expr),
327 rhs: Box::new(Expression::NumberLiteral(0.004, Unit::None)),
328 op: '*',
329 }
330 } else if i == 2 && expr.ty() == Type::Angle {
332 Expression::BinaryExpression {
333 lhs: Box::new(expr),
334 rhs: Box::new(Expression::NumberLiteral(1., Unit::Deg)),
335 op: '/',
336 }
337 } else {
338 expr.maybe_convert_to(Type::Float32, &n, diag)
339 }
340 })
341 .collect();
342 if arguments.len() < 4 {
343 arguments.push(Expression::NumberLiteral(1., Unit::None))
344 }
345 Expression::FunctionCall {
346 function: BuiltinFunction::Oklch.into(),
347 arguments,
348 source_location: Some(node.to_source_location()),
349 }
350}
351
352fn debug_macro(
353 node: &dyn Spanned,
354 args: Vec<(Expression, Option<NodeOrToken>)>,
355 diag: &mut BuildDiagnostics,
356) -> Expression {
357 let mut string = None;
358 for (expr, node) in args {
359 let val = to_debug_string(expr, &node, diag);
360 string = Some(match string {
361 None => val,
362 Some(string) => Expression::BinaryExpression {
363 lhs: Box::new(string),
364 op: '+',
365 rhs: Box::new(Expression::BinaryExpression {
366 lhs: Box::new(Expression::StringLiteral(" ".into())),
367 op: '+',
368 rhs: Box::new(val),
369 }),
370 },
371 });
372 }
373 Expression::FunctionCall {
374 function: BuiltinFunction::Debug.into(),
375 arguments: vec![
376 string.unwrap_or_else(|| Expression::default_value_for_type(&Type::String)),
377 ],
378 source_location: Some(node.to_source_location()),
379 }
380}
381
382fn to_debug_string(
383 expr: Expression,
384 node: &dyn Spanned,
385 diag: &mut BuildDiagnostics,
386) -> Expression {
387 let ty = expr.ty();
388 match &ty {
389 Type::Invalid => Expression::Invalid,
390 Type::Void
391 | Type::InferredCallback
392 | Type::InferredProperty
393 | Type::Callback { .. }
394 | Type::ComponentFactory
395 | Type::Function { .. }
396 | Type::ElementReference
397 | Type::LayoutCache
398 | Type::ArrayOfU16
399 | Type::Model
400 | Type::PathData => {
401 diag.push_error("Cannot debug this expression".into(), node);
402 Expression::Invalid
403 }
404 Type::Float32 | Type::Int32 => expr.maybe_convert_to(Type::String, node, diag),
405 Type::String => expr,
406 Type::Color
408 | Type::Brush
409 | Type::Image
410 | Type::Easing
411 | Type::StyledText
412 | Type::Array(_) => {
413 Expression::StringLiteral("<debug-of-this-type-not-yet-implemented>".into())
414 }
415 Type::Duration
416 | Type::PhysicalLength
417 | Type::LogicalLength
418 | Type::Rem
419 | Type::Angle
420 | Type::Percent
421 | Type::UnitProduct(_) => Expression::BinaryExpression {
422 lhs: Box::new(
423 Expression::Cast { from: Box::new(expr), to: Type::Float32 }.maybe_convert_to(
424 Type::String,
425 node,
426 diag,
427 ),
428 ),
429 op: '+',
430 rhs: Box::new(Expression::StringLiteral(
431 Type::UnitProduct(ty.as_unit_product().unwrap()).to_smolstr(),
432 )),
433 },
434 Type::Bool => Expression::Condition {
435 condition: Box::new(expr),
436 true_expr: Box::new(Expression::StringLiteral("true".into())),
437 false_expr: Box::new(Expression::StringLiteral("false".into())),
438 },
439 Type::Struct(s) => {
440 let local_object = format_smolstr!(
441 "debug_struct{}",
442 COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
443 );
444 let mut string = None;
445 for k in s.fields.keys() {
446 let field_name = if string.is_some() {
447 format_smolstr!(", {}: ", k)
448 } else {
449 format_smolstr!("{{ {}: ", k)
450 };
451 let value = to_debug_string(
452 Expression::StructFieldAccess {
453 base: Box::new(Expression::ReadLocalVariable {
454 name: local_object.clone(),
455 ty: ty.clone(),
456 }),
457 name: k.clone(),
458 },
459 node,
460 diag,
461 );
462 let field = Expression::BinaryExpression {
463 lhs: Box::new(Expression::StringLiteral(field_name)),
464 op: '+',
465 rhs: Box::new(value),
466 };
467 string = Some(match string {
468 None => field,
469 Some(x) => Expression::BinaryExpression {
470 lhs: Box::new(x),
471 op: '+',
472 rhs: Box::new(field),
473 },
474 });
475 }
476 match string {
477 None => Expression::StringLiteral("{}".into()),
478 Some(string) => Expression::CodeBlock(vec![
479 Expression::StoreLocalVariable { name: local_object, value: Box::new(expr) },
480 Expression::BinaryExpression {
481 lhs: Box::new(string),
482 op: '+',
483 rhs: Box::new(Expression::StringLiteral(" }".into())),
484 },
485 ]),
486 }
487 }
488 Type::Enumeration(_) => Expression::Cast { from: Box::new(expr), to: (Type::String) },
489 }
490}
491
492pub fn min_max_expression(lhs: Expression, rhs: Expression, op: MinMaxOp) -> Expression {
496 let lhs_ty = lhs.ty();
497 let rhs_ty = rhs.ty();
498 let ty = match (lhs_ty, rhs_ty) {
499 (a, b) if a == b => a,
500 (Type::Int32, Type::Float32) | (Type::Float32, Type::Int32) => Type::Float32,
501 _ => Type::Invalid,
502 };
503 Expression::MinMax { ty, op, lhs: Box::new(lhs), rhs: Box::new(rhs) }
504}