use crate::{FunctionRegistry, QuaternionFn, QuaternionValue, Signature, Type, Value};
use num_traits::Float;
/// `conj(q)`: quaternion conjugate.
///
/// Negates the vector part `(x, y, z)` and keeps the scalar part `w`
/// (quaternions are stored as `[x, y, z, w]`).
pub struct Conj;

impl<T, V> QuaternionFn<T, V> for Conj
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "conj"
    }

    fn signatures(&self) -> Vec<Signature> {
        let quat_to_quat = Signature {
            args: vec![Type::Quaternion],
            ret: Type::Quaternion,
        };
        vec![quat_to_quat]
    }

    fn call(&self, args: &[V]) -> V {
        // Relies on the caller having matched `signatures()`; the unwrap panics otherwise.
        let [x, y, z, w] = args[0].as_quaternion().unwrap();
        V::from_quaternion([-x, -y, -z, w])
    }
}
/// `length(v)` / `length(q)`: Euclidean norm of a vec3 or a quaternion, as a scalar.
pub struct Length;

impl<T, V> QuaternionFn<T, V> for Length
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "length"
    }

    fn signatures(&self) -> Vec<Signature> {
        let of_vec3 = Signature {
            args: vec![Type::Vec3],
            ret: Type::Scalar,
        };
        let of_quat = Signature {
            args: vec![Type::Quaternion],
            ret: Type::Scalar,
        };
        vec![of_vec3, of_quat]
    }

    fn call(&self, args: &[V]) -> V {
        let arg = &args[0];
        // Sum of squared components; the square root is taken once below.
        let sum_sq = match arg.typ() {
            Type::Vec3 => {
                let [x, y, z] = arg.as_vec3().unwrap();
                x * x + y * y + z * z
            }
            Type::Quaternion => {
                let [x, y, z, w] = arg.as_quaternion().unwrap();
                x * x + y * y + z * z + w * w
            }
            _ => unreachable!(),
        };
        V::from_scalar(sum_sq.sqrt())
    }
}
/// `normalize(v)` / `normalize(q)`: scales a vec3 or quaternion to unit length.
///
/// A zero-length input has no direction. Dividing it by its zero length would
/// produce all-NaN components that silently poison every downstream expression,
/// so a zero input is returned unchanged as the zero vector / zero quaternion.
pub struct Normalize;

impl<T, V> QuaternionFn<T, V> for Normalize
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "normalize"
    }

    fn signatures(&self) -> Vec<Signature> {
        vec![
            Signature {
                args: vec![Type::Vec3],
                ret: Type::Vec3,
            },
            Signature {
                args: vec![Type::Quaternion],
                ret: Type::Quaternion,
            },
        ]
    }

    fn call(&self, args: &[V]) -> V {
        match args[0].typ() {
            Type::Vec3 => {
                let v = args[0].as_vec3().unwrap();
                let len = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
                // Zero vector: nothing to normalize; avoid 0/0 = NaN.
                if len == T::zero() {
                    return V::from_vec3(v);
                }
                V::from_vec3([v[0] / len, v[1] / len, v[2] / len])
            }
            Type::Quaternion => {
                let q = args[0].as_quaternion().unwrap();
                let len = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
                // Zero quaternion: nothing to normalize; avoid 0/0 = NaN.
                if len == T::zero() {
                    return V::from_quaternion(q);
                }
                V::from_quaternion([q[0] / len, q[1] / len, q[2] / len, q[3] / len])
            }
            _ => unreachable!(),
        }
    }
}
/// `inverse(q)`: multiplicative inverse of a quaternion, `conj(q) / |q|^2`.
///
/// No guard is applied for a zero quaternion; its squared norm is zero, so the
/// divisions produce non-finite components.
pub struct Inverse;

impl<T, V> QuaternionFn<T, V> for Inverse
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "inverse"
    }

    fn signatures(&self) -> Vec<Signature> {
        let quat_to_quat = Signature {
            args: vec![Type::Quaternion],
            ret: Type::Quaternion,
        };
        vec![quat_to_quat]
    }

    fn call(&self, args: &[V]) -> V {
        let [x, y, z, w] = args[0].as_quaternion().unwrap();
        let norm_sq = x * x + y * y + z * z + w * w;
        // Conjugate (negated vector part) scaled by 1 / |q|^2.
        V::from_quaternion([-x / norm_sq, -y / norm_sq, -z / norm_sq, w / norm_sq])
    }
}
/// `dot(a, b)`: component-wise dot product of two vec3s or two quaternions.
pub struct Dot;

