1use crate::{
30 Error,
31 context::{Context, Node, Tree},
32 eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator},
33 types::{Grad, Interval},
34 var::{Var, VarIndex, VarMap},
35};
36use nalgebra::{Matrix4, Point3};
37use std::collections::HashMap;
38
39pub struct Shape<F, T = ()> {
58 f: F,
60
61 axes: [Var; 3],
63
64 transform: Option<Matrix4<f32>>,
69
70 _marker: std::marker::PhantomData<T>,
71}
72
73impl<F: Clone, T> Clone for Shape<F, T> {
74 fn clone(&self) -> Self {
75 Self {
76 f: self.f.clone(),
77 axes: self.axes,
78 transform: self.transform,
79 _marker: std::marker::PhantomData,
80 }
81 }
82}
83
84impl<F: Function + Clone, T> Shape<F, T> {
85 pub fn new_point_eval() -> ShapeTracingEval<F::PointEval> {
87 ShapeTracingEval {
88 eval: F::PointEval::default(),
89 scratch: vec![],
90 }
91 }
92
93 pub fn new_interval_eval() -> ShapeTracingEval<F::IntervalEval> {
95 ShapeTracingEval {
96 eval: F::IntervalEval::default(),
97 scratch: vec![],
98 }
99 }
100
101 pub fn new_float_slice_eval() -> ShapeBulkEval<F::FloatSliceEval> {
103 ShapeBulkEval {
104 eval: F::FloatSliceEval::default(),
105 scratch: vec![],
106 }
107 }
108
109 pub fn new_grad_slice_eval() -> ShapeBulkEval<F::GradSliceEval> {
111 ShapeBulkEval {
112 eval: F::GradSliceEval::default(),
113 scratch: vec![],
114 }
115 }
116
117 #[inline]
119 pub fn point_tape(
120 &self,
121 storage: F::TapeStorage,
122 ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape> {
123 let tape = self.f.point_tape(storage);
124 let vars = tape.vars();
125 let axes = self.axes.map(|v| vars.get(&v));
126 ShapeTape {
127 tape,
128 axes,
129 transform: self.transform,
130 }
131 }
132
133 #[inline]
135 pub fn interval_tape(
136 &self,
137 storage: F::TapeStorage,
138 ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
139 let tape = self.f.interval_tape(storage);
140 let vars = tape.vars();
141 let axes = self.axes.map(|v| vars.get(&v));
142 ShapeTape {
143 tape,
144 axes,
145 transform: self.transform,
146 }
147 }
148
149 #[inline]
151 pub fn float_slice_tape(
152 &self,
153 storage: F::TapeStorage,
154 ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
155 let tape = self.f.float_slice_tape(storage);
156 let vars = tape.vars();
157 let axes = self.axes.map(|v| vars.get(&v));
158 ShapeTape {
159 tape,
160 axes,
161 transform: self.transform,
162 }
163 }
164
165 #[inline]
167 pub fn grad_slice_tape(
168 &self,
169 storage: F::TapeStorage,
170 ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
171 let tape = self.f.grad_slice_tape(storage);
172 let vars = tape.vars();
173 let axes = self.axes.map(|v| vars.get(&v));
174 ShapeTape {
175 tape,
176 axes,
177 transform: self.transform,
178 }
179 }
180
181 #[inline]
183 pub fn simplify(
184 &self,
185 trace: &F::Trace,
186 storage: F::Storage,
187 workspace: &mut F::Workspace,
188 ) -> Result<Self, Error>
189 where
190 Self: Sized,
191 {
192 let f = self.f.simplify(trace, storage, workspace)?;
193 Ok(Self {
194 f,
195 axes: self.axes,
196 transform: self.transform,
197 _marker: std::marker::PhantomData,
198 })
199 }
200
201 #[inline]
206 pub fn recycle(self) -> Option<F::Storage> {
207 self.f.recycle()
208 }
209
210 #[inline]
215 pub fn size(&self) -> usize {
216 self.f.size()
217 }
218}
219
220impl<F, T> Shape<F, T> {
221 pub fn inner(&self) -> &F {
223 &self.f
224 }
225
226 pub fn axes(&self) -> &[Var; 3] {
228 &self.axes
229 }
230
231 pub fn new_raw(f: F, axes: [Var; 3]) -> Self {
233 Self {
234 f,
235 axes,
236 transform: None,
237 _marker: std::marker::PhantomData,
238 }
239 }
240}
241
/// Type-level tag for a [`Shape`] that carries an input transform
pub struct Transformed;
244
245impl<F: Clone> Shape<F, ()> {
246 pub fn with_transform(&self, mat: Matrix4<f32>) -> Shape<F, Transformed> {
248 Shape {
249 f: self.f.clone(),
250 axes: self.axes,
251 transform: Some(mat),
252 _marker: std::marker::PhantomData,
253 }
254 }
255}
256
257impl<F: Clone> Shape<F, Transformed> {
258 pub fn transform(&self) -> Matrix4<f32> {
260 self.transform.unwrap()
261 }
262}
263
264pub struct ShapeVars<F>(HashMap<VarIndex, F>);
270
271impl<F> Default for ShapeVars<F> {
272 fn default() -> Self {
273 Self(HashMap::default())
274 }
275}
276
277impl<F> ShapeVars<F> {
278 pub fn new() -> Self {
280 Self(HashMap::default())
281 }
282 pub fn len(&self) -> usize {
284 self.0.len()
285 }
286 pub fn is_empty(&self) -> bool {
288 self.0.is_empty()
289 }
290 pub fn insert(&mut self, v: VarIndex, f: F) -> Option<F> {
294 self.0.insert(v, f)
295 }
296
297 pub fn values(&self) -> impl Iterator<Item = &F> {
299 self.0.values()
300 }
301}
302
303impl<'a, F> IntoIterator for &'a ShapeVars<F> {
304 type Item = (&'a VarIndex, &'a F);
305 type IntoIter = std::collections::hash_map::Iter<'a, VarIndex, F>;
306 fn into_iter(self) -> Self::IntoIter {
307 self.0.iter()
308 }
309}
310
311pub trait EzShape<F: Function> {
321 fn ez_point_tape(
323 &self,
324 ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape>;
325
326 fn ez_interval_tape(
328 &self,
329 ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape>;
330
331 fn ez_float_slice_tape(
333 &self,
334 ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape>;
335
336 fn ez_grad_slice_tape(
338 &self,
339 ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape>;
340
341 fn ez_simplify(&self, trace: &F::Trace) -> Result<Self, Error>
343 where
344 Self: Sized;
345}
346
347impl<F: Function, T> EzShape<F> for Shape<F, T> {
348 fn ez_point_tape(
349 &self,
350 ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape> {
351 self.point_tape(Default::default())
352 }
353
354 fn ez_interval_tape(
355 &self,
356 ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
357 self.interval_tape(Default::default())
358 }
359
360 fn ez_float_slice_tape(
361 &self,
362 ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
363 self.float_slice_tape(Default::default())
364 }
365
366 fn ez_grad_slice_tape(
367 &self,
368 ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
369 self.grad_slice_tape(Default::default())
370 }
371
372 fn ez_simplify(&self, trace: &F::Trace) -> Result<Self, Error> {
373 let mut workspace = Default::default();
374 self.simplify(trace, Default::default(), &mut workspace)
375 }
376}
377
378impl<F: MathFunction> Shape<F> {
379 pub fn new_with_axes(
381 ctx: &Context,
382 node: Node,
383 axes: [Var; 3],
384 ) -> Result<Self, Error> {
385 let f = F::new(ctx, &[node])?;
386 Ok(Self {
387 f,
388 axes,
389 transform: None,
390 _marker: std::marker::PhantomData,
391 })
392 }
393
394 pub fn new(ctx: &Context, node: Node) -> Result<Self, Error>
396 where
397 Self: Sized,
398 {
399 Self::new_with_axes(ctx, node, [Var::X, Var::Y, Var::Z])
400 }
401}
402
403impl<F: MathFunction> From<Tree> for Shape<F> {
405 fn from(t: Tree) -> Self {
406 let mut ctx = Context::new();
407 let node = ctx.import(&t);
408 Self::new(&ctx, node).unwrap()
409 }
410}
411
412#[derive(Clone)]
414pub struct ShapeTape<T> {
415 tape: T,
416
417 axes: [Option<usize>; 3],
419
420 transform: Option<Matrix4<f32>>,
422}
423
424impl<T: Tape> ShapeTape<T> {
425 pub fn recycle(self) -> Option<T::Storage> {
427 self.tape.recycle()
428 }
429
430 pub fn vars(&self) -> &VarMap {
432 self.tape.vars()
433 }
434}
435
436#[derive(Debug)]
441pub struct ShapeTracingEval<E: TracingEvaluator> {
442 eval: E,
443 scratch: Vec<E::Data>,
444}
445
446impl<E: TracingEvaluator> Default for ShapeTracingEval<E> {
447 fn default() -> Self {
448 Self {
449 eval: E::default(),
450 scratch: vec![],
451 }
452 }
453}
454
455impl<E: TracingEvaluator> ShapeTracingEval<E>
456where
457 <E as TracingEvaluator>::Data: Transformable,
458{
459 #[inline]
466 pub fn eval<F: Into<E::Data> + Copy>(
467 &mut self,
468 tape: &ShapeTape<E::Tape>,
469 x: F,
470 y: F,
471 z: F,
472 ) -> Result<(E::Data, Option<&E::Trace>), Error> {
473 let h = ShapeVars::<f32>::new();
474 self.eval_v(tape, x, y, z, &h)
475 }
476
477 #[inline]
481 pub fn eval_v<F: Into<E::Data> + Copy, V: Into<E::Data> + Copy>(
482 &mut self,
483 tape: &ShapeTape<E::Tape>,
484 x: F,
485 y: F,
486 z: F,
487 vars: &ShapeVars<V>,
488 ) -> Result<(E::Data, Option<&E::Trace>), Error> {
489 assert_eq!(
490 tape.tape.output_count(),
491 1,
492 "ShapeTape has multiple outputs"
493 );
494
495 let x = x.into();
496 let y = y.into();
497 let z = z.into();
498 let (x, y, z) = if let Some(mat) = tape.transform {
499 Transformable::transform(x, y, z, mat)
500 } else {
501 (x, y, z)
502 };
503
504 let vs = tape.vars();
505 let expected_vars = vs.len()
506 - vs.get(&Var::X).is_some() as usize
507 - vs.get(&Var::Y).is_some() as usize
508 - vs.get(&Var::Z).is_some() as usize;
509 if expected_vars != vars.len() {
510 return Err(Error::BadVarSlice(vars.len(), expected_vars));
511 }
512
513 self.scratch.resize(tape.vars().len(), 0f32.into());
514 if let Some(a) = tape.axes[0] {
515 self.scratch[a] = x;
516 }
517 if let Some(b) = tape.axes[1] {
518 self.scratch[b] = y;
519 }
520 if let Some(c) = tape.axes[2] {
521 self.scratch[c] = z;
522 }
523 for (var, value) in vars {
524 if let Some(i) = vs.get(&Var::V(*var)) {
525 if i < self.scratch.len() {
526 self.scratch[i] = (*value).into();
527 } else {
528 return Err(Error::BadVarIndex(i, self.scratch.len()));
529 }
530 } else {
531 }
533 }
534
535 let (out, trace) = self.eval.eval(&tape.tape, &self.scratch)?;
536 Ok((out[0], trace))
537 }
538}
539
540#[derive(Debug, Default)]
545pub struct ShapeBulkEval<E: BulkEvaluator> {
546 eval: E,
547 scratch: Vec<Vec<E::Data>>,
548}
549
550impl<E: BulkEvaluator> ShapeBulkEval<E>
551where
552 E::Data: From<f32> + Transformable,
553{
554 #[inline]
562 pub fn eval(
563 &mut self,
564 tape: &ShapeTape<E::Tape>,
565 x: &[E::Data],
566 y: &[E::Data],
567 z: &[E::Data],
568 ) -> Result<&[E::Data], Error> {
569 let h: ShapeVars<&[E::Data]> = ShapeVars::new();
570 self.eval_vs(tape, x, y, z, &h)
571 }
572
573 #[inline]
575 fn setup<V>(
576 &mut self,
577 tape: &ShapeTape<E::Tape>,
578 x: &[E::Data],
579 y: &[E::Data],
580 z: &[E::Data],
581 vars: &ShapeVars<V>,
582 ) -> Result<usize, Error> {
583 assert_eq!(
584 tape.tape.output_count(),
585 1,
586 "ShapeTape has multiple outputs"
587 );
588
589 if x.len() != y.len() || x.len() != z.len() {
591 return Err(Error::MismatchedSlices);
592 }
593 let n = x.len();
594
595 let vs = tape.vars();
596 let expected_vars = vs.len()
597 - vs.get(&Var::X).is_some() as usize
598 - vs.get(&Var::Y).is_some() as usize
599 - vs.get(&Var::Z).is_some() as usize;
600 if expected_vars != vars.len() {
601 return Err(Error::BadVarSlice(vars.len(), expected_vars));
602 }
603
604 self.scratch.resize_with(vs.len().max(1), Vec::new);
607 for s in &mut self.scratch {
608 s.resize(n, 0.0.into());
609 }
610
611 if let Some(mat) = tape.transform {
612 for i in 0..n {
613 let (x, y, z) = Transformable::transform(x[i], y[i], z[i], mat);
614 if let Some(a) = tape.axes[0] {
615 self.scratch[a][i] = x;
616 }
617 if let Some(b) = tape.axes[1] {
618 self.scratch[b][i] = y;
619 }
620 if let Some(c) = tape.axes[2] {
621 self.scratch[c][i] = z;
622 }
623 }
624 } else {
625 if let Some(a) = tape.axes[0] {
626 self.scratch[a].copy_from_slice(x);
627 }
628 if let Some(b) = tape.axes[1] {
629 self.scratch[b].copy_from_slice(y);
630 }
631 if let Some(c) = tape.axes[2] {
632 self.scratch[c].copy_from_slice(z);
633 }
634 };
636
637 Ok(n)
638 }
639 #[inline]
649 pub fn eval_vs<
650 V: std::ops::Deref<Target = [G]>,
651 G: Into<E::Data> + Copy,
652 >(
653 &mut self,
654 tape: &ShapeTape<E::Tape>,
655 x: &[E::Data],
656 y: &[E::Data],
657 z: &[E::Data],
658 vars: &ShapeVars<V>,
659 ) -> Result<&[E::Data], Error> {
660 let n = self.setup(tape, x, y, z, vars)?;
661
662 if vars.values().any(|vs| vs.len() != n) {
663 return Err(Error::MismatchedSlices);
664 }
665
666 let vs = tape.vars();
667 for (var, value) in vars {
668 if let Some(i) = vs.get(&Var::V(*var)) {
669 if i < self.scratch.len() {
670 for (a, b) in
671 self.scratch[i].iter_mut().zip(value.deref().iter())
672 {
673 *a = (*b).into();
674 }
675 } else {
677 return Err(Error::BadVarIndex(i, self.scratch.len()));
678 }
679 } else {
680 }
682 }
683
684 let out = self.eval.eval(&tape.tape, &self.scratch)?;
685 Ok(out.borrow(0))
686 }
687
688 #[inline]
697 pub fn eval_v<G: Into<E::Data> + Copy>(
698 &mut self,
699 tape: &ShapeTape<E::Tape>,
700 x: &[E::Data],
701 y: &[E::Data],
702 z: &[E::Data],
703 vars: &ShapeVars<G>,
704 ) -> Result<&[E::Data], Error> {
705 self.setup(tape, x, y, z, vars)?;
706 let vs = tape.vars();
707 for (var, value) in vars {
708 if let Some(i) = vs.get(&Var::V(*var)) {
709 if i < self.scratch.len() {
710 self.scratch[i].fill((*value).into());
711 } else {
712 return Err(Error::BadVarIndex(i, self.scratch.len()));
713 }
714 } else {
715 }
717 }
718
719 let out = self.eval.eval(&tape.tape, &self.scratch)?;
720 Ok(out.borrow(0))
721 }
722}
723
724pub trait Transformable {
726 fn transform(
728 x: Self,
729 y: Self,
730 z: Self,
731 mat: Matrix4<f32>,
732 ) -> (Self, Self, Self)
733 where
734 Self: Sized;
735}
736
737impl Transformable for f32 {
738 fn transform(x: f32, y: f32, z: f32, mat: Matrix4<f32>) -> (f32, f32, f32) {
739 let out = mat.transform_point(&Point3::new(x, y, z));
740 (out.x, out.y, out.z)
741 }
742}
743
744impl Transformable for Interval {
745 fn transform(
746 x: Interval,
747 y: Interval,
748 z: Interval,
749 mat: Matrix4<f32>,
750 ) -> (Interval, Interval, Interval) {
751 let out = [0, 1, 2, 3].map(|i| {
752 let row = mat.row(i);
753 x * row[0] + y * row[1] + z * row[2] + Interval::from(row[3])
754 });
755
756 (out[0] / out[3], out[1] / out[3], out[2] / out[3])
757 }
758}
759
760impl Transformable for Grad {
761 fn transform(
762 x: Grad,
763 y: Grad,
764 z: Grad,
765 mat: Matrix4<f32>,
766 ) -> (Grad, Grad, Grad) {
767 let out = [0, 1, 2, 3].map(|i| {
768 let row = mat.row(i);
769 x * row[0] + y * row[1] + z * row[2] + Grad::from(row[3])
770 });
771
772 (out[0] / out[3], out[1] / out[3], out[2] / out[3])
773 }
774}
775
#[cfg(test)]
mod test {
    use super::*;
    use crate::vm::VmShape;

    #[test]
    fn shape_vars() {
        let extra = Var::new();
        let tree = Tree::x() + Tree::y() + extra;

        let mut ctx = Context::new();
        let root = ctx.import(&tree);

        let shape = VmShape::new(&ctx, root).unwrap();
        let vars = shape.inner().vars();
        assert_eq!(vars.len(), 3);

        assert!(vars.get(&Var::X).is_some());
        assert!(vars.get(&Var::Y).is_some());
        assert!(vars.get(&Var::Z).is_none());
        assert!(vars.get(&extra).is_some());

        // Each used variable must occupy a distinct slot in 0..3
        let mut hit = [false; 3];
        for v in &[Var::X, Var::Y, extra] {
            hit[vars[v]] = true;
        }
        assert!(hit.iter().all(|&h| h));
    }

    #[test]
    fn shape_eval_bulk_size() {
        // A constant shape must still yield one output per input point
        let mut ctx = Context::new();
        let root = ctx.import(&Tree::constant(1.0));

        let shape = VmShape::new(&ctx, root).unwrap();
        let tape = shape.ez_float_slice_tape();
        let mut eval = VmShape::new_float_slice_eval();
        let result = eval.eval_v::<f32>(
            &tape,
            &[1.0, 2.0, 3.0],
            &[4.0, 5.0, 6.0],
            &[7.0, 8.0, 9.0],
            &ShapeVars::default(),
        );
        let out = result.unwrap();
        assert_eq!(out, [1.0, 1.0, 1.0]);
    }
}