1#![warn(clippy::all)]
6
7use anyhow::{Error, Result};
39use std::collections::HashMap;
40
41mod basic;
42mod dynamic;
43mod fixed;
44mod helpers;
45mod memoizing;
46
47pub use basic::BasicInferer;
48pub use dynamic::DynamicInferer;
49pub use fixed::FixedBatchInferer;
50pub use memoizing::MemoizingDynamicInferer;
51
52use crate::{
53 batcher::{Batched, Batcher, ScratchPadView},
54 epsilon::{EpsilonInjector, NoiseGenerator},
55};
56
/// The observation for a single agent, handed to an inferer as input.
///
/// Derives `Default` for consistency with [`Response`]; `State::default()` and
/// [`State::empty`] produce the same value.
#[derive(Clone, Debug, Default)]
pub struct State<'a> {
    /// Per-input `f32` data keyed by input name (presumably matching the names reported by
    /// `Inferer::input_shapes` — confirm against callers).
    pub data: HashMap<&'a str, Vec<f32>>,
}

impl<'a> State<'a> {
    /// Creates a state that holds no input data.
    pub fn empty() -> Self {
        Self::default()
    }
}
71
/// The output an inferer produced for a single agent.
#[derive(Clone, Debug, Default)]
pub struct Response<'a> {
    /// Per-output `f32` data keyed by output name.
    pub data: HashMap<&'a str, Vec<f32>>,
}

impl<'a> Response<'a> {
    /// Returns a response that carries no outputs; identical to `Response::default()`.
    pub fn empty() -> Self {
        Self::default()
    }
}
86
87pub trait Inferer {
89 fn select_batch_size(&self, max_count: usize) -> usize;
91
92 fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error>;
94
95 fn input_shapes(&self) -> &[(String, Vec<usize>)] {
99 self.raw_input_shapes()
100 }
101
102 fn output_shapes(&self) -> &[(String, Vec<usize>)] {
106 self.raw_output_shapes()
107 }
108
109 fn raw_input_shapes(&self) -> &[(String, Vec<usize>)];
111
112 fn raw_output_shapes(&self) -> &[(String, Vec<usize>)];
114
115 fn begin_agent(&self, id: u64);
116 fn end_agent(&self, id: u64);
117}
118
119pub trait InfererProvider {
121 fn build_basic(self) -> Result<BasicInferer>;
123
124 fn build_fixed(self, sizes: &[usize]) -> Result<FixedBatchInferer>;
126
127 fn build_memoizing(self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer>;
129
130 fn build_dynamic(self) -> Result<DynamicInferer>;
132}
133
134pub struct InfererBuilder<P: InfererProvider> {
136 provider: P,
137}
138
139impl<P> InfererBuilder<P>
140where
141 P: InfererProvider,
142{
143 pub fn new(provider: P) -> Self {
145 Self { provider }
146 }
147
148 pub fn build_basic(self) -> Result<BasicInferer> {
150 self.provider.build_basic()
151 }
152
153 pub fn build_fixed(self, sizes: &[usize]) -> Result<FixedBatchInferer> {
155 self.provider.build_fixed(sizes)
156 }
157
158 pub fn build_dynamic(self) -> Result<DynamicInferer> {
160 self.provider.build_dynamic()
161 }
162
163 pub fn build_memoizing(self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer> {
165 self.provider.build_memoizing(preload_sizes)
166 }
167}
168
169pub trait InfererExt: Inferer + Sized {
172 fn with_default_epsilon(self, key: &str) -> Result<EpsilonInjector<Self>> {
174 EpsilonInjector::wrap(self, key)
175 }
176
177 fn with_epsilon<G: NoiseGenerator>(
179 self,
180 generator: G,
181 key: &str,
182 ) -> Result<EpsilonInjector<Self, G>> {
183 EpsilonInjector::with_generator(self, generator, key)
184 }
185
186 fn into_batched(self) -> Batched<Self> {
188 Batched::wrap(self)
189 }
190
191 #[deprecated(
193 note = "Please use the more explicit 'infer_batch' instead.",
194 since = "0.3.0"
195 )]
196 fn infer(
197 &mut self,
198 observations: HashMap<u64, State<'_>>,
199 ) -> Result<HashMap<u64, Response<'_>>, Error> {
200 self.infer_batch(observations)
201 }
202
203 fn infer_batch<'this>(
205 &'this self,
206 batch: HashMap<u64, State<'_>>,
207 ) -> Result<HashMap<u64, Response<'this>>, anyhow::Error> {
208 let mut batcher = Batcher::new_sized(self, batch.len());
209 batcher.extend(batch)?;
210
211 batcher.execute(self)
212 }
213
214 fn infer_single<'this>(
216 &'this self,
217 input: State<'_>,
218 ) -> Result<Response<'this>, anyhow::Error> {
219 let mut batcher = Batcher::new_sized(self, 1);
220 batcher.push(0, input)?;
221
222 Ok(batcher.execute(self)?.remove(&0).unwrap())
223 }
224}
225
226impl<T> InfererExt for T where T: Inferer + Sized {}
227
228impl Inferer for Box<dyn Inferer + Send> {
229 fn select_batch_size(&self, max_count: usize) -> usize {
230 self.as_ref().select_batch_size(max_count)
231 }
232
233 fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
234 self.as_ref().infer_raw(batch)
235 }
236
237 fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
238 self.as_ref().raw_input_shapes()
239 }
240
241 fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
242 self.as_ref().raw_output_shapes()
243 }
244
245 fn begin_agent(&self, id: u64) {
246 self.as_ref().begin_agent(id);
247 }
248
249 fn end_agent(&self, id: u64) {
250 self.as_ref().end_agent(id);
251 }
252}
253
254impl Inferer for Box<dyn Inferer> {
255 fn select_batch_size(&self, max_count: usize) -> usize {
256 self.as_ref().select_batch_size(max_count)
257 }
258
259 fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
260 self.as_ref().infer_raw(batch)
261 }
262
263 fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
264 self.as_ref().raw_input_shapes()
265 }
266
267 fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
268 self.as_ref().raw_output_shapes()
269 }
270
271 fn begin_agent(&self, id: u64) {
272 self.as_ref().begin_agent(id);
273 }
274
275 fn end_agent(&self, id: u64) {
276 self.as_ref().end_agent(id);
277 }
278}