impl<T, V> QuaternionFn<T, V> for Dot
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "dot"
    }

    fn signatures(&self) -> Vec<Signature> {
        let vec3_pair = Signature {
            args: vec![Type::Vec3, Type::Vec3],
            ret: Type::Scalar,
        };
        let quat_pair = Signature {
            args: vec![Type::Quaternion, Type::Quaternion],
            ret: Type::Scalar,
        };
        vec![vec3_pair, quat_pair]
    }

    fn call(&self, args: &[V]) -> V {
        let (lhs, rhs) = (&args[0], &args[1]);
        let product = match (lhs.typ(), rhs.typ()) {
            (Type::Vec3, Type::Vec3) => {
                let [ax, ay, az] = lhs.as_vec3().unwrap();
                let [bx, by, bz] = rhs.as_vec3().unwrap();
                ax * bx + ay * by + az * bz
            }
            (Type::Quaternion, Type::Quaternion) => {
                let [ax, ay, az, aw] = lhs.as_quaternion().unwrap();
                let [bx, by, bz, bw] = rhs.as_quaternion().unwrap();
                ax * bx + ay * by + az * bz + aw * bw
            }
            // Mixed-type pairs are not among the declared signatures.
            _ => unreachable!(),
        };
        V::from_scalar(product)
    }
}
/// `lerp(a, b, t)`: component-wise linear interpolation `a + (b - a) * t`.
///
/// Works on two vec3s or two quaternions. Quaternion results are NOT renormalized,
/// and `t` is not clamped.
pub struct Lerp;

impl<T, V> QuaternionFn<T, V> for Lerp
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "lerp"
    }

    fn signatures(&self) -> Vec<Signature> {
        let vec3_form = Signature {
            args: vec![Type::Vec3, Type::Vec3, Type::Scalar],
            ret: Type::Vec3,
        };
        let quat_form = Signature {
            args: vec![Type::Quaternion, Type::Quaternion, Type::Scalar],
            ret: Type::Quaternion,
        };
        vec![vec3_form, quat_form]
    }

    fn call(&self, args: &[V]) -> V {
        let t = args[2].as_scalar().unwrap();
        let mix = |from: T, to: T| from + (to - from) * t;
        match (args[0].typ(), args[1].typ()) {
            (Type::Vec3, Type::Vec3) => {
                let [ax, ay, az] = args[0].as_vec3().unwrap();
                let [bx, by, bz] = args[1].as_vec3().unwrap();
                V::from_vec3([mix(ax, bx), mix(ay, by), mix(az, bz)])
            }
            (Type::Quaternion, Type::Quaternion) => {
                let [ax, ay, az, aw] = args[0].as_quaternion().unwrap();
                let [bx, by, bz, bw] = args[1].as_quaternion().unwrap();
                V::from_quaternion([mix(ax, bx), mix(ay, by), mix(az, bz), mix(aw, bw)])
            }
            _ => unreachable!(),
        }
    }
}
/// `slerp(a, b, t)`: spherical linear interpolation between two quaternions.
///
/// The math lives in `slerp_impl`; this wrapper only unpacks and repacks values.
pub struct Slerp;

