use crate::{
Error,
context::{Context, Node, Tree},
eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator},
types::{Grad, Interval},
var::{Var, VarIndex, VarMap},
};
use nalgebra::{Matrix4, Point3};
use std::collections::HashMap;
/// A function evaluated over three spatial axes, with an optional transform.
///
/// `F` is the underlying function type. `T` is a type-level state marker: it
/// defaults to `()`, and [`Shape::with_transform`] produces a
/// `Shape<F, Transformed>`.
pub struct Shape<F, T = ()> {
    /// Underlying function
    f: F,
    /// Variables that the evaluators bind to the x, y and z inputs
    axes: [Var; 3],
    /// Optional matrix applied to (x, y, z) before the function is evaluated
    transform: Option<Matrix4<f32>>,
    /// Zero-sized marker carrying the `T` state parameter
    _marker: std::marker::PhantomData<T>,
}
/// Manual `Clone` so that the state marker `T` does not need to be `Clone`.
impl<F: Clone, T> Clone for Shape<F, T> {
    fn clone(&self) -> Self {
        // Only the function needs a real clone; every other field (axes,
        // transform, and the `PhantomData` marker) is `Copy` and is copied
        // straight out of `*self`.
        Self {
            f: self.f.clone(),
            ..*self
        }
    }
}
impl<F: Function + Clone, T> Shape<F, T> {
pub fn new_point_eval() -> ShapeTracingEval<F::PointEval> {
ShapeTracingEval {
eval: F::PointEval::default(),
scratch: vec![],
}
}
pub fn new_interval_eval() -> ShapeTracingEval<F::IntervalEval> {
ShapeTracingEval {
eval: F::IntervalEval::default(),
scratch: vec![],
}
}
pub fn new_float_slice_eval() -> ShapeBulkEval<F::FloatSliceEval> {
ShapeBulkEval {
eval: F::FloatSliceEval::default(),
scratch: vec![],
}
}
pub fn new_grad_slice_eval() -> ShapeBulkEval<F::GradSliceEval> {
ShapeBulkEval {
eval: F::GradSliceEval::default(),
scratch: vec![],
}
}
#[inline]
pub fn point_tape(
&self,
storage: F::TapeStorage,
) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape> {
let tape = self.f.point_tape(storage);
let vars = tape.vars();
let axes = self.axes.map(|v| vars.get(&v));
ShapeTape {
tape,
axes,
transform: self.transform,
}
}
#[inline]
pub fn interval_tape(
&self,
storage: F::TapeStorage,
) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
let tape = self.f.interval_tape(storage);
let vars = tape.vars();
let axes = self.axes.map(|v| vars.get(&v));
ShapeTape {
tape,
axes,
transform: self.transform,
}
}
#[inline]
pub fn float_slice_tape(
&self,
storage: F::TapeStorage,
) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
let tape = self.f.float_slice_tape(storage);
let vars = tape.vars();
let axes = self.axes.map(|v| vars.get(&v));
ShapeTape {
tape,
axes,
transform: self.transform,
}
}
#[inline]
pub fn grad_slice_tape(
&self,
storage: F::TapeStorage,
) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
let tape = self.f.grad_slice_tape(storage);
let vars = tape.vars();
let axes = self.axes.map(|v| vars.get(&v));
ShapeTape {
tape,
axes,
transform: self.transform,
}
}
#[inline]
pub fn simplify(
&self,
trace: &F::Trace,
storage: F::Storage,
workspace: &mut F::Workspace,
) -> Result<Self, Error>
where
Self: Sized,
{
let f = self.f.simplify(trace, storage, workspace)?;
Ok(Self {
f,
axes: self.axes,
transform: self.transform,
_marker: std::marker::PhantomData,
})
}
#[inline]
pub fn recycle(self) -> Option<F::Storage> {
self.f.recycle()
}
#[inline]
pub fn size(&self) -> usize {
self.f.size()
}
}
impl<F, T> Shape<F, T> {
    /// Borrows the underlying function.
    pub fn inner(&self) -> &F {
        let Self { f, .. } = self;
        f
    }

