1use std::collections::HashMap;
2
3use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::InfererWrapper};
4use anyhow::{Context, Result};
5use itertools::Itertools;
6use parking_lot::RwLock;
7use tract_core::tract_data::TVec;
8
/// Names one recurrent pair of a model: the input that receives the state
/// and the output that produces the next state.
///
/// `inkey` is looked up among the inferer's raw input names and `outkey`
/// among its raw output names when a tracker is constructed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecurrentInfo {
    /// Name of the model input that is fed from the stored state.
    pub inkey: String,
    /// Name of the model output whose values become the new stored state.
    pub outkey: String,
}
13
/// Resolved form of a recurrent input/output pair.
///
/// All fields are plain indices/sizes, so the type is cheap to copy and
/// printable for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RecurrentPair {
    /// Slot index of the recurrent input in the batch.
    inslot: usize,
    /// Slot index of the recurrent output in the batch.
    outslot: usize,
    /// Number of `f32` values this pair occupies per agent.
    numels: usize,
    /// Start of this pair's values inside an agent's flat state buffer.
    offset: usize,
}
20
21struct RecurrentState {
22 keys: TVec<RecurrentPair>,
23 per_agent_states: RwLock<HashMap<u64, Box<[f32]>>>,
24 agent_state_size: usize,
25 inputs: Vec<(String, Vec<usize>)>,
27 outputs: Vec<(String, Vec<usize>)>,
28}
29
30impl RecurrentState {
31 fn apply(&self, batch: &mut ScratchPadView<'_>) {
32 for pair in &self.keys {
33 let (ids, indata) = batch.input_slot_mut_with_id(pair.inslot);
34
35 let mut offset = 0;
36 let states = self.per_agent_states.read();
37 for id in ids {
38 if let Some(state) = states.get(id) {
40 indata[offset..offset + pair.numels]
41 .copy_from_slice(&state[pair.offset..pair.offset + pair.numels]);
42 } else {
43 indata[offset..offset + pair.numels].fill(0.0);
44 }
45 offset += pair.numels;
46 }
47 }
48 }
49
50 fn extract(&self, batch: &mut ScratchPadView<'_>) {
51 for pair in &self.keys {
52 let (ids, outdata) = batch.output_slot_mut_with_id(pair.outslot);
53
54 let mut offset = 0;
55 let mut states = self.per_agent_states.write();
56 for id in ids {
57 if let Some(state) = states.get_mut(id) {
59 state[pair.offset..pair.offset + pair.numels]
60 .copy_from_slice(&outdata[offset..offset + pair.numels]);
61 }
62
63 offset += pair.numels;
64 }
65 }
66 }
67}
68
69pub struct RecurrentTracker<T: Inferer> {
73 inner: T,
74 state: RecurrentState,
75}
76
77impl<T> RecurrentTracker<T>
78where
79 T: Inferer,
80{
81 pub fn wrap(inferer: T) -> Result<RecurrentTracker<T>> {
83 let inputs = inferer.raw_input_shapes();
84 let outputs = inferer.raw_output_shapes();
85
86 let mut keys = vec![];
87
88 for (inkey, inshape) in inputs {
89 for (outkey, outshape) in outputs {
90 if inkey == outkey && inshape == outshape {
91 keys.push(RecurrentInfo {
92 inkey: inkey.clone(),
93 outkey: outkey.clone(),
94 });
95 }
96 }
97 }
98
99 if keys.is_empty() {
100 let inkeys = inputs.iter().map(|(k, _)| k).join(", ");
101 let outkeys = outputs.iter().map(|(k, _)| k).join(", ");
102 anyhow::bail!(
103 "Unable to find a matching key between inputs [{inkeys}] and outputs [{outkeys}]"
104 );
105 }
106 Self::new(inferer, keys)
107 }
108
109 pub fn new(inferer: T, info: Vec<RecurrentInfo>) -> Result<Self> {
112 let raw_inputs = inferer.raw_input_shapes();
113 let raw_outputs = inferer.raw_output_shapes();
114
115 let mut offset = 0;
116 let keys = info
117 .iter()
118 .map(|info| {
119 let inslot = raw_inputs
120 .iter()
121 .position(|input| info.inkey == input.0)
122 .with_context(|| format!("no input named {}", info.inkey))?;
123 let outslot = raw_outputs
124 .iter()
125 .position(|output| info.outkey == output.0)
126 .with_context(|| format!("no output named {}", info.outkey))?;
127
128 let numels = raw_inputs[inslot].1.iter().product();
129 offset += numels;
130 Ok(RecurrentPair {
131 inslot,
132 outslot,
133 numels,
134 offset: offset - numels,
135 })
136 })
137 .collect::<Result<TVec<RecurrentPair>>>()?;
138
139 let inputs = inferer.input_shapes();
140 let outputs = inferer.output_shapes();
141
142 let inputs = inputs
143 .iter()
144 .filter(|(k, _)| !info.iter().any(|info| &info.inkey == k))
145 .cloned()
146 .collect::<Vec<_>>();
147
148 let outputs = outputs
149 .iter()
150 .filter(|(k, _)| !info.iter().any(|info| &info.outkey == k))
151 .cloned()
152 .collect::<Vec<_>>();
153
154 Ok(Self {
155 inner: inferer,
156 state: RecurrentState {
157 keys,
158 agent_state_size: offset,
159 inputs,
160 outputs,
161 per_agent_states: Default::default(),
162 },
163 })
164 }
165}
166
167impl<T> Inferer for RecurrentTracker<T>
168where
169 T: Inferer,
170{
171 fn select_batch_size(&self, max_count: usize) -> usize {
172 self.inner.select_batch_size(max_count)
173 }
174
175 fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
176 self.state.apply(batch);
177
178 self.inner.infer_raw(batch)?;
179
180 self.state.extract(batch);
181
182 Ok(())
183 }
184
185 fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
186 self.inner.raw_input_shapes()
187 }
188
189 fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
190 self.inner.raw_output_shapes()
191 }
192
193 fn input_shapes(&self) -> &[(String, Vec<usize>)] {
194 &self.state.inputs
195 }
196
197 fn output_shapes(&self) -> &[(String, Vec<usize>)] {
198 &self.state.outputs
199 }
200
201 fn begin_agent(&self, id: u64) {
202 self.state.per_agent_states.write().insert(
203 id,
204 vec![0.0; self.state.agent_state_size].into_boxed_slice(),
205 );
206 self.inner.begin_agent(id);
207 }
208
209 fn end_agent(&self, id: u64) {
210 self.state.per_agent_states.write().remove(&id);
211 self.inner.end_agent(id);
212 }
213}
214
215pub struct RecurrentTrackerWrapper<Inner: InfererWrapper> {
220 inner: Inner,
221 state: RecurrentState,
222}
223
224impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
225 pub fn wrap<T: Inferer>(inner: Inner, inferer: &T) -> Result<RecurrentTrackerWrapper<Inner>> {
227 let inputs = inferer.raw_input_shapes();
228 let outputs = inferer.raw_output_shapes();
229
230 let mut keys = vec![];
231
232 for (inkey, inshape) in inputs {
233 for (outkey, outshape) in outputs {
234 if inkey == outkey && inshape == outshape {
235 keys.push(RecurrentInfo {
236 inkey: inkey.clone(),
237 outkey: outkey.clone(),
238 });
239 }
240 }
241 }
242
243 if keys.is_empty() {
244 let inkeys = inputs.iter().map(|(k, _)| k).join(", ");
245 let outkeys = outputs.iter().map(|(k, _)| k).join(", ");
246 anyhow::bail!(
247 "Unable to find a matching key between inputs [{inkeys}] and outputs [{outkeys}]"
248 );
249 }
250 Self::new(inner, inferer, keys)
251 }
252
253 pub fn new<T: Inferer>(inner: Inner, inferer: &T, info: Vec<RecurrentInfo>) -> Result<Self> {
256 let raw_inputs = inferer.raw_input_shapes();
257 let raw_outputs = inferer.raw_output_shapes();
258
259 let mut offset = 0;
260 let keys = info
261 .iter()
262 .map(|info| {
263 let inslot = raw_inputs
264 .iter()
265 .position(|input| info.inkey == input.0)
266 .with_context(|| format!("no input named {}", info.inkey))?;
267 let outslot = raw_outputs
268 .iter()
269 .position(|output| info.outkey == output.0)
270 .with_context(|| format!("no output named {}", info.outkey))?;
271
272 let numels = raw_inputs[inslot].1.iter().product();
273 offset += numels;
274 Ok(RecurrentPair {
275 inslot,
276 outslot,
277 numels,
278 offset: offset - numels,
279 })
280 })
281 .collect::<Result<TVec<RecurrentPair>>>()?;
282
283 let inputs = inner.input_shapes(inferer);
284 let outputs = inner.output_shapes(inferer);
285
286 let inputs = inputs
287 .iter()
288 .filter(|(k, _)| !info.iter().any(|info| &info.inkey == k))
289 .cloned()
290 .collect::<Vec<_>>();
291
292 let outputs = outputs
293 .iter()
294 .filter(|(k, _)| !info.iter().any(|info| &info.outkey == k))
295 .cloned()
296 .collect::<Vec<_>>();
297
298 Ok(Self {
299 inner,
300 state: RecurrentState {
301 keys,
302 agent_state_size: offset,
303 inputs,
304 outputs,
305 per_agent_states: Default::default(),
306 },
307 })
308 }
309}
310
311impl<Inner: InfererWrapper> InfererWrapper for RecurrentTrackerWrapper<Inner> {
312 fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
313 self.state.apply(batch);
314 self.inner.invoke(inferer, batch)?;
315 self.state.extract(batch);
316
317 Ok(())
318 }
319
320 fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
321 self.state.inputs.as_ref()
322 }
323
324 fn output_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
325 self.state.outputs.as_ref()
326 }
327
328 fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
329 self.state.per_agent_states.write().insert(
330 id,
331 vec![0.0; self.state.agent_state_size].into_boxed_slice(),
332 );
333 self.inner.begin_agent(inferer, id);
334 }
335
336 fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
337 self.state.per_agent_states.write().remove(&id);
338 self.inner.end_agent(inferer, id);
339 }
340}
341
#[cfg(test)]
mod tests {