impl<T, V> QuaternionFn<T, V> for Slerp
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "slerp"
    }

    fn signatures(&self) -> Vec<Signature> {
        let quat_quat_scalar = Signature {
            args: vec![Type::Quaternion, Type::Quaternion, Type::Scalar],
            ret: Type::Quaternion,
        };
        vec![quat_quat_scalar]
    }

    fn call(&self, args: &[V]) -> V {
        let from = args[0].as_quaternion().unwrap();
        let to = args[1].as_quaternion().unwrap();
        let t = args[2].as_scalar().unwrap();
        let blended = slerp_impl(&from, &to, t);
        V::from_quaternion(blended)
    }
}
/// Spherical linear interpolation from `a` to `b` at parameter `t`.
///
/// Takes the shorter arc: when the 4D dot product is negative, `b` is negated first,
/// since `q` and `-q` describe the same rotation. When the cosine between the inputs
/// exceeds 0.9995, `sin(theta)` would be close to zero. In that case the function
/// falls back to a linear blend that is normalized afterwards. `t` is not clamped.
/// The inputs themselves are not normalized here; presumably callers pass unit
/// quaternions (TODO confirm).
fn slerp_impl<T: Float>(a: &[T; 4], b: &[T; 4], t: T) -> [T; 4] {
    let one = T::one();
    let raw_cos = a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];

    // Flip the target into a's hemisphere so the path follows the shorter arc.
    let (end, cos_theta) = if raw_cos < T::zero() {
        ([-b[0], -b[1], -b[2], -b[3]], -raw_cos)
    } else {
        (*b, raw_cos)
    };

    // Keep `acos` in-domain if rounding pushed the cosine slightly above 1.
    // (A plain comparison is used rather than `min` so a NaN cosine stays NaN.)
    let cos_theta = if cos_theta > one { one } else { cos_theta };

    if cos_theta > T::from(0.9995).unwrap() {
        // Nearly parallel: linear blend followed by renormalization.
        let blend = [
            a[0] + (end[0] - a[0]) * t,
            a[1] + (end[1] - a[1]) * t,
            a[2] + (end[2] - a[2]) * t,
            a[3] + (end[3] - a[3]) * t,
        ];
        let norm = (blend[0] * blend[0]
            + blend[1] * blend[1]
            + blend[2] * blend[2]
            + blend[3] * blend[3])
            .sqrt();
        [
            blend[0] / norm,
            blend[1] / norm,
            blend[2] / norm,
            blend[3] / norm,
        ]
    } else {
        let theta = cos_theta.acos();
        let sin_theta = theta.sin();
        let weight_a = ((one - t) * theta).sin() / sin_theta;
        let weight_b = (t * theta).sin() / sin_theta;
        [
            a[0] * weight_a + end[0] * weight_b,
            a[1] * weight_a + end[1] * weight_b,
            a[2] * weight_a + end[2] * weight_b,
            a[3] * weight_a + end[3] * weight_b,
        ]
    }
}
/// `axis_angle(axis, angle)`: rotation quaternion `[x, y, z, w]` for a rotation of
/// `angle` radians about `axis`.
///
/// The axis is normalized internally, so it need not be unit length. A zero-length
/// axis is not guarded against: it divides by zero and yields NaN in the x, y and z
/// components.
pub struct AxisAngle;

impl<T, V> QuaternionFn<T, V> for AxisAngle
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "axis_angle"
    }

    fn signatures(&self) -> Vec<Signature> {
        let vec3_scalar = Signature {
            args: vec![Type::Vec3, Type::Scalar],
            ret: Type::Quaternion,
        };
        vec![vec3_scalar]
    }

    fn call(&self, args: &[V]) -> V {
        let [ax, ay, az] = args[0].as_vec3().unwrap();
        let angle = args[1].as_scalar().unwrap();
        let half = angle / T::from(2.0).unwrap();
        let (sin_half, cos_half) = (half.sin(), half.cos());
        let axis_len = (ax * ax + ay * ay + az * az).sqrt();
        V::from_quaternion([
            ax / axis_len * sin_half,
            ay / axis_len * sin_half,
            az / axis_len * sin_half,
            cos_half,
        ])
    }
}
/// `rotate(v, q)`: rotates the vec3 `v` by the quaternion `q`.
///
/// Delegates to `rotate_vec3_by_quat`, which does not normalize `q`.
pub struct Rotate;

