cervo_core/batcher/
scratch.rs

1// Author: Tom Solberg <tom.solberg@embark-studios.com>
2// Copyright © 2022, Embark Studios, all rights reserved.
3// Created: 27 July 2022
4
5use std::ops::Range;
6use tract_core::tract_data::TVec;
7
8/// Data container for a single slot in the scratchpad.
9pub(super) struct ScratchPadData {
10    /// The slot name in the model input
11    pub(super) name: String,
12
13    /// The data store
14    pub(super) data: Vec<f32>,
15
16    /// Number of data elements per batch-element.
17    pub(super) count: usize,
18}
19
20impl ScratchPadData {
21    /// Construct a new slot data with the specified capacity and element count.
22    fn new(name: String, count: usize, capacity: usize) -> Self {
23        let mut this = Self {
24            name,
25            data: vec![],
26            count,
27        };
28
29        this.reserve(capacity);
30        this
31    }
32
33    /// Reserve space for this many batch elemeents.
34    fn reserve(&mut self, batch_size: usize) {
35        self.data.resize(batch_size * self.count, 0.0);
36    }
37
38    /// A view over the specified range of batch elements.
39    #[inline]
40    fn view(&self, range: Range<usize>) -> &[f32] {
41        &self.data[range.start * self.count..range.end * self.count]
42    }
43
44    /// A mutable view over the specified range of batch elements.
45    #[inline]
46    fn view_mut(&mut self, range: Range<usize>) -> &mut [f32] {
47        let start = range.start * self.count;
48        let end = range.end * self.count;
49
50        &mut self.data[start..end]
51    }
52}
53
54const DEFAULT_CAPACITY: usize = 6;
55/// A scratch pad used during each inference call to avoid fragmented
56/// allocations and copying.
57pub struct ScratchPad {
58    pub(super) inputs: TVec<ScratchPadData>,
59    pub(super) outputs: TVec<ScratchPadData>,
60    pub(super) ids: Vec<u64>,
61    pub(super) batch_size: usize,
62    capacity: usize,
63}
64
65impl ScratchPad {
66    // TODO[TSolberg]: When switching to raw ModelAPI, fix this.
67    /// Construct a new scratchpad for the provided API.
68    pub fn new_for_shapes(
69        inputs: &[(String, Vec<usize>)],
70        outputs: &[(String, Vec<usize>)],
71    ) -> Self {
72        Self::new_with_size(inputs, outputs, DEFAULT_CAPACITY)
73    }
74
75    // TODO[TSolberg]: When switching to raw ModelAPI, fix this.
76    /// Construct a new scratchpad for the provided API with a specified default capacity.
77    pub fn new_with_size(
78        inputs: &[(String, Vec<usize>)],
79        outputs: &[(String, Vec<usize>)],
80        capacity: usize,
81    ) -> Self {
82        let inputs = inputs
83            .iter()
84            .map(|(name, shape)| {
85                let count = shape.iter().product();
86                ScratchPadData::new(name.clone(), count, capacity)
87            })
88            .collect();
89
90        let outputs = outputs
91            .iter()
92            .map(|(name, shape)| {
93                let count = shape.iter().product();
94                ScratchPadData::new(name.clone(), count, capacity)
95            })
96            .collect();
97
98        Self {
99            inputs,
100            outputs,
101            ids: vec![],
102            batch_size: 0,
103            capacity,
104        }
105    }
106
107    /// Prepare the next slot to store data for the provided id.
108    pub fn next(&mut self, id: u64) {
109        self.batch_size += 1;
110        self.ids.push(id);
111
112        if self.batch_size > self.capacity {
113            self.capacity *= 2;
114
115            for slot in &mut self.inputs {
116                slot.reserve(self.capacity);
117            }
118
119            for slot in &mut self.outputs {
120                slot.reserve(self.capacity);
121            }
122        }
123    }
124
125    /// Push data for the specific slot.
126    pub fn push(&mut self, slot: usize, data: Vec<f32>) {
127        self.inputs[slot]
128            .view_mut(self.batch_size - 1..self.batch_size)
129            .copy_from_slice(&data);
130    }
131
132    /// View the chunk starting at batch-element `offset` containing `size` elements.x
133    pub fn chunk(&mut self, offset: usize, size: usize) -> ScratchPadView<'_> {
134        let size = size.min(self.batch_size);
135        self.batch_size -= size;
136
137        ScratchPadView {
138            pad: self,
139            batch_range: offset..offset + size,
140        }
141    }
142
143    /// View of the specified `range` of input at location `slot`.
144    #[inline]
145    pub(crate) fn input_slot(&self, slot: usize, range: Range<usize>) -> &[f32] {
146        self.inputs[slot].view(range)
147    }
148
149    /// A mutable view of the specified `range` of input at location `slot`.
150    #[inline]
151    pub(crate) fn input_slot_mut(&mut self, slot: usize, range: Range<usize>) -> &mut [f32] {
152        self.inputs[slot].view_mut(range)
153    }
154
155    /// Retrieve the input name for `slot`.
156    #[inline]
157    pub(crate) fn input_name(&self, slot: usize) -> &str {
158        &self.inputs[slot].name
159    }
160
161    /// View of the specified `range` of output at location `slot`.
162    #[inline]
163    pub(crate) fn output_slot(&self, slot: usize, range: Range<usize>) -> &[f32] {
164        self.outputs[slot].view(range)
165    }
166
167    /// A mutable view of the specified `range` of output at location `slot`.
168    #[inline]
169    pub(crate) fn output_slot_mut(&mut self, slot: usize, range: Range<usize>) -> &mut [f32] {
170        self.outputs[slot].view_mut(range)
171    }
172
173    /// Retrieve the output name for `slot`.
174    #[inline]
175    pub(crate) fn output_name(&self, slot: usize) -> &str {
176        &self.outputs[slot].name
177    }
178
179    pub(crate) fn lookup_output_slot(&self, name: &str) -> Option<usize> {
180        self.outputs.iter().position(|slot| slot.name == name)
181    }
182}
183
184/// A view over a set of batch elements in a scratch pad.
185pub struct ScratchPadView<'a> {
186    pad: &'a mut ScratchPad,
187    batch_range: Range<usize>,
188}
189
190impl<'a> ScratchPadView<'a> {
191    pub fn inner(&self) -> &ScratchPad {
192        self.pad
193    }
194
195    /// View of the input at location `slot`.
196    pub fn input_slot_with_id(&self, slot: usize) -> (&[u64], &[f32]) {
197        (
198            &self.pad.ids[self.batch_range.clone()],
199            self.pad.input_slot(slot, self.batch_range.clone()),
200        )
201    }
202
203    /// Mutable view of the input at location `slot`.
204    pub fn input_slot_mut_with_id(&mut self, slot: usize) -> (&[u64], &mut [f32]) {
205        (
206            &self.pad.ids[self.batch_range.clone()],
207            self.pad.inputs[slot].view_mut(self.batch_range.clone()),
208        )
209    }
210
211    /// Mutable view of the input at location `slot`.
212    pub fn output_slot_mut_with_id(&mut self, slot: usize) -> (&[u64], &mut [f32]) {
213        (
214            &self.pad.ids[self.batch_range.clone()],
215            self.pad.outputs[slot].view_mut(self.batch_range.clone()),
216        )
217    }
218
219    /// View of the input at location `slot`.
220    pub fn input_slot(&self, slot: usize) -> &[f32] {
221        self.pad.input_slot(slot, self.batch_range.clone())
222    }
223
224    /// A mutable view of the input at location `slot`.
225    pub fn input_slot_mut(&mut self, slot: usize) -> &mut [f32] {
226        self.pad.input_slot_mut(slot, self.batch_range.clone())
227    }
228
229    /// Retrieve the input name for `slot`.
230    pub fn input_name(&self, slot: usize) -> &str {
231        self.pad.input_name(slot)
232    }
233
234    /// A mutable view of the data at input `slot`.
235    pub fn output_slot(&self, slot: usize) -> &[f32] {
236        self.pad.output_slot(slot, self.batch_range.clone())
237    }
238
239    /// A mutable view of the data at location `slot`.
240    pub fn output_slot_mut(&mut self, slot: usize) -> &mut [f32] {
241        self.pad.output_slot_mut(slot, self.batch_range.clone())
242    }
243
244    /// Retrieve the output name for `slot`.
245    pub fn output_name(&self, slot: usize) -> &str {
246        self.pad.output_name(slot)
247    }
248
249    /// The batch size of this view.
250    #[allow(clippy::len_without_is_empty)]
251    pub fn len(&self) -> usize {
252        self.batch_range.len()
253    }
254}
255
#[cfg(test)]
mod tests {
    mod scratchpaddata {
        use super::super::ScratchPadData;

