1use std::ops::Range;
6use tract_core::tract_data::TVec;
7
8pub(super) struct ScratchPadData {
10 pub(super) name: String,
12
13 pub(super) data: Vec<f32>,
15
16 pub(super) count: usize,
18}
19
20impl ScratchPadData {
21 fn new(name: String, count: usize, capacity: usize) -> Self {
23 let mut this = Self {
24 name,
25 data: vec![],
26 count,
27 };
28
29 this.reserve(capacity);
30 this
31 }
32
33 fn reserve(&mut self, batch_size: usize) {
35 self.data.resize(batch_size * self.count, 0.0);
36 }
37
38 #[inline]
40 fn view(&self, range: Range<usize>) -> &[f32] {
41 &self.data[range.start * self.count..range.end * self.count]
42 }
43
44 #[inline]
46 fn view_mut(&mut self, range: Range<usize>) -> &mut [f32] {
47 let start = range.start * self.count;
48 let end = range.end * self.count;
49
50 &mut self.data[start..end]
51 }
52}
53
54const DEFAULT_CAPACITY: usize = 6;
55pub struct ScratchPad {
58 pub(super) inputs: TVec<ScratchPadData>,
59 pub(super) outputs: TVec<ScratchPadData>,
60 pub(super) ids: Vec<u64>,
61 pub(super) batch_size: usize,
62 capacity: usize,
63}
64
65impl ScratchPad {
66 pub fn new_for_shapes(
69 inputs: &[(String, Vec<usize>)],
70 outputs: &[(String, Vec<usize>)],
71 ) -> Self {
72 Self::new_with_size(inputs, outputs, DEFAULT_CAPACITY)
73 }
74
75 pub fn new_with_size(
78 inputs: &[(String, Vec<usize>)],
79 outputs: &[(String, Vec<usize>)],
80 capacity: usize,
81 ) -> Self {
82 let inputs = inputs
83 .iter()
84 .map(|(name, shape)| {
85 let count = shape.iter().product();
86 ScratchPadData::new(name.clone(), count, capacity)
87 })
88 .collect();
89
90 let outputs = outputs
91 .iter()
92 .map(|(name, shape)| {
93 let count = shape.iter().product();
94 ScratchPadData::new(name.clone(), count, capacity)
95 })
96 .collect();
97
98 Self {
99 inputs,
100 outputs,
101 ids: vec![],
102 batch_size: 0,
103 capacity,
104 }
105 }
106
107 pub fn next(&mut self, id: u64) {
109 self.batch_size += 1;
110 self.ids.push(id);
111
112 if self.batch_size > self.capacity {
113 self.capacity *= 2;
114
115 for slot in &mut self.inputs {
116 slot.reserve(self.capacity);
117 }
118
119 for slot in &mut self.outputs {
120 slot.reserve(self.capacity);
121 }
122 }
123 }
124
125 pub fn push(&mut self, slot: usize, data: Vec<f32>) {
127 self.inputs[slot]
128 .view_mut(self.batch_size - 1..self.batch_size)
129 .copy_from_slice(&data);
130 }
131
132 pub fn chunk(&mut self, offset: usize, size: usize) -> ScratchPadView<'_> {
134 let size = size.min(self.batch_size);
135 self.batch_size -= size;
136
137 ScratchPadView {
138 pad: self,
139 batch_range: offset..offset + size,
140 }
141 }
142
143 #[inline]
145 pub(crate) fn input_slot(&self, slot: usize, range: Range<usize>) -> &[f32] {
146 self.inputs[slot].view(range)
147 }
148
149 #[inline]
151 pub(crate) fn input_slot_mut(&mut self, slot: usize, range: Range<usize>) -> &mut [f32] {
152 self.inputs[slot].view_mut(range)
153 }
154
155 #[inline]
157 pub(crate) fn input_name(&self, slot: usize) -> &str {
158 &self.inputs[slot].name
159 }
160
161 #[inline]
163 pub(crate) fn output_slot(&self, slot: usize, range: Range<usize>) -> &[f32] {
164 self.outputs[slot].view(range)
165 }
166
167 #[inline]
169 pub(crate) fn output_slot_mut(&mut self, slot: usize, range: Range<usize>) -> &mut [f32] {
170 self.outputs[slot].view_mut(range)
171 }
172
173 #[inline]
175 pub(crate) fn output_name(&self, slot: usize) -> &str {
176 &self.outputs[slot].name
177 }
178
179 pub(crate) fn lookup_output_slot(&self, name: &str) -> Option<usize> {
180 self.outputs.iter().position(|slot| slot.name == name)
181 }
182}
183
184pub struct ScratchPadView<'a> {
186 pad: &'a mut ScratchPad,
187 batch_range: Range<usize>,
188}
189
190impl<'a> ScratchPadView<'a> {
191 pub fn inner(&self) -> &ScratchPad {
192 self.pad
193 }
194
195 pub fn input_slot_with_id(&self, slot: usize) -> (&[u64], &[f32]) {
197 (
198 &self.pad.ids[self.batch_range.clone()],
199 self.pad.input_slot(slot, self.batch_range.clone()),
200 )
201 }
202
203 pub fn input_slot_mut_with_id(&mut self, slot: usize) -> (&[u64], &mut [f32]) {
205 (
206 &self.pad.ids[self.batch_range.clone()],
207 self.pad.inputs[slot].view_mut(self.batch_range.clone()),
208 )
209 }
210
211 pub fn output_slot_mut_with_id(&mut self, slot: usize) -> (&[u64], &mut [f32]) {
213 (
214 &self.pad.ids[self.batch_range.clone()],
215 self.pad.outputs[slot].view_mut(self.batch_range.clone()),
216 )
217 }
218
219 pub fn input_slot(&self, slot: usize) -> &[f32] {
221 self.pad.input_slot(slot, self.batch_range.clone())
222 }
223
224 pub fn input_slot_mut(&mut self, slot: usize) -> &mut [f32] {
226 self.pad.input_slot_mut(slot, self.batch_range.clone())
227 }
228
229 pub fn input_name(&self, slot: usize) -> &str {
231 self.pad.input_name(slot)
232 }
233
234 pub fn output_slot(&self, slot: usize) -> &[f32] {
236 self.pad.output_slot(slot, self.batch_range.clone())
237 }
238
239 pub fn output_slot_mut(&mut self, slot: usize) -> &mut [f32] {
241 self.pad.output_slot_mut(slot, self.batch_range.clone())
242 }
243
244 pub fn output_name(&self, slot: usize) -> &str {
246 self.pad.output_name(slot)
247 }
248
249 #[allow(clippy::len_without_is_empty)]
251 pub fn len(&self) -> usize {
252 self.batch_range.len()
253 }
254}
255
#[cfg(test)]
mod tests {
    /// Tests for the per-slot storage type `ScratchPadData`.
    mod scratchpaddata {
        use super::super::ScratchPadData;