    use std::sync::atomic::{AtomicBool, Ordering};

    use crate::{
        batcher::ScratchPadView,
        inferer::State,
        prelude::{Batcher, Inferer},
        recurrent::RecurrentTrackerWrapper,
        wrapper::InfererWrapper,
    };

    use super::RecurrentTracker;

    /// Checks that an agent's response contains both derived outputs and that
    /// every value of `hidden_output` / `cell_output` equals the given numbers.
    macro_rules! assert_agent_outputs {
        ($agent_data:expr, $hidden:expr, $cell:expr) => {{
            let agent_data = $agent_data;

            assert!(agent_data.data.contains_key("hidden_output"));
            assert!(agent_data.data.contains_key("cell_output"));

            assert!(agent_data.data["hidden_output"]
                .iter()
                .all(|v| *v == $hidden));
            assert!(agent_data.data["cell_output"].iter().all(|v| *v == $cell));
        }};
    }

    /// Builds one `(name, shape)` entry.
    fn shaped(name: &str, shape: &[usize]) -> (String, Vec<usize>) {
        (name.to_owned(), shape.to_vec())
    }

    /// Inferer stub: each inference adds 1.0 to the hidden state and 2.0 to
    /// the cell state, and mirrors the new states into the derived outputs.
    /// It also records whether `begin_agent` / `end_agent` were called.
    struct DummyInferer {
        end_called: AtomicBool,
        begin_called: AtomicBool,
        inputs: Vec<(String, Vec<usize>)>,
        outputs: Vec<(String, Vec<usize>)>,
    }

