cervo_core/
recurrent.rs

1use std::collections::HashMap;
2
3use crate::{batcher::ScratchPadView, inferer::Inferer, prelude::InfererWrapper};
4use anyhow::{Context, Result};
5use itertools::Itertools;
6use parking_lot::RwLock;
7use tract_core::tract_data::TVec;
8
/// Names one recurrent connection of a model: the input tensor that receives
/// the tracked state and the output tensor that yields its next value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RecurrentInfo {
    /// Name of the model input the tracked state is written into.
    pub inkey: String,
    /// Name of the model output the updated state is read back from.
    pub outkey: String,
}
13
/// Resolved form of one recurrent key pair: the batch slots it maps to and
/// where its values live inside an agent's flat state buffer.
struct RecurrentPair {
    /// Index of the input slot the state is copied into before inference.
    inslot: usize,
    /// Index of the output slot the state is copied out of after inference.
    outslot: usize,
    /// Start of this pair's values within the per-agent state buffer.
    offset: usize,
    /// Number of values this pair occupies per agent.
    numels: usize,
}
20
21struct RecurrentState {
22    keys: TVec<RecurrentPair>,
23    per_agent_states: RwLock<HashMap<u64, Box<[f32]>>>,
24    agent_state_size: usize,
25    // https://github.com/EmbarkStudios/cervo/issues/31
26    inputs: Vec<(String, Vec<usize>)>,
27    outputs: Vec<(String, Vec<usize>)>,
28}
29
30impl RecurrentState {
31    fn apply(&self, batch: &mut ScratchPadView<'_>) {
32        for pair in &self.keys {
33            let (ids, indata) = batch.input_slot_mut_with_id(pair.inslot);
34
35            let mut offset = 0;
36            let states = self.per_agent_states.read();
37            for id in ids {
38                // if None, leave as zeros and pray
39                if let Some(state) = states.get(id) {
40                    indata[offset..offset + pair.numels]
41                        .copy_from_slice(&state[pair.offset..pair.offset + pair.numels]);
42                } else {
43                    indata[offset..offset + pair.numels].fill(0.0);
44                }
45                offset += pair.numels;
46            }
47        }
48    }
49
50    fn extract(&self, batch: &mut ScratchPadView<'_>) {
51        for pair in &self.keys {
52            let (ids, outdata) = batch.output_slot_mut_with_id(pair.outslot);
53
54            let mut offset = 0;
55            let mut states = self.per_agent_states.write();
56            for id in ids {
57                // if None, leave as zeros and pray
58                if let Some(state) = states.get_mut(id) {
59                    state[pair.offset..pair.offset + pair.numels]
60                        .copy_from_slice(&outdata[offset..offset + pair.numels]);
61                }
62
63                offset += pair.numels;
64            }
65        }
66    }
67}
68
69/// The [`RecurrentTracker`] wraps an inferer to manage states that
70/// are input/output in a recurrent fashion, instead of roundtripping
71/// them to the high-level code.
72pub struct RecurrentTracker<T: Inferer> {
73    inner: T,
74    state: RecurrentState,
75}
76
77impl<T> RecurrentTracker<T>
78where
79    T: Inferer,
80{
81    /// Wraps the provided `inferer` to automatically track any keys that are both inputs/outputs.
82    pub fn wrap(inferer: T) -> Result<RecurrentTracker<T>> {
83        let inputs = inferer.raw_input_shapes();
84        let outputs = inferer.raw_output_shapes();
85
86        let mut keys = vec![];
87
88        for (inkey, inshape) in inputs {
89            for (outkey, outshape) in outputs {
90                if inkey == outkey && inshape == outshape {
91                    keys.push(RecurrentInfo {
92                        inkey: inkey.clone(),
93                        outkey: outkey.clone(),
94                    });
95                }
96            }
97        }
98
99        if keys.is_empty() {
100            let inkeys = inputs.iter().map(|(k, _)| k).join(", ");
101            let outkeys = outputs.iter().map(|(k, _)| k).join(", ");
102            anyhow::bail!(
103                "Unable to find a matching key between inputs [{inkeys}] and outputs [{outkeys}]"
104            );
105        }
106        Self::new(inferer, keys)
107    }
108
109    /// Create a new recurrency tracker for the model.
110    ///
111    pub fn new(inferer: T, info: Vec<RecurrentInfo>) -> Result<Self> {
112        let raw_inputs = inferer.raw_input_shapes();
113        let raw_outputs = inferer.raw_output_shapes();
114
115        let mut offset = 0;
116        let keys = info
117            .iter()
118            .map(|info| {
119                let inslot = raw_inputs
120                    .iter()
121                    .position(|input| info.inkey == input.0)
122                    .with_context(|| format!("no input named {}", info.inkey))?;
123                let outslot = raw_outputs
124                    .iter()
125                    .position(|output| info.outkey == output.0)
126                    .with_context(|| format!("no output named {}", info.outkey))?;
127
128                let numels = raw_inputs[inslot].1.iter().product();
129                offset += numels;
130                Ok(RecurrentPair {
131                    inslot,
132                    outslot,
133                    numels,
134                    offset: offset - numels,
135                })
136            })
137            .collect::<Result<TVec<RecurrentPair>>>()?;
138
139        let inputs = inferer.input_shapes();
140        let outputs = inferer.output_shapes();
141
142        let inputs = inputs
143            .iter()
144            .filter(|(k, _)| !info.iter().any(|info| &info.inkey == k))
145            .cloned()
146            .collect::<Vec<_>>();
147
148        let outputs = outputs
149            .iter()
150            .filter(|(k, _)| !info.iter().any(|info| &info.outkey == k))
151            .cloned()
152            .collect::<Vec<_>>();
153
154        Ok(Self {
155            inner: inferer,
156            state: RecurrentState {
157                keys,
158                agent_state_size: offset,
159                inputs,
160                outputs,
161                per_agent_states: Default::default(),
162            },
163        })
164    }
165}
166
167impl<T> Inferer for RecurrentTracker<T>
168where
169    T: Inferer,
170{
171    fn select_batch_size(&self, max_count: usize) -> usize {
172        self.inner.select_batch_size(max_count)
173    }
174
175    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
176        self.state.apply(batch);
177
178        self.inner.infer_raw(batch)?;
179
180        self.state.extract(batch);
181
182        Ok(())
183    }
184
185    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
186        self.inner.raw_input_shapes()
187    }
188
189    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
190        self.inner.raw_output_shapes()
191    }
192
193    fn input_shapes(&self) -> &[(String, Vec<usize>)] {
194        &self.state.inputs
195    }
196
197    fn output_shapes(&self) -> &[(String, Vec<usize>)] {
198        &self.state.outputs
199    }
200
201    fn begin_agent(&self, id: u64) {
202        self.state.per_agent_states.write().insert(
203            id,
204            vec![0.0; self.state.agent_state_size].into_boxed_slice(),
205        );
206        self.inner.begin_agent(id);
207    }
208
209    fn end_agent(&self, id: u64) {
210        self.state.per_agent_states.write().remove(&id);
211        self.inner.end_agent(id);
212    }
213}
214
215/// A wrapper that adds recurrent state tracking to an inner model.
216///
217/// This is an alternative to using [`RecurrentTracker`] which allows separate
218/// state tracking from the inferer itself.
219pub struct RecurrentTrackerWrapper<Inner: InfererWrapper> {
220    inner: Inner,
221    state: RecurrentState,
222}
223
224impl<Inner: InfererWrapper> RecurrentTrackerWrapper<Inner> {
225    /// Wraps the provided `inferer` to automatically track any keys that are both inputs/outputs.
226    pub fn wrap<T: Inferer>(inner: Inner, inferer: &T) -> Result<RecurrentTrackerWrapper<Inner>> {
227        let inputs = inferer.raw_input_shapes();
228        let outputs = inferer.raw_output_shapes();
229
230        let mut keys = vec![];
231
232        for (inkey, inshape) in inputs {
233            for (outkey, outshape) in outputs {
234                if inkey == outkey && inshape == outshape {
235                    keys.push(RecurrentInfo {
236                        inkey: inkey.clone(),
237                        outkey: outkey.clone(),
238                    });
239                }
240            }
241        }
242
243        if keys.is_empty() {
244            let inkeys = inputs.iter().map(|(k, _)| k).join(", ");
245            let outkeys = outputs.iter().map(|(k, _)| k).join(", ");
246            anyhow::bail!(
247                "Unable to find a matching key between inputs [{inkeys}] and outputs [{outkeys}]"
248            );
249        }
250        Self::new(inner, inferer, keys)
251    }
252
253    /// Create a new recurrency tracker for the model.
254    ///
255    pub fn new<T: Inferer>(inner: Inner, inferer: &T, info: Vec<RecurrentInfo>) -> Result<Self> {
256        let raw_inputs = inferer.raw_input_shapes();
257        let raw_outputs = inferer.raw_output_shapes();
258
259        let mut offset = 0;
260        let keys = info
261            .iter()
262            .map(|info| {
263                let inslot = raw_inputs
264                    .iter()
265                    .position(|input| info.inkey == input.0)
266                    .with_context(|| format!("no input named {}", info.inkey))?;
267                let outslot = raw_outputs
268                    .iter()
269                    .position(|output| info.outkey == output.0)
270                    .with_context(|| format!("no output named {}", info.outkey))?;
271
272                let numels = raw_inputs[inslot].1.iter().product();
273                offset += numels;
274                Ok(RecurrentPair {
275                    inslot,
276                    outslot,
277                    numels,
278                    offset: offset - numels,
279                })
280            })
281            .collect::<Result<TVec<RecurrentPair>>>()?;
282
283        let inputs = inner.input_shapes(inferer);
284        let outputs = inner.output_shapes(inferer);
285
286        let inputs = inputs
287            .iter()
288            .filter(|(k, _)| !info.iter().any(|info| &info.inkey == k))
289            .cloned()
290            .collect::<Vec<_>>();
291
292        let outputs = outputs
293            .iter()
294            .filter(|(k, _)| !info.iter().any(|info| &info.outkey == k))
295            .cloned()
296            .collect::<Vec<_>>();
297
298        Ok(Self {
299            inner,
300            state: RecurrentState {
301                keys,
302                agent_state_size: offset,
303                inputs,
304                outputs,
305                per_agent_states: Default::default(),
306            },
307        })
308    }
309}
310
311impl<Inner: InfererWrapper> InfererWrapper for RecurrentTrackerWrapper<Inner> {
312    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
313        self.state.apply(batch);
314        self.inner.invoke(inferer, batch)?;
315        self.state.extract(batch);
316
317        Ok(())
318    }
319
320    fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
321        self.state.inputs.as_ref()
322    }
323
324    fn output_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
325        self.state.outputs.as_ref()
326    }
327
328    fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
329        self.state.per_agent_states.write().insert(
330            id,
331            vec![0.0; self.state.agent_state_size].into_boxed_slice(),
332        );
333        self.inner.begin_agent(inferer, id);
334    }
335
336    fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
337        self.state.per_agent_states.write().remove(&id);
338        self.inner.end_agent(inferer, id);
339    }
340}
341
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::RecurrentTracker;
    use crate::{
        batcher::ScratchPadView,
        inferer::State,
        prelude::{Batcher, Inferer},
        recurrent::RecurrentTrackerWrapper,
        wrapper::InfererWrapper,
    };

    /// Fake model with one plain input and two recurrent tensors. Each call
    /// adds 1 to the hidden state and 2 to the cell state, and mirrors the new
    /// values into `hidden_output` / `cell_output`.
    struct DummyInferer {
        end_called: AtomicBool,
        begin_called: AtomicBool,
        inputs: Vec<(String, Vec<usize>)>,
        outputs: Vec<(String, Vec<usize>)>,
    }

    impl Default for DummyInferer {
        /// Uses identical names on both sides so the states are detected as recurrent.
        fn default() -> Self {
            Self::new_named(
                "lstm_hidden_state",
                "lstm_cell_state",
                "lstm_hidden_state",
                "lstm_cell_state",
            )
        }
    }

    impl DummyInferer {
        /// Builds the dummy with custom names for the state inputs and outputs.
        fn new_named(
            hidden_name_in: &str,
            cell_name_in: &str,
            hidden_name_out: &str,
            cell_name_out: &str,
        ) -> Self {
            let inputs = vec![
                ("epsilon".to_owned(), vec![2]),
                (hidden_name_in.to_owned(), vec![2, 1]),
                (cell_name_in.to_owned(), vec![2, 3]),
            ];
            let outputs = vec![
                (hidden_name_out.to_owned(), vec![2, 1]),
                (cell_name_out.to_owned(), vec![2, 3]),
                ("hidden_output".to_owned(), vec![2]),
                ("cell_output".to_owned(), vec![6]),
            ];

            Self {
                end_called: AtomicBool::new(false),
                begin_called: AtomicBool::new(false),
                inputs,
                outputs,
            }
        }
    }

    impl Inferer for DummyInferer {
        fn select_batch_size(&self, _max_count: usize) -> usize {
            1
        }

        fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> {
            // Hidden state: previous value + 1.
            assert_eq!(batch.inner().input_name(1), "lstm_hidden_state");
            let next_hidden = batch
                .input_slot(1)
                .iter()
                .map(|v| *v + 1.0)
                .collect::<Vec<_>>();

            assert_eq!(batch.inner().output_name(0), "lstm_hidden_state");
            batch.output_slot_mut(0).copy_from_slice(&next_hidden);

            // Cell state: previous value + 2.
            assert_eq!(batch.inner().input_name(2), "lstm_cell_state");
            let next_cell = batch
                .input_slot(2)
                .iter()
                .map(|v| *v + 2.0)
                .collect::<Vec<_>>();

            assert_eq!(batch.inner().output_name(1), "lstm_cell_state");
            batch.output_slot_mut(1).copy_from_slice(&next_cell);

            // Mirror both states into the outputs that stay visible to callers.
            assert_eq!(batch.inner().output_name(2), "hidden_output");
            batch.output_slot_mut(2).copy_from_slice(&next_hidden);

            assert_eq!(batch.inner().output_name(3), "cell_output");
            batch.output_slot_mut(3).copy_from_slice(&next_cell);

            Ok(())
        }

        fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
            &self.inputs
        }

        fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
            &self.outputs
        }

        fn begin_agent(&self, _id: u64) {
            self.begin_called.store(true, Ordering::Relaxed);
        }

        fn end_agent(&self, _id: u64) {
            self.end_called.store(true, Ordering::Relaxed);
        }
    }

    /// `begin_agent` / `end_agent` must reach the wrapped inferer.
    #[test]
    fn begin_end_forwarded() {
        let model = DummyInferer::default();
        let tracker = RecurrentTracker::wrap(model).unwrap();

        tracker.begin_agent(10);
        assert!(tracker.inner.begin_called.load(Ordering::Relaxed));

        tracker.end_agent(10);
        assert!(tracker.inner.end_called.into_inner());
    }

    /// Beginning an agent allocates a state entry for it.
    #[test]
    fn begin_creates_state() {
        let model = DummyInferer::default();
        let tracker = RecurrentTracker::wrap(model).unwrap();

        tracker.begin_agent(10);
        assert!(tracker.state.per_agent_states.read().contains_key(&10));
    }

    /// Ending an agent drops its state entry.
    #[test]
    fn end_removes_state() {
        let model = DummyInferer::default();
        let tracker = RecurrentTracker::wrap(model).unwrap();

        tracker.begin_agent(10);
        tracker.end_agent(10);

        assert!(!tracker.state.per_agent_states.read().contains_key(&10));
    }

    /// With no name shared between inputs and outputs, `wrap` must fail.
    #[test]
    fn wrap_warns_no_keys() {
        let model = DummyInferer::new_named("a", "b", "c", "d");
        let should_err = RecurrentTracker::wrap(model);
        assert!(should_err.is_err());
    }

    /// A single inference through the tracker runs without error.
    #[test]
    fn test_infer() {
        let model = DummyInferer::default();
        let mut batcher = Batcher::new(&model);
        let tracker = RecurrentTracker::wrap(model).unwrap();

        tracker.begin_agent(10);
        batcher.push(10, State::empty()).unwrap();

        batcher.execute(&tracker).unwrap();
    }

    /// First inference starts from zeroed state: hidden is 1, cell is 2.
    #[test]
    fn test_infer_output() {
        let model = DummyInferer::default();
        let mut batcher = Batcher::new(&model);
        let tracker = RecurrentTracker::wrap(model).unwrap();

        tracker.begin_agent(10);
        batcher.push(10, State::empty()).unwrap();

        let responses = batcher.execute(&tracker).unwrap();
        let response = &responses[&10];
        assert!(response.data.contains_key("hidden_output"));
        assert!(response.data.contains_key("cell_output"));

        assert!(response.data["hidden_output"].iter().all(|v| *v == 1.0));
        assert!(response.data["cell_output"].iter().all(|v| *v == 2.0));
    }

    /// State carries over between calls: the second inference yields 2 and 4.
    #[test]
    fn test_infer_twice_output() {
        let model = DummyInferer::default();
        let mut batcher = Batcher::new(&model);
        let tracker = RecurrentTracker::wrap(model).unwrap();

        tracker.begin_agent(10);
        batcher.push(10, State::empty()).unwrap();

        batcher.execute(&tracker).unwrap();
        batcher.push(10, State::empty()).unwrap();

        let responses = batcher.execute(&tracker).unwrap();

        let response = &responses[&10];
        assert!(response.data.contains_key("hidden_output"));
        assert!(response.data.contains_key("cell_output"));

        assert!(response.data["hidden_output"].iter().all(|v| *v == 2.0));
        assert!(response.data["cell_output"].iter().all(|v| *v == 4.0));
    }

    /// Ending and re-beginning an id resets its state to zeros.
    #[test]
    fn test_infer_twice_reuse_id() {
        let model = DummyInferer::default();
        let mut batcher = Batcher::new(&model);
        let tracker = RecurrentTracker::wrap(model).unwrap();

        tracker.begin_agent(10);
        batcher.push(10, State::empty()).unwrap();
        batcher.execute(&tracker).unwrap();

        tracker.end_agent(10);

        tracker.begin_agent(10);

        batcher.push(10, State::empty()).unwrap();

        let responses = batcher.execute(&tracker).unwrap();
        let response = &responses[&10];

        assert!(response.data.contains_key("hidden_output"));
        assert!(response.data.contains_key("cell_output"));

        assert!(response.data["hidden_output"].iter().all(|v| *v == 1.0));
        assert!(response.data["cell_output"].iter().all(|v| *v == 2.0));
    }

    /// Each agent keeps its own state: agent 10 has run three times, agent 30 once.
    #[test]
    fn test_infer_multiple_agents() {
        let model = DummyInferer::default();
        let mut batcher = Batcher::new(&model);
        let tracker = RecurrentTracker::wrap(model).unwrap();

        // Round one: agents 10 and 20.
        tracker.begin_agent(10);
        tracker.begin_agent(20);
        batcher.push(10, State::empty()).unwrap();
        batcher.push(20, State::empty()).unwrap();
        batcher.execute(&tracker).unwrap();

        // Round two: agent 20 is begun again before running.
        tracker.begin_agent(20);
        batcher.push(10, State::empty()).unwrap();
        batcher.push(20, State::empty()).unwrap();
        batcher.execute(&tracker).unwrap();

        // Round three: agent 30 joins agent 10.
        tracker.begin_agent(30);
        batcher.push(10, State::empty()).unwrap();
        batcher.push(30, State::empty()).unwrap();
        let responses = batcher.execute(&tracker).unwrap();

        let veteran = &responses[&10];

        assert!(veteran.data.contains_key("hidden_output"));
        assert!(veteran.data.contains_key("cell_output"));

        assert!(veteran.data["hidden_output"].iter().all(|v| *v == 3.0));
        assert!(veteran.data["cell_output"].iter().all(|v| *v == 6.0));

        let newcomer = &responses[&30];

        assert!(newcomer.data.contains_key("hidden_output"));
        assert!(newcomer.data.contains_key("cell_output"));

        assert!(newcomer.data["hidden_output"].iter().all(|v| *v == 1.0));
        assert!(newcomer.data["cell_output"].iter().all(|v| *v == 2.0));
    }

    #[test]
    fn test_wrapper_does_not_expose_inner_hidden() {
        // Imagine Recurrent<Epsilon<...>>. We want to assert that
        // Recurrent hides its own fields while also not exposing any
        // fields from the inner epsilon wrapper.

        /// Stand-in for a wrapper that already hides the `epsilon` input.
        struct DummyEpsilonWrapper {
            inputs: Vec<(String, Vec<usize>)>,
        }

        impl InfererWrapper for DummyEpsilonWrapper {
            fn invoke(
                &self,
                _inferer: &dyn Inferer,
                _batch: &mut ScratchPadView<'_>,
            ) -> anyhow::Result<(), anyhow::Error> {
                Ok(())
            }

            fn input_shapes<'a>(&'a self, _inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
                &self.inputs
            }

            fn output_shapes<'a>(
                &'a self,
                inferer: &'a dyn Inferer,
            ) -> &'a [(String, Vec<usize>)] {
                inferer.output_shapes()
            }

            fn begin_agent(&self, _inferer: &dyn Inferer, _id: u64) {}

            fn end_agent(&self, _inferer: &dyn Inferer, _id: u64) {}
        }

        let model = DummyInferer::default();
        let epsilon_layer = DummyEpsilonWrapper {
            inputs: vec![
                ("lstm_hidden_state".to_owned(), vec![2, 1]),
                ("lstm_cell_state".to_owned(), vec![2, 3]),
            ],
        };

        let tracker = RecurrentTrackerWrapper::wrap(epsilon_layer, &model).unwrap();

        assert_eq!(tracker.input_shapes(&model).len(), 0);
        assert_eq!(
            tracker.output_shapes(&model).len(),
            2,
            "only hidden and cell state are recurrent: {:?}",
            tracker.output_shapes(&model)
        );

        assert_eq!(tracker.output_shapes(&model)[0].0, "hidden_output");
        assert_eq!(tracker.output_shapes(&model)[1].0, "cell_output");

        assert_eq!(tracker.state.inputs.len(), 0);
        assert_eq!(tracker.state.outputs.len(), 2);

        assert_eq!(tracker.state.keys.len(), 2);
        // slots are still correct despite epsilon being hidden
        assert_eq!(tracker.state.keys[0].inslot, 1);
        assert_eq!(tracker.state.keys[1].inslot, 2);
    }
}