        #[test]
        fn has_right_initial_space() {
            let spd = ScratchPadData::new("epsilon".to_owned(), 24, 2);

            assert_eq!(spd.count, 24);
            assert_eq!(spd.data.len(), 48);
            assert_eq!(spd.name, "epsilon");
        }

        #[test]
        fn reserves_correct_size() {
            let mut spd = ScratchPadData::new("epsilon".to_owned(), 24, 2);

            spd.reserve(4);
            assert_eq!(spd.count, 24);
            assert_eq!(spd.data.len(), 24 * 4);
        }

        #[test]
        fn views_correct_range() {
            let mut spd = ScratchPadData::new("epsilon".to_owned(), 6, 4);

            spd.reserve(4);
            for (idx, value) in spd.data.iter_mut().enumerate() {
                *value = idx as f32;
            }

            assert_eq!(spd.view(0..1), [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
            assert_eq!(spd.view_mut(0..1), [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
            assert_eq!(spd.view(1..2), [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);
            assert_eq!(spd.view_mut(1..2), [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);

            assert_eq!(spd.view(3..4), [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]);
            assert_eq!(spd.view_mut(3..4), [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]);
        }
    }

    mod scratchpad {
        use super::super::ScratchPad;

        /// A pad with one 2-value input ("obs") and one 1-value output ("act").
        fn pad_with_capacity(capacity: usize) -> ScratchPad {
            ScratchPad::new_with_size(
                &[("obs".to_owned(), vec![2])],
                &[("act".to_owned(), vec![1])],
                capacity,
            )
        }

        #[test]
        fn grows_past_initial_capacity() {
            let mut pad = pad_with_capacity(1);
            for id in 0..3u64 {
                pad.next(id);
                pad.push(0, vec![id as f32, id as f32 + 0.5]);
            }

            assert_eq!(pad.batch_size, 3);
            assert_eq!(pad.ids, [0, 1, 2]);
            assert_eq!(pad.input_slot(0, 0..3), [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]);
            assert_eq!(pad.output_slot(0, 0..3), [0.0, 0.0, 0.0]);
        }

        #[test]
        fn chunk_consumes_pending_elements() {
            let mut pad = pad_with_capacity(4);
            for id in 10..15u64 {
                pad.next(id);
                pad.push(0, vec![id as f32, 0.0]);
            }

            {
                let view = pad.chunk(0, 3);
                assert_eq!(view.len(), 3);
                let (ids, data) = view.input_slot_with_id(0);
                assert_eq!(ids, [10, 11, 12]);
                assert_eq!(data, [10.0, 0.0, 11.0, 0.0, 12.0, 0.0]);
            }
            assert_eq!(pad.batch_size, 2);

            // Requests larger than what is pending are clamped.
            let view = pad.chunk(3, 8);
            assert_eq!(view.len(), 2);
            assert_eq!(view.input_slot_with_id(0).0, [13, 14]);
            assert_eq!(pad.batch_size, 0);
        }
    }
}