1use std::collections::HashMap;
78use std::sync::Arc;
79use wick_core::{Ast, BinOp, CompareOp, Numeric, UnaryOp};
80
81mod funcs;
82pub mod ops;
83
84#[cfg(feature = "wgsl")]
85pub mod wgsl;
86
87#[cfg(feature = "glsl")]
88pub mod glsl;
89
90#[cfg(feature = "rust")]
91pub mod rust;
92
93#[cfg(feature = "c")]
94pub mod c;
95
96#[cfg(feature = "opencl")]
97pub mod opencl;
98
99#[cfg(feature = "cuda")]
100pub mod cuda;
101
102#[cfg(feature = "hip")]
103pub mod hip;
104
105#[cfg(feature = "tokenstream")]
106pub mod tokenstream;
107
108#[cfg(feature = "lua-codegen")]
109pub mod lua;
110
111#[cfg(feature = "cranelift")]
112pub mod cranelift;
113
114#[cfg(feature = "optimize")]
115pub mod optimize;
116
117#[cfg(feature = "3d")]
118pub use funcs::Cross;
119pub use funcs::{
120 Distance, Dot, Hadamard, Length, Lerp, Mix, Normalize, Reflect, linalg_registry,
121 linalg_registry_int, register_linalg, register_linalg_numeric,
122};
123
/// Static shape of a value handled by the evaluator: a scalar, a vector, or a
/// square matrix.
///
/// The 3- and 4-dimensional variants only exist when the `3d` / `4d` cargo
/// features are enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// A single scalar.
    Scalar,
    /// A 2-component vector.
    Vec2,
    /// A 3-component vector (requires the `3d` feature).
    #[cfg(feature = "3d")]
    Vec3,
    /// A 4-component vector (requires the `4d` feature).
    #[cfg(feature = "4d")]
    Vec4,
    /// A 2x2 matrix (4 components).
    Mat2,
    /// A 3x3 matrix (9 components; requires the `3d` feature).
    #[cfg(feature = "3d")]
    Mat3,
    /// A 4x4 matrix (16 components; requires the `4d` feature).
    #[cfg(feature = "4d")]
    Mat4,
}
143
144impl std::fmt::Display for Type {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 Type::Scalar => write!(f, "scalar"),
148 Type::Vec2 => write!(f, "vec2"),
149 #[cfg(feature = "3d")]
150 Type::Vec3 => write!(f, "vec3"),
151 #[cfg(feature = "4d")]
152 Type::Vec4 => write!(f, "vec4"),
153 Type::Mat2 => write!(f, "mat2"),
154 #[cfg(feature = "3d")]
155 Type::Mat3 => write!(f, "mat3"),
156 #[cfg(feature = "4d")]
157 Type::Mat4 => write!(f, "mat4"),
158 }
159 }
160}
161
162pub trait LinalgValue<T: Numeric>: Clone + PartialEq + Sized + std::fmt::Debug {
186 fn typ(&self) -> Type;
188
189 fn from_scalar(v: T) -> Self;
191 fn from_vec2(v: [T; 2]) -> Self;
192 #[cfg(feature = "3d")]
193 fn from_vec3(v: [T; 3]) -> Self;
194 #[cfg(feature = "4d")]
195 fn from_vec4(v: [T; 4]) -> Self;
196 fn from_mat2(v: [T; 4]) -> Self;
197 #[cfg(feature = "3d")]
198 fn from_mat3(v: [T; 9]) -> Self;
199 #[cfg(feature = "4d")]
200 fn from_mat4(v: [T; 16]) -> Self;
201
202 fn as_scalar(&self) -> Option<T>;
204 fn as_vec2(&self) -> Option<[T; 2]>;
205 #[cfg(feature = "3d")]
206 fn as_vec3(&self) -> Option<[T; 3]>;
207 #[cfg(feature = "4d")]
208 fn as_vec4(&self) -> Option<[T; 4]>;
209 fn as_mat2(&self) -> Option<[T; 4]>;
210 #[cfg(feature = "3d")]
211 fn as_mat3(&self) -> Option<[T; 9]>;
212 #[cfg(feature = "4d")]
213 fn as_mat4(&self) -> Option<[T; 16]>;
214}
215
/// Default concrete value type: a scalar, a vector, or a matrix whose
/// components live in a flat fixed-size array.
///
/// Matrix variants hold `N * N` components.
/// NOTE(review): component ordering (row- vs column-major) is not established
/// here — confirm against `ops` / the codegen backends.
#[derive(Debug, Clone, PartialEq)]
pub enum Value<T> {
    /// A single scalar.
    Scalar(T),
    /// A 2-component vector.
    Vec2([T; 2]),
    /// A 3-component vector (requires the `3d` feature).
    #[cfg(feature = "3d")]
    Vec3([T; 3]),
    /// A 4-component vector (requires the `4d` feature).
    #[cfg(feature = "4d")]
    Vec4([T; 4]),
    /// A 2x2 matrix stored as 4 components.
    Mat2([T; 4]),
    /// A 3x3 matrix stored as 9 components (requires the `3d` feature).
    #[cfg(feature = "3d")]
    Mat3([T; 9]),
    /// A 4x4 matrix stored as 16 components (requires the `4d` feature).
    #[cfg(feature = "4d")]
    Mat4([T; 16]),
}
239
240impl<T> Value<T> {
242 pub fn typ(&self) -> Type {
244 match self {
245 Value::Scalar(_) => Type::Scalar,
246 Value::Vec2(_) => Type::Vec2,
247 #[cfg(feature = "3d")]
248 Value::Vec3(_) => Type::Vec3,
249 #[cfg(feature = "4d")]
250 Value::Vec4(_) => Type::Vec4,
251 Value::Mat2(_) => Type::Mat2,
252 #[cfg(feature = "3d")]
253 Value::Mat3(_) => Type::Mat3,
254 #[cfg(feature = "4d")]
255 Value::Mat4(_) => Type::Mat4,
256 }
257 }
258}
259
260impl<T: Copy> Value<T> {
261 pub fn as_scalar(&self) -> Option<T> {
263 match self {
264 Value::Scalar(v) => Some(*v),
265 _ => None,
266 }
267 }
268}
269
270impl<T: Numeric> LinalgValue<T> for Value<T> {
271 fn typ(&self) -> Type {
272 Value::typ(self)
274 }
275
276 fn from_scalar(v: T) -> Self {
277 Value::Scalar(v)
278 }
279 fn from_vec2(v: [T; 2]) -> Self {
280 Value::Vec2(v)
281 }
282 #[cfg(feature = "3d")]
283 fn from_vec3(v: [T; 3]) -> Self {
284 Value::Vec3(v)
285 }
286 #[cfg(feature = "4d")]
287 fn from_vec4(v: [T; 4]) -> Self {
288 Value::Vec4(v)
289 }
290 fn from_mat2(v: [T; 4]) -> Self {
291 Value::Mat2(v)
292 }
293 #[cfg(feature = "3d")]
294 fn from_mat3(v: [T; 9]) -> Self {
295 Value::Mat3(v)
296 }
297 #[cfg(feature = "4d")]
298 fn from_mat4(v: [T; 16]) -> Self {
299 Value::Mat4(v)
300 }
301
302 fn as_scalar(&self) -> Option<T> {
303 match self {
304 Value::Scalar(v) => Some(*v),
305 _ => None,
306 }
307 }
308 fn as_vec2(&self) -> Option<[T; 2]> {
309 match self {
310 Value::Vec2(v) => Some(*v),
311 _ => None,
312 }
313 }
314 #[cfg(feature = "3d")]
315 fn as_vec3(&self) -> Option<[T; 3]> {
316 match self {
317 Value::Vec3(v) => Some(*v),
318 _ => None,
319 }
320 }
321 #[cfg(feature = "4d")]
322 fn as_vec4(&self) -> Option<[T; 4]> {
323 match self {
324 Value::Vec4(v) => Some(*v),
325 _ => None,
326 }
327 }
328 fn as_mat2(&self) -> Option<[T; 4]> {
329 match self {
330 Value::Mat2(v) => Some(*v),
331 _ => None,
332 }
333 }
334 #[cfg(feature = "3d")]
335 fn as_mat3(&self) -> Option<[T; 9]> {
336 match self {
337 Value::Mat3(v) => Some(*v),
338 _ => None,
339 }
340 }
341 #[cfg(feature = "4d")]
342 fn as_mat4(&self) -> Option<[T; 16]> {
343 match self {
344 Value::Mat4(v) => Some(*v),
345 _ => None,
346 }
347 }
348}
349
350#[derive(Debug, Clone, PartialEq)]
356pub enum Error {
357 UnknownVariable(String),
359 UnknownFunction(String),
361 BinaryTypeMismatch { op: BinOp, left: Type, right: Type },
363 UnaryTypeMismatch { op: UnaryOp, operand: Type },
365 WrongArgCount {
367 func: String,
368 expected: usize,
369 got: usize,
370 },
371 FunctionTypeMismatch {
373 func: String,
374 expected: Vec<Type>,
375 got: Vec<Type>,
376 },
377 UnsupportedTypeForConditional(Type),
379 NegativeExponent,
381}
382
383impl std::fmt::Display for Error {
384 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 match self {
386 Error::UnknownVariable(name) => write!(f, "unknown variable: '{name}'"),
387 Error::UnknownFunction(name) => write!(f, "unknown function: '{name}'"),
388 Error::BinaryTypeMismatch { op, left, right } => {
389 write!(f, "cannot apply {op:?} to {left} and {right}")
390 }
391 Error::UnaryTypeMismatch { op, operand } => {
392 write!(f, "cannot apply {op:?} to {operand}")
393 }
394 Error::WrongArgCount {
395 func,
396 expected,
397 got,
398 } => {
399 write!(f, "function '{func}' expects {expected} args, got {got}")
400 }
401 Error::FunctionTypeMismatch {
402 func,
403 expected,
404 got,
405 } => {
406 write!(
407 f,
408 "function '{func}' expects types {expected:?}, got {got:?}"
409 )
410 }
411 Error::UnsupportedTypeForConditional(t) => {
412 write!(f, "conditionals require scalar type, got {t}")
413 }
414 Error::NegativeExponent => {
415 write!(f, "negative exponent not supported for integer types")
416 }
417 }
418 }
419}
420
421impl std::error::Error for Error {}
422
423#[derive(Debug, Clone, PartialEq)]
429pub struct Signature {
430 pub args: Vec<Type>,
431 pub ret: Type,
432}
433
434pub trait LinalgFn<T, V>: Send + Sync
439where
440 T: Numeric,
441 V: LinalgValue<T>,
442{
443 fn name(&self) -> &str;
445
446 fn signatures(&self) -> Vec<Signature>;
448
449 fn call(&self, args: &[V]) -> V;
452}
453
454#[derive(Clone)]
456pub struct FunctionRegistry<T, V>
457where
458 T: Numeric,
459 V: LinalgValue<T>,
460{
461 funcs: HashMap<String, Arc<dyn LinalgFn<T, V>>>,
462}
463
464impl<T, V> Default for FunctionRegistry<T, V>
465where
466 T: Numeric,
467 V: LinalgValue<T>,
468{
469 fn default() -> Self {
470 Self {
471 funcs: HashMap::new(),
472 }
473 }
474}
475
476impl<T, V> FunctionRegistry<T, V>
477where
478 T: Numeric,
479 V: LinalgValue<T>,
480{
481 pub fn new() -> Self {
482 Self::default()
483 }
484
485 pub fn register<F: LinalgFn<T, V> + 'static>(&mut self, func: F) {
486 self.funcs.insert(func.name().to_string(), Arc::new(func));
487 }
488
489 pub fn get(&self, name: &str) -> Option<&Arc<dyn LinalgFn<T, V>>> {
490 self.funcs.get(name)
491 }
492}
493
494pub fn eval<T, V>(
505 ast: &Ast,
506 vars: &HashMap<String, V>,
507 funcs: &FunctionRegistry<T, V>,
508) -> Result<V, Error>
509where
510 T: Numeric,
511 V: LinalgValue<T>,
512{
513 match ast {
514 Ast::Num(n) => {
515 Ok(V::from_scalar(T::from(*n).unwrap()))
517 }
518
519 Ast::Var(name) => vars
520 .get(name)
521 .cloned()
522 .ok_or_else(|| Error::UnknownVariable(name.clone())),
523
524 Ast::BinOp(op, left, right) => {
525 let left_val = eval(left, vars, funcs)?;
526 let right_val = eval(right, vars, funcs)?;
527 ops::apply_binop(*op, left_val, right_val)
528 }
529
530 Ast::UnaryOp(op, inner) => {
531 let val = eval(inner, vars, funcs)?;
532 ops::apply_unaryop(*op, val)
533 }
534
535 Ast::Call(name, args) => {
536 let func = funcs
537 .get(name)
538 .ok_or_else(|| Error::UnknownFunction(name.clone()))?;
539
540 let arg_vals: Vec<V> = args
541 .iter()
542 .map(|a| eval(a, vars, funcs))
543 .collect::<Result<_, _>>()?;
544
545 let arg_types: Vec<Type> = arg_vals.iter().map(|v| v.typ()).collect();
546
547 let matched = func.signatures().iter().any(|sig| sig.args == arg_types);
549 if !matched {
550 return Err(Error::FunctionTypeMismatch {
551 func: name.clone(),
552 expected: func
553 .signatures()
554 .first()
555 .map(|s| s.args.clone())
556 .unwrap_or_default(),
557 got: arg_types,
558 });
559 }
560
561 Ok(func.call(&arg_vals))
562 }
563
564 Ast::Compare(op, left, right) => {
565 let left_val = eval(left, vars, funcs)?;
566 let right_val = eval(right, vars, funcs)?;
567 match (left_val.as_scalar(), right_val.as_scalar()) {
569 (Some(l), Some(r)) => {
570 let result = match op {
571 CompareOp::Lt => l < r,
572 CompareOp::Le => l <= r,
573 CompareOp::Gt => l > r,
574 CompareOp::Ge => l >= r,
575 CompareOp::Eq => l == r,
576 CompareOp::Ne => l != r,
577 };
578 Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
579 }
580 _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
581 }
582 }
583
584 Ast::And(left, right) => {
585 let left_val = eval(left, vars, funcs)?;
586 let right_val = eval(right, vars, funcs)?;
587 match (left_val.as_scalar(), right_val.as_scalar()) {
588 (Some(l), Some(r)) => {
589 let result = !l.is_zero() && !r.is_zero();
590 Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
591 }
592 _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
593 }
594 }
595
596 Ast::Or(left, right) => {
597 let left_val = eval(left, vars, funcs)?;
598 let right_val = eval(right, vars, funcs)?;
599 match (left_val.as_scalar(), right_val.as_scalar()) {
600 (Some(l), Some(r)) => {
601 let result = !l.is_zero() || !r.is_zero();
602 Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
603 }
604 _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
605 }
606 }
607
608 Ast::If(cond, then_ast, else_ast) => {
609 let cond_val = eval(cond, vars, funcs)?;
610 match cond_val.as_scalar() {
611 Some(c) => {
612 if !c.is_zero() {
613 eval(then_ast, vars, funcs)
614 } else {
615 eval(else_ast, vars, funcs)
616 }
617 }
618 None => Err(Error::UnsupportedTypeForConditional(cond_val.typ())),
619 }
620 }
621
622 Ast::Let { name, value, body } => {
623 let val = eval(value, vars, funcs)?;
624 let mut new_vars = vars.clone();
625 new_vars.insert(name.clone(), val);
626 eval(body, &new_vars, funcs)
627 }
628 }
629}
630
631#[cfg(test)]
636mod exhaustive_tests;
637
638#[cfg(test)]
639mod parity_tests;
640
#[cfg(test)]
mod tests {
    use super::*;
    use wick_core::Expr;

