fidget_core/eval/mod.rs
1//! Traits and data structures for function evaluation
2use crate::{
3    Error,
4    context::{Context, Node},
5    types::{Grad, Interval},
6    var::VarMap,
7};
8
9#[cfg(any(test, feature = "eval-tests"))]
10#[allow(missing_docs)]
11pub mod test;
12
13mod bulk;
14mod tracing;
15
16// Reexport a few types
17pub use bulk::{BulkEvaluator, BulkOutput};
18pub use tracing::TracingEvaluator;
19
/// A tape represents something that can be evaluated by an evaluator
///
/// It includes some kind of storage (which could be empty) and the ability to
/// look up variable mapping.
///
/// Tapes may be shared between threads (hence the `Send + Sync` bound), so
/// they should be cheap to clone (e.g. a wrapper around an `Arc<..>`).
pub trait Tape: Send + Sync + Clone {
    /// Associated type for this tape's data storage
    ///
    /// The storage must be [`Default`]-constructible, so a fresh value is
    /// always available when nothing could be recycled.
    type Storage: Default;

    /// Tries to retrieve the internal storage from this tape
    ///
    /// This consumes the tape.  A `None` return presumably means that the
    /// storage could not be reclaimed; the exact conditions are up to the
    /// implementation.
    ///
    /// This matters most for JIT evaluators, whose tapes are regions of
    /// executable memory-mapped RAM (which is expensive to map and unmap).
    fn recycle(self) -> Option<Self::Storage>;

    /// Returns a mapping from [`Var`](crate::var::Var) to evaluation index
    ///
    /// This must be identical to [`Function::vars`] on the `Function` which
    /// produced this tape.
    fn vars(&self) -> &VarMap;

    /// Returns the number of outputs written by this tape
    ///
    /// The order of outputs is set by the caller at tape construction, so we
    /// don't need a map to determine the index of a particular output (unlike
    /// variables).
    fn output_count(&self) -> usize;
}
50
/// The record of choices captured during a tracing evaluation
///
/// The trait enforces a single capability: overwriting an existing trace with
/// the contents of another one, so that trace allocations can be reused.
/// Wherever [`Trace`] is used in [`Function`] it is accompanied by `Clone`,
/// which would make this trivial to provide generically; however, a blanket
/// default implementation would run afoul of `impl` specialization, so each
/// trace type implements it explicitly.
pub trait Trace {
    /// Overwrites `self` with the contents of `other`
    fn copy_from(&mut self, other: &Self);
}

/// Any vector of plain-old-data values can serve as a trace
impl<T: Copy + Clone + Default> Trace for Vec<T> {
    fn copy_from(&mut self, other: &Self) {
        // Drop the previous contents but keep the allocation, then append
        // everything from `other`; afterwards `self` equals `other`.
        self.clear();
        self.extend_from_slice(other);
    }
}
68
/// A function represents something that can be evaluated
///
/// It is mostly agnostic to _how_ that something is represented; we simply
/// require that it can generate evaluators of various kinds.
///
/// Inputs to the function should be represented as [`Var`](crate::var::Var)
/// values; the [`vars()`](Function::vars) function returns the mapping from
/// `Var` to position in the input slice.
///
/// Functions are shared between threads, so they should be cheap to clone.  In
/// most cases, they're a thin wrapper around an `Arc<..>`.
pub trait Function: Send + Sync + Clone {
    /// Associated type for traces collected during tracing evaluation
    ///
    /// This type must implement [`Eq`] so that traces can be compared; calling
    /// [`Function::simplify`] with traces that compare equal should produce an
    /// identical result and may be cached.
    type Trace: Clone + Eq + Send + Sync + Trace;

    /// Associated type for storage used by the function itself
    ///
    /// This is the type consumed by [`Function::simplify`] and handed back by
    /// [`Function::recycle`].
    type Storage: Default + Send;

    /// Associated type for workspace used during function simplification
    ///
    /// A mutable reference to the workspace is passed to
    /// [`Function::simplify`].
    type Workspace: Default + Send;

    /// Associated type for storage used by tapes
    ///
    /// For simplicity, we require that every tape use the same type for storage.
    /// This could change in the future!
    type TapeStorage: Default + Send;

    /// Associated type for single-point tracing evaluation
    ///
    /// The evaluator operates on `f32` values and shares this function's
    /// trace and tape storage types.
    type PointEval: TracingEvaluator<
            Data = f32,
            Trace = Self::Trace,
            TapeStorage = Self::TapeStorage,
        > + Send
        + Sync;

