similari 0.21.4

Machine learning framework for building object trackers and similarity search engines
Documentation
use crate::trackers::sort::SortAttributesOptions;
use crate::trackers::visual::metric::builder::VisualMetricBuilder;
use crate::trackers::visual::metric::{
    PositionalMetricType, PyPositionalMetricType, PyVisualMetricType, VisualMetric,
    VisualMetricType,
};
use pyo3::prelude::*;
use std::collections::HashMap;
use std::sync::RwLock;

#[pyclass]
#[derive(Debug, Clone)]
pub struct VisualSortOptions {
    max_idle_epochs: usize,
    kept_history_length: usize,
    metric_builder: VisualMetricBuilder,
}

impl VisualSortOptions {
    pub fn build(self) -> (SortAttributesOptions, VisualMetric) {
        (
            SortAttributesOptions::new(
                Some(RwLock::new(HashMap::default())),
                self.max_idle_epochs,
                self.kept_history_length,
            ),
            self.metric_builder.build(),
        )
    }

    pub fn max_idle_epochs(mut self, n: usize) -> Self {
        self.max_idle_epochs = n;
        self
    }

    pub fn kept_history_length(mut self, n: usize) -> Self {
        self.kept_history_length = n;
        self
    }

    pub fn visual_metric(mut self, metric: VisualMetricType) -> Self {
        self.metric_builder = self.metric_builder.visual_metric(metric);
        self
    }

    pub fn visual_min_votes(mut self, n: usize) -> Self {
        self.metric_builder = self.metric_builder.visual_min_votes(n);
        self
    }

    pub fn positional_metric(mut self, metric: PositionalMetricType) -> Self {
        self.metric_builder = self.metric_builder.positional_metric(metric);
        self
    }

    pub fn visual_minimal_track_length(mut self, length: usize) -> Self {
        self.metric_builder = self.metric_builder.visual_minimal_track_length(length);
        self
    }

    pub fn visual_minimal_area(mut self, area: f32) -> Self {
        self.metric_builder = self.metric_builder.visual_minimal_area(area);
        self
    }

    pub fn visual_minimal_quality_use(mut self, q: f32) -> Self {
        self.metric_builder = self.metric_builder.visual_minimal_quality_use(q);
        self
    }

    pub fn visual_max_observations(mut self, n: usize) -> Self {
        self.metric_builder = self.metric_builder.visual_max_observations(n);
        self
    }

    pub fn visual_minimal_quality_collect(mut self, q: f32) -> Self {
        self.metric_builder = self.metric_builder.visual_minimal_quality_collect(q);
        self
    }
}

impl Default for VisualSortOptions {
    fn default() -> Self {
        Self {
            max_idle_epochs: 2,
            kept_history_length: 10,
            metric_builder: VisualMetricBuilder::default(),
        }
    }
}

#[pymethods]
impl VisualSortOptions {
    #[new]
    fn new() -> Self {
        Self::default()
    }

    #[pyo3(name = "max_idle_epochs", text_signature = "($self, n)")]
    fn max_idle_epochs_py(&mut self, n: i64) {
        self.max_idle_epochs = n.try_into().expect("Parameter must be a positive number");
    }

    #[pyo3(name = "kept_history_length", text_signature = "($self, n)")]
    fn kept_history_length_py(&mut self, n: i64) {
        self.kept_history_length = n.try_into().expect("Parameter must be a positive number");
    }

    #[pyo3(name = "visual_min_votes", text_signature = "($self, n)")]
    fn visual_min_votes_py(&mut self, n: i64) {
        self.metric_builder.visual_min_votes_py(n);
    }

    #[pyo3(name = "visual_metric", text_signature = "($self, metric)")]
    fn visual_metric_py(&mut self, metric: PyVisualMetricType) {
        self.metric_builder.visual_metric_py(metric);
    }

    #[pyo3(name = "positional_metric", text_signature = "($self, metric)")]
    fn positional_metric_py(&mut self, metric: PyPositionalMetricType) {
        self.metric_builder.positional_metric_py(metric.0);
    }

    #[pyo3(
        name = "visual_minimal_track_length",
        text_signature = "($self, length)"
    )]
    fn visual_minimal_track_length_py(&mut self, length: i64) {
        self.metric_builder.visual_minimal_track_length_py(
            length
                .try_into()
                .expect("Parameter must be a positive number"),
        );
    }

    #[pyo3(name = "visual_minimal_area", text_signature = "($self, area)")]
    fn visual_minimal_area_py(&mut self, area: f32) {
        self.metric_builder.visual_minimal_area_py(area);
    }

    #[pyo3(name = "visual_minimal_quality_use", text_signature = "($self, q)")]
    fn visual_minimal_quality_use_py(&mut self, q: f32) {
        self.metric_builder.visual_minimal_quality_use_py(q);
    }

    #[pyo3(name = "visual_max_observations", text_signature = "($self, n)")]
    fn visual_max_observations_py(&mut self, n: i64) {
        self.metric_builder
            .visual_max_observations_py(n.try_into().expect("Parameter must be a positive number"));
    }

    #[pyo3(name = "visual_minimal_quality_collect", text_signature = "($self, q)")]
    fn visual_minimal_quality_collect_py(&mut self, q: f32) {
        self.metric_builder.visual_minimal_quality_collect_py(q);
    }

    #[classattr]
    const __hash__: Option<Py<PyAny>> = None;

    fn __repr__(&self) -> String {
        format!("{:?}", self)
    }

    fn __str__(&self) -> String {
        format!("{:#?}", self)
    }
}

#[cfg(test)]
mod tests {
    use crate::trackers::visual::metric::{
        PositionalMetricType, PyPositionalMetricType, PyVisualMetricType, VisualMetricType,
    };
    use crate::trackers::visual::simple_visual::options::VisualSortOptions;

    #[test]
    fn visual_sort_options_builder() {
        let (opts, metric) = dbg!(VisualSortOptions::new()
            .max_idle_epochs(3)
            .kept_history_length(10)
            .visual_metric(VisualMetricType::Euclidean(100.0))
            .positional_metric(PositionalMetricType::Mahalanobis)
            .visual_minimal_track_length(3)
            .visual_minimal_area(5.0)
            .visual_minimal_quality_use(0.45)
            .visual_minimal_quality_collect(0.5)
            .visual_max_observations(25)
            .visual_min_votes(5)
            .build());

        let mut opts_builder = VisualSortOptions::new();
        opts_builder.max_idle_epochs_py(3);
        opts_builder.kept_history_length_py(10);
        opts_builder.visual_metric_py(PyVisualMetricType::euclidean(100.0));
        opts_builder.positional_metric_py(PyPositionalMetricType::maha());
        opts_builder.visual_minimal_track_length_py(3);
        opts_builder.visual_minimal_area_py(5.0);
        opts_builder.visual_minimal_quality_use_py(0.45);
        opts_builder.visual_minimal_quality_collect_py(0.5);
        opts_builder.visual_max_observations_py(25);
        opts_builder.visual_min_votes_py(5);
        let (opts_py, metric_py) = dbg!(opts_builder.build());

        assert_eq!(format!("{:?}", opts), format!("{:?}", opts_py));
        assert_eq!(format!("{:?}", metric), format!("{:?}", metric_py));
    }
}