    impl Default for DummyInferer {
        /// Uses identical names on the input and output side, so both LSTM
        /// states are detected as recurrent.
        fn default() -> Self {
            let (hidden, cell) = ("lstm_hidden_state", "lstm_cell_state");
            Self::new_named(hidden, cell, hidden, cell)
        }
    }

    impl DummyInferer {
        /// Creates the stub with custom names for the state inputs/outputs.
        fn new_named(
            hidden_name_in: &str,
            cell_name_in: &str,
            hidden_name_out: &str,
            cell_name_out: &str,
        ) -> Self {
            let inputs = vec![
                shaped("epsilon", &[2]),
                shaped(hidden_name_in, &[2, 1]),
                shaped(cell_name_in, &[2, 3]),
            ];
            let outputs = vec![
                shaped(hidden_name_out, &[2, 1]),
                shaped(cell_name_out, &[2, 3]),
                shaped("hidden_output", &[2]),
                shaped("cell_output", &[6]),
            ];

            Self {
                end_called: AtomicBool::new(false),
                begin_called: AtomicBool::new(false),
                inputs,
                outputs,
            }
        }
    }

    impl Inferer for DummyInferer {
        fn select_batch_size(&self, _max_count: usize) -> usize {
            1
        }

        fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> {
            // Hidden state: read input slot 1, add one, write output slot 0.
            assert_eq!(batch.inner().input_name(1), "lstm_hidden_state");
            let next_hidden = batch
                .input_slot(1)
                .iter()
                .map(|v| *v + 1.0)
                .collect::<Vec<_>>();

            assert_eq!(batch.inner().output_name(0), "lstm_hidden_state");
            batch.output_slot_mut(0).copy_from_slice(&next_hidden);

            // Cell state: read input slot 2, add two, write output slot 1.
            assert_eq!(batch.inner().input_name(2), "lstm_cell_state");
            let next_cell = batch
                .input_slot(2)
                .iter()
                .map(|v| *v + 2.0)
                .collect::<Vec<_>>();

            assert_eq!(batch.inner().output_name(1), "lstm_cell_state");
            batch.output_slot_mut(1).copy_from_slice(&next_cell);

            // Mirror the new states into the non-recurrent outputs.
            assert_eq!(batch.inner().output_name(2), "hidden_output");
            batch.output_slot_mut(2).copy_from_slice(&next_hidden);

            assert_eq!(batch.inner().output_name(3), "cell_output");
            batch.output_slot_mut(3).copy_from_slice(&next_cell);

            Ok(())
        }

        fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
            &self.inputs
        }

        fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
            &self.outputs
        }

        fn begin_agent(&self, _id: u64) {
            self.begin_called.store(true, Ordering::Relaxed);
        }
        fn end_agent(&self, _id: u64) {
            self.end_called.store(true, Ordering::Relaxed);
        }
    }

    /// `begin_agent` / `end_agent` must reach the wrapped inferer.
    #[test]
    fn begin_end_forwarded() {
        let recurrent = RecurrentTracker::wrap(DummyInferer::default()).unwrap();

        recurrent.begin_agent(10);
        assert!(recurrent.inner.begin_called.load(Ordering::Relaxed));

        recurrent.end_agent(10);
        assert!(recurrent.inner.end_called.load(Ordering::Relaxed));
    }

    /// `begin_agent` registers a state entry for the agent.
    #[test]
    fn begin_creates_state() {
        let recurrent = RecurrentTracker::wrap(DummyInferer::default()).unwrap();

        recurrent.begin_agent(10);

        let states = recurrent.state.per_agent_states.read();
        assert!(states.contains_key(&10));
    }

    /// `end_agent` removes the agent's state entry again.
    #[test]
    fn end_removes_state() {
        let recurrent = RecurrentTracker::wrap(DummyInferer::default()).unwrap();

        recurrent.begin_agent(10);
        recurrent.end_agent(10);

        let states = recurrent.state.per_agent_states.read();
        assert!(!states.contains_key(&10));
    }

    /// Without any name shared between inputs and outputs, `wrap` fails.
    #[test]
    fn wrap_warns_no_keys() {
        let unmatched = DummyInferer::new_named("a", "b", "c", "d");

        assert!(RecurrentTracker::wrap(unmatched).is_err());
    }

    /// A single inference through the tracker succeeds.
    #[test]
    fn test_infer() {
        let inferer = DummyInferer::default();
        let mut batcher = Batcher::new(&inferer);
        let recurrent = RecurrentTracker::wrap(inferer).unwrap();

        recurrent.begin_agent(10);
        batcher.push(10, State::empty()).unwrap();

        batcher.execute(&recurrent).unwrap();
    }

    /// First inference starts from zeros: hidden becomes 1.0, cell 2.0.
    #[test]
    fn test_infer_output() {
        let inferer = DummyInferer::default();
        let mut batcher = Batcher::new(&inferer);
        let recurrent = RecurrentTracker::wrap(inferer).unwrap();

        recurrent.begin_agent(10);
        batcher.push(10, State::empty()).unwrap();

        let res = batcher.execute(&recurrent).unwrap();
        assert_agent_outputs!(&res[&10], 1.0, 2.0);
    }

    /// State is carried between inferences: second run yields 2.0 / 4.0.
    #[test]
    fn test_infer_twice_output() {
        let inferer = DummyInferer::default();
        let mut batcher = Batcher::new(&inferer);
        let recurrent = RecurrentTracker::wrap(inferer).unwrap();

        recurrent.begin_agent(10);

        batcher.push(10, State::empty()).unwrap();
        batcher.execute(&recurrent).unwrap();

        batcher.push(10, State::empty()).unwrap();
        let res = batcher.execute(&recurrent).unwrap();

        assert_agent_outputs!(&res[&10], 2.0, 4.0);
    }

    /// Ending and re-beginning an agent resets its state to zeros.
    #[test]
    fn test_infer_twice_reuse_id() {
        let inferer = DummyInferer::default();
        let mut batcher = Batcher::new(&inferer);
        let recurrent = RecurrentTracker::wrap(inferer).unwrap();

        recurrent.begin_agent(10);
        batcher.push(10, State::empty()).unwrap();
        batcher.execute(&recurrent).unwrap();

        recurrent.end_agent(10);
        recurrent.begin_agent(10);

        batcher.push(10, State::empty()).unwrap();
        let res = batcher.execute(&recurrent).unwrap();

        assert_agent_outputs!(&res[&10], 1.0, 2.0);
    }

    /// States of different agents are independent of each other.
    #[test]
    fn test_infer_multiple_agents() {
        let inferer = DummyInferer::default();
        let mut batcher = Batcher::new(&inferer);
        let recurrent = RecurrentTracker::wrap(inferer).unwrap();

        // Round one: agents 10 and 20.
        recurrent.begin_agent(10);
        recurrent.begin_agent(20);
        batcher.push(10, State::empty()).unwrap();
        batcher.push(20, State::empty()).unwrap();
        batcher.execute(&recurrent).unwrap();

        // Round two: agent 20 is begun again, agent 10 keeps its state.
        recurrent.begin_agent(20);
        batcher.push(10, State::empty()).unwrap();
        batcher.push(20, State::empty()).unwrap();
        batcher.execute(&recurrent).unwrap();

        // Round three: agent 30 joins fresh next to agent 10.
        recurrent.begin_agent(30);
        batcher.push(10, State::empty()).unwrap();
        batcher.push(30, State::empty()).unwrap();
        let res = batcher.execute(&recurrent).unwrap();

        // Agent 10 has been through three inferences, agent 30 through one.
        assert_agent_outputs!(&res[&10], 3.0, 6.0);
        assert_agent_outputs!(&res[&30], 1.0, 2.0);
    }

    /// The wrapper variant must hide the recurrent inputs/outputs from the
    /// shapes it reports while still resolving the raw slots.
    #[test]
    fn test_wrapper_does_not_expose_inner_hidden() {
        /// Inner wrapper that reports its own fixed input list and the
        /// inferer's outputs; every other operation is a no-op.
        struct DummyEpsilonWrapper {
            inputs: Vec<(String, Vec<usize>)>,
        }

        impl InfererWrapper for DummyEpsilonWrapper {
            fn invoke(
                &self,
                _inferer: &dyn Inferer,
                _batch: &mut ScratchPadView<'_>,
            ) -> anyhow::Result<(), anyhow::Error> {
                Ok(())
            }
            fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
                &self.inputs
            }
            fn output_shapes<'a>(
                &'a self,
                inferer: &'a dyn Inferer,
            ) -> &'a [(String, Vec<usize>)] {
                inferer.output_shapes()
            }
            fn begin_agent(&self, _inferer: &dyn Inferer, _id: u64) {}
            fn end_agent(&self, _inferer: &dyn Inferer, _id: u64) {}
        }

        let inferer = DummyInferer::default();
        let wrapper = DummyEpsilonWrapper {
            inputs: vec![
                shaped("lstm_hidden_state", &[2, 1]),
                shaped("lstm_cell_state", &[2, 3]),
            ],
        };

        let recurrent = RecurrentTrackerWrapper::wrap(wrapper, &inferer).unwrap();

        assert_eq!(recurrent.input_shapes(&inferer).len(), 0);

        let visible_outputs = recurrent.output_shapes(&inferer);
        assert_eq!(
            visible_outputs.len(),
            2,
            "only hidden and cell state are recurrent: {:?}",
            visible_outputs
        );

        assert_eq!(visible_outputs[0].0, "hidden_output");
        assert_eq!(visible_outputs[1].0, "cell_output");

        let tracked = &recurrent.state;
        assert_eq!(tracked.inputs.len(), 0);
        assert_eq!(tracked.outputs.len(), 2);

        assert_eq!(tracked.keys.len(), 2);
        assert_eq!(tracked.keys[0].inslot, 1);
        assert_eq!(tracked.keys[1].inslot, 2);
    }
}