cervo_runtime/
runtime.rs

1// Author: Tom Solberg <tom.solberg@embark-studios.com>
2// Copyright © 2022, Tom Solberg, all rights reserved.
3// Created: 22 September 2022
4
5mod ticket;
6
7use crate::{error::CervoError, state::ModelState, AgentId, BrainId};
8use ticket::Ticket;
9
10use cervo_core::prelude::{Inferer, Response, State};
11#[cfg(feature = "threaded")]
12use rayon::iter::IntoParallelIterator;
13#[cfg(feature = "threaded")]
14use rayon::iter::IntoParallelRefMutIterator;
15#[cfg(feature = "threaded")]
16use rayon::iter::ParallelIterator;
17use std::time::Instant;
18use std::{
19    collections::{BinaryHeap, HashMap},
20    time::Duration,
21};
22
23/// The runtime wraps a multitude of inference models with batching support, and support for time-limited execution.
24pub struct Runtime {
25    models: Vec<ModelState>,
26    queue: BinaryHeap<Ticket>,
27    ticket_generation: u64,
28    brain_generation: u16,
29}
30
31impl Default for Runtime {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl Runtime {
38    /// Create a new empty runtime.
39    pub fn new() -> Self {
40        Self {
41            models: Vec::with_capacity(16),
42            queue: BinaryHeap::with_capacity(16),
43            ticket_generation: 0,
44            brain_generation: 0,
45        }
46    }
47
48    /// Add a new inferer to this runtime. The new infererer will be at the end of the inference queue when using timed inference.
49    pub fn add_inferer(&mut self, inferer: impl Inferer + 'static + Send) -> BrainId {
50        let id = BrainId(self.brain_generation);
51        self.brain_generation += 1;
52
53        self.models.push(ModelState::new(id, inferer));
54
55        // New models always go to head of queue
56        self.queue.push(Ticket(self.ticket_generation, id));
57        self.ticket_generation += 1;
58
59        id
60    }
61
62    /// Queue the `state` to `brain` for `agent`, to be included in the next inference batch.
63    pub fn push(
64        &mut self,
65        brain: BrainId,
66        agent: AgentId,
67        state: State<'_>,
68    ) -> Result<(), CervoError> {
69        match self.models.iter_mut().find(|m| m.id == brain) {
70            Some(model) => model.push(agent, state),
71            None => Err(CervoError::UnknownBrain(brain)),
72        }
73    }
74
75    /// Run a single item through the specific `brain`. If there's
76    /// pending data for the `brain`, this'll have some extra overhead
77    /// for new allocations.
78    pub fn infer_single(
79        &mut self,
80        brain_id: BrainId,
81        state: State<'_>,
82    ) -> Result<Response<'_>, CervoError> {
83        match self.models.iter_mut().find(|m| m.id == brain_id) {
84            Some(model) => model.infer_single(state),
85            None => Err(CervoError::UnknownBrain(brain_id)),
86        }
87    }
88
89    /// Executes all models with queued data in parallel.
90    #[cfg(feature = "threaded")]
91    pub fn run_threaded(&mut self) -> HashMap<BrainId, HashMap<AgentId, Response<'_>>> {
92        // Use the iterator method from rayon
93        self.models
94            .par_iter_mut()
95            .filter(|model| model.needs_to_execute())
96            .map(|model| (model.id, model.run().unwrap()))
97            .collect::<HashMap<BrainId, HashMap<AgentId, Response<'_>>>>()
98    }
99
100    /// Executes all models with queued data.
101    pub fn run(&mut self) -> Result<HashMap<BrainId, HashMap<AgentId, Response<'_>>>, CervoError> {
102        let mut result = HashMap::default();
103
104        for model in self.models.iter_mut() {
105            if !model.needs_to_execute() {
106                continue;
107            }
108
109            result.insert(model.id, model.run()?);
110        }
111
112        Ok(result)
113    }
114
115    /// Executes all models with queued data in parallel. Will attempt to keep
116    /// total time below the provided duration, but due to noise or lack
117    /// of samples might miss the deadline. See the note in [the root](./index.html).
118    #[cfg(feature = "threaded")]
119    pub fn run_for_threaded(
120        &mut self,
121        duration: Duration,
122    ) -> Result<HashMap<BrainId, HashMap<AgentId, Response<'_>>>, CervoError> {
123        let mut available_cpu_time = duration * rayon::current_num_threads() as u32;
124        let mut selected_jobs = Vec::new();
125        let mut unselected_jobs = Vec::new();
126
127        while let Some(ticket) = self.queue.pop() {
128            let Some(model) = self.models.iter().find(|m| m.id == ticket.1) else {
129                continue;
130            };
131
132            if model.needs_to_execute()
133                && (selected_jobs.is_empty() || model.can_run_in_time(available_cpu_time))
134            {
135                available_cpu_time = available_cpu_time.saturating_sub(model.estimated_time());
136                selected_jobs.push((ticket, model));
137            } else {
138                unselected_jobs.push(ticket);
139            }
140        }
141
142        let results = selected_jobs
143            .into_par_iter()
144            .map(|(ticket, model)| (ticket.1, model.run()))
145            .collect::<Vec<(_, _)>>(); // collect necessary to conserve ordering
146
147        let new_tickets = results.iter().map(|(b, _)| {
148            let gen = self.ticket_generation;
149            self.ticket_generation += 1;
150            Ticket(gen, *b)
151        });
152
153        self.queue
154            .extend(unselected_jobs.into_iter().chain(new_tickets));
155
156        // transpose Iter<(B, Res<V, E>)> into Iter<Res<(B, Val)>> before collecting
157        results
158            .into_iter()
159            .map(|(b, res)| res.map(|val| (b, val)))
160            .collect::<Result<_, _>>()
161    }
162
163    /// Executes all models with queued data. Will attempt to keep
164    /// total time below the provided duration, but due to noise or lack
165    /// of samples might miss the deadline. See the note in [the root](./index.html).
166    pub fn run_for(
167        &mut self,
168        mut duration: Duration,
169    ) -> Result<HashMap<BrainId, HashMap<AgentId, Response<'_>>>, CervoError> {
170        let mut result = HashMap::default();
171
172        let mut any_executed = false;
173        let mut executed: Vec<BrainId> = vec![];
174        let mut non_executed = vec![];
175
176        while !self.queue.is_empty() {
177            let ticket = self.queue.pop().unwrap();
178            let res = match self.models.iter().find(|m| m.id == ticket.1) {
179                Some(model) => {
180                    if !model.needs_to_execute() || any_executed && !model.can_run_in_time(duration)
181                    {
182                        Ok(None)
183                    } else {
184                        let start = Instant::now();
185                        let r = model.run();
186
187                        let elapsed = start.elapsed();
188                        duration = duration.saturating_sub(elapsed);
189
190                        any_executed = true;
191                        r.map(Some)
192                    }
193                }
194
195                None => return Err(CervoError::UnknownBrain(ticket.1)),
196            }?;
197
198            match res {
199                Some(res) => {
200                    result.insert(ticket.1, res);
201                    executed.push(ticket.1);
202                }
203                None => {
204                    non_executed.push(ticket);
205                }
206            }
207        }
208
209        self.queue.extend(non_executed);
210        for id in executed {
211            let gen = self.ticket_generation;
212            self.ticket_generation += 1;
213            self.queue.push(Ticket(gen, id));
214        }
215
216        Ok(result)
217    }
218
219    /// Retrieve the output shapes for the provided brain.
220    pub fn output_shapes(&self, brain: BrainId) -> Result<&[(String, Vec<usize>)], CervoError> {
221        match self.models.iter().find(|m| m.id == brain) {
222            Some(model) => Ok(model.inferer.output_shapes()),
223            None => Err(CervoError::UnknownBrain(brain)),
224        }
225    }
226
227    /// Retrieve the input shapes for the provided brain.
228    pub fn input_shapes(&self, brain: BrainId) -> Result<&[(String, Vec<usize>)], CervoError> {
229        match self.models.iter().find(|m| m.id == brain) {
230            Some(model) => Ok(model.inferer.input_shapes()),
231            None => Err(CervoError::UnknownBrain(brain)),
232        }
233    }
234
235    /// Clear all models and all related data. Will error (after
236    /// clearing *all* data) if there was queued items that are now
237    /// orphaned.
238    pub fn clear(&mut self) -> Result<(), CervoError> {
239        // N.b. we don't clear brain generation; to avoid generational issues.
240        self.queue.clear();
241        self.ticket_generation = 0;
242
243        let mut has_data = vec![];
244        for model in self.models.drain(..) {
245            if model.needs_to_execute() {
246                has_data.push(model.id);
247            }
248        }
249
250        if !has_data.is_empty() {
251            Err(CervoError::OrphanedData(has_data))
252        } else {
253            Ok(())
254        }
255    }
256
257    /// Clear a model and related data. Will error (after clearing
258    /// *all* data) if there was queued items that are now orphaned.
259    pub fn remove_inferer(&mut self, brain: BrainId) -> Result<(), CervoError> {
260        // TODO[TSolberg]: when BinaryHeap::retain is stabilized, use that here.
261        let mut to_repush = vec![];
262        while !self.queue.is_empty() {
263            // Safety: ^ must contain 1 item
264            let elem = self.queue.pop().unwrap();
265
266            if elem.1 == brain {
267                break;
268            } else {
269                to_repush.push(elem);
270            }
271        }
272
273        self.queue.extend(to_repush);
274
275        if let Some(index) = self.models.iter().position(|state| state.id == brain) {
276            // Safety: ^ we just found the index.
277            let state = self.models.remove(index);
278            if state.needs_to_execute() {
279                Err(CervoError::OrphanedData(vec![brain]))
280            } else {
281                Ok(())
282            }
283        } else {
284            Err(CervoError::UnknownBrain(brain))
285        }
286    }
287}
288
289#[cfg(test)]
290mod tests {
291    use super::Runtime;
292    use crate::{BrainId, CervoError};
293    use cervo_core::prelude::{Inferer, State};
294    use std::time::Duration;
295
296    struct DummyInferer {
297        sleep_duration: Duration,
298    }
299
300    impl Inferer for DummyInferer {
301        fn select_batch_size(&self, count: usize) -> usize {
302            assert_eq!(count, 1);
303            count
304        }
305
306        fn infer_raw(
307            &self,
308            _batch: &mut cervo_core::batcher::ScratchPadView<'_>,
309        ) -> anyhow::Result<(), anyhow::Error> {
310            std::thread::sleep(self.sleep_duration);
311            Ok(())
312        }
313
314        fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
315            &[]
316        }
317
318        fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
319            &[]
320        }
321
322        fn begin_agent(&self, _id: u64) {}
323        fn end_agent(&self, _id: u64) {}
324    }
325
326    #[test]
327    fn test_run_for_rotation() {
328        let mut runtime = Runtime::new();
329        let mut keys = vec![];
330        for sleep in [0.02, 0.04, 0.06, 0.04] {
331            keys.push(runtime.add_inferer(DummyInferer {
332                sleep_duration: Duration::from_secs_f32(sleep),
333            }));
334        }
335
336        let push = |runtime: &mut Runtime, keys: &[BrainId]| {
337            for k in keys {
338                runtime.push(*k, 0, State::empty()).unwrap();
339            }
340        };
341
342        for _ in 0..10 {
343            push(&mut runtime, &keys);
344            runtime.run().unwrap();
345        }
346
347        push(&mut runtime, &keys);
348        let res = runtime.run_for(Duration::from_secs_f32(0.07)).unwrap();
349        assert_eq!(res.len(), 2, "got keys: {:?}", res.keys());
350        assert!(res.contains_key(&keys[0]));
351        assert!(res.contains_key(&keys[1]));
352
353        let res = runtime.run_for(Duration::from_secs_f32(0.07)).unwrap();
354        assert_eq!(res.len(), 1);
355        assert!(res.contains_key(&keys[2]));
356
357        // queue should be 3, 0, 1, 2.
358        // The below can run both 3 and 0 but only 3 has data.
359        let res = runtime.run_for(Duration::from_secs_f32(0.07)).unwrap();
360        assert_eq!(res.len(), 1);
361        assert!(res.contains_key(&keys[3]));
362
363        push(&mut runtime, &keys);
364        let res = runtime.run_for(Duration::from_secs_f32(0.165)).unwrap();
365        assert_eq!(res.len(), 4, "got keys: {:?}", res.keys());
366        assert!(res.contains_key(&keys[0]));
367        assert!(res.contains_key(&keys[1]));
368        assert!(res.contains_key(&keys[2]));
369        assert!(res.contains_key(&keys[3]));
370    }
371
372    #[test]
373    fn test_run_skip_expensive() {
374        let mut runtime = Runtime::new();
375        let mut keys = vec![];
376        for sleep in [0.02, 0.04, 0.06, 0.04] {
377            keys.push(runtime.add_inferer(DummyInferer {
378                sleep_duration: Duration::from_secs_f32(sleep),
379            }));
380        }
381
382        let push = |runtime: &mut Runtime, keys: &[BrainId]| {
383            for k in keys {
384                runtime.push(*k, 0, State::empty()).unwrap();
385            }
386        };
387
388        for _ in 0..10 {
389            push(&mut runtime, &keys);
390            runtime.run().unwrap();
391        }
392
393        push(&mut runtime, &keys);
394        let res = runtime.run_for(Duration::from_secs_f32(0.11)).unwrap();
395        assert_eq!(res.len(), 3, "got keys: {:?}", res.keys());
396        assert!(res.contains_key(&keys[0]));
397        assert!(res.contains_key(&keys[1]));
398        assert!(res.contains_key(&keys[3]));
399    }
400
401    #[test]
402    fn test_run_for_greedy() {
403        let mut runtime = Runtime::new();
404        let mut keys = vec![];
405        for sleep in [0.02, 0.04, 0.06] {
406            keys.push(runtime.add_inferer(DummyInferer {
407                sleep_duration: Duration::from_secs_f32(sleep),
408            }));
409        }
410
411        let push = |runtime: &mut Runtime, keys: &[BrainId]| {
412            for k in keys {
413                runtime.push(*k, 0, State::empty()).unwrap();
414            }
415        };
416
417        for _ in 0..10 {
418            push(&mut runtime, &keys);
419            runtime.run().unwrap();
420        }
421
422        push(&mut runtime, &keys);
423        let res = runtime.run_for(Duration::from_secs_f32(0.0)).unwrap();
424        assert_eq!(res.len(), 1, "got keys: {:?}", res.keys());
425        assert!(res.contains_key(&keys[0]));
426
427        // queue should be 1, 2, 0
428        let res = runtime.run_for(Duration::from_secs_f32(0.0)).unwrap();
429        assert_eq!(res.len(), 1);
430        assert!(res.contains_key(&keys[1]));
431
432        // queue should be 2, 1, 0
433        let res = runtime.run_for(Duration::from_secs_f32(0.0)).unwrap();
434        assert_eq!(res.len(), 1);
435        assert!(res.contains_key(&keys[2]));
436    }
437
438    #[test]
439    fn test_run_single() {
440        let mut runtime = Runtime::new();
441
442        let k = runtime.add_inferer(DummyInferer {
443            sleep_duration: Duration::from_secs_f32(0.01),
444        });
445
446        runtime.infer_single(k, State::empty()).unwrap();
447        let r = runtime.run().unwrap();
448        assert_eq!(r.len(), 0);
449    }
450
451    #[test]
452    fn test_run_single_with_push() {
453        let mut runtime = Runtime::new();
454
455        let k = runtime.add_inferer(DummyInferer {
456            sleep_duration: Duration::from_secs_f32(0.01),
457        });
458
459        runtime.push(k, 0, State::empty()).unwrap();
460
461        runtime.infer_single(k, State::empty()).unwrap();
462        let mut r = runtime.run().unwrap();
463        assert_eq!(r.len(), 1);
464        let data = r.remove(&k).unwrap();
465
466        assert_eq!(data.len(), 1);
467        assert!(data.contains_key(&0));
468    }
469
470    #[test]
471    fn unknown_brain_push() {
472        let mut runtime = Runtime::new();
473        let res = runtime.push(BrainId(10), 0, State::empty());
474
475        assert!(res.is_err());
476        let err = res.unwrap_err();
477
478        if let CervoError::UnknownBrain(BrainId(10)) = err {
479        } else {
480            panic!("expected CervoError::UnknownBrain")
481        }
482    }
483
484    #[test]
485    fn unknown_brain_infer_single() {
486        let mut runtime = Runtime::new();
487        let res = runtime.infer_single(BrainId(10), State::empty());
488
489        assert!(res.is_err());
490        let err = res.unwrap_err();
491
492        if let CervoError::UnknownBrain(BrainId(10)) = err {
493        } else {
494            panic!("expected CervoError::UnknownBrain")
495        }
496    }
497
498    #[test]
499    fn unknown_brain_remove() {
500        let mut runtime = Runtime::new();
501        let res = runtime.remove_inferer(BrainId(10));
502
503        assert!(res.is_err());
504        let err = res.unwrap_err();
505
506        if let CervoError::UnknownBrain(BrainId(10)) = err {
507        } else {
508            panic!("expected CervoError::UnknownBrain")
509        }
510    }
511
512    #[test]
513    fn unknown_brain_remove_orphaned() {
514        let mut runtime = Runtime::new();
515        let k = runtime.add_inferer(DummyInferer {
516            sleep_duration: Duration::from_secs_f32(0.1),
517        });
518        runtime.push(k, 0, State::empty()).unwrap();
519        let res = runtime.remove_inferer(k);
520
521        assert!(res.is_err());
522        let err = res.unwrap_err();
523
524        if let CervoError::OrphanedData(keys) = err {
525            assert_eq!(keys, vec![k]);
526        } else {
527            panic!("expected CervoError::OrphanedData")
528        }
529    }
530
531    #[test]
532    fn unknown_brain_clear_orphaned() {
533        let mut runtime = Runtime::new();
534        let k = runtime.add_inferer(DummyInferer {
535            sleep_duration: Duration::from_secs_f32(0.1),
536        });
537        runtime.push(k, 0, State::empty()).unwrap();
538        let res = runtime.clear();
539
540        assert!(res.is_err());
541        let err = res.unwrap_err();
542
543        if let CervoError::OrphanedData(keys) = err {
544            assert_eq!(keys, vec![k]);
545        } else {
546            panic!("expected CervoError::OrphanedData")
547        }
548    }
549}