    /// Builds a new point evaluator
    fn new_point_eval() -> Self::PointEval {
        Self::PointEval::new()
    }

    /// Associated type for single interval tracing evaluation
    ///
    /// The evaluator operates on [`Interval`] values and shares this
    /// function's trace and tape storage types.
    type IntervalEval: TracingEvaluator<
            Data = Interval,
            Trace = Self::Trace,
            TapeStorage = Self::TapeStorage,
        > + Send
        + Sync;

    /// Builds a new interval evaluator
    fn new_interval_eval() -> Self::IntervalEval {
        Self::IntervalEval::new()
    }

    /// Associated type for evaluating many points in one call
    ///
    /// The evaluator operates on `f32` values and shares this function's tape
    /// storage type.
    type FloatSliceEval: BulkEvaluator<Data = f32, TapeStorage = Self::TapeStorage>
        + Send
        + Sync;

    /// Builds a new float slice evaluator
    fn new_float_slice_eval() -> Self::FloatSliceEval {
        Self::FloatSliceEval::new()
    }

    /// Associated type for evaluating many gradients in one call
    ///
    /// The evaluator operates on [`Grad`] values and shares this function's
    /// tape storage type.
    type GradSliceEval: BulkEvaluator<Data = Grad, TapeStorage = Self::TapeStorage>
        + Send
        + Sync;

    /// Builds a new gradient slice evaluator
    fn new_grad_slice_eval() -> Self::GradSliceEval {
        Self::GradSliceEval::new()
    }

    /// Returns an evaluation tape for a point evaluator
    ///
    /// The given `storage` is taken by value so that the tape may reuse it.
    fn point_tape(
        &self,
        storage: Self::TapeStorage,
    ) -> <Self::PointEval as TracingEvaluator>::Tape;

    /// Returns an evaluation tape for an interval evaluator
    ///
    /// The given `storage` is taken by value so that the tape may reuse it.
    fn interval_tape(
        &self,
        storage: Self::TapeStorage,
    ) -> <Self::IntervalEval as TracingEvaluator>::Tape;

    /// Returns an evaluation tape for a float slice evaluator
    ///
    /// The given `storage` is taken by value so that the tape may reuse it.
    fn float_slice_tape(
        &self,
        storage: Self::TapeStorage,
    ) -> <Self::FloatSliceEval as BulkEvaluator>::Tape;

    /// Returns an evaluation tape for a gradient slice evaluator
    ///
    /// The given `storage` is taken by value so that the tape may reuse it.
    fn grad_slice_tape(
        &self,
        storage: Self::TapeStorage,
    ) -> <Self::GradSliceEval as BulkEvaluator>::Tape;

    /// Computes a simplified tape using the given trace, and reusing storage
    ///
    /// `trace` is a trace of type [`Function::Trace`]; `storage` is consumed
    /// so that its allocations can be reused, and `workspace` is scratch space
    /// that may be kept by the caller between calls.
    ///
    /// Returns the simplified function, or an [`Error`] if simplification
    /// fails.
    fn simplify(
        &self,
        trace: &Self::Trace,
        storage: Self::Storage,
        workspace: &mut Self::Workspace,
    ) -> Result<Self, Error>
    where
        Self: Sized;

    /// Attempt to reclaim storage from this function
    ///
    /// This may fail, because functions are `Clone` and are often implemented
    /// using an `Arc` around a heavier data structure.
    fn recycle(self) -> Option<Self::Storage>;

    /// Returns a size associated with this function
    ///
    /// This is underspecified and only used for unit testing; for tape-based
    /// functions, it's typically the length of the tape.
    fn size(&self) -> usize;

    /// Returns the map from [`Var`](crate::var::Var) to input index
    fn vars(&self) -> &VarMap;

    /// Checks to see whether this function can ever be simplified
    fn can_simplify(&self) -> bool;
}
198
/// A [`Function`] which can be built from a math expression
pub trait MathFunction: Function {
    /// Builds a new function from the given context and node
    ///
    /// `nodes` are [`Node`] handles that are looked up in `ctx`.
    /// NOTE(review): presumably each entry of `nodes` becomes one output of
    /// the resulting function, in order — confirm against implementors.
    ///
    /// Returns an [`Error`] if the function cannot be constructed.
    fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error>
    where
        Self: Sized;
}