    /// Borrows the variables bound to the x, y and z axes.
    pub fn axes(&self) -> &[Var; 3] {
        let Self { axes, .. } = self;
        axes
    }

    /// Wraps an existing function with the given axes and no transform.
    pub fn new_raw(f: F, axes: [Var; 3]) -> Self {
        Self {
            _marker: std::marker::PhantomData,
            transform: None,
            axes,
            f,
        }
    }
}
/// Type-level marker for a [`Shape`] produced by `Shape::with_transform`.
pub struct Transformed;
impl<F: Clone> Shape<F, ()> {
    /// Returns a copy of this shape that applies `mat` to its input
    /// coordinates before evaluation.
    pub fn with_transform(&self, mat: Matrix4<f32>) -> Shape<F, Transformed> {
        let f = self.f.clone();
        let axes = self.axes;
        Shape {
            f,
            axes,
            transform: Some(mat),
            _marker: std::marker::PhantomData,
        }
    }
}
impl<F: Clone> Shape<F, Transformed> {
pub fn transform(&self) -> Matrix4<f32> {
self.transform.unwrap()
}
}
/// Values for a shape's non-axis variables, keyed by variable index.
///
/// `F` is the per-variable value type. The evaluators in this file use a
/// scalar for single-point and broadcast evaluation, or a slice-like type for
/// per-point values.
pub struct ShapeVars<F>(HashMap<VarIndex, F>);
impl<F> Default for ShapeVars<F> {
fn default() -> Self {
Self(HashMap::default())
}
}
impl<F> ShapeVars<F> {
pub fn new() -> Self {
Self(HashMap::default())
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn insert(&mut self, v: VarIndex, f: F) -> Option<F> {
self.0.insert(v, f)
}
pub fn values(&self) -> impl Iterator<Item = &F> {
self.0.values()
}
}
impl<'a, F> IntoIterator for &'a ShapeVars<F> {
type Item = (&'a VarIndex, &'a F);
type IntoIter = std::collections::hash_map::Iter<'a, VarIndex, F>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
/// Convenience wrappers for building tapes and simplified shapes without
/// passing storage or workspace objects explicitly.
///
/// The implementation for [`Shape`] passes default-constructed storage (and a
/// default workspace, for simplification) to the matching `Shape` method.
pub trait EzShape<F: Function> {
    /// Builds a point-evaluation tape using default storage.
    fn ez_point_tape(
        &self,
    ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape>;
    /// Builds an interval-evaluation tape using default storage.
    fn ez_interval_tape(
        &self,
    ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape>;
    /// Builds a float-slice (bulk) evaluation tape using default storage.
    fn ez_float_slice_tape(
        &self,
    ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape>;
    /// Builds a gradient-slice (bulk) evaluation tape using default storage.
    fn ez_grad_slice_tape(
        &self,
    ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape>;
    /// Simplifies the shape using `trace`, with default storage and a fresh
    /// workspace.
    ///
    /// # Errors
    /// Propagates any error from the underlying simplification.
    fn ez_simplify(&self, trace: &F::Trace) -> Result<Self, Error>
    where
        Self: Sized;
}
/// Forwards each convenience method to the matching `Shape` method, using
/// default-constructed storage and workspace.
impl<F: Function, T> EzShape<F> for Shape<F, T> {
    fn ez_point_tape(
        &self,
    ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape> {
        let storage: F::TapeStorage = Default::default();
        self.point_tape(storage)
    }

    fn ez_interval_tape(
        &self,
    ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
        let storage: F::TapeStorage = Default::default();
        self.interval_tape(storage)
    }

    fn ez_float_slice_tape(
        &self,
    ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
        let storage: F::TapeStorage = Default::default();
        self.float_slice_tape(storage)
    }

    fn ez_grad_slice_tape(
        &self,
    ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
        let storage: F::TapeStorage = Default::default();
        self.grad_slice_tape(storage)
    }

    fn ez_simplify(&self, trace: &F::Trace) -> Result<Self, Error> {
        let mut scratch_space: F::Workspace = Default::default();
        let storage: F::Storage = Default::default();
        self.simplify(trace, storage, &mut scratch_space)
    }
}
impl<F: MathFunction> Shape<F> {
    /// Builds a shape from `node` in `ctx`, binding `axes` to x, y and z.
    ///
    /// # Errors
    /// Propagates any error from building the underlying function.
    pub fn new_with_axes(
        ctx: &Context,
        node: Node,
        axes: [Var; 3],
    ) -> Result<Self, Error> {
        let function = F::new(ctx, &[node])?;
        let shape = Self {
            f: function,
            axes,
            transform: None,
            _marker: std::marker::PhantomData,
        };
        Ok(shape)
    }

    /// Builds a shape from `node` in `ctx` using the standard `X`, `Y`, `Z`
    /// axes.
    ///
    /// # Errors
    /// Propagates any error from building the underlying function.
    pub fn new(ctx: &Context, node: Node) -> Result<Self, Error>
    where
        Self: Sized,
    {
        let default_axes = [Var::X, Var::Y, Var::Z];
        Self::new_with_axes(ctx, node, default_axes)
    }
}
impl<F: MathFunction> From<Tree> for Shape<F> {
    /// Builds a shape from a [`Tree`] using the standard `X`, `Y`, `Z` axes.
    ///
    /// The tree is imported into a fresh, private [`Context`].
    ///
    /// # Panics
    /// Panics if the underlying function cannot be built from the tree (i.e.
    /// if `Shape::new` returns an error). `From` cannot report failure, so
    /// fallible callers should import the tree and call `Shape::new`
    /// themselves.
    fn from(t: Tree) -> Self {
        let mut ctx = Context::new();
        let node = ctx.import(&t);
        Self::new(&ctx, node).expect("failed to build shape from tree")
    }
}
/// An evaluator tape paired with the shape's axis bindings and transform.
#[derive(Clone)]
pub struct ShapeTape<T> {
    /// Underlying evaluator tape
    tape: T,
    /// Slot index of the x, y and z axis variables in the tape's variable
    /// list, or `None` when the tape does not use that axis
    axes: [Option<usize>; 3],
    /// Optional matrix applied to input coordinates before evaluation
    transform: Option<Matrix4<f32>>,
}
impl<T: Tape> ShapeTape<T> {
    /// Consumes this tape and returns whatever storage the inner tape's
    /// `recycle` yields.
    pub fn recycle(self) -> Option<T::Storage> {
        let Self { tape, .. } = self;
        tape.recycle()
    }

    /// Borrows the variable map of the inner tape.
    pub fn vars(&self) -> &VarMap {
        let Self { tape, .. } = self;
        tape.vars()
    }
}
/// Single-point evaluator for a [`ShapeTape`], wrapping a tracing evaluator.
#[derive(Debug)]
pub struct ShapeTracingEval<E: TracingEvaluator> {
    /// Inner tracing evaluator
    eval: E,
    /// Per-variable input buffer, reused across calls to avoid reallocating
    scratch: Vec<E::Data>,
}
/// A default inner evaluator with an empty scratch buffer.
impl<E: TracingEvaluator> Default for ShapeTracingEval<E> {
    fn default() -> Self {
        let eval = E::default();
        Self {
            scratch: Vec::new(),
            eval,
        }
    }
}
impl<E: TracingEvaluator> ShapeTracingEval<E>
where
    <E as TracingEvaluator>::Data: Transformable,
{
    /// Evaluates a single-output shape tape at one point, with no extra
    /// variables.
    ///
    /// Equivalent to [`Self::eval_v`] with an empty [`ShapeVars`].
    ///
    /// # Errors
    /// Same as [`Self::eval_v`].
    ///
    /// # Panics
    /// Panics if the tape does not have exactly one output.
    #[inline]
    pub fn eval<F: Into<E::Data> + Copy>(
        &mut self,
        tape: &ShapeTape<E::Tape>,
        x: F,
        y: F,
        z: F,
    ) -> Result<(E::Data, Option<&E::Trace>), Error> {
        let h = ShapeVars::<f32>::new();
        self.eval_v(tape, x, y, z, &h)
    }

    /// Returns the number of distinct tape variables bound to the axes.
    ///
    /// Two axes may be bound to the same variable; that slot is counted once.
    #[inline]
    fn axis_var_count(axes: &[Option<usize>; 3]) -> usize {
        let [a, b, c] = *axes;
        usize::from(a.is_some())
            + usize::from(b.is_some() && b != a)
            + usize::from(c.is_some() && c != a && c != b)
    }

    /// Evaluates a single-output shape tape at one point.
    ///
    /// `x`, `y`, `z` are transformed by the tape's matrix (if any) and
    /// written into the tape's axis slots. Every other tape variable must be
    /// supplied in `vars`.
    ///
    /// # Errors
    /// - `Error::BadVarSlice` if `vars.len()` differs from the number of
    ///   tape variables not bound to an axis.
    /// - `Error::BadVarIndex` if a variable's slot is outside the scratch
    ///   buffer.
    /// - Any error returned by the inner evaluator.
    ///
    /// # Panics
    /// Panics if the tape does not have exactly one output.
    #[inline]
    pub fn eval_v<F: Into<E::Data> + Copy, V: Into<E::Data> + Copy>(
        &mut self,
        tape: &ShapeTape<E::Tape>,
        x: F,
        y: F,
        z: F,
        vars: &ShapeVars<V>,
    ) -> Result<(E::Data, Option<&E::Trace>), Error> {
        assert_eq!(
            tape.tape.output_count(),
            1,
            "ShapeTape has multiple outputs"
        );

        let x = x.into();
        let y = y.into();
        let z = z.into();
        let (x, y, z) = if let Some(mat) = tape.transform {
            Transformable::transform(x, y, z, mat)
        } else {
            (x, y, z)
        };

        // Callers must supply every tape variable that is not an axis. Count
        // the axis bindings via `tape.axes` instead of assuming `Var::X/Y/Z`,
        // so shapes built with custom axes (`new_with_axes`) evaluate too.
        // For the default axes this count is unchanged.
        let vs = tape.vars();
        let expected_vars = vs.len() - Self::axis_var_count(&tape.axes);
        if expected_vars != vars.len() {
            return Err(Error::BadVarSlice(vars.len(), expected_vars));
        }

        // `resize` keeps existing entries. Every axis slot is overwritten
        // below, and every other slot is expected to come from `vars`.
        self.scratch.resize(vs.len(), 0f32.into());
        if let Some(a) = tape.axes[0] {
            self.scratch[a] = x;
        }
        if let Some(b) = tape.axes[1] {
            self.scratch[b] = y;
        }
        if let Some(c) = tape.axes[2] {
            self.scratch[c] = z;
        }

        for (var, value) in vars {
            // Variables that this tape does not use are skipped.
            // NOTE(review): because the length check above only compares
            // counts, an unused variable in `vars` can leave a required slot
            // holding a value from a previous call.
            if let Some(i) = vs.get(&Var::V(*var)) {
                if i < self.scratch.len() {
                    self.scratch[i] = (*value).into();
                } else {
                    return Err(Error::BadVarIndex(i, self.scratch.len()));
                }
            }
        }

        let (out, trace) = self.eval.eval(&tape.tape, &self.scratch)?;
        Ok((out[0], trace))
    }
}
/// Bulk (slice) evaluator for a [`ShapeTape`], wrapping a bulk evaluator.
#[derive(Debug, Default)]
pub struct ShapeBulkEval<E: BulkEvaluator> {
    /// Inner bulk evaluator
    eval: E,
    /// One input buffer per tape variable, each holding one value per point;
    /// reused across calls
    scratch: Vec<Vec<E::Data>>,
}
impl<E: BulkEvaluator> ShapeBulkEval<E>
where
    E::Data: From<f32> + Transformable,
{
    /// Evaluates a single-output shape tape over slices of points, with no
    /// extra variables.
    ///
    /// Equivalent to [`Self::eval_vs`] with an empty [`ShapeVars`].
    ///
    /// # Errors
    /// Same as [`Self::eval_vs`].
    ///
    /// # Panics
    /// Panics if the tape does not have exactly one output.
    #[inline]
    pub fn eval(
        &mut self,
        tape: &ShapeTape<E::Tape>,
        x: &[E::Data],
        y: &[E::Data],
        z: &[E::Data],
    ) -> Result<&[E::Data], Error> {
        let h: ShapeVars<&[E::Data]> = ShapeVars::new();
        self.eval_vs(tape, x, y, z, &h)
    }

    /// Returns the number of distinct tape variables bound to the axes.
    ///
    /// Two axes may be bound to the same variable; that slot is counted once.
    #[inline]
    fn axis_var_count(axes: &[Option<usize>; 3]) -> usize {
        let [a, b, c] = *axes;
        usize::from(a.is_some())
            + usize::from(b.is_some() && b != a)
            + usize::from(c.is_some() && c != a && c != b)
    }

    /// Validates inputs, sizes the scratch buffers, and writes the
    /// (optionally transformed) axis values into their slots.
    ///
    /// Returns the number of points `n`.
    ///
    /// # Errors
    /// - `Error::MismatchedSlices` if `x`, `y`, `z` differ in length.
    /// - `Error::BadVarSlice` if `vars.len()` differs from the number of
    ///   tape variables not bound to an axis.
    ///
    /// # Panics
    /// Panics if the tape does not have exactly one output.
    #[inline]
    fn setup<V>(
        &mut self,
        tape: &ShapeTape<E::Tape>,
        x: &[E::Data],
        y: &[E::Data],
        z: &[E::Data],
        vars: &ShapeVars<V>,
    ) -> Result<usize, Error> {
        assert_eq!(
            tape.tape.output_count(),
            1,
            "ShapeTape has multiple outputs"
        );

        if x.len() != y.len() || x.len() != z.len() {
            return Err(Error::MismatchedSlices);
        }
        let n = x.len();

        // Callers must supply every tape variable that is not an axis. Count
        // the axis bindings via `tape.axes` instead of assuming `Var::X/Y/Z`,
        // so shapes built with custom axes (`new_with_axes`) evaluate too.
        // For the default axes this count is unchanged.
        let vs = tape.vars();
        let expected_vars = vs.len() - Self::axis_var_count(&tape.axes);
        if expected_vars != vars.len() {
            return Err(Error::BadVarSlice(vars.len(), expected_vars));
        }

        // Keep at least one slot of length `n`, so a tape with no variables
        // still yields `n` outputs. (Presumably the inner evaluator takes the
        // batch size from the input slices; TODO confirm.)
        self.scratch.resize_with(vs.len().max(1), Vec::new);
        for s in &mut self.scratch {
            s.resize(n, 0.0.into());
        }

        if let Some(mat) = tape.transform {
            // Transform point by point, then scatter into the axis slots.
            for i in 0..n {
                let (tx, ty, tz) =
                    Transformable::transform(x[i], y[i], z[i], mat);
                if let Some(a) = tape.axes[0] {
                    self.scratch[a][i] = tx;
                }
                if let Some(b) = tape.axes[1] {
                    self.scratch[b][i] = ty;
                }
                if let Some(c) = tape.axes[2] {
                    self.scratch[c][i] = tz;
                }
            }
        } else {
            // Untransformed: copy each axis slice straight into its slot.
            if let Some(a) = tape.axes[0] {
                self.scratch[a].copy_from_slice(x);
            }
            if let Some(b) = tape.axes[1] {
                self.scratch[b].copy_from_slice(y);
            }
            if let Some(c) = tape.axes[2] {
                self.scratch[c].copy_from_slice(z);
            }
        }
        Ok(n)
    }

    /// Evaluates a single-output shape tape over slices of points, with a
    /// separate value per point for each extra variable.
    ///
    /// # Errors
    /// - Any error from input validation (mismatched x/y/z lengths or a
    ///   wrong number of variables).
    /// - `Error::MismatchedSlices` if a variable's slice length differs from
    ///   the number of points.
    /// - `Error::BadVarIndex` if a variable's slot is outside the scratch
    ///   buffers.
    /// - Any error returned by the inner evaluator.
    ///
    /// # Panics
    /// Panics if the tape does not have exactly one output.
    #[inline]
    pub fn eval_vs<
        V: std::ops::Deref<Target = [G]>,
        G: Into<E::Data> + Copy,
    >(
        &mut self,
        tape: &ShapeTape<E::Tape>,
        x: &[E::Data],
        y: &[E::Data],
        z: &[E::Data],
        vars: &ShapeVars<V>,
    ) -> Result<&[E::Data], Error> {
        let n = self.setup(tape, x, y, z, vars)?;

        if vars.values().any(|vs| vs.len() != n) {
            return Err(Error::MismatchedSlices);
        }

        let vs = tape.vars();
        for (var, value) in vars {
            // Variables that this tape does not use are skipped.
            // NOTE(review): the count-only check in `setup` means an unused
            // variable here can leave a required slot stale.
            if let Some(i) = vs.get(&Var::V(*var)) {
                if i < self.scratch.len() {
                    for (a, b) in
                        self.scratch[i].iter_mut().zip(value.deref().iter())
                    {
                        *a = (*b).into();
                    }
                } else {
                    return Err(Error::BadVarIndex(i, self.scratch.len()));
                }
            }
        }
        let out = self.eval.eval(&tape.tape, &self.scratch)?;
        Ok(out.borrow(0))
    }

    /// Evaluates a single-output shape tape over slices of points, with each
    /// extra variable held at one constant value for every point.
    ///
    /// # Errors
    /// - Any error from input validation (mismatched x/y/z lengths or a
    ///   wrong number of variables).
    /// - `Error::BadVarIndex` if a variable's slot is outside the scratch
    ///   buffers.
    /// - Any error returned by the inner evaluator.
    ///
    /// # Panics
    /// Panics if the tape does not have exactly one output.
    #[inline]
    pub fn eval_v<G: Into<E::Data> + Copy>(
        &mut self,
        tape: &ShapeTape<E::Tape>,
        x: &[E::Data],
        y: &[E::Data],
        z: &[E::Data],
        vars: &ShapeVars<G>,
    ) -> Result<&[E::Data], Error> {
        self.setup(tape, x, y, z, vars)?;
        let vs = tape.vars();
        for (var, value) in vars {
            // Variables that this tape does not use are skipped (see the
            // NOTE in `eval_vs`).
            if let Some(i) = vs.get(&Var::V(*var)) {
                if i < self.scratch.len() {
                    self.scratch[i].fill((*value).into());
                } else {
                    return Err(Error::BadVarIndex(i, self.scratch.len()));
                }
            }
        }
        let out = self.eval.eval(&tape.tape, &self.scratch)?;
        Ok(out.borrow(0))
    }
}
/// Coordinate types that can be pushed through a 4×4 transform matrix.
///
/// The `Interval` and `Grad` implementations in this file divide each output
/// coordinate by the result of the matrix's fourth row. The `f32`
/// implementation delegates to nalgebra's `transform_point`.
pub trait Transformable {
    /// Transforms the point `(x, y, z)` by `mat` and returns the new
    /// coordinates.
    fn transform(
        x: Self,
        y: Self,
        z: Self,
        mat: Matrix4<f32>,
    ) -> (Self, Self, Self)
    where
        Self: Sized;
}
impl Transformable for f32 {
    /// Transforms a scalar point with nalgebra's homogeneous point transform.
    fn transform(x: f32, y: f32, z: f32, mat: Matrix4<f32>) -> (f32, f32, f32) {
        let input = Point3::new(x, y, z);
        let moved = mat.transform_point(&input);
        (moved.x, moved.y, moved.z)
    }
}
impl Transformable for Interval {
    /// Transforms an interval point.
    ///
    /// Each matrix row is applied as `x*m0 + y*m1 + z*m2 + m3`. The first
    /// three results are then divided by the fourth row's result.
    fn transform(
        x: Interval,
        y: Interval,
        z: Interval,
        mat: Matrix4<f32>,
    ) -> (Interval, Interval, Interval) {
        // Applies row `r` of `mat` to the homogeneous point (x, y, z, 1).
        let apply_row = |r: usize| {
            x * mat[(r, 0)]
                + y * mat[(r, 1)]
                + z * mat[(r, 2)]
                + Interval::from(mat[(r, 3)])
        };
        let w = apply_row(3);
        (apply_row(0) / w, apply_row(1) / w, apply_row(2) / w)
    }
}
impl Transformable for Grad {
    /// Transforms a point whose coordinates carry gradients.
    ///
    /// Each matrix row is applied as `x*m0 + y*m1 + z*m2 + m3`. The first
    /// three results are then divided by the fourth row's result.
    fn transform(
        x: Grad,
        y: Grad,
        z: Grad,
        mat: Matrix4<f32>,
    ) -> (Grad, Grad, Grad) {
        // Applies row `r` of `mat` to the homogeneous point (x, y, z, 1).
        let apply_row = |r: usize| {
            x * mat[(r, 0)]
                + y * mat[(r, 1)]
                + z * mat[(r, 2)]
                + Grad::from(mat[(r, 3)])
        };
        let w = apply_row(3);
        (apply_row(0) / w, apply_row(1) / w, apply_row(2) / w)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::vm::VmShape;

    /// Imports `tree` into a fresh context and builds a `VmShape` from it.
    fn build(tree: Tree) -> VmShape {
        let mut ctx = Context::new();
        let node = ctx.import(&tree);
        VmShape::new(&ctx, node).unwrap()
    }

    #[test]
    fn shape_vars() {
        let v = Var::new();
        let shape = build(Tree::x() + Tree::y() + v);
        let var_map = shape.inner().vars();

        assert_eq!(var_map.len(), 3);
        assert!(var_map.get(&Var::X).is_some());
        assert!(var_map.get(&Var::Y).is_some());
        assert!(var_map.get(&Var::Z).is_none());
        assert!(var_map.get(&v).is_some());

        // The three used variables must occupy three distinct slots in 0..3.
        let mut seen = [false; 3];
        for var in [Var::X, Var::Y, v] {
            seen[var_map[&var]] = true;
        }
        assert!(seen.iter().all(|&hit| hit));
    }

    #[test]
    fn shape_eval_bulk_size() {
        let shape = build(Tree::constant(1.0));
        let tape = shape.ez_float_slice_tape();
        let mut eval = VmShape::new_float_slice_eval();

        let xs: [f32; 3] = [1.0, 2.0, 3.0];
        let ys: [f32; 3] = [4.0, 5.0, 6.0];
        let zs: [f32; 3] = [7.0, 8.0, 9.0];
        let out = eval
            .eval_v::<f32>(&tape, &xs, &ys, &zs, &ShapeVars::default())
            .unwrap();
        assert_eq!(out, [1.0, 1.0, 1.0]);
    }
}