        /// A fresh slot holds `count * capacity` elements and keeps its name.
        #[test]
        fn has_right_initial_space() {
            let spd = ScratchPadData::new("epsilon".to_owned(), 24, 2);

            assert_eq!(spd.count, 24);
            assert_eq!(spd.data.len(), 48);
            assert_eq!(spd.name, "epsilon");
        }

        /// `reserve` resizes the buffer to the requested number of entries.
        #[test]
        fn reserves_correct_size() {
            let mut spd = ScratchPadData::new("epsilon".to_owned(), 24, 2);

            spd.reserve(4);
            assert_eq!(spd.count, 24);
            assert_eq!(spd.data.len(), 24 * 4);
        }

        /// `view` and `view_mut` return exactly the elements of the requested entries.
        #[test]
        fn views_correct_range() {
            let mut spd = ScratchPadData::new("epsilon".to_owned(), 6, 4);

            spd.reserve(4);
            assert_eq!(spd.data.len(), 24);
            for (idx, value) in spd.data.iter_mut().enumerate() {
                *value = idx as f32;
            }

            assert_eq!(spd.view(0..1), [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
            assert_eq!(spd.view_mut(0..1), [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
            assert_eq!(spd.view(1..2), [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);
            assert_eq!(spd.view_mut(1..2), [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);

            assert_eq!(spd.view(3..4), [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]);
            assert_eq!(spd.view_mut(3..4), [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]);
        }
    }
}