1mod scratch;
15mod wrapper;
16
17use self::scratch::ScratchPad;
18use crate::inferer::{Inferer, Response, State};
19pub use scratch::ScratchPadView;
20use std::collections::HashMap;
21pub use wrapper::Batched;
22
23pub struct Batcher {
34 scratch: ScratchPad,
35}
36
37impl Batcher {
38 pub fn new(inferer: &dyn Inferer) -> Self {
40 Self {
41 scratch: ScratchPad::new_for_shapes(
42 inferer.raw_input_shapes(),
43 inferer.raw_output_shapes(),
44 ),
45 }
46 }
47
48 pub fn new_sized(inferer: &dyn Inferer, size: usize) -> Self {
50 Self {
51 scratch: ScratchPad::new_with_size(
52 inferer.raw_input_shapes(),
53 inferer.raw_output_shapes(),
54 size,
55 ),
56 }
57 }
58
59 #[inline]
60 fn input_slot(&self, name: &str) -> Option<usize> {
61 self.scratch
62 .inputs
63 .iter()
64 .position(|slot| slot.name == name)
65 }
66
67 pub fn push(&mut self, id: u64, state: State<'_>) -> anyhow::Result<()> {
69 self.scratch.next(id);
70 for (k, v) in state.data {
71 let slot = self
72 .input_slot(k)
73 .ok_or_else(|| anyhow::anyhow!("key doesn't match an input: {:?}", k))?;
74
75 self.scratch.push(slot, v);
76 }
77
78 Ok(())
79 }
80
81 pub fn extend<'a, Iter: IntoIterator<Item = (u64, State<'a>)>>(
83 &mut self,
84 states: Iter,
85 ) -> anyhow::Result<()> {
86 for (id, state) in states {
87 self.push(id, state)?;
88 }
89
90 Ok(())
91 }
92
93 pub fn execute<'b>(
95 &mut self,
96 inferer: &'b dyn Inferer,
97 ) -> anyhow::Result<HashMap<u64, Response<'b>>> {
98 let mut total_offset = 0;
101 while self.scratch.batch_size > 0 {
102 let preferred_batch_size = inferer.select_batch_size(self.scratch.batch_size);
103
104 let mut view = self.scratch.chunk(total_offset, preferred_batch_size);
105
106 inferer.infer_raw(&mut view)?;
107 total_offset += preferred_batch_size;
108 }
109
110 let mut outputs = vec![Response::empty(); self.scratch.ids.len()];
111
112 for slot in 0..inferer.output_shapes().len() {
113 let slot_name = &inferer.output_shapes()[slot].0;
114 let scratch_slot = self
115 .scratch
116 .lookup_output_slot(slot_name)
117 .expect("invalid inferer passed to `Batcher::execute`");
118
119 for (idx, o) in outputs.iter_mut().enumerate() {
120 let slot_response = self.scratch.output_slot(scratch_slot, idx..idx + 1);
121 o.data.insert(slot_name, slot_response.to_owned());
122 }
123 }
124
125 Ok(self.scratch.ids.drain(..).zip(outputs).collect::<_>())
126 }
127
128 pub fn is_empty(&self) -> bool {
130 self.scratch.batch_size == 0
131 }
132
133 pub fn len(&self) -> usize {
135 self.scratch.batch_size
136 }
137}