use num_traits::Float;
use std::collections::HashMap;
use std::sync::Arc;
use wick_core::{Ast, BinOp, CompareOp, UnaryOp};
mod funcs;
pub mod ops;
#[cfg(test)]
mod parity_tests;
#[cfg(feature = "wgsl")]
pub mod wgsl;
#[cfg(feature = "glsl")]
pub mod glsl;
#[cfg(feature = "rust")]
pub mod rust;
#[cfg(feature = "c")]
pub mod c;
#[cfg(feature = "opencl")]
pub mod opencl;
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(feature = "hip")]
pub mod hip;
#[cfg(feature = "tokenstream")]
pub mod tokenstream;
#[cfg(feature = "lua-codegen")]
pub mod lua;
#[cfg(feature = "cranelift")]
pub mod cranelift;
#[cfg(feature = "optimize")]
pub mod optimize;
pub use funcs::{
AxisAngle, Conj, Dot, Inverse, Length, Lerp, Normalize, QuatConstructor, Rotate, Slerp,
Vec3Constructor, quaternion_registry, register_quaternion,
};
/// Type tag describing which kind of value an expression produces:
/// a scalar, a 3-component vector, or a quaternion.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// A single scalar.
    Scalar,
    /// A 3-component vector.
    Vec3,
    /// A 4-component quaternion.
    Quaternion,
}

impl std::fmt::Display for Type {
    /// Writes the lowercase type name (used when formatting type errors).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Type::Scalar => "scalar",
            Type::Vec3 => "vec3",
            Type::Quaternion => "quaternion",
        };
        f.write_str(name)
    }
}
/// Interface the evaluator uses to build and inspect runtime values.
///
/// `T` is the underlying float component type. The `as_*` accessors are
/// expected to return `None` when the value is not of the requested kind;
/// `eval` relies on `as_scalar` returning `None` to detect non-scalar operands.
pub trait QuaternionValue<T>: Clone + PartialEq + Sized + std::fmt::Debug
where
    T: Float,
{
    /// Reports which [`Type`] this value has.
    fn typ(&self) -> Type;

    /// Wraps a scalar.
    fn from_scalar(v: T) -> Self;
    /// Wraps a 3-component vector.
    fn from_vec3(v: [T; 3]) -> Self;
    /// Wraps a quaternion given as four components.
    fn from_quaternion(q: [T; 4]) -> Self;

    /// Returns the scalar payload, if this value is a scalar.
    fn as_scalar(&self) -> Option<T>;
    /// Returns the vector payload, if this value is a 3-vector.
    fn as_vec3(&self) -> Option<[T; 3]>;
    /// Returns the four quaternion components, if this value is a quaternion.
    fn as_quaternion(&self) -> Option<[T; 4]>;
}
/// Default concrete value type: a scalar, a 3-vector, or a quaternion.
///
/// The tests treat `Quaternion([0, 0, 0, 1])` as the multiplicative identity,
/// which suggests `[x, y, z, w]` component order (confirm against `ops`).
#[derive(Debug, Clone, PartialEq)]
pub enum Value<T> {
    /// A single scalar.
    Scalar(T),
    /// A 3-component vector.
    Vec3([T; 3]),
    /// A quaternion stored as four components.
    Quaternion([T; 4]),
}

impl<T> Value<T> {
    /// Returns the [`Type`] tag matching this variant.
    pub fn typ(&self) -> Type {
        match *self {
            Self::Scalar(_) => Type::Scalar,
            Self::Vec3(_) => Type::Vec3,
            Self::Quaternion(_) => Type::Quaternion,
        }
    }
}

impl<T: Copy> Value<T> {
    /// Copies out the payload of a [`Value::Scalar`]; `None` for other variants.
    pub fn as_scalar(&self) -> Option<T> {
        if let Self::Scalar(s) = self {
            Some(*s)
        } else {
            None
        }
    }

    /// Copies out the payload of a [`Value::Vec3`]; `None` for other variants.
    pub fn as_vec3(&self) -> Option<[T; 3]> {
        if let Self::Vec3(v) = self {
            Some(*v)
        } else {
            None
        }
    }

    /// Copies out the payload of a [`Value::Quaternion`]; `None` for other variants.
    pub fn as_quaternion(&self) -> Option<[T; 4]> {
        if let Self::Quaternion(q) = self {
            Some(*q)
        } else {
            None
        }
    }
}