impl<T, V> QuaternionFn<T, V> for Rotate
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "rotate"
    }

    fn signatures(&self) -> Vec<Signature> {
        let vec3_by_quat = Signature {
            args: vec![Type::Vec3, Type::Quaternion],
            ret: Type::Vec3,
        };
        vec![vec3_by_quat]
    }

    fn call(&self, args: &[V]) -> V {
        let vector = args[0].as_vec3().unwrap();
        let rotation = args[1].as_quaternion().unwrap();
        let rotated = rotate_vec3_by_quat(&vector, &rotation);
        V::from_vec3(rotated)
    }
}
/// Rotates `v` by the quaternion `q` (stored as `[x, y, z, w]`).
///
/// Uses the expanded form of `q * v * conj(q)`:
/// `t = 2 * (q.xyz × v)`, then `v' = v + w * t + q.xyz × t`.
/// This matches a true rotation only when `q` is a unit quaternion; `q` is not
/// normalized here.
fn rotate_vec3_by_quat<T: Float>(v: &[T; 3], q: &[T; 4]) -> [T; 3] {
    let [vx, vy, vz] = *v;
    let [qx, qy, qz, qw] = *q;
    let two = T::from(2.0).unwrap();

    // t = 2 * cross(q.xyz, v)
    let t = [
        two * (qy * vz - qz * vy),
        two * (qz * vx - qx * vz),
        two * (qx * vy - qy * vx),
    ];

    // v + w * t + cross(q.xyz, t)
    [
        vx + qw * t[0] + (qy * t[2] - qz * t[1]),
        vy + qw * t[1] + (qz * t[0] - qx * t[2]),
        vz + qw * t[2] + (qx * t[1] - qy * t[0]),
    ]
}
/// `vec3(x, y, z)`: builds a vec3 from three scalars.
pub struct Vec3Constructor;

impl<T, V> QuaternionFn<T, V> for Vec3Constructor
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "vec3"
    }

    fn signatures(&self) -> Vec<Signature> {
        let three_scalars = Signature {
            args: vec![Type::Scalar, Type::Scalar, Type::Scalar],
            ret: Type::Vec3,
        };
        vec![three_scalars]
    }

    fn call(&self, args: &[V]) -> V {
        // Array elements are evaluated left to right, so extraction order is x, y, z.
        let component = |i: usize| args[i].as_scalar().unwrap();
        V::from_vec3([component(0), component(1), component(2)])
    }
}
/// `quat(x, y, z, w)`: builds a quaternion from four scalars.
///
/// The scalar part `w` is passed last, matching the `[x, y, z, w]` storage order.
pub struct QuatConstructor;

