1use num_traits::{Float, NumCast, PrimInt};
130use wick_core::{Ast, BinOp, CompareOp, UnaryOp};
131
132use std::collections::HashMap;
134use std::sync::Arc;
135pub use wick_core::Numeric;
136
137#[cfg(feature = "wgsl")]
138pub mod wgsl;
139
140#[cfg(feature = "glsl")]
141pub mod glsl;
142
143#[cfg(feature = "rust")]
144pub mod rust;
145
146#[cfg(feature = "c")]
147pub mod c;
148
149#[cfg(feature = "opencl")]
150pub mod opencl;
151
152#[cfg(feature = "cuda")]
153pub mod cuda;
154
155#[cfg(feature = "hip")]
156pub mod hip;
157
158#[cfg(feature = "tokenstream")]
159pub mod tokenstream;
160
161#[cfg(any(feature = "lua", feature = "lua-codegen"))]
162pub mod lua;
163
164#[cfg(feature = "cranelift")]
165pub mod cranelift;
166
167#[cfg(feature = "optimize")]
168pub mod optimize;
169
/// Errors produced while evaluating an expression with [`eval`] or [`eval_int`].
#[derive(Debug, Clone, PartialEq)]
pub enum Error {
    /// A variable was referenced that is not present in the variable map.
    UnknownVariable(String),
    /// A function was called that is not present in the [`FunctionRegistry`].
    UnknownFunction(String),
    /// A function was called with the wrong number of arguments.
    WrongArgCount {
        /// Name used at the call site.
        func: String,
        /// Number of arguments the function accepts.
        expected: usize,
        /// Number of arguments that were supplied.
        got: usize,
    },
    /// An operation cannot be evaluated for the numeric type in use; the
    /// payload describes the operation (for example the operator symbol).
    UnsupportedOperation(String),
    /// A numeric literal cannot be represented in the integer type in use.
    InvalidLiteral(f64),
    /// An integer exponentiation was given a negative exponent.
    NegativeExponent,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownVariable(name) => write!(f, "unknown variable: '{name}'"),
            Self::UnknownFunction(name) => write!(f, "unknown function: '{name}'"),
            Self::WrongArgCount {
                func,
                expected,
                got,
            } => write!(f, "function '{func}' expects {expected} args, got {got}"),
            Self::UnsupportedOperation(op) => {
                write!(f, "operation '{op}' not supported for this numeric type")
            }
            Self::InvalidLiteral(n) => {
                write!(f, "literal {n} cannot be converted to integer type")
            }
            Self::NegativeExponent => {
                f.write_str("negative exponent not allowed for integer types")
            }
        }
    }
}

impl std::error::Error for Error {}
221
/// A named scalar function that expressions can call, e.g. `sin(x)`.
///
/// Implementations are stored in a [`FunctionRegistry`] behind an `Arc`, which
/// is why they must be `Send + Sync`.
pub trait ScalarFn<T>: Send + Sync {
    /// Name under which the function is registered and called.
    fn name(&self) -> &str;

    /// Exact number of arguments the function accepts.
    ///
    /// The evaluators compare this with the call site and report
    /// [`Error::WrongArgCount`] on a mismatch before invoking
    /// [`call`](Self::call).
    fn arg_count(&self) -> usize;

