use crate::{Error, QuaternionValue, Type};
use num_traits::Float;
use wick_core::{BinOp, UnaryOp};
pub fn apply_binop<T, V>(op: BinOp, left: V, right: V) -> Result<V, Error>
where
T: Float,
V: QuaternionValue<T>,
{
match op {
BinOp::Add => apply_add(left, right),
BinOp::Sub => apply_sub(left, right),
BinOp::Mul => apply_mul(left, right),
BinOp::Div => apply_div(left, right),
BinOp::Pow => apply_pow(left, right),
BinOp::Rem | BinOp::BitAnd | BinOp::BitOr | BinOp::Shl | BinOp::Shr => {
Err(Error::BinaryTypeMismatch {
op,
left: left.typ(),
right: right.typ(),
})
}
}
}
pub fn apply_unaryop<T, V>(op: UnaryOp, val: V) -> Result<V, Error>
where
T: Float,
V: QuaternionValue<T>,
{
match op {
UnaryOp::Neg => apply_neg(val),
UnaryOp::Not => apply_not(val),
UnaryOp::BitNot => Err(Error::UnaryTypeMismatch {
op,
operand: val.typ(),
}),
}
}
fn apply_not<T, V>(val: V) -> Result<V, Error>
where
T: Float,
V: QuaternionValue<T>,
{
match val.typ() {
Type::Scalar => {
let v = val.as_scalar().unwrap();
let result = if v.is_zero() { T::one() } else { T::zero() };
Ok(V::from_scalar(result))
}
_ => Err(Error::UnaryTypeMismatch {
op: UnaryOp::Not,
operand: val.typ(),
}),
}
}
/// Adds two values of the same type.
///
/// Scalars are added directly; `Vec3` and quaternion operands are added
/// component by component.
///
/// # Errors
///
/// Returns [`Error::BinaryTypeMismatch`] when the operand types differ.
fn apply_add<T, V>(left: V, right: V) -> Result<V, Error>
where
    T: Float,
    V: QuaternionValue<T>,
{
    // Each arm unwraps the accessor that matches the type tag it just matched on.
    match (left.typ(), right.typ()) {
        (Type::Scalar, Type::Scalar) => {
            let sum = left.as_scalar().unwrap() + right.as_scalar().unwrap();
            Ok(V::from_scalar(sum))
        }
        (Type::Vec3, Type::Vec3) => {
            let (lhs, rhs) = (left.as_vec3().unwrap(), right.as_vec3().unwrap());
            let sum: [T; 3] = std::array::from_fn(|i| lhs[i] + rhs[i]);
            Ok(V::from_vec3(sum))
        }
        (Type::Quaternion, Type::Quaternion) => {
            let (lhs, rhs) = (
                left.as_quaternion().unwrap(),
                right.as_quaternion().unwrap(),
            );
            let sum: [T; 4] = std::array::from_fn(|i| lhs[i] + rhs[i]);
            Ok(V::from_quaternion(sum))
        }
        _ => Err(Error::BinaryTypeMismatch {
            op: BinOp::Add,
            left: left.typ(),
            right: right.typ(),
        }),
    }
}
/// Subtracts `right` from `left`; both must have the same type.
///
/// Scalars are subtracted directly; `Vec3` and quaternion operands are
/// subtracted component by component.
///
/// # Errors
///
/// Returns [`Error::BinaryTypeMismatch`] when the operand types differ.
fn apply_sub<T, V>(left: V, right: V) -> Result<V, Error>
where
    T: Float,
    V: QuaternionValue<T>,
{
    // Each arm unwraps the accessor that matches the type tag it just matched on.
    match (left.typ(), right.typ()) {
        (Type::Scalar, Type::Scalar) => {
            let difference = left.as_scalar().unwrap() - right.as_scalar().unwrap();
            Ok(V::from_scalar(difference))
        }
        (Type::Vec3, Type::Vec3) => {
            let (lhs, rhs) = (left.as_vec3().unwrap(), right.as_vec3().unwrap());
            let difference: [T; 3] = std::array::from_fn(|i| lhs[i] - rhs[i]);
            Ok(V::from_vec3(difference))
        }
        (Type::Quaternion, Type::Quaternion) => {
            let (lhs, rhs) = (
                left.as_quaternion().unwrap(),
                right.as_quaternion().unwrap(),
            );
            let difference: [T; 4] = std::array::from_fn(|i| lhs[i] - rhs[i]);
            Ok(V::from_quaternion(difference))
        }
        _ => Err(Error::BinaryTypeMismatch {
            op: BinOp::Sub,
            left: left.typ(),
            right: right.typ(),
        }),
    }
}
/// Multiplies two values.
///
/// Supported operand combinations:
///
/// * scalar × scalar — ordinary product;
/// * scalar × `Vec3` / `Vec3` × scalar — every component is scaled;
/// * scalar × quaternion / quaternion × scalar — every component is scaled;
/// * quaternion × quaternion — Hamilton product, with components laid out as
///   `[x, y, z, w]`;
/// * quaternion × `Vec3` — the vector is transformed by the quaternion via
///   `rotate_vec3_by_quat` (the quaternion is not normalized first).
///
/// # Errors
///
/// Returns [`Error::BinaryTypeMismatch`] for every other combination.
fn apply_mul<T, V>(left: V, right: V) -> Result<V, Error>
where
    T: Float,
    V: QuaternionValue<T>,
{
    // Each arm unwraps the accessors that match the type tags it just matched on.
    match (left.typ(), right.typ()) {
        (Type::Scalar, Type::Scalar) => {
            let product = left.as_scalar().unwrap() * right.as_scalar().unwrap();
            Ok(V::from_scalar(product))
        }
        (Type::Scalar, Type::Vec3) => {
            let factor = left.as_scalar().unwrap();
            let vector = right.as_vec3().unwrap();
            let scaled: [T; 3] = std::array::from_fn(|i| factor * vector[i]);
            Ok(V::from_vec3(scaled))
        }
        (Type::Vec3, Type::Scalar) => {
            let vector = left.as_vec3().unwrap();
            let factor = right.as_scalar().unwrap();
            let scaled: [T; 3] = std::array::from_fn(|i| vector[i] * factor);
            Ok(V::from_vec3(scaled))
        }
        (Type::Scalar, Type::Quaternion) => {
            let factor = left.as_scalar().unwrap();
            let quat = right.as_quaternion().unwrap();
            let scaled: [T; 4] = std::array::from_fn(|i| factor * quat[i]);
            Ok(V::from_quaternion(scaled))
        }
        (Type::Quaternion, Type::Scalar) => {
            let quat = left.as_quaternion().unwrap();
            let factor = right.as_scalar().unwrap();
            let scaled: [T; 4] = std::array::from_fn(|i| quat[i] * factor);
            Ok(V::from_quaternion(scaled))
        }
        (Type::Quaternion, Type::Quaternion) => {
            let p = left.as_quaternion().unwrap();
            let q = right.as_quaternion().unwrap();
            let (px, py, pz, pw) = (p[0], p[1], p[2], p[3]);
            let (qx, qy, qz, qw) = (q[0], q[1], q[2], q[3]);

            // Hamilton product p * q, one output component per line.
            let x = pw * qx + px * qw + py * qz - pz * qy;
            let y = pw * qy - px * qz + py * qw + pz * qx;
            let z = pw * qz + px * qy - py * qx + pz * qw;
            let w = pw * qw - px * qx - py * qy - pz * qz;
            Ok(V::from_quaternion([x, y, z, w]))
        }
        (Type::Quaternion, Type::Vec3) => {
            let quat = left.as_quaternion().unwrap();
            let vector = right.as_vec3().unwrap();
            let rotated = rotate_vec3_by_quat(&vector, &quat);
            Ok(V::from_vec3(rotated))
        }
        _ => Err(Error::BinaryTypeMismatch {
            op: BinOp::Mul,
            left: left.typ(),
            right: right.typ(),
        }),
    }
}
/// Transforms the vector `v` by the quaternion `q`, given as `[x, y, z, w]`.
///
/// With `u = (x, y, z)` the result is `v + w * t + u × t`, where
/// `t = 2 * (u × v)`. For a unit quaternion this is the rotation of `v` by
/// `q`; the quaternion is used as given, without normalization.
fn rotate_vec3_by_quat<T: Float>(v: &[T; 3], q: &[T; 4]) -> [T; 3] {
    // Right-handed cross product of two 3-vectors.
    let cross = |a: &[T; 3], b: &[T; 3]| -> [T; 3] {
        [
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        ]
    };

    let axis = [q[0], q[1], q[2]];
    let w = q[3];
    let two = T::from(2.0).unwrap();

    // t = 2 * (u × v)
    let uv = cross(&axis, v);
    let t = [two * uv[0], two * uv[1], two * uv[2]];

    // v' = v + w * t + u × t
    let ut = cross(&axis, &t);
    [
        v[0] + w * t[0] + ut[0],
        v[1] + w * t[1] + ut[1],
        v[2] + w * t[2] + ut[2],
    ]
}
/// Divides `left` by a scalar `right`.
///
/// A scalar dividend is divided directly; `Vec3` and quaternion dividends
/// are divided component by component. The divisor is not checked for zero,
/// so the result follows the division semantics of `T`.
///
/// # Errors
///
/// Returns [`Error::BinaryTypeMismatch`] when `right` is not a scalar.
fn apply_div<T, V>(left: V, right: V) -> Result<V, Error>
where
    T: Float,
    V: QuaternionValue<T>,
{
    // Each arm unwraps the accessors that match the type tags it just matched on.
    match (left.typ(), right.typ()) {
        (Type::Scalar, Type::Scalar) => {
            let quotient = left.as_scalar().unwrap() / right.as_scalar().unwrap();
            Ok(V::from_scalar(quotient))
        }
        (Type::Vec3, Type::Scalar) => {
            let vector = left.as_vec3().unwrap();
            let divisor = right.as_scalar().unwrap();
            let quotient: [T; 3] = std::array::from_fn(|i| vector[i] / divisor);
            Ok(V::from_vec3(quotient))
        }
        (Type::Quaternion, Type::Scalar) => {
            let quat = left.as_quaternion().unwrap();
            let divisor = right.as_scalar().unwrap();
            let quotient: [T; 4] = std::array::from_fn(|i| quat[i] / divisor);
            Ok(V::from_quaternion(quotient))
        }
        _ => Err(Error::BinaryTypeMismatch {
            op: BinOp::Div,
            left: left.typ(),
            right: right.typ(),
        }),
    }
}
fn apply_pow<T, V>(left: V, right: V) -> Result<V, Error>
where
T: Float,
V: QuaternionValue<T>,
{
match (left.typ(), right.typ()) {
(Type::Scalar, Type::Scalar) => {
let a = left.as_scalar().unwrap();
let b = right.as_scalar().unwrap();
Ok(V::from_scalar(a.powf(b)))
}
_ => Err(Error::BinaryTypeMismatch {
op: BinOp::Pow,
left: left.typ(),
right: right.typ(),
}),
}
}
/// Negates `val`: a scalar is negated directly, while `Vec3` and quaternion
/// values have every component negated.
///
/// Every value type is handled, so this function always returns `Ok`.
fn apply_neg<T, V>(val: V) -> Result<V, Error>
where
    T: Float,
    V: QuaternionValue<T>,
{
    // Each arm unwraps the accessor that matches the type tag it just matched on.
    let negated = match val.typ() {
        Type::Scalar => V::from_scalar(-val.as_scalar().unwrap()),
        Type::Vec3 => {
            let vector = val.as_vec3().unwrap();
            let flipped: [T; 3] = std::array::from_fn(|i| -vector[i]);
            V::from_vec3(flipped)
        }
        Type::Quaternion => {
            let quat = val.as_quaternion().unwrap();
            let flipped: [T; 4] = std::array::from_fn(|i| -quat[i]);
            V::from_quaternion(flipped)
        }
    };
    Ok(negated)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    /// Returns `true` when `a` and `b` differ by less than `0.0001`.
    fn approx_eq(a: f32, b: f32) -> bool {
        let delta = (a - b).abs();
        delta < 0.0001
    }

    /// Multiplying by `[0, 0, 0, 1]` on the right must return the left operand unchanged.
    #[test]
    fn test_quaternion_mul_identity() {
        let operand = Value::Quaternion([1.0_f32, 2.0, 3.0, 4.0]);
        let identity = Value::Quaternion([0.0, 0.0, 0.0, 1.0]);

        let product = apply_binop(BinOp::Mul, operand, identity).unwrap();

        assert_eq!(product, Value::Quaternion([1.0, 2.0, 3.0, 4.0]));
    }

    /// A unit quaternion times its conjugate must be (approximately) the identity.
    #[test]
    fn test_quaternion_mul_inverse() {
        let angle = std::f32::consts::FRAC_PI_4;
        let q = Value::Quaternion([0.0, 0.0, angle.sin(), angle.cos()]);
        let q_conj = Value::Quaternion([0.0, 0.0, -angle.sin(), angle.cos()]);

        let product = apply_binop(BinOp::Mul, q, q_conj).unwrap();
        let components = match product {
            Value::Quaternion(components) => components,
            _ => panic!("expected quaternion"),
        };

        let expected = [0.0, 0.0, 0.0, 1.0];
        for (got, want) in components.iter().zip(expected.iter()) {
            assert!(approx_eq(*got, *want));
        }
    }

    /// `[0, 0, sin(pi/4), cos(pi/4)]` applied to the x axis must yield the y axis.
    #[test]
    fn test_rotate_vec3() {
        let angle = std::f32::consts::FRAC_PI_4;
        let q = Value::Quaternion([0.0, 0.0, angle.sin(), angle.cos()]);
        let x_axis = Value::Vec3([1.0, 0.0, 0.0]);

        let rotated = apply_binop(BinOp::Mul, q, x_axis).unwrap();
        let components = match rotated {
            Value::Vec3(components) => components,
            _ => panic!("expected vec3"),
        };

        let expected = [0.0, 1.0, 0.0];
        for (got, want) in components.iter().zip(expected.iter()) {
            assert!(approx_eq(*got, *want));
        }
    }
}