/// Bridges [`Value`] to the evaluator's value abstraction by forwarding to the
/// inherent accessors above and wrapping payloads in the matching variant.
impl<T> QuaternionValue<T> for Value<T>
where
    T: Float + std::fmt::Debug,
{
    fn typ(&self) -> Type {
        Value::<T>::typ(self)
    }

    fn from_scalar(v: T) -> Self {
        Self::Scalar(v)
    }

    fn from_vec3(v: [T; 3]) -> Self {
        Self::Vec3(v)
    }

    fn from_quaternion(q: [T; 4]) -> Self {
        Self::Quaternion(q)
    }

    fn as_scalar(&self) -> Option<T> {
        Value::<T>::as_scalar(self)
    }

    fn as_vec3(&self) -> Option<[T; 3]> {
        Value::<T>::as_vec3(self)
    }

    fn as_quaternion(&self) -> Option<[T; 4]> {
        Value::<T>::as_quaternion(self)
    }
}
/// Errors that can occur while evaluating an expression.
#[derive(Debug, Clone, PartialEq)]
pub enum Error {
    /// A referenced variable has no binding.
    UnknownVariable(String),
    /// A called function is not present in the registry.
    UnknownFunction(String),
    /// A binary operator cannot be applied to this pair of operand types.
    BinaryTypeMismatch { op: BinOp, left: Type, right: Type },
    /// A unary operator cannot be applied to this operand type.
    UnaryTypeMismatch { op: UnaryOp, operand: Type },
    /// A function received a different number of arguments than it expects.
    WrongArgCount {
        func: String,
        expected: usize,
        got: usize,
    },
    /// A function's argument types did not match its expected types.
    FunctionTypeMismatch {
        func: String,
        expected: Vec<Type>,
        got: Vec<Type>,
    },
    /// A conditional construct received a non-scalar operand.
    UnsupportedTypeForConditional(Type),
}

impl std::fmt::Display for Error {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownVariable(name) => write!(f, "unknown variable: '{name}'"),
            Self::UnknownFunction(name) => write!(f, "unknown function: '{name}'"),
            Self::BinaryTypeMismatch { op, left, right } => {
                write!(f, "cannot apply {op:?} to {left} and {right}")
            }
            Self::UnaryTypeMismatch { op, operand } => {
                write!(f, "cannot apply {op:?} to {operand}")
            }
            Self::WrongArgCount { func, expected, got } => {
                write!(f, "function '{func}' expects {expected} args, got {got}")
            }
            Self::FunctionTypeMismatch { func, expected, got } => write!(
                f,
                "function '{func}' expects types {expected:?}, got {got:?}"
            ),
            Self::UnsupportedTypeForConditional(t) => {
                write!(f, "conditionals require scalar type, got {t}")
            }
        }
    }
}

/// Marker impl so [`Error`] works with `?`, `Box<dyn Error>`, and similar.
impl std::error::Error for Error {}
/// One accepted calling convention of a [`QuaternionFn`].
///
/// `eval` accepts a call when the argument types equal `args` exactly for at
/// least one of the function's signatures.
#[derive(Debug, Clone, PartialEq)]
pub struct Signature {
    /// Argument types, in call order.
    pub args: Vec<Type>,
    /// Declared return type for this argument combination.
    pub ret: Type,
}
/// A function that expressions can call by name through a [`FunctionRegistry`].
///
/// Implementors must be `Send + Sync` because the registry shares them via `Arc`.
pub trait QuaternionFn<T, V>: Send + Sync
where
    T: Float,
    V: QuaternionValue<T>,
{
    /// Name the function is registered and called under.
    fn name(&self) -> &str;

    /// Every argument-type combination the function accepts.
    fn signatures(&self) -> Vec<Signature>;

    /// Invokes the function on already-evaluated arguments.
    ///
    /// `eval` calls this only after the argument types have matched one of
    /// [`signatures`](Self::signatures).
    fn call(&self, args: &[V]) -> V;
}
/// Name-keyed table of the functions an expression may call.
///
/// Entries are held behind `Arc`, so cloning a registry shares the function
/// objects instead of copying them.
#[derive(Clone)]
pub struct FunctionRegistry<T, V>
where
    T: Float,
    V: QuaternionValue<T>,
{
    funcs: HashMap<String, Arc<dyn QuaternionFn<T, V>>>,
}

impl<T, V> Default for FunctionRegistry<T, V>
where
    T: Float,
    V: QuaternionValue<T>,
{
    /// Builds a registry with no functions.
    fn default() -> Self {
        let funcs = HashMap::default();
        Self { funcs }
    }
}