    /// Compute the result from the already-evaluated arguments, in call order.
    fn call(&self, args: &[T]) -> T;
}
237
238#[derive(Clone)]
240pub struct FunctionRegistry<T> {
241 funcs: HashMap<String, Arc<dyn ScalarFn<T>>>,
242}
243
244impl<T> Default for FunctionRegistry<T> {
245 fn default() -> Self {
246 Self {
247 funcs: HashMap::new(),
248 }
249 }
250}
251
252impl<T> FunctionRegistry<T> {
253 pub fn new() -> Self {
254 Self::default()
255 }
256
257 pub fn register<F: ScalarFn<T> + 'static>(&mut self, func: F) {
258 self.funcs.insert(func.name().to_string(), Arc::new(func));
259 }
260
261 pub fn get(&self, name: &str) -> Option<&Arc<dyn ScalarFn<T>>> {
262 self.funcs.get(name)
263 }
264
265 pub fn names(&self) -> impl Iterator<Item = &str> {
267 self.funcs.keys().map(|s| s.as_str())
268 }
269}
270
271pub fn eval<T: Float>(
279 ast: &Ast,
280 vars: &HashMap<String, T>,
281 funcs: &FunctionRegistry<T>,
282) -> Result<T, Error> {
283 match ast {
284 Ast::Num(n) => Ok(T::from(*n).unwrap()),
285
286 Ast::Var(name) => vars
287 .get(name)
288 .copied()
289 .ok_or_else(|| Error::UnknownVariable(name.clone())),
290
291 Ast::BinOp(op, left, right) => {
292 let l = eval(left, vars, funcs)?;
293 let r = eval(right, vars, funcs)?;
294 match op {
295 BinOp::Add => Ok(l + r),
296 BinOp::Sub => Ok(l - r),
297 BinOp::Mul => Ok(l * r),
298 BinOp::Div => Ok(l / r),
299 BinOp::Pow => Ok(l.powf(r)),
300 BinOp::Rem => Ok(l % r),
301 BinOp::BitAnd => Err(Error::UnsupportedOperation("&".into())),
302 BinOp::BitOr => Err(Error::UnsupportedOperation("|".into())),
303 BinOp::Shl => Err(Error::UnsupportedOperation("<<".into())),
304 BinOp::Shr => Err(Error::UnsupportedOperation(">>".into())),
305 }
306 }
307
308 Ast::UnaryOp(op, inner) => {
309 let v = eval(inner, vars, funcs)?;
310 match op {
311 UnaryOp::Neg => Ok(-v),
312 UnaryOp::BitNot => Err(Error::UnsupportedOperation("~".into())),
313 UnaryOp::Not => {
314 if v == T::zero() {
315 Ok(T::one())
316 } else {
317 Ok(T::zero())
318 }
319 }
320 }
321 }
322
323 Ast::Compare(op, left, right) => {
324 let l = eval(left, vars, funcs)?;
325 let r = eval(right, vars, funcs)?;
326 let result = match op {
327 CompareOp::Lt => l < r,
328 CompareOp::Le => l <= r,
329 CompareOp::Gt => l > r,
330 CompareOp::Ge => l >= r,
331 CompareOp::Eq => l == r,
332 CompareOp::Ne => l != r,
333 };
334 Ok(if result { T::one() } else { T::zero() })
335 }
336
337 Ast::And(left, right) => {
338 let l = eval(left, vars, funcs)?;
339 if l == T::zero() {
340 Ok(T::zero()) } else {
342 let r = eval(right, vars, funcs)?;
343 Ok(if r != T::zero() { T::one() } else { T::zero() })
344 }
345 }
346
347 Ast::Or(left, right) => {
348 let l = eval(left, vars, funcs)?;
349 if l != T::zero() {
350 Ok(T::one()) } else {
352 let r = eval(right, vars, funcs)?;
353 Ok(if r != T::zero() { T::one() } else { T::zero() })
354 }
355 }
356
357 Ast::If(cond, then_expr, else_expr) => {
358 let c = eval(cond, vars, funcs)?;
359 if c != T::zero() {
360 eval(then_expr, vars, funcs)
361 } else {
362 eval(else_expr, vars, funcs)
363 }
364 }
365
366 Ast::Call(name, args) => {
367 let func = funcs
368 .get(name)
369 .ok_or_else(|| Error::UnknownFunction(name.clone()))?;
370
371 if args.len() != func.arg_count() {
372 return Err(Error::WrongArgCount {
373 func: name.clone(),
374 expected: func.arg_count(),
375 got: args.len(),
376 });
377 }
378
379 let arg_vals: Vec<T> = args
380 .iter()
381 .map(|a| eval(a, vars, funcs))
382 .collect::<Result<_, _>>()?;
383
384 Ok(func.call(&arg_vals))
385 }
386
387 Ast::Let { name, value, body } => {
388 let val = eval(value, vars, funcs)?;
389 let mut new_vars = vars.clone();
390 new_vars.insert(name.clone(), val);
391 eval(body, &new_vars, funcs)
392 }
393 }
394}
395
396pub fn eval_int<T: PrimInt + NumCast>(
403 ast: &Ast,
404 vars: &HashMap<String, T>,
405 funcs: &FunctionRegistry<T>,
406) -> Result<T, Error> {
407 match ast {
408 Ast::Num(n) => {
409 if n.fract() != 0.0 {
411 return Err(Error::InvalidLiteral(*n));
412 }
413 T::from(*n).ok_or(Error::InvalidLiteral(*n))
414 }
415
416 Ast::Var(name) => vars
417 .get(name)
418 .copied()
419 .ok_or_else(|| Error::UnknownVariable(name.clone())),
420
421 Ast::BinOp(op, left, right) => {
422 let l = eval_int(left, vars, funcs)?;
423 let r = eval_int(right, vars, funcs)?;
424 match op {
425 BinOp::Add => Ok(l + r),
426 BinOp::Sub => Ok(l - r),
427 BinOp::Mul => Ok(l * r),
428 BinOp::Div => Ok(l / r),
429 BinOp::Rem => Ok(l % r),
430 BinOp::Pow => {
431 if r < T::zero() {
433 return Err(Error::NegativeExponent);
434 }
435 let mut result = T::one();
437 let mut exp = r;
438 let mut base = l;
439 while exp > T::zero() {
440 if exp & T::one() == T::one() {
441 result = result * base;
442 }
443 base = base * base;
444 exp = exp >> 1;
445 }
446 Ok(result)
447 }
448 BinOp::BitAnd => Ok(l & r),
449 BinOp::BitOr => Ok(l | r),
450 BinOp::Shl => {
451 let shift: u32 = r.to_u32().unwrap_or(0);
453 Ok(l << shift as usize)
454 }
455 BinOp::Shr => {
456 let shift: u32 = r.to_u32().unwrap_or(0);
457 Ok(l >> shift as usize)
458 }
459 }
460 }
461
462 Ast::UnaryOp(op, inner) => {
463 let v = eval_int(inner, vars, funcs)?;
464 match op {
465 UnaryOp::Neg => Ok(T::zero() - v),
466 UnaryOp::BitNot => Ok(!v),
467 UnaryOp::Not => {
468 if v == T::zero() {
469 Ok(T::one())
470 } else {
471 Ok(T::zero())
472 }
473 }
474 }
475 }
476
477 Ast::Compare(op, left, right) => {
478 let l = eval_int(left, vars, funcs)?;
479 let r = eval_int(right, vars, funcs)?;
480 let result = match op {
481 CompareOp::Lt => l < r,
482 CompareOp::Le => l <= r,
483 CompareOp::Gt => l > r,
484 CompareOp::Ge => l >= r,
485 CompareOp::Eq => l == r,
486 CompareOp::Ne => l != r,
487 };
488 Ok(if result { T::one() } else { T::zero() })
489 }
490
491 Ast::And(left, right) => {
492 let l = eval_int(left, vars, funcs)?;
493 if l == T::zero() {
494 Ok(T::zero())
495 } else {
496 let r = eval_int(right, vars, funcs)?;
497 Ok(if r != T::zero() { T::one() } else { T::zero() })
498 }
499 }
500
501 Ast::Or(left, right) => {
502 let l = eval_int(left, vars, funcs)?;
503 if l != T::zero() {
504 Ok(T::one())
505 } else {
506 let r = eval_int(right, vars, funcs)?;
507 Ok(if r != T::zero() { T::one() } else { T::zero() })
508 }
509 }
510
511 Ast::If(cond, then_expr, else_expr) => {
512 let c = eval_int(cond, vars, funcs)?;
513 if c != T::zero() {
514 eval_int(then_expr, vars, funcs)
515 } else {
516 eval_int(else_expr, vars, funcs)
517 }
518 }
519
520 Ast::Call(name, args) => {
521 let func = funcs
522 .get(name)
523 .ok_or_else(|| Error::UnknownFunction(name.clone()))?;
524
525 if args.len() != func.arg_count() {
526 return Err(Error::WrongArgCount {
527 func: name.clone(),
528 expected: func.arg_count(),
529 got: args.len(),
530 });
531 }
532
533 let arg_vals: Vec<T> = args
534 .iter()
535 .map(|a| eval_int(a, vars, funcs))
536 .collect::<Result<_, _>>()?;
537
538 Ok(func.call(&arg_vals))
539 }
540
541 Ast::Let { name, value, body } => {
542 let val = eval_int(value, vars, funcs)?;
543 let mut new_vars = vars.clone();
544 new_vars.insert(name.clone(), val);
545 eval_int(body, &new_vars, funcs)
546 }
547 }
548}
549
550pub struct Pi;
556impl<T: Float> ScalarFn<T> for Pi {
557 fn name(&self) -> &str {
558 "pi"
559 }
560 fn arg_count(&self) -> usize {
561 0
562 }
563 fn call(&self, _args: &[T]) -> T {
564 T::from(std::f64::consts::PI).unwrap()
565 }
566}
567
568pub struct E;
570impl<T: Float> ScalarFn<T> for E {
571 fn name(&self) -> &str {
572 "e"
573 }
574 fn arg_count(&self) -> usize {
575 0
576 }
577 fn call(&self, _args: &[T]) -> T {
578 T::from(std::f64::consts::E).unwrap()
579 }
580}
581
582pub struct Tau;
584impl<T: Float> ScalarFn<T> for Tau {
585 fn name(&self) -> &str {
586 "tau"
587 }
588 fn arg_count(&self) -> usize {
589 0
590 }
591 fn call(&self, _args: &[T]) -> T {
592 T::from(std::f64::consts::TAU).unwrap()
593 }
594}
595
596pub struct Sin;
601impl<T: Float> ScalarFn<T> for Sin {
602 fn name(&self) -> &str {
603 "sin"
604 }
605 fn arg_count(&self) -> usize {
606 1
607 }
608 fn call(&self, args: &[T]) -> T {
609 args[0].sin()
610 }
611}
612
613pub struct Cos;
614impl<T: Float> ScalarFn<T> for Cos {
615 fn name(&self) -> &str {
616 "cos"
617 }
618 fn arg_count(&self) -> usize {
619 1
620 }
621 fn call(&self, args: &[T]) -> T {
622 args[0].cos()
623 }
624}
625
626pub struct Tan;
627impl<T: Float> ScalarFn<T> for Tan {
628 fn name(&self) -> &str {
629 "tan"
630 }
631 fn arg_count(&self) -> usize {
632 1
633 }
634 fn call(&self, args: &[T]) -> T {
635 args[0].tan()
636 }
637}
638
639pub struct Asin;
640impl<T: Float> ScalarFn<T> for Asin {
641 fn name(&self) -> &str {
642 "asin"
643 }
644 fn arg_count(&self) -> usize {
645 1
646 }
647 fn call(&self, args: &[T]) -> T {
648 args[0].asin()
649 }
650}
651
652pub struct Acos;
653impl<T: Float> ScalarFn<T> for Acos {
654 fn name(&self) -> &str {
655 "acos"
656 }
657 fn arg_count(&self) -> usize {
658 1
659 }
660 fn call(&self, args: &[T]) -> T {
661 args[0].acos()
662 }
663}
664
665pub struct Atan;
666impl<T: Float> ScalarFn<T> for Atan {
667 fn name(&self) -> &str {
668 "atan"
669 }
670 fn arg_count(&self) -> usize {
671 1
672 }
673 fn call(&self, args: &[T]) -> T {
674 args[0].atan()
675 }
676}
677
678pub struct Atan2;
679impl<T: Float> ScalarFn<T> for Atan2 {
680 fn name(&self) -> &str {
681 "atan2"
682 }
683 fn arg_count(&self) -> usize {
684 2
685 }
686 fn call(&self, args: &[T]) -> T {
687 args[0].atan2(args[1])
688 }
689}
690
691pub struct Sinh;
692impl<T: Float> ScalarFn<T> for Sinh {
693 fn name(&self) -> &str {
694 "sinh"
695 }
696 fn arg_count(&self) -> usize {
697 1
698 }
699 fn call(&self, args: &[T]) -> T {
700 args[0].sinh()
701 }
702}
703
704pub struct Cosh;
705impl<T: Float> ScalarFn<T> for Cosh {
706 fn name(&self) -> &str {
707 "cosh"
708 }
709 fn arg_count(&self) -> usize {
710 1
711 }
712 fn call(&self, args: &[T]) -> T {
713 args[0].cosh()
714 }
715}
716
717pub struct Tanh;
718impl<T: Float> ScalarFn<T> for Tanh {
719 fn name(&self) -> &str {
720 "tanh"
721 }
722 fn arg_count(&self) -> usize {
723 1
724 }
725 fn call(&self, args: &[T]) -> T {
726 args[0].tanh()
727 }
728}
729
730pub struct Exp;
735impl<T: Float> ScalarFn<T> for Exp {
736 fn name(&self) -> &str {
737 "exp"
738 }
739 fn arg_count(&self) -> usize {
740 1
741 }
742 fn call(&self, args: &[T]) -> T {
743 args[0].exp()
744 }
745}
746
747pub struct Exp2;
748impl<T: Float> ScalarFn<T> for Exp2 {
749 fn name(&self) -> &str {
750 "exp2"
751 }
752 fn arg_count(&self) -> usize {
753 1
754 }
755 fn call(&self, args: &[T]) -> T {
756 args[0].exp2()
757 }
758}
759
760pub struct Log;
761impl<T: Float> ScalarFn<T> for Log {
762 fn name(&self) -> &str {
763 "log"
764 }
765 fn arg_count(&self) -> usize {
766 1
767 }
768 fn call(&self, args: &[T]) -> T {
769 args[0].ln()
770 }
771}
772
773pub struct Ln;
774impl<T: Float> ScalarFn<T> for Ln {
775 fn name(&self) -> &str {
776 "ln"
777 }
778 fn arg_count(&self) -> usize {
779 1
780 }
781 fn call(&self, args: &[T]) -> T {
782 args[0].ln()
783 }
784}
785
786pub struct Log2;
787impl<T: Float> ScalarFn<T> for Log2 {
788 fn name(&self) -> &str {
789 "log2"
790 }
791 fn arg_count(&self) -> usize {
792 1
793 }
794 fn call(&self, args: &[T]) -> T {
795 args[0].log2()
796 }
797}
798
799pub struct Log10;
800impl<T: Float> ScalarFn<T> for Log10 {
801 fn name(&self) -> &str {
802 "log10"
803 }
804 fn arg_count(&self) -> usize {
805 1
806 }
807 fn call(&self, args: &[T]) -> T {
808 args[0].log10()
809 }
810}
811
812pub struct Pow;
813impl<T: Float> ScalarFn<T> for Pow {
814 fn name(&self) -> &str {
815 "pow"
816 }
817 fn arg_count(&self) -> usize {
818 2
819 }
820 fn call(&self, args: &[T]) -> T {
821 args[0].powf(args[1])
822 }
823}
824
825pub struct Sqrt;
826impl<T: Float> ScalarFn<T> for Sqrt {
827 fn name(&self) -> &str {
828 "sqrt"
829 }
830 fn arg_count(&self) -> usize {
831 1
832 }
833 fn call(&self, args: &[T]) -> T {
834 args[0].sqrt()
835 }
836}
837
838pub struct InverseSqrt;
839impl<T: Float> ScalarFn<T> for InverseSqrt {
840 fn name(&self) -> &str {
841 "inversesqrt"
842 }
843 fn arg_count(&self) -> usize {
844 1
845 }
846 fn call(&self, args: &[T]) -> T {
847 T::one() / args[0].sqrt()
848 }
849}
850
851pub struct Abs;
856impl<T: Float> ScalarFn<T> for Abs {
857 fn name(&self) -> &str {
858 "abs"
859 }
860 fn arg_count(&self) -> usize {
861 1
862 }
863 fn call(&self, args: &[T]) -> T {
864 args[0].abs()
865 }
866}
867
868pub struct Sign;
869impl<T: Float> ScalarFn<T> for Sign {
870 fn name(&self) -> &str {
871 "sign"
872 }
873 fn arg_count(&self) -> usize {
874 1
875 }
876 fn call(&self, args: &[T]) -> T {
877 let x = args[0];
878 if x > T::zero() {
879 T::one()
880 } else if x < T::zero() {
881 -T::one()
882 } else {
883 T::zero()
884 }
885 }
886}
887
888pub struct Floor;
889impl<T: Float> ScalarFn<T> for Floor {
890 fn name(&self) -> &str {
891 "floor"
892 }
893 fn arg_count(&self) -> usize {
894 1
895 }
896 fn call(&self, args: &[T]) -> T {
897 args[0].floor()
898 }
899}
900
901pub struct Ceil;
902impl<T: Float> ScalarFn<T> for Ceil {
903 fn name(&self) -> &str {
904 "ceil"
905 }
906 fn arg_count(&self) -> usize {
907 1
908 }
909 fn call(&self, args: &[T]) -> T {
910 args[0].ceil()
911 }
912}
913
914pub struct Round;
915impl<T: Float> ScalarFn<T> for Round {
916 fn name(&self) -> &str {
917 "round"
918 }
919 fn arg_count(&self) -> usize {
920 1
921 }
922 fn call(&self, args: &[T]) -> T {
923 args[0].round()
924 }
925}
926
927pub struct Trunc;
928impl<T: Float> ScalarFn<T> for Trunc {
929 fn name(&self) -> &str {
930 "trunc"
931 }
932 fn arg_count(&self) -> usize {
933 1
934 }
935 fn call(&self, args: &[T]) -> T {
936 args[0].trunc()
937 }
938}
939
940pub struct Fract;
941impl<T: Float> ScalarFn<T> for Fract {
942 fn name(&self) -> &str {
943 "fract"
944 }
945 fn arg_count(&self) -> usize {
946 1
947 }
948 fn call(&self, args: &[T]) -> T {
949 args[0].fract()
950 }
951}
952
953pub struct Min;
954impl<T: Float> ScalarFn<T> for Min {
955 fn name(&self) -> &str {
956 "min"
957 }
958 fn arg_count(&self) -> usize {
959 2
960 }
961 fn call(&self, args: &[T]) -> T {
962 args[0].min(args[1])
963 }
964}
965
966pub struct Max;
967impl<T: Float> ScalarFn<T> for Max {
968 fn name(&self) -> &str {
969 "max"
970 }
971 fn arg_count(&self) -> usize {
972 2
973 }
974 fn call(&self, args: &[T]) -> T {
975 args[0].max(args[1])
976 }
977}
978
979pub struct Clamp;
980impl<T: Float> ScalarFn<T> for Clamp {
981 fn name(&self) -> &str {
982 "clamp"
983 }
984 fn arg_count(&self) -> usize {
985 3
986 }
987 fn call(&self, args: &[T]) -> T {
988 args[0].max(args[1]).min(args[2])
989 }
990}
991
992pub struct Saturate;
993impl<T: Float> ScalarFn<T> for Saturate {
994 fn name(&self) -> &str {
995 "saturate"
996 }
997 fn arg_count(&self) -> usize {
998 1
999 }
1000 fn call(&self, args: &[T]) -> T {
1001 args[0].max(T::zero()).min(T::one())
1002 }
1003}
1004
1005pub struct Lerp;
1011impl<T: Float> ScalarFn<T> for Lerp {
1012 fn name(&self) -> &str {
1013 "lerp"
1014 }
1015 fn arg_count(&self) -> usize {
1016 3
1017 }
1018 fn call(&self, args: &[T]) -> T {
1019 let (a, b, t) = (args[0], args[1], args[2]);
1020 a + (b - a) * t
1021 }
1022}
1023
1024pub struct Mix;
1026impl<T: Float> ScalarFn<T> for Mix {
1027 fn name(&self) -> &str {
1028 "mix"
1029 }
1030 fn arg_count(&self) -> usize {
1031 3
1032 }
1033 fn call(&self, args: &[T]) -> T {
1034 let (a, b, t) = (args[0], args[1], args[2]);
1035 a + (b - a) * t
1036 }
1037}
1038
1039pub struct Step;
1041impl<T: Float> ScalarFn<T> for Step {
1042 fn name(&self) -> &str {
1043 "step"
1044 }
1045 fn arg_count(&self) -> usize {
1046 2
1047 }
1048 fn call(&self, args: &[T]) -> T {
1049 if args[1] < args[0] {
1050 T::zero()
1051 } else {
1052 T::one()
1053 }
1054 }
1055}
1056
1057pub struct Smoothstep;
1059impl<T: Float> ScalarFn<T> for Smoothstep {
1060 fn name(&self) -> &str {
1061 "smoothstep"
1062 }
1063 fn arg_count(&self) -> usize {
1064 3
1065 }
1066 fn call(&self, args: &[T]) -> T {
1067 let (edge0, edge1, x) = (args[0], args[1], args[2]);
1068 let t = ((x - edge0) / (edge1 - edge0)).max(T::zero()).min(T::one());
1069 let three = T::from(3.0).unwrap();
1070 let two = T::from(2.0).unwrap();
1071 t * t * (three - two * t)
1072 }
1073}
1074
1075pub struct InverseLerp;
1077impl<T: Float> ScalarFn<T> for InverseLerp {
1078 fn name(&self) -> &str {
1079 "inverse_lerp"
1080 }
1081 fn arg_count(&self) -> usize {
1082 3
1083 }
1084 fn call(&self, args: &[T]) -> T {
1085 let (a, b, v) = (args[0], args[1], args[2]);
1086 (v - a) / (b - a)
1087 }
1088}
1089
1090pub struct Remap;
1092impl<T: Float> ScalarFn<T> for Remap {
1093 fn name(&self) -> &str {
1094 "remap"
1095 }
1096 fn arg_count(&self) -> usize {
1097 5
1098 }
1099 fn call(&self, args: &[T]) -> T {
1100 let (x, in_lo, in_hi, out_lo, out_hi) = (args[0], args[1], args[2], args[3], args[4]);
1101 let t = (x - in_lo) / (in_hi - in_lo);
1102 out_lo + (out_hi - out_lo) * t
1103 }
1104}
1105
1106pub struct Xor;
1112impl<T: PrimInt> ScalarFn<T> for Xor {
1113 fn name(&self) -> &str {
1114 "xor"
1115 }
1116 fn arg_count(&self) -> usize {
1117 2
1118 }
1119 fn call(&self, args: &[T]) -> T {
1120 args[0] ^ args[1]
1121 }
1122}
1123
1124pub struct AbsInt;
1126impl<T: PrimInt> ScalarFn<T> for AbsInt {
1127 fn name(&self) -> &str {
1128 "abs"
1129 }
1130 fn arg_count(&self) -> usize {
1131 1
1132 }
1133 fn call(&self, args: &[T]) -> T {
1134 let x = args[0];
1135 if x < T::zero() { T::zero() - x } else { x }
1136 }
1137}
1138
1139pub struct MinInt;
1141impl<T: PrimInt> ScalarFn<T> for MinInt {
1142 fn name(&self) -> &str {
1143 "min"
1144 }
1145 fn arg_count(&self) -> usize {
1146 2
1147 }
1148 fn call(&self, args: &[T]) -> T {
1149 if args[0] < args[1] { args[0] } else { args[1] }
1150 }
1151}
1152
1153pub struct MaxInt;
1155impl<T: PrimInt> ScalarFn<T> for MaxInt {
1156 fn name(&self) -> &str {
1157 "max"
1158 }
1159 fn arg_count(&self) -> usize {
1160 2
1161 }
1162 fn call(&self, args: &[T]) -> T {
1163 if args[0] > args[1] { args[0] } else { args[1] }
1164 }
1165}
1166
1167pub struct ClampInt;
1169impl<T: PrimInt> ScalarFn<T> for ClampInt {
1170 fn name(&self) -> &str {
1171 "clamp"
1172 }
1173 fn arg_count(&self) -> usize {
1174 3
1175 }
1176 fn call(&self, args: &[T]) -> T {
1177 let (x, lo, hi) = (args[0], args[1], args[2]);
1178 if x < lo {
1179 lo
1180 } else if x > hi {
1181 hi
1182 } else {
1183 x
1184 }
1185 }
1186}
1187
1188pub struct SignInt;
1190impl<T: PrimInt> ScalarFn<T> for SignInt {
1191 fn name(&self) -> &str {
1192 "sign"
1193 }
1194 fn arg_count(&self) -> usize {
1195 1
1196 }
1197 fn call(&self, args: &[T]) -> T {
1198 let x = args[0];
1199 if x > T::zero() {
1200 T::one()
1201 } else if x < T::zero() {
1202 T::zero() - T::one()
1203 } else {
1204 T::zero()
1205 }
1206 }
1207}
1208
1209pub fn register_scalar<T: Float + 'static>(registry: &mut FunctionRegistry<T>) {
1215 registry.register(Pi);
1217 registry.register(E);
1218 registry.register(Tau);
1219
1220 registry.register(Sin);
1222 registry.register(Cos);
1223 registry.register(Tan);
1224 registry.register(Asin);
1225 registry.register(Acos);
1226 registry.register(Atan);
1227 registry.register(Atan2);
1228 registry.register(Sinh);
1229 registry.register(Cosh);
1230 registry.register(Tanh);
1231
1232 registry.register(Exp);
1234 registry.register(Exp2);
1235 registry.register(Log);
1236 registry.register(Ln);
1237 registry.register(Log2);
1238 registry.register(Log10);
1239 registry.register(Pow);
1240 registry.register(Sqrt);
1241 registry.register(InverseSqrt);
1242
1243 registry.register(Abs);
1245 registry.register(Sign);
1246 registry.register(Floor);
1247 registry.register(Ceil);
1248 registry.register(Round);
1249 registry.register(Trunc);
1250 registry.register(Fract);
1251 registry.register(Min);
1252 registry.register(Max);
1253 registry.register(Clamp);
1254 registry.register(Saturate);
1255
1256 registry.register(Lerp);
1258 registry.register(Mix);
1259 registry.register(Step);
1260 registry.register(Smoothstep);
1261 registry.register(InverseLerp);
1262 registry.register(Remap);
1263}
1264
1265pub fn scalar_registry<T: Float + 'static>() -> FunctionRegistry<T> {
1267 let mut registry = FunctionRegistry::new();
1268 register_scalar(&mut registry);
1269 registry
1270}
1271
1272pub fn register_scalar_int<T: PrimInt + 'static>(registry: &mut FunctionRegistry<T>) {
1276 registry.register(AbsInt);
1277 registry.register(MinInt);
1278 registry.register(MaxInt);
1279 registry.register(ClampInt);
1280 registry.register(SignInt);
1281 registry.register(Xor);
1282}
1283
1284pub fn scalar_registry_int<T: PrimInt + 'static>() -> FunctionRegistry<T> {
1286 let mut registry = FunctionRegistry::new();
1287 register_scalar_int(&mut registry);
1288 registry
1289}
1290
#[cfg(test)]
/// Tests for the float and integer evaluators and their builtin registries.
mod tests {
    use super::*;
    use wick_core::Expr;

