cecile-supercool-tracker 0.0.1

Machine learning framework for building object trackers and similarity search engines
Documentation
use crate::prelude::SortTrack;
use crossbeam::channel::{Receiver, Sender};
use log::debug;

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

pub type BatchRecords<T> = HashMap<u64, Vec<T>>;
pub type SceneTracks = (u64, Vec<SortTrack>);

#[derive(Debug, Clone)]
pub struct PredictionBatchRequest<T> {
    batch: BatchRecords<T>,
    sender: Sender<SceneTracks>,
    batch_size: Arc<Mutex<usize>>,
}

#[derive(Clone, Debug)]
pub struct PredictionBatchResult {
    receiver: Receiver<SceneTracks>,
    batch_size: Arc<Mutex<usize>>,
}

impl PredictionBatchResult {
    pub fn ready(&self) -> bool {
        !self.receiver.is_empty()
    }

    pub fn get(&self) -> SceneTracks {
        self.receiver
            .recv()
            .expect("Receiver must always receive batch computation result")
    }

    pub fn batch_size(&self) -> usize {
        *self.batch_size.lock().unwrap()
    }
}

#[cfg(feature = "python")]
pub mod python {
    use crate::trackers::sort::python::PySortTrack;

    use super::PredictionBatchResult;
    use pyo3::prelude::*;

    pub type PySceneTracks = (u64, Vec<PySortTrack>);

    #[pyclass]
    #[derive(Clone, Debug)]
    #[pyo3(name = "PredictionBatchResult")]
    pub struct PyPredictionBatchResult(pub(crate) PredictionBatchResult);

    #[pymethods]
    impl PyPredictionBatchResult {
        pub fn ready(&self) -> bool {
            self.0.ready()
        }

        #[pyo3(signature = ())]
        fn get(&self) -> PySceneTracks {
            Python::with_gil(|py| py.allow_threads(|| unsafe { std::mem::transmute(self.0.get()) }))
        }

        pub fn batch_size(&self) -> usize {
            self.0.batch_size()
        }
    }
}

impl<T> PredictionBatchRequest<T> {
    pub fn get_sender(&self) -> Sender<SceneTracks> {
        self.sender.clone()
    }

    #[allow(dead_code)]
    pub(crate) fn send(&self, res: SceneTracks) -> bool {
        let res = self.sender.send(res);
        if let Err(e) = res {
            debug!(
                "Error occurred when sending results to the batch result object. Error is: {:?}",
                e
            );
            false
        } else {
            true
        }
    }

    pub fn batch_size(&self) -> usize {
        *self.batch_size.lock().unwrap()
    }

    pub fn add(&mut self, scene_id: u64, elt: T) {
        let vec = self.batch.get_mut(&scene_id);
        if let Some(vec) = vec {
            vec.push(elt);
        } else {
            self.batch.insert(scene_id, vec![elt]);
        }
        let mut batch_size = self.batch_size.lock().unwrap();
        *batch_size = self.batch.len();
    }

    pub fn new() -> (Self, PredictionBatchResult) {
        let (sender, receiver) = crossbeam::channel::bounded(1);
        let batch_size = Arc::new(Mutex::new(0));
        (
            Self {
                batch: BatchRecords::default(),
                sender,
                batch_size: batch_size.clone(),
            },
            PredictionBatchResult {
                receiver,
                batch_size,
            },
        )
    }

    pub fn get_batch(&self) -> &BatchRecords<T> {
        &self.batch
    }
}

#[cfg(test)]
mod tests {
    use crate::prelude::Universal2DBox;
    use crate::trackers::batch::PredictionBatchRequest;

    #[test]
    fn test() {
        let (mut request, result) = PredictionBatchRequest::<Universal2DBox>::new();
        request.add(0, Universal2DBox::new(0.0, 0.0, Some(0.5), 1.0, 5.0));
        request.add(0, Universal2DBox::new(5.0, 5.0, Some(0.0), 1.5, 10.0));
        request.add(1, Universal2DBox::new(0.0, 0.0, Some(1.0), 0.7, 5.1));
        let _batch = request.get_batch();
        assert_eq!(result.batch_size(), 2);

        assert!(request.send((0, vec![])));
        assert_eq!(result.ready(), true);
        let res = result.get();
        assert_eq!(res.0, 0);
        assert!(res.1.is_empty());
        drop(result);
        assert!(!request.send((0, vec![])));
    }
}