1use crate::batcher::ScratchPadView;
24use crate::inferer::{
25 BasicInferer, DynamicInferer, FixedBatchInferer, Inferer, MemoizingDynamicInferer,
26};
27
/// One layer of behavior placed around an [`Inferer`].
///
/// Every method receives the inferer being wrapped as `&dyn Inferer`, so an implementation
/// decides for itself whether (and how) to forward each call to it.
pub trait InfererWrapper {
    /// Returns the input `(name, shape)` pairs as exposed through this wrapper for `inferer`.
    fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)];

    /// Returns the output `(name, shape)` pairs as exposed through this wrapper for `inferer`.
    fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)];

    /// Runs inference for `batch` through this wrapper using `inferer`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation (or the inferer it calls) reports.
    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()>;

    /// Hook invoked for the start of agent `id`; implementations may forward it to `inferer`.
    fn begin_agent(&self, inferer: &dyn Inferer, id: u64);

    /// Hook invoked for the end of agent `id`; implementations may forward it to `inferer`.
    fn end_agent(&self, inferer: &dyn Inferer, id: u64);
}
51
52pub struct BaseWrapper;
54
55impl InfererWrapper for BaseWrapper {
56 fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
58 inferer.input_shapes()
59 }
60
61 fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
63 inferer.output_shapes()
64 }
65
66 fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
68 inferer.infer_raw(batch)
69 }
70
71 fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
72 inferer.begin_agent(id);
73 }
74
75 fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
76 inferer.end_agent(id);
77 }
78}
79
80impl InfererWrapper for Box<dyn InfererWrapper> {
81 fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
82 self.as_ref().input_shapes(inferer)
83 }
84
85 fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
86 self.as_ref().output_shapes(inferer)
87 }
88
89 fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
90 self.as_ref().invoke(inferer, batch)
91 }
92
93 fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
94 self.as_ref().begin_agent(inferer, id);
95 }
96
97 fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
98 self.as_ref().end_agent(inferer, id);
99 }
100}
101
/// An [`Inferer`] paired with a wrapper stack that its routed calls pass through.
///
/// In this module's `Inferer` impl, inference and agent begin/end calls go through
/// `wrapper_stack` (which is handed `inferer`). Batch-size selection and raw shapes are
/// taken from `inferer` directly.
pub struct StatefulInferer<WrapStack: InfererWrapper, Inf: Inferer> {
    // The wrapper (or composed stack of wrappers) that routed calls go through.
    wrapper_stack: WrapStack,
    // The wrapped inferer that the wrapper stack is given on each call.
    inferer: Inf,
}
111
112impl<WrapStack: InfererWrapper, Inf: Inferer> StatefulInferer<WrapStack, Inf> {
113 pub fn new(wrapper_stack: WrapStack, inferer: Inf) -> Self {
116 Self {
117 wrapper_stack,
118 inferer,
119 }
120 }
121
122 pub fn with_new_inferer<NewInf: Inferer>(
129 self,
130 new_inferer: NewInf,
131 ) -> Result<StatefulInferer<WrapStack, NewInf>, (Self, anyhow::Error)> {
132 if let Err(e) = Self::check_compatible_shapes(&self.inferer, &new_inferer) {
133 return Err((self, e));
134 }
135 Ok(StatefulInferer {
136 wrapper_stack: self.wrapper_stack,
137 inferer: new_inferer,
138 })
139 }
140
141 pub fn replace_inferer(&mut self, new_inferer: Inf) -> anyhow::Result<()> {
149 if let Err(e) = Self::check_compatible_shapes(&self.inferer, &new_inferer) {
150 Err(e)
151 } else {
152 self.inferer = new_inferer;
153 Ok(())
154 }
155 }
156
157 pub fn check_compatible_shapes<Old: Inferer, New: Inferer>(
160 old: &Old,
161 new: &New,
162 ) -> Result<(), anyhow::Error> {
163 let old_in = old.raw_input_shapes();
164 let new_in = new.raw_input_shapes();
165
166 let old_out = old.raw_output_shapes();
167 let new_out = new.raw_output_shapes();
168
169 for (i, (o, n)) in old_in.iter().zip(new_in).enumerate() {
170 if o != n {
171 if o.0 != n.0 {
172 return Err(anyhow::format_err!(
173 "name mismatch for input {i}: '{}' != '{}'",
174 o.0,
175 n.0,
176 ));
177 }
178
179 return Err(anyhow::format_err!(
180 "shape mismatch for input '{}': {:?} != {:?}",
181 o.0,
182 o.1,
183 n.1,
184 ));
185 }
186 }
187
188 for (i, (o, n)) in old_out.iter().zip(new_out).enumerate() {
189 if o != n {
190 if o.0 != n.0 {
191 return Err(anyhow::format_err!(
192 "name mismatch for output {i}: '{}' != '{}'",
193 o.0,
194 n.0,
195 ));
196 }
197
198 return Err(anyhow::format_err!(
199 "shape mismatch for output {}: {:?} != {:?}",
200 o.0,
201 o.1,
202 n.1,
203 ));
204 }
205 }
206
207 Ok(())
208 }
209
210 pub fn input_shapes(&self) -> &[(String, Vec<usize>)] {
212 self.wrapper_stack.input_shapes(&self.inferer)
213 }
214
215 pub fn output_shapes(&self) -> &[(String, Vec<usize>)] {
217 self.wrapper_stack.output_shapes(&self.inferer)
218 }
219}
220
221impl<WrapStack: InfererWrapper, Inf: Inferer> Inferer for StatefulInferer<WrapStack, Inf> {
223 fn select_batch_size(&self, max_count: usize) -> usize {
224 self.inferer.select_batch_size(max_count)
225 }
226
227 fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> {
228 self.wrapper_stack.invoke(&self.inferer, batch)
229 }
230
231 fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
232 self.inferer.raw_input_shapes()
233 }
234
235 fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
236 self.inferer.raw_output_shapes()
237 }
238
239 fn begin_agent(&self, id: u64) {
240 self.wrapper_stack.begin_agent(&self.inferer, id);
241 }
242
243 fn end_agent(&self, id: u64) {
244 self.wrapper_stack.end_agent(&self.inferer, id);
245 }
246}
247
248pub trait IntoStateful: Inferer + Sized {
250 fn into_stateful<WrapStack: InfererWrapper>(
253 self,
254 wrapper_stack: WrapStack,
255 ) -> StatefulInferer<WrapStack, Self> {
256 StatefulInferer::new(wrapper_stack, self)
257 }
258}
259
260impl IntoStateful for BasicInferer {}
261impl IntoStateful for DynamicInferer {}
262impl IntoStateful for MemoizingDynamicInferer {}
263impl IntoStateful for FixedBatchInferer {}
264
265pub trait InfererWrapperExt: InfererWrapper + Sized {
267 fn wrap<Inf: Inferer>(self, inferer: Inf) -> StatefulInferer<Self, Inf> {
269 StatefulInferer::new(self, inferer)
270 }
271}
272
273impl<T: InfererWrapper> InfererWrapperExt for T {}