    /// Parse `expr` and evaluate it as `f32` with the full float registry.
    fn eval_expr(expr: &str, vars: &[(&str, f32)]) -> f32 {
        let registry = scalar_registry();
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, f32> =
            vars.iter().map(|(k, v)| (k.to_string(), *v)).collect();
        eval(expr.ast(), &var_map, &registry).unwrap()
    }

    #[test]
    fn test_constants() {
        assert!((eval_expr("pi()", &[]) - std::f32::consts::PI).abs() < 0.001);
        assert!((eval_expr("e()", &[]) - std::f32::consts::E).abs() < 0.001);
        assert!((eval_expr("tau()", &[]) - std::f32::consts::TAU).abs() < 0.001);
    }

    #[test]
    fn test_trig() {
        assert!(eval_expr("sin(0)", &[]).abs() < 0.001);
        assert!((eval_expr("cos(0)", &[]) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_exp_log() {
        assert!((eval_expr("exp(0)", &[]) - 1.0).abs() < 0.001);
        assert!((eval_expr("ln(1)", &[]) - 0.0).abs() < 0.001);
        assert!((eval_expr("sqrt(16)", &[]) - 4.0).abs() < 0.001);
    }

    #[test]
    fn test_common() {
        assert_eq!(eval_expr("abs(-5)", &[]), 5.0);
        assert_eq!(eval_expr("floor(3.7)", &[]), 3.0);
        assert_eq!(eval_expr("ceil(3.2)", &[]), 4.0);
        assert_eq!(eval_expr("min(3, 7)", &[]), 3.0);
        assert_eq!(eval_expr("max(3, 7)", &[]), 7.0);
        assert_eq!(eval_expr("clamp(5, 0, 3)", &[]), 3.0);
        assert_eq!(eval_expr("saturate(1.5)", &[]), 1.0);
    }

    #[test]
    fn test_interpolation() {
        assert_eq!(eval_expr("lerp(0, 10, 0.5)", &[]), 5.0);
        assert_eq!(eval_expr("mix(0, 10, 0.5)", &[]), 5.0);
        assert_eq!(eval_expr("step(0.5, 0.3)", &[]), 0.0);
        assert_eq!(eval_expr("step(0.5, 0.7)", &[]), 1.0);
        assert!((eval_expr("smoothstep(0, 1, 0.5)", &[]) - 0.5).abs() < 0.1);
        assert_eq!(eval_expr("inverse_lerp(0, 10, 5)", &[]), 0.5);
    }

    #[test]
    fn test_remap() {
        assert_eq!(eval_expr("remap(5, 0, 10, 0, 100)", &[]), 50.0);
    }

    #[test]
    fn test_with_variables() {
        let v = eval_expr("sin(x * pi())", &[("x", 0.5)]);
        assert!((v - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_f64() {
        let registry: FunctionRegistry<f64> = scalar_registry();
        let expr = Expr::parse("sin(x) + 1").unwrap();
        let vars: HashMap<String, f64> = [("x".to_string(), 0.0)].into();
        let result = eval(expr.ast(), &vars, &registry).unwrap();
        assert!((result - 1.0).abs() < 0.001);
    }

    mod int_tests {
        use super::*;

        /// Parse `expr_str` and evaluate it as `i32` with the integer registry.
        fn eval_int_expr(expr_str: &str, vars: &[(&str, i32)]) -> i32 {
            let registry = scalar_registry_int();
            let expr = Expr::parse(expr_str).unwrap();
            let var_map: HashMap<String, i32> =
                vars.iter().map(|(k, v)| (k.to_string(), *v)).collect();
            eval_int(expr.ast(), &var_map, &registry).unwrap()
        }

        #[test]
        fn test_int_arithmetic() {
            assert_eq!(eval_int_expr("5 + 3", &[]), 8);
            assert_eq!(eval_int_expr("10 - 4", &[]), 6);
            assert_eq!(eval_int_expr("6 * 7", &[]), 42);
            // Integer division truncates.
            assert_eq!(eval_int_expr("15 / 4", &[]), 3);
        }

        #[test]
        fn test_int_modulo() {
            assert_eq!(eval_int_expr("8 % 3", &[]), 2);
            assert_eq!(eval_int_expr("10 % 5", &[]), 0);
            assert_eq!(eval_int_expr("17 % 7", &[]), 3);
        }

        #[test]
        fn test_int_power() {
            assert_eq!(eval_int_expr("2 ^ 3", &[]), 8);
            assert_eq!(eval_int_expr("3 ^ 4", &[]), 81);
            assert_eq!(eval_int_expr("5 ^ 0", &[]), 1);
        }

        #[test]
        fn test_int_bitwise() {
            // 0b101 & 0b011 = 0b001
            assert_eq!(eval_int_expr("5 & 3", &[]), 1);
            // 0b101 | 0b011 = 0b111
            assert_eq!(eval_int_expr("5 | 3", &[]), 7);
            // 0b101 ^ 0b011 = 0b110
            assert_eq!(eval_int_expr("xor(5, 3)", &[]), 6);
            assert_eq!(eval_int_expr("1 << 4", &[]), 16);
            assert_eq!(eval_int_expr("16 >> 2", &[]), 4);
        }

        #[test]
        fn test_int_bitnot() {
            // All bits set is -1 in two's complement.
            assert_eq!(eval_int_expr("~0", &[]), -1);
        }

        #[test]
        fn test_int_functions() {
            assert_eq!(eval_int_expr("abs(-5)", &[]), 5);
            assert_eq!(eval_int_expr("min(3, 7)", &[]), 3);
            assert_eq!(eval_int_expr("max(3, 7)", &[]), 7);
            assert_eq!(eval_int_expr("clamp(5, 0, 3)", &[]), 3);
            assert_eq!(eval_int_expr("sign(-10)", &[]), -1);
            assert_eq!(eval_int_expr("sign(10)", &[]), 1);
            assert_eq!(eval_int_expr("sign(0)", &[]), 0);
        }

        #[test]
        fn test_int_with_variables() {
            assert_eq!(eval_int_expr("x + y", &[("x", 5), ("y", 3)]), 8);
            assert_eq!(
                eval_int_expr("steps % beats", &[("steps", 8), ("beats", 3)]),
                2
            );
        }

        #[test]
        fn test_int_fractional_literal_error() {
            let registry: FunctionRegistry<i32> = scalar_registry_int();
            let expr = Expr::parse("3.14 + 1").unwrap();
            let result = eval_int(expr.ast(), &HashMap::new(), &registry);
            assert!(matches!(result, Err(Error::InvalidLiteral(_))));
        }

        #[test]
        fn test_int_negative_exponent_error() {
            let registry: FunctionRegistry<i32> = scalar_registry_int();
            let expr = Expr::parse("2 ^ -1").unwrap();
            let vars: HashMap<String, i32> = HashMap::new();
            let result = eval_int(expr.ast(), &vars, &registry);
            assert!(matches!(result, Err(Error::NegativeExponent)));
        }
    }
}