Skip to main content

wick_linalg/
lib.rs

//! Linear algebra types and operations for wick expressions.
//!
//! This crate provides vector and matrix types (Vec2, Vec3, Mat2, Mat3, etc.)
//! that work with `wick_core`'s AST. Value types propagate through evaluation
//! and code emission.
5//!
6//! # Quick Start
7//!
8//! ```
9//! use wick_core::Expr;
10//! use wick_linalg::{Value, eval, linalg_registry};
11//! use std::collections::HashMap;
12//!
13//! let expr = Expr::parse("dot(a, b)").unwrap();
14//!
15//! let vars: HashMap<String, Value<f32>> = [
16//!     ("a".into(), Value::Vec2([1.0, 0.0])),
17//!     ("b".into(), Value::Vec2([0.0, 1.0])),
18//! ].into();
19//!
20//! let result = eval(expr.ast(), &vars, &linalg_registry()).unwrap();
21//! assert_eq!(result, Value::Scalar(0.0)); // perpendicular vectors
22//! ```
23//!
24//! # Features
25//!
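//! | Feature       | Description                        |
//! |---------------|------------------------------------|
//! | `3d`          | Vec3, Mat3 (default)               |
//! | `4d`          | Vec4, Mat4 (implies 3d)            |
//! | `wgsl`        | WGSL shader code generation        |
//! | `glsl`        | GLSL shader code generation        |
//! | `rust`        | Rust code generation               |
//! | `c`           | C code generation                  |
//! | `opencl`      | OpenCL code generation             |
//! | `cuda`        | CUDA code generation               |
//! | `hip`         | HIP code generation                |
//! | `tokenstream` | `TokenStream` generation           |
//! | `lua-codegen` | Lua code generation                |
//! | `cranelift`   | Cranelift JIT compilation          |
//! | `optimize`    | Expression optimization            |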
33//!
34//! # Types
35//!
36//! | Type     | Description                    |
37//! |----------|--------------------------------|
38//! | `Scalar` | Single f32/f64 value           |
39//! | `Vec2`   | 2D vector [x, y]               |
40//! | `Vec3`   | 3D vector [x, y, z] (3d)       |
41//! | `Vec4`   | 4D vector [x, y, z, w] (4d)    |
42//! | `Mat2`   | 2x2 matrix, column-major       |
43//! | `Mat3`   | 3x3 matrix, column-major (3d)  |
44//! | `Mat4`   | 4x4 matrix, column-major (4d)  |
45//!
46//! # Functions
47//!
48//! | Function           | Description                              |
49//! |--------------------|------------------------------------------|
50//! | `dot(a, b)`        | Dot product → scalar                     |
51//! | `cross(a, b)`      | Cross product → vec3 (3d only)           |
52//! | `length(v)`        | Vector magnitude → scalar                |
53//! | `normalize(v)`     | Unit vector → same type                  |
54//! | `distance(a, b)`   | Distance between points → scalar         |
55//! | `reflect(v, n)`    | Reflect v around normal n → same type    |
56//! | `hadamard(a, b)`   | Element-wise multiply → same type        |
57//! | `lerp(a, b, t)`    | Linear interpolation                     |
58//! | `mix(a, b, t)`     | Alias for lerp                           |
59//!
60//! # Operators
61//!
//! | Operation          | Description                     |
63//! |--------------------|---------------------------------|
64//! | `vec + vec`        | Component-wise addition         |
65//! | `vec - vec`        | Component-wise subtraction      |
66//! | `vec * scalar`     | Scalar multiplication           |
67//! | `scalar * vec`     | Scalar multiplication           |
68//! | `mat * vec`        | Matrix-vector multiplication    |
69//! | `mat * mat`        | Matrix multiplication           |
70//! | `-vec`             | Negation                        |
71//!
72//! # Composability
73//!
74//! For composing multiple domain crates (e.g., linalg + rotors), the [`LinalgValue`]
75//! trait allows defining a combined value type that works with both crates.
76
77use std::collections::HashMap;
78use std::sync::Arc;
79use wick_core::{Ast, BinOp, CompareOp, Numeric, UnaryOp};
80
81mod funcs;
82pub mod ops;
83
84#[cfg(feature = "wgsl")]
85pub mod wgsl;
86
87#[cfg(feature = "glsl")]
88pub mod glsl;
89
90#[cfg(feature = "rust")]
91pub mod rust;
92
93#[cfg(feature = "c")]
94pub mod c;
95
96#[cfg(feature = "opencl")]
97pub mod opencl;
98
99#[cfg(feature = "cuda")]
100pub mod cuda;
101
102#[cfg(feature = "hip")]
103pub mod hip;
104
105#[cfg(feature = "tokenstream")]
106pub mod tokenstream;
107
108#[cfg(feature = "lua-codegen")]
109pub mod lua;
110
111#[cfg(feature = "cranelift")]
112pub mod cranelift;
113
114#[cfg(feature = "optimize")]
115pub mod optimize;
116
117#[cfg(feature = "3d")]
118pub use funcs::Cross;
119pub use funcs::{
120    Distance, Dot, Hadamard, Length, Lerp, Mix, Normalize, Reflect, linalg_registry,
121    linalg_registry_int, register_linalg, register_linalg_numeric,
122};
123
124// ============================================================================
125// Types
126// ============================================================================
127
/// Type of a linalg value (shape only, not numeric type).
///
/// Variants marked with a feature exist only when that cargo feature
/// (`3d` or `4d`) is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// A single number.
    Scalar,
    /// 2-component vector.
    Vec2,
    /// 3-component vector (`3d` feature).
    #[cfg(feature = "3d")]
    Vec3,
    /// 4-component vector (`4d` feature).
    #[cfg(feature = "4d")]
    Vec4,
    /// 2x2 matrix.
    Mat2,
    /// 3x3 matrix (`3d` feature).
    #[cfg(feature = "3d")]
    Mat3,
    /// 4x4 matrix (`4d` feature).
    #[cfg(feature = "4d")]
    Mat4,
}
143
144impl std::fmt::Display for Type {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        match self {
147            Type::Scalar => write!(f, "scalar"),
148            Type::Vec2 => write!(f, "vec2"),
149            #[cfg(feature = "3d")]
150            Type::Vec3 => write!(f, "vec3"),
151            #[cfg(feature = "4d")]
152            Type::Vec4 => write!(f, "vec4"),
153            Type::Mat2 => write!(f, "mat2"),
154            #[cfg(feature = "3d")]
155            Type::Mat3 => write!(f, "mat3"),
156            #[cfg(feature = "4d")]
157            Type::Mat4 => write!(f, "mat4"),
158        }
159    }
160}
161
162// ============================================================================
163// LinalgValue trait (Option 3: generic over value type)
164// ============================================================================
165
/// Trait for values that support linalg operations.
///
/// This enables composing multiple domain crates. Users can define their own
/// combined enum implementing traits from each domain, then use both crates
/// with zero conversion.
///
/// The `from_*` constructors wrap raw component arrays into a value, and the
/// `as_*` accessors return `None` when the value holds a different shape.
/// Matrix arrays are expected to use the same column-major layout as
/// [`Value`].
///
/// # Example (composing linalg + rotors)
///
/// ```ignore
/// enum CombinedValue<T> {
///     Scalar(T),
///     Vec2([T; 2]),
///     Vec3([T; 3]),
///     Rotor2(Rotor2<T>),
/// }
///
/// impl<T: Numeric> LinalgValue<T> for CombinedValue<T> { ... }
/// impl<T: Numeric> RotorValue<T> for CombinedValue<T> { ... }
/// ```
pub trait LinalgValue<T: Numeric>: Clone + PartialEq + Sized + std::fmt::Debug {
    /// Returns the type of this value.
    fn typ(&self) -> Type;

    // Construction
    /// Wraps a scalar.
    fn from_scalar(v: T) -> Self;
    /// Wraps a 2D vector `[x, y]`.
    fn from_vec2(v: [T; 2]) -> Self;
    /// Wraps a 3D vector `[x, y, z]`.
    #[cfg(feature = "3d")]
    fn from_vec3(v: [T; 3]) -> Self;
    /// Wraps a 4D vector `[x, y, z, w]`.
    #[cfg(feature = "4d")]
    fn from_vec4(v: [T; 4]) -> Self;
    /// Wraps a 2x2 matrix given as 4 components.
    fn from_mat2(v: [T; 4]) -> Self;
    /// Wraps a 3x3 matrix given as 9 components.
    #[cfg(feature = "3d")]
    fn from_mat3(v: [T; 9]) -> Self;
    /// Wraps a 4x4 matrix given as 16 components.
    #[cfg(feature = "4d")]
    fn from_mat4(v: [T; 16]) -> Self;

    // Extraction (returns None if wrong type)
    /// Returns the scalar, or `None` if this value is not a scalar.
    fn as_scalar(&self) -> Option<T>;
    /// Returns the 2D vector components, or `None` for any other shape.
    fn as_vec2(&self) -> Option<[T; 2]>;
    /// Returns the 3D vector components, or `None` for any other shape.
    #[cfg(feature = "3d")]
    fn as_vec3(&self) -> Option<[T; 3]>;
    /// Returns the 4D vector components, or `None` for any other shape.
    #[cfg(feature = "4d")]
    fn as_vec4(&self) -> Option<[T; 4]>;
    /// Returns the 2x2 matrix components, or `None` for any other shape.
    fn as_mat2(&self) -> Option<[T; 4]>;
    /// Returns the 3x3 matrix components, or `None` for any other shape.
    #[cfg(feature = "3d")]
    fn as_mat3(&self) -> Option<[T; 9]>;
    /// Returns the 4x4 matrix components, or `None` for any other shape.
    #[cfg(feature = "4d")]
    fn as_mat4(&self) -> Option<[T; 16]>;
}
215
216// ============================================================================
217// Values
218// ============================================================================
219
/// A linalg value, generic over numeric type.
///
/// This is the default concrete type for standalone use of `wick_linalg`.
/// For composing with other domain crates, implement `LinalgValue<T>` for
/// your own combined enum.
///
/// Matrices are stored as flat arrays in column-major order.
#[derive(Debug, Clone, PartialEq)]
pub enum Value<T> {
    /// A single number.
    Scalar(T),
    /// 2D vector `[x, y]`.
    Vec2([T; 2]),
    /// 3D vector `[x, y, z]` (`3d` feature).
    #[cfg(feature = "3d")]
    Vec3([T; 3]),
    /// 4D vector `[x, y, z, w]` (`4d` feature).
    #[cfg(feature = "4d")]
    Vec4([T; 4]),
    /// 2x2 matrix, column-major: `[c0r0, c0r1, c1r0, c1r1]`.
    Mat2([T; 4]),
    /// 3x3 matrix, column-major (`3d` feature).
    #[cfg(feature = "3d")]
    Mat3([T; 9]),
    /// 4x4 matrix, column-major (`4d` feature).
    #[cfg(feature = "4d")]
    Mat4([T; 16]),
}
239
240// Inherent methods for backwards compatibility (don't require Debug bound)
241impl<T> Value<T> {
242    /// Returns the type of this value.
243    pub fn typ(&self) -> Type {
244        match self {
245            Value::Scalar(_) => Type::Scalar,
246            Value::Vec2(_) => Type::Vec2,
247            #[cfg(feature = "3d")]
248            Value::Vec3(_) => Type::Vec3,
249            #[cfg(feature = "4d")]
250            Value::Vec4(_) => Type::Vec4,
251            Value::Mat2(_) => Type::Mat2,
252            #[cfg(feature = "3d")]
253            Value::Mat3(_) => Type::Mat3,
254            #[cfg(feature = "4d")]
255            Value::Mat4(_) => Type::Mat4,
256        }
257    }
258}
259
260impl<T: Copy> Value<T> {
261    /// Try to get as scalar.
262    pub fn as_scalar(&self) -> Option<T> {
263        match self {
264            Value::Scalar(v) => Some(*v),
265            _ => None,
266        }
267    }
268}
269
270impl<T: Numeric> LinalgValue<T> for Value<T> {
271    fn typ(&self) -> Type {
272        // Delegate to inherent method
273        Value::typ(self)
274    }
275
276    fn from_scalar(v: T) -> Self {
277        Value::Scalar(v)
278    }
279    fn from_vec2(v: [T; 2]) -> Self {
280        Value::Vec2(v)
281    }
282    #[cfg(feature = "3d")]
283    fn from_vec3(v: [T; 3]) -> Self {
284        Value::Vec3(v)
285    }
286    #[cfg(feature = "4d")]
287    fn from_vec4(v: [T; 4]) -> Self {
288        Value::Vec4(v)
289    }
290    fn from_mat2(v: [T; 4]) -> Self {
291        Value::Mat2(v)
292    }
293    #[cfg(feature = "3d")]
294    fn from_mat3(v: [T; 9]) -> Self {
295        Value::Mat3(v)
296    }
297    #[cfg(feature = "4d")]
298    fn from_mat4(v: [T; 16]) -> Self {
299        Value::Mat4(v)
300    }
301
302    fn as_scalar(&self) -> Option<T> {
303        match self {
304            Value::Scalar(v) => Some(*v),
305            _ => None,
306        }
307    }
308    fn as_vec2(&self) -> Option<[T; 2]> {
309        match self {
310            Value::Vec2(v) => Some(*v),
311            _ => None,
312        }
313    }
314    #[cfg(feature = "3d")]
315    fn as_vec3(&self) -> Option<[T; 3]> {
316        match self {
317            Value::Vec3(v) => Some(*v),
318            _ => None,
319        }
320    }
321    #[cfg(feature = "4d")]
322    fn as_vec4(&self) -> Option<[T; 4]> {
323        match self {
324            Value::Vec4(v) => Some(*v),
325            _ => None,
326        }
327    }
328    fn as_mat2(&self) -> Option<[T; 4]> {
329        match self {
330            Value::Mat2(v) => Some(*v),
331            _ => None,
332        }
333    }
334    #[cfg(feature = "3d")]
335    fn as_mat3(&self) -> Option<[T; 9]> {
336        match self {
337            Value::Mat3(v) => Some(*v),
338            _ => None,
339        }
340    }
341    #[cfg(feature = "4d")]
342    fn as_mat4(&self) -> Option<[T; 16]> {
343        match self {
344            Value::Mat4(v) => Some(*v),
345            _ => None,
346        }
347    }
348}
349
350// ============================================================================
351// Errors
352// ============================================================================
353
/// Linalg evaluation error.
///
/// Returned by [`eval`] and by the operator implementations in [`ops`].
#[derive(Debug, Clone, PartialEq)]
pub enum Error {
    /// Unknown variable (the name is not bound in the variable map).
    UnknownVariable(String),
    /// Unknown function (the name is not in the function registry).
    UnknownFunction(String),
    /// Type mismatch for binary operation.
    BinaryTypeMismatch { op: BinOp, left: Type, right: Type },
    /// Type mismatch for unary operation.
    UnaryTypeMismatch { op: UnaryOp, operand: Type },
    /// Wrong number of arguments to function.
    WrongArgCount {
        /// Function name.
        func: String,
        /// Number of arguments the function accepts.
        expected: usize,
        /// Number of arguments supplied.
        got: usize,
    },
    /// Type mismatch in function arguments.
    FunctionTypeMismatch {
        /// Function name.
        func: String,
        /// Argument types of one accepted signature (in `eval`, the
        /// function's first signature, or empty if it declares none).
        expected: Vec<Type>,
        /// Argument types actually supplied.
        got: Vec<Type>,
    },
    /// Conditionals require scalar types.
    ///
    /// Used for comparisons, logical and/or, and `if` conditions.
    UnsupportedTypeForConditional(Type),
    /// Negative exponent for integer power.
    NegativeExponent,
}
382
383impl std::fmt::Display for Error {
384    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        match self {
386            Error::UnknownVariable(name) => write!(f, "unknown variable: '{name}'"),
387            Error::UnknownFunction(name) => write!(f, "unknown function: '{name}'"),
388            Error::BinaryTypeMismatch { op, left, right } => {
389                write!(f, "cannot apply {op:?} to {left} and {right}")
390            }
391            Error::UnaryTypeMismatch { op, operand } => {
392                write!(f, "cannot apply {op:?} to {operand}")
393            }
394            Error::WrongArgCount {
395                func,
396                expected,
397                got,
398            } => {
399                write!(f, "function '{func}' expects {expected} args, got {got}")
400            }
401            Error::FunctionTypeMismatch {
402                func,
403                expected,
404                got,
405            } => {
406                write!(
407                    f,
408                    "function '{func}' expects types {expected:?}, got {got:?}"
409                )
410            }
411            Error::UnsupportedTypeForConditional(t) => {
412                write!(f, "conditionals require scalar type, got {t}")
413            }
414            Error::NegativeExponent => {
415                write!(f, "negative exponent not supported for integer types")
416            }
417        }
418    }
419}
420
421impl std::error::Error for Error {}
422
423// ============================================================================
424// Function Registry
425// ============================================================================
426
/// A function signature: argument types and return type.
///
/// [`eval`] accepts a call when the runtime argument types equal `args`
/// exactly (same length and order); it does not check `ret`.
#[derive(Debug, Clone, PartialEq)]
pub struct Signature {
    /// Argument types, in call order.
    pub args: Vec<Type>,
    /// Declared type of the value returned for these arguments.
    pub ret: Type,
}
433
/// A function that can be called from linalg expressions.
///
/// Generic over both the numeric type `T` and the value type `V`.
/// This allows using custom combined value types when composing multiple domains.
///
/// Implementations must be `Send + Sync`; [`FunctionRegistry`] stores them
/// behind `Arc`, so one instance may be shared by cloned registries.
pub trait LinalgFn<T, V>: Send + Sync
where
    T: Numeric,
    V: LinalgValue<T>,
{
    /// Function name. This is the key used to register and look up the function.
    fn name(&self) -> &str;

    /// Available signatures for this function.
    ///
    /// [`eval`] accepts a call if the argument types match any of them; the
    /// first one is reported when none match.
    fn signatures(&self) -> Vec<Signature>;

    /// Call the function with typed arguments.
    /// Caller guarantees args match one of the signatures.
    fn call(&self, args: &[V]) -> V;
}
453
/// Registry of linalg functions, keyed by function name.
///
/// Functions are stored as `Arc<dyn LinalgFn>`, so cloning a registry copies
/// the name map but shares the function objects themselves.
#[derive(Clone)]
pub struct FunctionRegistry<T, V>
where
    T: Numeric,
    V: LinalgValue<T>,
{
    // Name (as returned by `LinalgFn::name`) -> function implementation.
    funcs: HashMap<String, Arc<dyn LinalgFn<T, V>>>,
}
463
464impl<T, V> Default for FunctionRegistry<T, V>
465where
466    T: Numeric,
467    V: LinalgValue<T>,
468{
469    fn default() -> Self {
470        Self {
471            funcs: HashMap::new(),
472        }
473    }
474}
475
476impl<T, V> FunctionRegistry<T, V>
477where
478    T: Numeric,
479    V: LinalgValue<T>,
480{
481    pub fn new() -> Self {
482        Self::default()
483    }
484
485    pub fn register<F: LinalgFn<T, V> + 'static>(&mut self, func: F) {
486        self.funcs.insert(func.name().to_string(), Arc::new(func));
487    }
488
489    pub fn get(&self, name: &str) -> Option<&Arc<dyn LinalgFn<T, V>>> {
490        self.funcs.get(name)
491    }
492}
493
494// ============================================================================
495// Evaluation
496// ============================================================================
497
498/// Evaluate an AST with linalg values.
499///
500/// Generic over both numeric type `T` and value type `V`, allowing use of
501/// custom combined value types when composing multiple domains.
502///
503/// Literals from the AST (f64) are converted to T via `T::from(f64)`.
504pub fn eval<T, V>(
505    ast: &Ast,
506    vars: &HashMap<String, V>,
507    funcs: &FunctionRegistry<T, V>,
508) -> Result<V, Error>
509where
510    T: Numeric,
511    V: LinalgValue<T>,
512{
513    match ast {
514        Ast::Num(n) => {
515            // Convert f32 literal to T
516            Ok(V::from_scalar(T::from(*n).unwrap()))
517        }
518
519        Ast::Var(name) => vars
520            .get(name)
521            .cloned()
522            .ok_or_else(|| Error::UnknownVariable(name.clone())),
523
524        Ast::BinOp(op, left, right) => {
525            let left_val = eval(left, vars, funcs)?;
526            let right_val = eval(right, vars, funcs)?;
527            ops::apply_binop(*op, left_val, right_val)
528        }
529
530        Ast::UnaryOp(op, inner) => {
531            let val = eval(inner, vars, funcs)?;
532            ops::apply_unaryop(*op, val)
533        }
534
535        Ast::Call(name, args) => {
536            let func = funcs
537                .get(name)
538                .ok_or_else(|| Error::UnknownFunction(name.clone()))?;
539
540            let arg_vals: Vec<V> = args
541                .iter()
542                .map(|a| eval(a, vars, funcs))
543                .collect::<Result<_, _>>()?;
544
545            let arg_types: Vec<Type> = arg_vals.iter().map(|v| v.typ()).collect();
546
547            // Find matching signature
548            let matched = func.signatures().iter().any(|sig| sig.args == arg_types);
549            if !matched {
550                return Err(Error::FunctionTypeMismatch {
551                    func: name.clone(),
552                    expected: func
553                        .signatures()
554                        .first()
555                        .map(|s| s.args.clone())
556                        .unwrap_or_default(),
557                    got: arg_types,
558                });
559            }
560
561            Ok(func.call(&arg_vals))
562        }
563
564        Ast::Compare(op, left, right) => {
565            let left_val = eval(left, vars, funcs)?;
566            let right_val = eval(right, vars, funcs)?;
567            // Comparisons only supported for scalars
568            match (left_val.as_scalar(), right_val.as_scalar()) {
569                (Some(l), Some(r)) => {
570                    let result = match op {
571                        CompareOp::Lt => l < r,
572                        CompareOp::Le => l <= r,
573                        CompareOp::Gt => l > r,
574                        CompareOp::Ge => l >= r,
575                        CompareOp::Eq => l == r,
576                        CompareOp::Ne => l != r,
577                    };
578                    Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
579                }
580                _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
581            }
582        }
583
584        Ast::And(left, right) => {
585            let left_val = eval(left, vars, funcs)?;
586            let right_val = eval(right, vars, funcs)?;
587            match (left_val.as_scalar(), right_val.as_scalar()) {
588                (Some(l), Some(r)) => {
589                    let result = !l.is_zero() && !r.is_zero();
590                    Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
591                }
592                _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
593            }
594        }
595
596        Ast::Or(left, right) => {
597            let left_val = eval(left, vars, funcs)?;
598            let right_val = eval(right, vars, funcs)?;
599            match (left_val.as_scalar(), right_val.as_scalar()) {
600                (Some(l), Some(r)) => {
601                    let result = !l.is_zero() || !r.is_zero();
602                    Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
603                }
604                _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
605            }
606        }
607
608        Ast::If(cond, then_ast, else_ast) => {
609            let cond_val = eval(cond, vars, funcs)?;
610            match cond_val.as_scalar() {
611                Some(c) => {
612                    if !c.is_zero() {
613                        eval(then_ast, vars, funcs)
614                    } else {
615                        eval(else_ast, vars, funcs)
616                    }
617                }
618                None => Err(Error::UnsupportedTypeForConditional(cond_val.typ())),
619            }
620        }
621
622        Ast::Let { name, value, body } => {
623            let val = eval(value, vars, funcs)?;
624            let mut new_vars = vars.clone();
625            new_vars.insert(name.clone(), val);
626            eval(body, &new_vars, funcs)
627        }
628    }
629}
630
631// ============================================================================
632// Tests
633// ============================================================================
634
// Additional test suites, each in its own module file.
#[cfg(test)]
mod exhaustive_tests;

#[cfg(test)]
mod parity_tests;

// Unit tests for `eval` using an empty function registry, so only operators
// (no function calls) are exercised here.
#[cfg(test)]
mod tests {
    use super::*;
    use wick_core::Expr;

    // Parses `expr`, binds `vars`, and evaluates with an empty registry.
    // Panics if `expr` fails to parse.
    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Result<Value<f32>, Error> {
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, Value<f32>> = vars
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        let registry = FunctionRegistry::new();
        eval(expr.ast(), &var_map, &registry)
    }

    #[test]
    fn test_scalar_add() {
        let result = eval_expr(
            "a + b",
            &[("a", Value::Scalar(1.0)), ("b", Value::Scalar(2.0))],
        );
        assert_eq!(result.unwrap(), Value::Scalar(3.0));
    }

    #[test]
    fn test_vec2_add() {
        let result = eval_expr(
            "a + b",
            &[
                ("a", Value::Vec2([1.0, 2.0])),
                ("b", Value::Vec2([3.0, 4.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Vec2([4.0, 6.0]));
    }

    #[test]
    fn test_vec2_scalar_mul() {
        let result = eval_expr(
            "v * s",
            &[("v", Value::Vec2([2.0, 3.0])), ("s", Value::Scalar(2.0))],
        );
        assert_eq!(result.unwrap(), Value::Vec2([4.0, 6.0]));
    }

    #[test]
    fn test_scalar_vec2_mul() {
        let result = eval_expr(
            "s * v",
            &[("s", Value::Scalar(2.0)), ("v", Value::Vec2([2.0, 3.0]))],
        );
        assert_eq!(result.unwrap(), Value::Vec2([4.0, 6.0]));
    }

    #[test]
    fn test_vec2_neg() {
        let result = eval_expr("-v", &[("v", Value::Vec2([1.0, -2.0]))]);
        assert_eq!(result.unwrap(), Value::Vec2([-1.0, 2.0]));
    }

    #[cfg(feature = "3d")]
    #[test]
    fn test_vec3_add() {
        let result = eval_expr(
            "a + b",
            &[
                ("a", Value::Vec3([1.0, 2.0, 3.0])),
                ("b", Value::Vec3([4.0, 5.0, 6.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Vec3([5.0, 7.0, 9.0]));
    }

    // Scalar + vector is not a defined operation and must be rejected.
    #[test]
    fn test_type_mismatch() {
        let result = eval_expr(
            "a + b",
            &[("a", Value::Scalar(1.0)), ("b", Value::Vec2([1.0, 2.0]))],
        );
        assert!(matches!(result, Err(Error::BinaryTypeMismatch { .. })));
    }

    #[test]
    fn test_literal_conversion() {
        // AST literals are converted to the values' numeric type (here f64).
        let expr = Expr::parse("a + 1.5").unwrap();
        let mut vars: HashMap<String, Value<f64>> = HashMap::new();
        vars.insert("a".to_string(), Value::Scalar(2.5));
        let registry = FunctionRegistry::new();
        let result = eval(expr.ast(), &vars, &registry).unwrap();
        assert_eq!(result, Value::Scalar(4.0));
    }
}
729        let registry = FunctionRegistry::new();
730        let result = eval(expr.ast(), &vars, &registry).unwrap();
731        assert_eq!(result, Value::Scalar(4.0));
732    }
733}