    /// Parses `expr` and evaluates it over `f32` values with an empty function registry.
    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Result<Value<f32>, Error> {
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, Value<f32>> = vars
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        let registry = FunctionRegistry::new();
        eval(expr.ast(), &var_map, &registry)
    }

    #[test]
    fn test_scalar_add() {
        let result = eval_expr(
            "a + b",
            &[("a", Value::Scalar(1.0)), ("b", Value::Scalar(2.0))],
        );
        assert_eq!(result.unwrap(), Value::Scalar(3.0));
    }

    #[test]
    fn test_vec2_add() {
        let result = eval_expr(
            "a + b",
            &[
                ("a", Value::Vec2([1.0, 2.0])),
                ("b", Value::Vec2([3.0, 4.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Vec2([4.0, 6.0]));
    }

    #[test]
    fn test_vec2_scalar_mul() {
        let result = eval_expr(
            "v * s",
            &[("v", Value::Vec2([2.0, 3.0])), ("s", Value::Scalar(2.0))],
        );
        assert_eq!(result.unwrap(), Value::Vec2([4.0, 6.0]));
    }

    #[test]
    fn test_scalar_vec2_mul() {
        let result = eval_expr(
            "s * v",
            &[("s", Value::Scalar(2.0)), ("v", Value::Vec2([2.0, 3.0]))],
        );
        assert_eq!(result.unwrap(), Value::Vec2([4.0, 6.0]));
    }

    #[test]
    fn test_vec2_neg() {
        let result = eval_expr("-v", &[("v", Value::Vec2([1.0, -2.0]))]);
        assert_eq!(result.unwrap(), Value::Vec2([-1.0, 2.0]));
    }

    #[cfg(feature = "3d")]
    #[test]
    fn test_vec3_add() {
        let result = eval_expr(
            "a + b",
            &[
                ("a", Value::Vec3([1.0, 2.0, 3.0])),
                ("b", Value::Vec3([4.0, 5.0, 6.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Vec3([5.0, 7.0, 9.0]));
    }

    #[test]
    fn test_type_mismatch() {
        let result = eval_expr(
            "a + b",
            &[("a", Value::Scalar(1.0)), ("b", Value::Vec2([1.0, 2.0]))],
        );
        assert!(matches!(result, Err(Error::BinaryTypeMismatch { .. })));
    }

    #[test]
    fn test_unknown_variable() {
        // Only `a` is bound, so resolving `b` must fail with its name.
        let result = eval_expr("a + b", &[("a", Value::Scalar(1.0))]);
        assert_eq!(result, Err(Error::UnknownVariable("b".to_string())));
    }

    #[test]
    fn test_literal_conversion() {
        // Literals are converted into the scalar type of the values (f64 here).
        let expr = Expr::parse("a + 1.5").unwrap();
        let mut vars: HashMap<String, Value<f64>> = HashMap::new();
        vars.insert("a".to_string(), Value::Scalar(2.5));
        let registry = FunctionRegistry::new();
        let result = eval(expr.ast(), &vars, &registry).unwrap();
        assert_eq!(result, Value::Scalar(4.0));
    }
}