impl<T, V> FunctionRegistry<T, V>
where
    T: Float,
    V: QuaternionValue<T>,
{
    /// Creates an empty registry (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `func` under the name returned by its [`QuaternionFn::name`].
    ///
    /// Any function previously registered under that name is replaced.
    pub fn register<F: QuaternionFn<T, V> + 'static>(&mut self, func: F) {
        let key = func.name().to_owned();
        let shared: Arc<dyn QuaternionFn<T, V>> = Arc::new(func);
        self.funcs.insert(key, shared);
    }

    /// Looks up a registered function by name.
    pub fn get(&self, name: &str) -> Option<&Arc<dyn QuaternionFn<T, V>>> {
        self.funcs.get(name)
    }
}
pub fn eval<T, V>(
ast: &Ast,
vars: &HashMap<String, V>,
funcs: &FunctionRegistry<T, V>,
) -> Result<V, Error>
where
T: Float,
V: QuaternionValue<T>,
{
match ast {
Ast::Num(n) => Ok(V::from_scalar(T::from(*n).unwrap())),
Ast::Var(name) => vars
.get(name)
.cloned()
.ok_or_else(|| Error::UnknownVariable(name.clone())),
Ast::BinOp(op, left, right) => {
let left_val = eval(left, vars, funcs)?;
let right_val = eval(right, vars, funcs)?;
ops::apply_binop(*op, left_val, right_val)
}
Ast::UnaryOp(op, inner) => {
let val = eval(inner, vars, funcs)?;
ops::apply_unaryop(*op, val)
}
Ast::Call(name, args) => {
let func = funcs
.get(name)
.ok_or_else(|| Error::UnknownFunction(name.clone()))?;
let arg_vals: Vec<V> = args
.iter()
.map(|a| eval(a, vars, funcs))
.collect::<Result<_, _>>()?;
let arg_types: Vec<Type> = arg_vals.iter().map(|v| v.typ()).collect();
let matched = func.signatures().iter().any(|sig| sig.args == arg_types);
if !matched {
return Err(Error::FunctionTypeMismatch {
func: name.clone(),
expected: func
.signatures()
.first()
.map(|s| s.args.clone())
.unwrap_or_default(),
got: arg_types,
});
}
Ok(func.call(&arg_vals))
}
Ast::Compare(op, left, right) => {
let left_val = eval(left, vars, funcs)?;
let right_val = eval(right, vars, funcs)?;
match (left_val.as_scalar(), right_val.as_scalar()) {
(Some(l), Some(r)) => {
let result = match op {
CompareOp::Lt => l < r,
CompareOp::Le => l <= r,
CompareOp::Gt => l > r,
CompareOp::Ge => l >= r,
CompareOp::Eq => l == r,
CompareOp::Ne => l != r,
};
Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
}
_ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
}
}
Ast::And(left, right) => {
let left_val = eval(left, vars, funcs)?;
let right_val = eval(right, vars, funcs)?;
match (left_val.as_scalar(), right_val.as_scalar()) {
(Some(l), Some(r)) => {
let result = !l.is_zero() && !r.is_zero();
Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
}
_ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
}
}
Ast::Or(left, right) => {
let left_val = eval(left, vars, funcs)?;
let right_val = eval(right, vars, funcs)?;
match (left_val.as_scalar(), right_val.as_scalar()) {
(Some(l), Some(r)) => {
let result = !l.is_zero() || !r.is_zero();
Ok(V::from_scalar(if result { T::one() } else { T::zero() }))
}
_ => Err(Error::UnsupportedTypeForConditional(left_val.typ())),
}
}
Ast::If(cond, then_ast, else_ast) => {
let cond_val = eval(cond, vars, funcs)?;
if let Some(c) = cond_val.as_scalar() {
if !c.is_zero() {
eval(then_ast, vars, funcs)
} else {
eval(else_ast, vars, funcs)
}
} else {
Err(Error::UnsupportedTypeForConditional(cond_val.typ()))
}
}
Ast::Let { name, value, body } => {
let val = eval(value, vars, funcs)?;
let mut new_vars = vars.clone();
new_vars.insert(name.clone(), val);
eval(body, &new_vars, funcs)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use wick_core::Expr;

    /// Parses `expr`, binds `vars`, and evaluates the result against the
    /// built-in quaternion function registry.
    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Result<Value<f32>, Error> {
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, Value<f32>> = vars
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        let registry = quaternion_registry();
        eval(expr.ast(), &var_map, &registry)
    }

    #[test]
    fn test_quaternion_add() {
        let result = eval_expr(
            "a + b",
            &[
                ("a", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
                ("b", Value::Quaternion([5.0, 6.0, 7.0, 8.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Quaternion([6.0, 8.0, 10.0, 12.0]));
    }

    #[test]
    fn test_quaternion_mul() {
        // Right-multiplying by [0, 0, 0, 1] must leave `a` unchanged.
        let result = eval_expr(
            "a * b",
            &[
                ("a", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
                ("b", Value::Quaternion([0.0, 0.0, 0.0, 1.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Quaternion([1.0, 2.0, 3.0, 4.0]));
    }

    #[test]
    fn test_quaternion_neg() {
        let result = eval_expr("-q", &[("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0]))]);
        assert_eq!(result.unwrap(), Value::Quaternion([-1.0, -2.0, -3.0, -4.0]));
    }

    #[test]
    fn test_quaternion_scalar_mul() {
        let result = eval_expr(
            "s * q",
            &[
                ("s", Value::Scalar(2.0)),
                ("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0])),
            ],
        );
        assert_eq!(result.unwrap(), Value::Quaternion([2.0, 4.0, 6.0, 8.0]));
    }
}