Skip to main content

wick_quaternion/
lib.rs

1//! Quaternion support for dew expressions.
2//!
3//! Provides quaternion types and operations for 3D rotations. Uses [x, y, z, w]
4//! component order (scalar last, matching GLM/glTF convention).
5//!
6//! # Quick Start
7//!
8//! ```
9//! use wick_core::Expr;
10//! use wick_quaternion::{Value, eval, quaternion_registry};
11//! use std::collections::HashMap;
12//!
13//! // Create a rotation quaternion and normalize it
14//! let expr = Expr::parse("normalize(q)").unwrap();
15//!
16//! let vars: HashMap<String, Value<f32>> = [
17//!     ("q".into(), Value::Quaternion([0.0, 0.0, 0.0, 2.0])),
18//! ].into();
19//!
20//! let result = eval(expr.ast(), &vars, &quaternion_registry()).unwrap();
21//! assert_eq!(result, Value::Quaternion([0.0, 0.0, 0.0, 1.0]));
22//! ```
23//!
24//! # Features
25//!
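//! | Feature       | Description                    |
//! |---------------|--------------------------------|
//! | `wgsl`        | WGSL shader code generation    |
//! | `glsl`        | GLSL shader code generation    |
//! | `rust`        | Rust code generation           |
//! | `c`           | C code generation              |
//! | `opencl`      | OpenCL code generation         |
//! | `cuda`        | CUDA code generation           |
//! | `hip`         | HIP code generation            |
//! | `tokenstream` | Rust token-stream generation   |
//! | `lua-codegen` | Lua code generation            |
//! | `cranelift`   | Cranelift JIT compilation      |
//! | `optimize`    | Expression optimization        |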
31//!
32//! # Types
33//!
34//! | Type        | Description                      |
35//! |-------------|----------------------------------|
36//! | `Scalar`    | Real number                      |
37//! | `Vec3`      | 3D vector [x, y, z]              |
38//! | `Quaternion`| Quaternion [x, y, z, w]          |
39//!
40//! # Functions
41//!
42//! | Function              | Description                              |
43//! |-----------------------|------------------------------------------|
44//! | `vec3(x, y, z)`       | Construct vector → vec3                  |
45//! | `quat(x, y, z, w)`    | Construct quaternion → quaternion        |
46//! | `conj(q)`             | Conjugate → quaternion                   |
47//! | `length(q)`           | Magnitude → scalar                       |
48//! | `normalize(q)`        | Unit quaternion → quaternion             |
49//! | `inverse(q)`          | Multiplicative inverse → quaternion      |
50//! | `dot(q1, q2)`         | Dot product → scalar                     |
51//! | `lerp(q1, q2, t)`     | Linear interpolation → quaternion        |
52//! | `slerp(q1, q2, t)`    | Spherical interpolation → quaternion     |
53//! | `axis_angle(axis, θ)` | From axis-angle → quaternion             |
54//! | `rotate(v, q)`        | Rotate vector by quaternion → vec3       |
55//!
56//! # Operators
57//!
58//! | Operation          | Result                          |
59//! |--------------------|---------------------------------|
60//! | `q1 * q2`          | Quaternion multiplication       |
61//! | `q * scalar`       | Scalar multiplication           |
62//! | `q1 + q2`          | Component-wise addition         |
63//! | `q1 - q2`          | Component-wise subtraction      |
64//! | `-q`               | Negation                        |
65//!
66//! # Component Order
67//!
68//! This crate uses [x, y, z, w] order (scalar last), matching:
69//! - GLM (OpenGL Mathematics)
70//! - glTF format
71//! - Unity (internal representation)
72//!
73//! Other conventions exist (w first), so be careful when interfacing
74//! with external libraries.
75
76use num_traits::Float;
77use std::collections::HashMap;
78use std::sync::Arc;
79use wick_core::{Ast, BinOp, CompareOp, UnaryOp};
80
81mod funcs;
82pub mod ops;
83#[cfg(test)]
84mod parity_tests;
85
86#[cfg(feature = "wgsl")]
87pub mod wgsl;
88
89#[cfg(feature = "glsl")]
90pub mod glsl;
91
92#[cfg(feature = "rust")]
93pub mod rust;
94
95#[cfg(feature = "c")]
96pub mod c;
97
98#[cfg(feature = "opencl")]
99pub mod opencl;
100
101#[cfg(feature = "cuda")]
102pub mod cuda;
103
104#[cfg(feature = "hip")]
105pub mod hip;
106
107#[cfg(feature = "tokenstream")]
108pub mod tokenstream;
109
110#[cfg(feature = "lua-codegen")]
111pub mod lua;
112
113#[cfg(feature = "cranelift")]
114pub mod cranelift;
115
116#[cfg(feature = "optimize")]
117pub mod optimize;
118
119pub use funcs::{
120    AxisAngle, Conj, Dot, Inverse, Length, Lerp, Normalize, QuatConstructor, Rotate, Slerp,
121    Vec3Constructor, quaternion_registry, register_quaternion,
122};
123
124// ============================================================================
125// Types
126// ============================================================================
127
/// Runtime type tag of a quaternion-domain value.
///
/// Reported by [`Value::typ`], listed in function [`Signature`]s, and embedded
/// in [`Error`] payloads.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Type {
    /// Real scalar.
    Scalar,
    /// 3D vector `[x, y, z]`.
    Vec3,
    /// Quaternion `[x, y, z, w]` (scalar last).
    Quaternion,
}

impl std::fmt::Display for Type {
    /// Writes the lowercase type name (`scalar`, `vec3`, `quaternion`).
    ///
    /// The text is written verbatim; width/alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Scalar => "scalar",
            Self::Vec3 => "vec3",
            Self::Quaternion => "quaternion",
        };
        f.write_str(label)
    }
}
148
149// ============================================================================
150// QuaternionValue trait (for composability)
151// ============================================================================
152
153/// Trait for values that support quaternion operations.
154///
155/// Implement this for combined value types when composing multiple domain crates.
156pub trait QuaternionValue<T: Float>: Clone + PartialEq + Sized + std::fmt::Debug {
157    /// Returns the type of this value.
158    fn typ(&self) -> Type;
159
160    // Construction
161    fn from_scalar(v: T) -> Self;
162    fn from_vec3(v: [T; 3]) -> Self;
163    fn from_quaternion(q: [T; 4]) -> Self;
164
165    // Extraction
166    fn as_scalar(&self) -> Option<T>;
167    fn as_vec3(&self) -> Option<[T; 3]>;
168    fn as_quaternion(&self) -> Option<[T; 4]>;
169}
170
171// ============================================================================
172// Values
173// ============================================================================
174
175/// A quaternion value, generic over numeric type.
176///
177/// Quaternion uses [x, y, z, w] order (scalar last).
178#[derive(Debug, Clone, PartialEq)]
179pub enum Value<T> {
180    /// Real scalar.
181    Scalar(T),
182    /// 3D vector [x, y, z].
183    Vec3([T; 3]),
184    /// Quaternion [x, y, z, w] (scalar last).
185    Quaternion([T; 4]),
186}
187
188impl<T> Value<T> {
189    /// Returns the type of this value.
190    pub fn typ(&self) -> Type {
191        match self {
192            Value::Scalar(_) => Type::Scalar,
193            Value::Vec3(_) => Type::Vec3,
194            Value::Quaternion(_) => Type::Quaternion,
195        }
196    }
197}
198
199impl<T: Copy> Value<T> {
200    /// Try to get as scalar.
201    pub fn as_scalar(&self) -> Option<T> {
202        match self {
203            Value::Scalar(v) => Some(*v),
204            _ => None,
205        }
206    }
207
208    /// Try to get as vec3.
209    pub fn as_vec3(&self) -> Option<[T; 3]> {
210        match self {
211            Value::Vec3(v) => Some(*v),
212            _ => None,
213        }
214    }
215
216    /// Try to get as quaternion.
217    pub fn as_quaternion(&self) -> Option<[T; 4]> {
218        match self {
219            Value::Quaternion(q) => Some(*q),
220            _ => None,
221        }
222    }
223}
224
225impl<T: Float + std::fmt::Debug> QuaternionValue<T> for Value<T> {
226    fn typ(&self) -> Type {
227        Value::typ(self)
228    }
229
230    fn from_scalar(v: T) -> Self {
231        Value::Scalar(v)
232    }
233
234    fn from_vec3(v: [T; 3]) -> Self {
235        Value::Vec3(v)
236    }
237
238    fn from_quaternion(q: [T; 4]) -> Self {
239        Value::Quaternion(q)
240    }
241
242    fn as_scalar(&self) -> Option<T> {
243        Value::as_scalar(self)
244    }
245
246    fn as_vec3(&self) -> Option<[T; 3]> {
247        Value::as_vec3(self)
248    }
249
250    fn as_quaternion(&self) -> Option<[T; 4]> {
251        Value::as_quaternion(self)
252    }
253}
254
255// ============================================================================
256// Errors
257// ============================================================================
258
259/// Quaternion evaluation error.
260#[derive(Debug, Clone, PartialEq)]
261pub enum Error {
262    /// Unknown variable.
263    UnknownVariable(String),
264    /// Unknown function.
265    UnknownFunction(String),
266    /// Type mismatch for binary operation.
267    BinaryTypeMismatch { op: BinOp, left: Type, right: Type },
268    /// Type mismatch for unary operation.
269    UnaryTypeMismatch { op: UnaryOp, operand: Type },
270    /// Wrong number of arguments to function.
271    WrongArgCount {
272        func: String,
273        expected: usize,
274        got: usize,
275    },
276    /// Type mismatch in function arguments.
277    FunctionTypeMismatch {
278        func: String,
279        expected: Vec<Type>,
280        got: Vec<Type>,
281    },
282    /// Conditionals require scalar types.
283    UnsupportedTypeForConditional(Type),
284}
285
286impl std::fmt::Display for Error {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        match self {
289            Error::UnknownVariable(name) => write!(f, "unknown variable: '{name}'"),
290            Error::UnknownFunction(name) => write!(f, "unknown function: '{name}'"),
291            Error::BinaryTypeMismatch { op, left, right } => {
292                write!(f, "cannot apply {op:?} to {left} and {right}")
293            }
294            Error::UnaryTypeMismatch { op, operand } => {
295                write!(f, "cannot apply {op:?} to {operand}")
296            }
297            Error::WrongArgCount {
298                func,
299                expected,
300                got,
301            } => {
302                write!(f, "function '{func}' expects {expected} args, got {got}")
303            }
304            Error::FunctionTypeMismatch {
305                func,
306                expected,
307                got,
308            } => {
309                write!(
310                    f,
311                    "function '{func}' expects types {expected:?}, got {got:?}"
312                )
313            }
314            Error::UnsupportedTypeForConditional(t) => {
315                write!(f, "conditionals require scalar type, got {t}")
316            }
317        }
318    }
319}
320
321impl std::error::Error for Error {}
322
323// ============================================================================
324// Function Registry
325// ============================================================================
326
327/// A function signature.
328#[derive(Debug, Clone, PartialEq)]
329pub struct Signature {
330    pub args: Vec<Type>,
331    pub ret: Type,
332}
333
334/// A function that can be called from quaternion expressions.
335///
336/// Generic over both the numeric type `T` and the value type `V`.
337pub trait QuaternionFn<T, V>: Send + Sync
338where
339    T: Float,
340    V: QuaternionValue<T>,
341{
342    /// Function name.
343    fn name(&self) -> &str;
344
345    /// Available signatures for this function.
346    fn signatures(&self) -> Vec<Signature>;
347
348    /// Call the function with typed arguments.
349    fn call(&self, args: &[V]) -> V;
350}
351
352/// Registry of quaternion functions.
353#[derive(Clone)]
354pub struct FunctionRegistry<T, V>
355where
356    T: Float,
357    V: QuaternionValue<T>,
358{
359    funcs: HashMap<String, Arc<dyn QuaternionFn<T, V>>>,
360}
361
362impl<T, V> Default for FunctionRegistry<T, V>
363where
364    T: Float,
365    V: QuaternionValue<T>,
366{
367    fn default() -> Self {
368        Self {
369            funcs: HashMap::new(),
370        }
371    }
372}
373
374impl<T, V> FunctionRegistry<T, V>
375where
376    T: Float,
377    V: QuaternionValue<T>,
378{
379    pub fn new() -> Self {
380        Self::default()
381    }
382
383    pub fn register<F: QuaternionFn<T, V> + 'static>(&mut self, func: F) {
384        self.funcs.insert(func.name().to_string(), Arc::new(func));
385    }
386
387    pub fn get(&self, name: &str) -> Option<&Arc<dyn QuaternionFn<T, V>>> {
388        self.funcs.get(name)
389    }
390}
391
392// ============================================================================
393// Evaluation
394// ============================================================================
395
396/// Evaluate an AST with quaternion values.
397///
398/// Generic over both the numeric type `T` and the value type `V`.
399pub fn eval<T, V>(
400    ast: &Ast,
401    vars: &HashMap<String, V>,
402    funcs: &FunctionRegistry<T, V>,
403) -> Result<V, Error>
404where
405    T: Float,
406    V: QuaternionValue<T>,
407{
408    match ast {
409        Ast::Num(n) => Ok(V::from_scalar(T::from(*n).unwrap())),
410
411        Ast::Var(name) => vars
412            .get(name)
413            .cloned()
414            .ok_or_else(|| Error::UnknownVariable(name.clone())),
415
416        Ast::BinOp(op, left, right) => {
417            let left_val = eval(left, vars, funcs)?;
418            let right_val = eval(right, vars, funcs)?;
419            ops::apply_binop(*op, left_val, right_val)
420        }
421
422        Ast::UnaryOp(op, inner) => {
423            let val = eval(inner, vars, funcs)?;
424            ops::apply_unaryop(*op, val)
425        }
426
427        Ast::Call(name, args) => {
428            let func = funcs
429                .get(name)
430                .ok_or_else(|| Error::UnknownFunction(name.clone()))?;
431
432            let arg_vals: Vec<V> = args
433                .iter()
434                .map(|a| eval(a, vars, funcs))
435                .collect::<Result<_, _>>()?;
436
437            let arg_types: Vec<Type> = arg_vals.iter().map(|v| v.typ()).collect();
438
439            // Find matching signature
440            let matched = func.signatures().iter().any(|sig| sig.args == arg_types);
441            if !matched {
442                return Err(Error::FunctionTypeMismatch {
443                    func: name.clone(),
444                    expected: func
445                        .signatures()
446                        .first()
447                        .map(|s| s.args.clone())
448                        .unwrap_or_default(),
449                    got: arg_types,
450                });
451            }
452
453            Ok(func.call(&arg_vals))
454        }
455
456        Ast::Compare(op, left, right) => {
457            let left_val = eval(left, vars, funcs)?;
458            let right_val = eval(right, vars, funcs)?;
459            match (left_val.as_scalar(), right_val.as_scalar()) {
460                (Some(l), Some(r)) => {
461                    let result = match op {
462                        CompareOp::Lt => l < r,
463                        CompareOp::Le => l <= r,
464                        CompareOp::Gt => l > r,
465                        CompareOp::Ge => l >= r,
466                        CompareOp::Eq => l == r,
467                        CompareOp::Ne => l != r,
468                    };
469                    Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
470                }
471                _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
472            }
473        }
474
475        Ast::And(left, right) => {
476            let left_val = eval(left, vars, funcs)?;
477            let right_val = eval(right, vars, funcs)?;
478            match (left_val.as_scalar(), right_val.as_scalar()) {
479                (Some(l), Some(r)) => {
480                    let result = !l.is_zero() && !r.is_zero();
481                    Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
482                }
483                _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
484            }
485        }
486
487        Ast::Or(left, right) => {
488            let left_val = eval(left, vars, funcs)?;
489            let right_val = eval(right, vars, funcs)?;
490            match (left_val.as_scalar(), right_val.as_scalar()) {
491                (Some(l), Some(r)) => {
492                    let result = !l.is_zero() || !r.is_zero();
493                    Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
494                }
495                _ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
496            }
497        }
498
499        Ast::If(cond, then_ast, else_ast) => {
500            let cond_val = eval(cond, vars, funcs)?;
501            if let Some(c) = cond_val.as_scalar() {
502                if !c.is_zero() {
503                    eval(then_ast, vars, funcs)
504                } else {
505                    eval(else_ast, vars, funcs)
506                }
507            } else {
508                Err(Error::UnsupportedTypeForConditional(cond_val.typ()))
509            }
510        }
511
512        Ast::Let { name, value, body } => {
513            let val = eval(value, vars, funcs)?;
514            let mut new_vars = vars.clone();
515            new_vars.insert(name.clone(), val);
516            eval(body, &new_vars, funcs)
517        }
518    }
519}
520
521// ============================================================================
522// Tests
523// ============================================================================
524
525#[cfg(test)]
526mod tests {
527    use super::*;
528    use wick_core::Expr;
529
530    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Result<Value<f32>, Error> {
531        let expr = Expr::parse(expr).unwrap();
532        let var_map: HashMap<String, Value<f32>> = vars
533            .iter()
534            .map(|(k, v)| (k.to_string(), v.clone()))
535            .collect();
536        let registry = quaternion_registry();
537        eval(expr.ast(), &var_map, &registry)
538    }
539
540    #[test]
541    fn test_quaternion_add() {
542        let result = eval_expr(
543            "a + b",
544            &[
545                ("a", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
546                ("b", Value::Quaternion([5.0, 6.0, 7.0, 8.0])),
547            ],
548        );
549        assert_eq!(result.unwrap(), Value::Quaternion([6.0, 8.0, 10.0, 12.0]));
550    }
551
552    #[test]
553    fn test_quaternion_mul() {
554        // Identity quaternion: [0, 0, 0, 1]
555        // q * identity = q
556        let result = eval_expr(
557            "a * b",
558            &[
559                ("a", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
560                ("b", Value::Quaternion([0.0, 0.0, 0.0, 1.0])),
561            ],
562        );
563        assert_eq!(result.unwrap(), Value::Quaternion([1.0, 2.0, 3.0, 4.0]));
564    }
565
566    #[test]
567    fn test_quaternion_neg() {
568        let result = eval_expr("-q", &[("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0]))]);
569        assert_eq!(result.unwrap(), Value::Quaternion([-1.0, -2.0, -3.0, -4.0]));
570    }
571
572    #[test]
573    fn test_quaternion_scalar_mul() {
574        let result = eval_expr(
575            "s * q",
576            &[
577                ("s", Value::Scalar(2.0)),
578                ("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
579            ],
580        );
581        assert_eq!(result.unwrap(), Value::Quaternion([2.0, 4.0, 6.0, 8.0]));
582    }
583}