//! cervo-runtime 0.9.2
//!
//! Multi-model multi-agent RL runtime for games.
//! See the rendered crate documentation for usage details.
// Author: Tom Solberg <tom.solberg@embark-studios.com>
// Copyright © 2022, Tom Solberg, all rights reserved.
// Created: 29 July 2022

use crate::{timing::TimingBucket, AgentId};
use cervo_core::prelude::{Batcher, Inferer, InfererExt, Response, State};
use std::{
    cell::RefCell,
    collections::HashMap,
    time::{Duration, Instant},
};

use crate::{error::CervoError, BrainId};

/// Per-brain state owned by the runtime: the inference engine itself plus the
/// queued agent observations and the timing statistics used for scheduling.
pub struct ModelState {
    /// Identifier of the brain this state belongs to.
    pub(crate) id: BrainId,
    /// The underlying inference engine executing this brain's model.
    pub(crate) inferer: Box<dyn Inferer + 'static + Send>,
    /// Pending per-agent states awaiting the next batched execution (`run`).
    // RefCell: mutated through `&self` in `run`; single-threaded access is
    // assumed (see the `unsafe impl Sync` below).
    batcher: RefCell<Batcher>,
    /// Observed execution durations bucketed by batch size. Kept sorted by
    /// size (see `run`/`infer_single`) so `estimated_time` can use
    /// `partition_point`.
    timings: RefCell<Vec<TimingBucket>>,
}

// SAFETY: `ModelState` holds `RefCell`s, which are `!Sync`; this impl asserts
// that the surrounding runtime never lets two threads touch the same
// `ModelState` concurrently.
// NOTE(review): `Sync` is exactly the permission to share `&ModelState`
// across threads, so "never shared between threads" is an invariant the
// *callers* must uphold rather than something this type enforces — confirm
// at the runtime's call sites that access is externally serialized.
#[allow(unsafe_code)]
unsafe impl Sync for ModelState {}

impl ModelState {
    /// Wrap `inferer` in a fresh state holder for brain `id`, with an empty
    /// batch queue and no timing history.
    pub(crate) fn new(id: BrainId, inferer: impl Inferer + 'static + Send) -> Self {
        let batcher = RefCell::new(Batcher::new(&inferer));
        Self {
            id,
            inferer: Box::new(inferer),
            batcher,
            timings: RefCell::new(vec![]),
        }
    }

    /// Queue `state` for `agent_id`, to be executed by the next call to
    /// [`Self::run`].
    pub(crate) fn push(&mut self, agent_id: AgentId, state: State<'_>) -> Result<(), CervoError> {
        self.batcher
            .borrow_mut()
            .push(agent_id, state)
            .map_err(CervoError::Internal)
    }

    /// Whether any agent states are queued for execution.
    pub(crate) fn needs_to_execute(&self) -> bool {
        !self.batcher.borrow().is_empty()
    }

    /// Estimate how long executing the currently queued batch would take,
    /// based on previously observed timings.
    ///
    /// Returns [`Duration::ZERO`] when no timings have been recorded yet.
    pub(crate) fn estimated_time(&self) -> Duration {
        let timings = self.timings.borrow();
        if timings.is_empty() {
            return Duration::ZERO;
        }

        let size = self.batcher.borrow().len();

        // `timings` is kept sorted by bucket size (see `record_timing`), so
        // `partition` is the index of the first bucket *strictly larger*
        // than `size`.
        let partition = timings.partition_point(|b| b.size <= size);

        if partition == 0 {
            // Every recorded batch was larger than the queued one;
            // extrapolate down from the smallest bucket.
            timings[0].scaled_mean(size)
        } else {
            // `timings[partition - 1]` is the largest bucket with
            // `bucket.size <= size`. BUGFIX: the original inspected
            // `timings[partition]`, whose size is strictly greater than
            // `size` by the partition-point predicate, so its exact-match
            // branch (`mean()`) was unreachable; the exact match lives at
            // `partition - 1`.
            let bucket = &timings[partition - 1];
            if bucket.size == size {
                bucket.mean()
            } else {
                bucket.scaled_mean(size)
            }
        }
    }

    /// Whether the queued batch is expected to finish within `duration`.
    pub(crate) fn can_run_in_time(&self, duration: Duration) -> bool {
        self.estimated_time() <= duration
    }