impl<T, V> QuaternionFn<T, V> for QuatConstructor
where
    T: Float,
    V: QuaternionValue<T>,
{
    fn name(&self) -> &str {
        "quat"
    }

    fn signatures(&self) -> Vec<Signature> {
        let four_scalars = Signature {
            args: vec![Type::Scalar, Type::Scalar, Type::Scalar, Type::Scalar],
            ret: Type::Quaternion,
        };
        vec![four_scalars]
    }

    fn call(&self, args: &[V]) -> V {
        // Array elements are evaluated left to right, so extraction order is x, y, z, w.
        let component = |i: usize| args[i].as_scalar().unwrap();
        V::from_quaternion([component(0), component(1), component(2), component(3)])
    }
}
/// Registers every builtin defined in this module into `registry`.
///
/// Registration order is: conj, length, normalize, inverse, dot, lerp, slerp,
/// axis_angle, rotate, vec3, quat.
pub fn register_quaternion<T, V>(registry: &mut FunctionRegistry<T, V>)
where
    T: Float + 'static,
    V: QuaternionValue<T> + 'static,
{
    // Expands to one `registry.register(..)` call per listed builtin, in list order.
    macro_rules! register_all {
        ($($builtin:expr),* $(,)?) => {
            $(registry.register($builtin);)*
        };
    }

    register_all!(
        Conj,
        Length,
        Normalize,
        Inverse,
        Dot,
        Lerp,
        Slerp,
        AxisAngle,
        Rotate,
        Vec3Constructor,
        QuatConstructor,
    );
}
/// Creates a fresh `FunctionRegistry` over the crate's `Value<T>` with all
/// quaternion builtins from this module already registered.
pub fn quaternion_registry<T: Float + std::fmt::Debug + 'static>() -> FunctionRegistry<T, Value<T>>
{
    let mut reg = FunctionRegistry::new();
    register_quaternion::<T, Value<T>>(&mut reg);
    reg
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use wick_core::Expr;

    /// Absolute-tolerance comparison for results that go through trig / sqrt.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 0.0001
    }

    /// Parses `expr`, binds `vars` by name, and evaluates it against a registry
    /// containing only the quaternion builtins.
    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Value<f32> {
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, Value<f32>> = vars
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        let registry = quaternion_registry();
        crate::eval(expr.ast(), &var_map, &registry).unwrap()
    }

    #[test]
    fn test_conj() {
        let result = eval_expr("conj(q)", &[("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0]))]);
        assert_eq!(result, Value::Quaternion([-1.0, -2.0, -3.0, 4.0]));
    }

    #[test]
    fn test_normalize() {
        let result = eval_expr(
            "normalize(q)",
            &[("q", Value::Quaternion([0.0, 0.0, 0.0, 2.0]))],
        );
        assert_eq!(result, Value::Quaternion([0.0, 0.0, 0.0, 1.0]));
    }

    #[test]
    fn test_length() {
        let result = eval_expr(
            "length(q)",
            &[("q", Value::Quaternion([0.0, 0.0, 3.0, 4.0]))],
        );
        assert_eq!(result, Value::Scalar(5.0));
    }

    #[test]
    fn test_length_vec3() {
        let result = eval_expr("length(v)", &[("v", Value::Vec3([3.0, 4.0, 0.0]))]);
        assert_eq!(result, Value::Scalar(5.0));
    }

    #[test]
    fn test_dot() {
        let result = eval_expr(
            "dot(a, b)",
            &[
                ("a", Value::Quaternion([1.0, 0.0, 0.0, 0.0])),
                ("b", Value::Quaternion([1.0, 0.0, 0.0, 0.0])),
            ],
        );
        assert_eq!(result, Value::Scalar(1.0));
    }

    #[test]
    fn test_lerp_vec3() {
        let result = eval_expr(
            "lerp(a, b, t)",
            &[
                ("a", Value::Vec3([0.0, 0.0, 0.0])),
                ("b", Value::Vec3([2.0, 4.0, 6.0])),
                ("t", Value::Scalar(0.5)),
            ],
        );
        assert_eq!(result, Value::Vec3([1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_inverse() {
        // inverse(q) = conj(q) / |q|^2, with |q|^2 = 1 + 4 + 9 + 16 = 30.
        let result = eval_expr(
            "inverse(q)",
            &[("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0]))],
        );
        if let Value::Quaternion(q) = result {
            assert!(approx_eq(q[0], -1.0 / 30.0));
            assert!(approx_eq(q[1], -2.0 / 30.0));
            assert!(approx_eq(q[2], -3.0 / 30.0));
            assert!(approx_eq(q[3], 4.0 / 30.0));
        } else {
            panic!("expected quaternion");
        }
    }

    #[test]
    fn test_axis_angle() {
        let result = eval_expr(
            "axis_angle(axis, angle)",
            &[
                ("axis", Value::Vec3([0.0, 0.0, 1.0])),
                ("angle", Value::Scalar(std::f32::consts::FRAC_PI_2)),
            ],
        );
        if let Value::Quaternion(q) = result {
            assert!(approx_eq(q[0], 0.0));
            assert!(approx_eq(q[1], 0.0));
            assert!(approx_eq(q[2], std::f32::consts::FRAC_PI_4.sin()));
            assert!(approx_eq(q[3], std::f32::consts::FRAC_PI_4.cos()));
        } else {
            panic!("expected quaternion");
        }
    }

    #[test]
    fn test_rotate() {
        let half_angle = std::f32::consts::FRAC_PI_4;
        let result = eval_expr(
            "rotate(v, q)",
            &[
                ("v", Value::Vec3([1.0, 0.0, 0.0])),
                (
                    "q",
                    Value::Quaternion([0.0, 0.0, half_angle.sin(), half_angle.cos()]),
                ),
            ],
        );
        if let Value::Vec3(v) = result {
            assert!(approx_eq(v[0], 0.0));
            assert!(approx_eq(v[1], 1.0));
            assert!(approx_eq(v[2], 0.0));
        } else {
            panic!("expected vec3");
        }
    }

    #[test]
    fn test_slerp() {
        let identity = Value::Quaternion([0.0, 0.0, 0.0, 1.0]);
        let half_turn = Value::Quaternion([0.0, 0.0, 1.0, 0.0]);
        let result = eval_expr(
            "slerp(a, b, t)",
            &[("a", identity), ("b", half_turn), ("t", Value::Scalar(0.5))],
        );
        if let Value::Quaternion(q) = result {
            assert!(approx_eq(q[0], 0.0));
            assert!(approx_eq(q[1], 0.0));
            assert!(approx_eq(q[2], std::f32::consts::FRAC_PI_4.sin()));
            assert!(approx_eq(q[3], std::f32::consts::FRAC_PI_4.cos()));
        } else {
            panic!("expected quaternion");
        }
    }
}