cervo_core/
wrapper.rs

/*!
Inferer wrappers with state separated from the inferer.

This separates stateful logic from the inner inferer, so the inner
inferer can be swapped out while the state held by the wrappers is
preserved.

This is an alternative to the old layered inferer setup, which
tightly coupled the inner inferer with the wrapper state.

```rust,ignore
let inferer = ...;
// the root of the stack needs [`BaseWrapper`] passed as its base case.
let wrappers = RecurrentTrackerWrapper::new(BaseWrapper, inferer);
let wrapped = StatefulInferer::new(wrappers, inferer);
// or
let wrapped = inferer.into_stateful(wrappers);
// or
let wrapped = wrappers.wrap(inferer);
```
*/
22
23use crate::batcher::ScratchPadView;
24use crate::inferer::{
25    BasicInferer, DynamicInferer, FixedBatchInferer, Inferer, MemoizingDynamicInferer,
26};
27
28/// A trait for wrapping an inferer with additional functionality.
29///
30/// This works similar to the old layered inferer setup, but allows
31/// separation of wrapper state from the inner inferer. This allows
32/// swapping out the inner inferer while maintaining state in the
33/// wrappers.
34pub trait InfererWrapper {
35    /// Returns the input shapes after this wrapper has been applied.
36    fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)];
37
38    /// Returns the output shapes after this wrapper has been applied.
39    fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)];
40
41    /// Invokes the inner inferer, applying any additional logic before
42    /// and after the call.
43    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()>;
44
45    /// Called when starting inference for a new agent.
46    fn begin_agent(&self, inferer: &dyn Inferer, id: u64);
47
48    /// Called when finishing inference for an agent.
49    fn end_agent(&self, inferer: &dyn Inferer, id: u64);
50}
51
/// A no-op inferer wrapper that forwards every call directly to the inner
/// inferer. It is the base case at the bottom of a wrapper stack.
///
/// Being a stateless unit struct, it is cheap to copy and has an obvious
/// default value.
#[derive(Debug, Clone, Copy, Default)]
pub struct BaseWrapper;
54
55impl InfererWrapper for BaseWrapper {
56    /// Returns the input shapes after this wrapper has been applied.
57    fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
58        inferer.input_shapes()
59    }
60
61    /// Returns the output shapes after this wrapper has been applied.
62    fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
63        inferer.output_shapes()
64    }
65
66    /// Invokes the inner inferer.
67    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
68        inferer.infer_raw(batch)
69    }
70
71    fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
72        inferer.begin_agent(id);
73    }
74
75    fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
76        inferer.end_agent(id);
77    }
78}
79
80impl InfererWrapper for Box<dyn InfererWrapper> {
81    fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
82        self.as_ref().input_shapes(inferer)
83    }
84
85    fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
86        self.as_ref().output_shapes(inferer)
87    }
88
89    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
90        self.as_ref().invoke(inferer, batch)
91    }
92
93    fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
94        self.as_ref().begin_agent(inferer, id);
95    }
96
97    fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
98        self.as_ref().end_agent(inferer, id);
99    }
100}
101
102/// An inferer that maintains state in wrappers around an inferer.
103///
104/// This is an alternative to direct wrapping of an inferer, which
105/// allows the inner inferer to be swapped out while maintaining
106/// state in the wrappers.
107pub struct StatefulInferer<WrapStack: InfererWrapper, Inf: Inferer> {
108    wrapper_stack: WrapStack,
109    inferer: Inf,
110}
111
112impl<WrapStack: InfererWrapper, Inf: Inferer> StatefulInferer<WrapStack, Inf> {
113    /// Construct a new [`StatefulInferer`] by wrapping the given
114    /// inferer with the given wrapper stack.
115    pub fn new(wrapper_stack: WrapStack, inferer: Inf) -> Self {
116        Self {
117            wrapper_stack,
118            inferer,
119        }
120    }
121
122    /// Replace the inner inferer with a new inferer while maintaining
123    /// any state in wrappers.
124    ///
125    /// Requires that the shapes of the policies are compatible, but
126    /// they may be different inferer types. If this check fails, will
127    /// return self unchanged.
128    pub fn with_new_inferer<NewInf: Inferer>(
129        self,
130        new_inferer: NewInf,
131    ) -> Result<StatefulInferer<WrapStack, NewInf>, (Self, anyhow::Error)> {
132        if let Err(e) = Self::check_compatible_shapes(&self.inferer, &new_inferer) {
133            return Err((self, e));
134        }
135        Ok(StatefulInferer {
136            wrapper_stack: self.wrapper_stack,
137            inferer: new_inferer,
138        })
139    }
140
141    /// Replace the inner inferer with a new inferer while maintaining
142    /// any state in wrappers.
143    ///
144    /// Requires that the shapes of the policies are compatible If
145    /// this check fails, will not change self. Compared to
146    /// [`with_new_inferer`], also requires that the new inferer has
147    /// the same type as the old one.
148    pub fn replace_inferer(&mut self, new_inferer: Inf) -> anyhow::Result<()> {
149        if let Err(e) = Self::check_compatible_shapes(&self.inferer, &new_inferer) {
150            Err(e)
151        } else {
152            self.inferer = new_inferer;
153            Ok(())
154        }
155    }
156
157    /// Validate that [`Old`] and [`New`] are compatible with each
158    /// other.
159    pub fn check_compatible_shapes<Old: Inferer, New: Inferer>(
160        old: &Old,
161        new: &New,
162    ) -> Result<(), anyhow::Error> {
163        let old_in = old.raw_input_shapes();
164        let new_in = new.raw_input_shapes();
165
166        let old_out = old.raw_output_shapes();
167        let new_out = new.raw_output_shapes();
168
169        for (i, (o, n)) in old_in.iter().zip(new_in).enumerate() {
170            if o != n {
171                if o.0 != n.0 {
172                    return Err(anyhow::format_err!(
173                        "name mismatch for input {i}: '{}' != '{}'",
174                        o.0,
175                        n.0,
176                    ));
177                }
178
179                return Err(anyhow::format_err!(
180                    "shape mismatch for input '{}': {:?} != {:?}",
181                    o.0,
182                    o.1,
183                    n.1,
184                ));
185            }
186        }
187
188        for (i, (o, n)) in old_out.iter().zip(new_out).enumerate() {
189            if o != n {
190                if o.0 != n.0 {
191                    return Err(anyhow::format_err!(
192                        "name mismatch for output {i}: '{}' != '{}'",
193                        o.0,
194                        n.0,
195                    ));
196                }
197
198                return Err(anyhow::format_err!(
199                    "shape mismatch for output {}: {:?} != {:?}",
200                    o.0,
201                    o.1,
202                    n.1,
203                ));
204            }
205        }
206
207        Ok(())
208    }
209
210    /// Returns the input shapes after all wrappers have been applied.
211    pub fn input_shapes(&self) -> &[(String, Vec<usize>)] {
212        self.wrapper_stack.input_shapes(&self.inferer)
213    }
214
215    /// Returns the output shapes after all wrappers have been applied.
216    pub fn output_shapes(&self) -> &[(String, Vec<usize>)] {
217        self.wrapper_stack.output_shapes(&self.inferer)
218    }
219}
220
221/// See [`Inferer`] for documentation.
222impl<WrapStack: InfererWrapper, Inf: Inferer> Inferer for StatefulInferer<WrapStack, Inf> {
223    fn select_batch_size(&self, max_count: usize) -> usize {
224        self.inferer.select_batch_size(max_count)
225    }
226
227    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> {
228        self.wrapper_stack.invoke(&self.inferer, batch)
229    }
230
231    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
232        self.inferer.raw_input_shapes()
233    }
234
235    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
236        self.inferer.raw_output_shapes()
237    }
238
239    fn begin_agent(&self, id: u64) {
240        self.wrapper_stack.begin_agent(&self.inferer, id);
241    }
242
243    fn end_agent(&self, id: u64) {
244        self.wrapper_stack.end_agent(&self.inferer, id);
245    }
246}
247
248/// Extension trait to allow easy wrapping of an inferer with a wrapper stack.
249pub trait IntoStateful: Inferer + Sized {
250    /// Construct a [`StatefulInferer`] by wrapping this concrete
251    /// inferer with the given wrapper stack.
252    fn into_stateful<WrapStack: InfererWrapper>(
253        self,
254        wrapper_stack: WrapStack,
255    ) -> StatefulInferer<WrapStack, Self> {
256        StatefulInferer::new(wrapper_stack, self)
257    }
258}
259
260impl IntoStateful for BasicInferer {}
261impl IntoStateful for DynamicInferer {}
262impl IntoStateful for MemoizingDynamicInferer {}
263impl IntoStateful for FixedBatchInferer {}
264
265/// Extension trait to allow easy wrapping of an inferer with a wrapper stack.
266pub trait InfererWrapperExt: InfererWrapper + Sized {
267    /// Construct a [`StatefulInferer`] by wrapping an inner inferer with this wrapper.
268    fn wrap<Inf: Inferer>(self, inferer: Inf) -> StatefulInferer<Self, Inf> {
269        StatefulInferer::new(self, inferer)
270    }
271}
272
273impl<T: InfererWrapper> InfererWrapperExt for T {}