    /// Run inference for a single `state`, bypassing the queued batch.
    ///
    /// If nothing is queued, the shared batcher is used (it owns the scratch
    /// storage); otherwise the inferer is invoked directly so the pending
    /// batch is left untouched. The elapsed time is recorded under batch
    /// size 1 either way.
    pub(crate) fn infer_single<'a>(
        &'a mut self,
        state: State<'_>,
    ) -> Result<Response<'a>, CervoError> {
        let start = Instant::now();
        let mut batcher = self.batcher.borrow_mut();

        let res = if batcher.is_empty() {
            // Temporary agent id 0 is safe: the batcher is empty, so it
            // cannot collide with a queued agent.
            batcher.push(0, state).map_err(CervoError::Internal)?;

            let mut responses = batcher
                .execute(&self.inferer)
                .map_err(CervoError::Internal)?;

            responses.remove(&0).ok_or_else(|| {
                CervoError::Internal(anyhow::anyhow!(
                    "fatal error, no data when data was expected"
                ))
            })
        } else {
            self.inferer
                .infer_single(state)
                .map_err(CervoError::Internal)
        }?;

        self.record_timing(1, start.elapsed());
        Ok(res)
    }

    /// Execute all queued agent states as one batch, returning each agent's
    /// response keyed by agent id. Records the elapsed time under the
    /// executed batch size; returns an empty map when nothing is queued.
    pub(crate) fn run(&self) -> Result<HashMap<AgentId, Response<'_>>, CervoError> {
        let mut batcher = self.batcher.borrow_mut();

        if batcher.is_empty() {
            return Ok(HashMap::default());
        }

        let start = Instant::now();
        let batch_size = batcher.len();

        let res = batcher
            .execute(&self.inferer)
            .map_err(CervoError::Internal)?;

        self.record_timing(batch_size, start.elapsed());
        Ok(res)
    }

    /// Fold one observed execution duration into the timing bucket for
    /// `batch_size`, creating the bucket if needed. Buckets stay sorted by
    /// size, which `estimated_time` relies on for `partition_point`.
    /// (Extracted from the duplicated tail of `infer_single` and `run`.)
    fn record_timing(&self, batch_size: usize, elapsed: Duration) {
        let mut timings = self.timings.borrow_mut();
        if let Some(bucket) = timings.iter_mut().find(|b| b.size == batch_size) {
            bucket.add(elapsed);
        } else {
            timings.push(TimingBucket::new(batch_size, elapsed));
            timings.sort_by_key(|b| b.size);
        }
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use cervo_core::{
        batcher::ScratchPadView,
        prelude::{Batcher, Inferer, State},
    };

    use super::ModelState;
    use crate::{timing::TimingBucket, BrainId};

    /// Inferer stub with no inputs or outputs; lets the tests drive
    /// `ModelState` timing logic without a real model.
    struct DummyInferer;

    impl Inferer for DummyInferer {
        fn select_batch_size(&self, _max_count: usize) -> usize {
            0
        }

        fn infer_raw(&self, _batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> {
            Ok(())
        }

        fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
            &[]
        }

        fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
            &[]
        }

        fn begin_agent(&self, _id: u64) {}
        fn end_agent(&self, _id: u64) {}
    }

    /// Build a `ModelState` over `DummyInferer` pre-seeded with `timings`.
    fn state_with_timings(timings: Vec<TimingBucket>) -> ModelState {
        ModelState {
            id: BrainId(0),
            inferer: Box::new(DummyInferer),
            batcher: Batcher::new(&DummyInferer).into(),
            timings: timings.into(),
        }
    }

    #[test]
    fn can_fit_yes() {
        // One queued state, one recorded 1-element batch at 1 s: fits in 1 s.
        let state = state_with_timings(vec![TimingBucket::new(1, Duration::from_secs(1))]);

        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        assert!(state.can_run_in_time(Duration::from_secs(1)));
    }

    #[test]
    fn can_fit_yes_extrapolate() {
        // Two queued states scale the 1-element bucket up to ~2 s: fits in 2 s.
        let state = state_with_timings(vec![TimingBucket::new(1, Duration::from_secs(1))]);

        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        assert!(state.can_run_in_time(Duration::from_secs(2)));
    }

    #[test]
    fn can_fit_no_extrapolate() {
        // BUGFIX: this test was a verbatim copy of `can_fit_yes_extrapolate`
        // and asserted the positive case. Two queued states scale the
        // 1-element bucket up to ~2 s, which must NOT fit in 1 s.
        let state = state_with_timings(vec![TimingBucket::new(1, Duration::from_secs(1))]);

        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        assert!(!state.can_run_in_time(Duration::from_secs(1)));
    }

    #[test]
    fn can_fit_empty() {
        // With no timing history the estimate is zero, so anything fits.
        let state = state_with_timings(vec![]);

        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        assert!(state.can_run_in_time(Duration::from_secs(2)));
    }

    #[test]
    fn can_fit_yes_after() {
        // One queued state against a 2-element bucket at 1 s extrapolates
        // down to ~0.5 s: fits in 0.5 s.
        let state = state_with_timings(vec![TimingBucket::new(2, Duration::from_secs(1))]);

        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        assert!(state.can_run_in_time(Duration::from_secs_f32(0.5)));
    }

    #[test]
    fn can_fit_no_after() {
        // BUGFIX: this test asserted `can_run_in_time(0.6 s)`, which is true
        // for a ~0.5 s estimate — the "no" case was never exercised. One
        // queued state extrapolates the 2-element bucket down to ~0.5 s,
        // which must NOT fit in 0.4 s.
        let state = state_with_timings(vec![TimingBucket::new(2, Duration::from_secs(1))]);

        state.batcher.borrow_mut().push(0, State::empty()).unwrap();
        assert!(!state.can_run_in_time(Duration::from_secs_f32(0.4)));
    }
}