1use crate::{
30 Error,
31 context::{Context, Node, Tree},
32 eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator},
33 types::{Grad, Interval},
34 var::{Var, VarIndex, VarMap},
35};
36use nalgebra::{Matrix4, Point3};
37use std::collections::HashMap;
38
/// A function paired with the variables used as its X, Y, and Z axes, plus an
/// optional transform applied to input coordinates before evaluation.
///
/// `T` is a zero-sized type-state marker. [`Shape::with_transform`] is only
/// available on `Shape<F, ()>` and returns a `Shape<F, Transformed>`, so a
/// transform cannot be attached twice through that method.
pub struct Shape<F, T = ()> {
    /// Underlying function that is evaluated.
    f: F,

    /// Variables used as the X, Y, and Z axes (in that order).
    axes: [Var; 3],

    /// Optional matrix applied to `(x, y, z)` before evaluation; it is copied
    /// into every [`ShapeTape`] built from this shape.
    transform: Option<Matrix4<f32>>,

    /// Type-state marker (see the type-level docs); carries no data.
    _marker: std::marker::PhantomData<T>,
}
72
73impl<F: Clone, T> Clone for Shape<F, T> {
74 fn clone(&self) -> Self {
75 Self {
76 f: self.f.clone(),
77 axes: self.axes,
78 transform: self.transform,
79 _marker: std::marker::PhantomData,
80 }
81 }
82}
83
84impl<F: Function + Clone, T> Shape<F, T> {
85 pub fn new_point_eval() -> ShapeTracingEval<F::PointEval> {
87 ShapeTracingEval {
88 eval: F::PointEval::default(),
89 scratch: vec![],
90 }
91 }
92
93 pub fn new_interval_eval() -> ShapeTracingEval<F::IntervalEval> {
95 ShapeTracingEval {
96 eval: F::IntervalEval::default(),
97 scratch: vec![],
98 }
99 }
100
101 pub fn new_float_slice_eval() -> ShapeBulkEval<F::FloatSliceEval> {
103 ShapeBulkEval {
104 eval: F::FloatSliceEval::default(),
105 scratch: vec![],
106 }
107 }
108
109 pub fn new_grad_slice_eval() -> ShapeBulkEval<F::GradSliceEval> {
111 ShapeBulkEval {
112 eval: F::GradSliceEval::default(),
113 scratch: vec![],
114 }
115 }
116
117 #[inline]
119 pub fn point_tape(
120 &self,
121 storage: F::TapeStorage,
122 ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape> {
123 let tape = self.f.point_tape(storage);
124 let vars = tape.vars();
125 let axes = self.axes.map(|v| vars.get(&v));
126 ShapeTape {
127 tape,
128 axes,
129 transform: self.transform,
130 }
131 }
132
133 #[inline]
135 pub fn interval_tape(
136 &self,
137 storage: F::TapeStorage,
138 ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
139 let tape = self.f.interval_tape(storage);
140 let vars = tape.vars();
141 let axes = self.axes.map(|v| vars.get(&v));
142 ShapeTape {
143 tape,
144 axes,
145 transform: self.transform,
146 }
147 }
148
149 #[inline]
151 pub fn float_slice_tape(
152 &self,
153 storage: F::TapeStorage,
154 ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
155 let tape = self.f.float_slice_tape(storage);
156 let vars = tape.vars();
157 let axes = self.axes.map(|v| vars.get(&v));
158 ShapeTape {
159 tape,
160 axes,
161 transform: self.transform,
162 }
163 }
164
165 #[inline]
167 pub fn grad_slice_tape(
168 &self,
169 storage: F::TapeStorage,
170 ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
171 let tape = self.f.grad_slice_tape(storage);
172 let vars = tape.vars();
173 let axes = self.axes.map(|v| vars.get(&v));
174 ShapeTape {
175 tape,
176 axes,
177 transform: self.transform,
178 }
179 }
180
181 #[inline]
183 pub fn simplify(
184 &self,
185 trace: &F::Trace,
186 storage: F::Storage,
187 workspace: &mut F::Workspace,
188 ) -> Result<Self, Error>
189 where
190 Self: Sized,
191 {
192 let f = self.f.simplify(trace, storage, workspace)?;
193 Ok(Self {
194 f,
195 axes: self.axes,
196 transform: self.transform,
197 _marker: std::marker::PhantomData,
198 })
199 }
200
201 #[inline]
206 pub fn recycle(self) -> Option<F::Storage> {
207 self.f.recycle()
208 }
209
210 #[inline]
215 pub fn size(&self) -> usize {
216 self.f.size()
217 }
218}
219
220impl<F, T> Shape<F, T> {
221 pub fn inner(&self) -> &F {
223 &self.f
224 }
225
226 pub fn axes(&self) -> &[Var; 3] {
228 &self.axes
229 }
230
231 pub fn new_raw(f: F, axes: [Var; 3]) -> Self {
233 Self {
234 f,
235 axes,
236 transform: None,
237 _marker: std::marker::PhantomData,
238 }
239 }
240}
241
/// Type-state marker for a [`Shape`] returned by [`Shape::with_transform`].
pub struct Transformed;
244
245impl<F: Clone> Shape<F, ()> {
246 pub fn with_transform(&self, mat: Matrix4<f32>) -> Shape<F, Transformed> {
248 Shape {
249 f: self.f.clone(),
250 axes: self.axes,
251 transform: Some(mat),
252 _marker: std::marker::PhantomData,
253 }
254 }
255}
256
/// Map from (non-axis) variable index to the value supplied for that variable
/// during shape evaluation.
pub struct ShapeVars<F>(HashMap<VarIndex, F>);
263
264impl<F> Default for ShapeVars<F> {
265 fn default() -> Self {
266 Self(HashMap::default())
267 }
268}
269
270impl<F> ShapeVars<F> {
271 pub fn new() -> Self {
273 Self(HashMap::default())
274 }
275 pub fn len(&self) -> usize {
277 self.0.len()
278 }
279 pub fn is_empty(&self) -> bool {
281 self.0.is_empty()
282 }
283 pub fn insert(&mut self, v: VarIndex, f: F) -> Option<F> {
287 self.0.insert(v, f)
288 }
289
290 pub fn values(&self) -> impl Iterator<Item = &F> {
292 self.0.values()
293 }
294}
295
296impl<'a, F> IntoIterator for &'a ShapeVars<F> {
297 type Item = (&'a VarIndex, &'a F);
298 type IntoIter = std::collections::hash_map::Iter<'a, VarIndex, F>;
299 fn into_iter(self) -> Self::IntoIter {
300 self.0.iter()
301 }
302}
303
/// Convenience constructors for tapes and simplified shapes that use
/// default-constructed storage and workspaces, for callers that do not want
/// to manage those allocations themselves.
pub trait EzShape<F: Function> {
    /// Builds a point-evaluation tape with default storage.
    fn ez_point_tape(
        &self,
    ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape>;

    /// Builds an interval-evaluation tape with default storage.
    fn ez_interval_tape(
        &self,
    ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape>;

    /// Builds a float-slice evaluation tape with default storage.
    fn ez_float_slice_tape(
        &self,
    ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape>;

    /// Builds a gradient-slice evaluation tape with default storage.
    fn ez_grad_slice_tape(
        &self,
    ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape>;

    /// Simplifies the shape using `trace`, with default storage and a fresh
    /// workspace.
    fn ez_simplify(&self, trace: &F::Trace) -> Result<Self, Error>
    where
        Self: Sized;
}
339
340impl<F: Function, T> EzShape<F> for Shape<F, T> {
341 fn ez_point_tape(
342 &self,
343 ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape> {
344 self.point_tape(Default::default())
345 }
346
347 fn ez_interval_tape(
348 &self,
349 ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
350 self.interval_tape(Default::default())
351 }
352
353 fn ez_float_slice_tape(
354 &self,
355 ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
356 self.float_slice_tape(Default::default())
357 }
358
359 fn ez_grad_slice_tape(
360 &self,
361 ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
362 self.grad_slice_tape(Default::default())
363 }
364
365 fn ez_simplify(&self, trace: &F::Trace) -> Result<Self, Error> {
366 let mut workspace = Default::default();
367 self.simplify(trace, Default::default(), &mut workspace)
368 }
369}
370
371impl<F: MathFunction> Shape<F> {
372 pub fn new_with_axes(
374 ctx: &Context,
375 node: Node,
376 axes: [Var; 3],
377 ) -> Result<Self, Error> {
378 let f = F::new(ctx, &[node])?;
379 Ok(Self {
380 f,
381 axes,
382 transform: None,
383 _marker: std::marker::PhantomData,
384 })
385 }
386
387 pub fn new(ctx: &Context, node: Node) -> Result<Self, Error>
389 where
390 Self: Sized,
391 {
392 Self::new_with_axes(ctx, node, [Var::X, Var::Y, Var::Z])
393 }
394}
395
396impl<F: MathFunction> From<Tree> for Shape<F> {
398 fn from(t: Tree) -> Self {
399 let mut ctx = Context::new();
400 let node = ctx.import(&t);
401 Self::new(&ctx, node).unwrap()
402 }
403}
404
/// An evaluator-specific tape, together with the information needed to feed
/// shape coordinates into it.
#[derive(Clone)]
pub struct ShapeTape<T> {
    /// Underlying evaluator tape.
    tape: T,

    /// Index of each axis (X, Y, Z) in the tape's variable map, or `None` if
    /// the tape does not use that axis.
    axes: [Option<usize>; 3],

    /// Optional transform applied to input coordinates before evaluation.
    transform: Option<Matrix4<f32>>,
}
416
417impl<T: Tape> ShapeTape<T> {
418 pub fn recycle(self) -> Option<T::Storage> {
420 self.tape.recycle()
421 }
422
423 pub fn vars(&self) -> &VarMap {
425 self.tape.vars()
426 }
427}
428
/// Wrapper around a tracing evaluator that maps shape coordinates and
/// [`ShapeVars`] onto the evaluator's inputs.
#[derive(Debug)]
pub struct ShapeTracingEval<E: TracingEvaluator> {
    /// Underlying evaluator.
    eval: E,
    /// Input buffer with one entry per tape variable; resized on each call and
    /// kept between calls.
    scratch: Vec<E::Data>,
}
438
439impl<E: TracingEvaluator> Default for ShapeTracingEval<E> {
440 fn default() -> Self {
441 Self {
442 eval: E::default(),
443 scratch: vec![],
444 }
445 }
446}
447
448impl<E: TracingEvaluator> ShapeTracingEval<E>
449where
450 <E as TracingEvaluator>::Data: Transformable,
451{
452 #[inline]
459 pub fn eval<F: Into<E::Data> + Copy>(
460 &mut self,
461 tape: &ShapeTape<E::Tape>,
462 x: F,
463 y: F,
464 z: F,
465 ) -> Result<(E::Data, Option<&E::Trace>), Error> {
466 let h = ShapeVars::<f32>::new();
467 self.eval_v(tape, x, y, z, &h)
468 }
469
470 #[inline]
474 pub fn eval_v<F: Into<E::Data> + Copy, V: Into<E::Data> + Copy>(
475 &mut self,
476 tape: &ShapeTape<E::Tape>,
477 x: F,
478 y: F,
479 z: F,
480 vars: &ShapeVars<V>,
481 ) -> Result<(E::Data, Option<&E::Trace>), Error> {
482 assert_eq!(
483 tape.tape.output_count(),
484 1,
485 "ShapeTape has multiple outputs"
486 );
487
488 let x = x.into();
489 let y = y.into();
490 let z = z.into();
491 let (x, y, z) = if let Some(mat) = tape.transform {
492 Transformable::transform(x, y, z, mat)
493 } else {
494 (x, y, z)
495 };
496
497 let vs = tape.vars();
498 let expected_vars = vs.len()
499 - vs.get(&Var::X).is_some() as usize
500 - vs.get(&Var::Y).is_some() as usize
501 - vs.get(&Var::Z).is_some() as usize;
502 if expected_vars != vars.len() {
503 return Err(Error::BadVarSlice(vars.len(), expected_vars));
504 }
505
506 self.scratch.resize(tape.vars().len(), 0f32.into());
507 if let Some(a) = tape.axes[0] {
508 self.scratch[a] = x;
509 }
510 if let Some(b) = tape.axes[1] {
511 self.scratch[b] = y;
512 }
513 if let Some(c) = tape.axes[2] {
514 self.scratch[c] = z;
515 }
516 for (var, value) in vars {
517 if let Some(i) = vs.get(&Var::V(*var)) {
518 if i < self.scratch.len() {
519 self.scratch[i] = (*value).into();
520 } else {
521 return Err(Error::BadVarIndex(i, self.scratch.len()));
522 }
523 } else {
524 }
526 }
527
528 let (out, trace) = self.eval.eval(&tape.tape, &self.scratch)?;
529 Ok((out[0], trace))
530 }
531}
532
/// Wrapper around a bulk evaluator that maps slices of shape coordinates and
/// [`ShapeVars`] onto the evaluator's inputs.
#[derive(Debug, Default)]
pub struct ShapeBulkEval<E: BulkEvaluator> {
    /// Underlying evaluator.
    eval: E,
    /// One input slice per tape variable (always at least one slice), each
    /// resized to the batch length on every call.
    scratch: Vec<Vec<E::Data>>,
}
542
543impl<E: BulkEvaluator> ShapeBulkEval<E>
544where
545 E::Data: From<f32> + Transformable,
546{
547 #[inline]
555 pub fn eval(
556 &mut self,
557 tape: &ShapeTape<E::Tape>,
558 x: &[E::Data],
559 y: &[E::Data],
560 z: &[E::Data],
561 ) -> Result<&[E::Data], Error> {
562 let h: ShapeVars<&[E::Data]> = ShapeVars::new();
563 self.eval_vs(tape, x, y, z, &h)
564 }
565
566 #[inline]
568 fn setup<V>(
569 &mut self,
570 tape: &ShapeTape<E::Tape>,
571 x: &[E::Data],
572 y: &[E::Data],
573 z: &[E::Data],
574 vars: &ShapeVars<V>,
575 ) -> Result<usize, Error> {
576 assert_eq!(
577 tape.tape.output_count(),
578 1,
579 "ShapeTape has multiple outputs"
580 );
581
582 if x.len() != y.len() || x.len() != z.len() {
584 return Err(Error::MismatchedSlices);
585 }
586 let n = x.len();
587
588 let vs = tape.vars();
589 let expected_vars = vs.len()
590 - vs.get(&Var::X).is_some() as usize
591 - vs.get(&Var::Y).is_some() as usize
592 - vs.get(&Var::Z).is_some() as usize;
593 if expected_vars != vars.len() {
594 return Err(Error::BadVarSlice(vars.len(), expected_vars));
595 }
596
597 self.scratch.resize_with(vs.len().max(1), Vec::new);
600 for s in &mut self.scratch {
601 s.resize(n, 0.0.into());
602 }
603
604 if let Some(mat) = tape.transform {
605 for i in 0..n {
606 let (x, y, z) = Transformable::transform(x[i], y[i], z[i], mat);
607 if let Some(a) = tape.axes[0] {
608 self.scratch[a][i] = x;
609 }
610 if let Some(b) = tape.axes[1] {
611 self.scratch[b][i] = y;
612 }
613 if let Some(c) = tape.axes[2] {
614 self.scratch[c][i] = z;
615 }
616 }
617 } else {
618 if let Some(a) = tape.axes[0] {
619 self.scratch[a].copy_from_slice(x);
620 }
621 if let Some(b) = tape.axes[1] {
622 self.scratch[b].copy_from_slice(y);
623 }
624 if let Some(c) = tape.axes[2] {
625 self.scratch[c].copy_from_slice(z);
626 }
627 };
629
630 Ok(n)
631 }
632 #[inline]
642 pub fn eval_vs<
643 V: std::ops::Deref<Target = [G]>,
644 G: Into<E::Data> + Copy,
645 >(
646 &mut self,
647 tape: &ShapeTape<E::Tape>,
648 x: &[E::Data],
649 y: &[E::Data],
650 z: &[E::Data],
651 vars: &ShapeVars<V>,
652 ) -> Result<&[E::Data], Error> {
653 let n = self.setup(tape, x, y, z, vars)?;
654
655 if vars.values().any(|vs| vs.len() != n) {
656 return Err(Error::MismatchedSlices);
657 }
658
659 let vs = tape.vars();
660 for (var, value) in vars {
661 if let Some(i) = vs.get(&Var::V(*var)) {
662 if i < self.scratch.len() {
663 for (a, b) in
664 self.scratch[i].iter_mut().zip(value.deref().iter())
665 {
666 *a = (*b).into();
667 }
668 } else {
670 return Err(Error::BadVarIndex(i, self.scratch.len()));
671 }
672 } else {
673 }
675 }
676
677 let out = self.eval.eval(&tape.tape, &self.scratch)?;
678 Ok(out.borrow(0))
679 }
680
681 #[inline]
690 pub fn eval_v<G: Into<E::Data> + Copy>(
691 &mut self,
692 tape: &ShapeTape<E::Tape>,
693 x: &[E::Data],
694 y: &[E::Data],
695 z: &[E::Data],
696 vars: &ShapeVars<G>,
697 ) -> Result<&[E::Data], Error> {
698 self.setup(tape, x, y, z, vars)?;
699 let vs = tape.vars();
700 for (var, value) in vars {
701 if let Some(i) = vs.get(&Var::V(*var)) {
702 if i < self.scratch.len() {
703 self.scratch[i].fill((*value).into());
704 } else {
705 return Err(Error::BadVarIndex(i, self.scratch.len()));
706 }
707 } else {
708 }
710 }
711
712 let out = self.eval.eval(&tape.tape, &self.scratch)?;
713 Ok(out.borrow(0))
714 }
715}
716
/// Types whose `(x, y, z)` coordinate triples can be transformed by a 4x4
/// matrix before shape evaluation.
pub trait Transformable {
    /// Applies `mat` to the point `(x, y, z)` and returns the transformed
    /// coordinates.
    fn transform(
        x: Self,
        y: Self,
        z: Self,
        mat: Matrix4<f32>,
    ) -> (Self, Self, Self)
    where
        Self: Sized;
}
729
730impl Transformable for f32 {
731 fn transform(x: f32, y: f32, z: f32, mat: Matrix4<f32>) -> (f32, f32, f32) {
732 let out = mat.transform_point(&Point3::new(x, y, z));
733 (out.x, out.y, out.z)
734 }
735}
736
737impl Transformable for Interval {
738 fn transform(
739 x: Interval,
740 y: Interval,
741 z: Interval,
742 mat: Matrix4<f32>,
743 ) -> (Interval, Interval, Interval) {
744 let out = [0, 1, 2, 3].map(|i| {
745 let row = mat.row(i);
746 x * row[0] + y * row[1] + z * row[2] + Interval::from(row[3])
747 });
748
749 (out[0] / out[3], out[1] / out[3], out[2] / out[3])
750 }
751}
752
753impl Transformable for Grad {
754 fn transform(
755 x: Grad,
756 y: Grad,
757 z: Grad,
758 mat: Matrix4<f32>,
759 ) -> (Grad, Grad, Grad) {
760 let out = [0, 1, 2, 3].map(|i| {
761 let row = mat.row(i);
762 x * row[0] + y * row[1] + z * row[2] + Grad::from(row[3])
763 });
764
765 (out[0] / out[3], out[1] / out[3], out[2] / out[3])
766 }
767}
768
#[cfg(test)]
mod test {
    use super::*;
    use crate::vm::VmShape;

    /// A shape's variable map holds exactly the variables its tree uses, each
    /// at a distinct index.
    #[test]
    fn shape_vars() {
        let v = Var::new();
        let tree = Tree::x() + Tree::y() + v;

        let mut ctx = Context::new();
        let root = ctx.import(&tree);

        let shape = VmShape::new(&ctx, root).unwrap();
        let vars = shape.inner().vars();
        assert_eq!(vars.len(), 3);

        assert!(vars.get(&Var::X).is_some());
        assert!(vars.get(&Var::Y).is_some());
        assert!(vars.get(&Var::Z).is_none());
        assert!(vars.get(&v).is_some());

        // Every used variable must land on a distinct index in 0..3
        let mut seen = [false; 3];
        [Var::X, Var::Y, v]
            .iter()
            .for_each(|var| seen[vars[var]] = true);
        assert!(seen.iter().all(|&hit| hit));
    }

    /// Bulk evaluation of a constant shape (a tape with no inputs) still
    /// produces one output per input point.
    #[test]
    fn shape_eval_bulk_size() {
        let mut ctx = Context::new();
        let root = ctx.import(&Tree::constant(1.0));

        let shape = VmShape::new(&ctx, root).unwrap();
        let tape = shape.ez_float_slice_tape();
        let mut eval = VmShape::new_float_slice_eval();

        let xs: [f32; 3] = [1.0, 2.0, 3.0];
        let ys: [f32; 3] = [4.0, 5.0, 6.0];
        let zs: [f32; 3] = [7.0, 8.0, 9.0];
        let out = eval
            .eval_v::<f32>(&tape, &xs, &ys, &zs, &ShapeVars::default())
            .unwrap();
        assert_eq!(out, [1.0, 1.0, 1.0]);
    }
}