use crate::trackers::sort::{PositionalMetricType, PyPositionalMetricType, SortAttributesOptions};
use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
use crate::trackers::visual_sort::metric::builder::VisualMetricBuilder;
use crate::trackers::visual_sort::metric::{
PyVisualSortMetricType, VisualMetric, VisualSortMetricType,
};
use pyo3::prelude::*;
use std::collections::HashMap;
use std::sync::RwLock;
/// Configuration options for the VisualSort tracker, exposed to Python via `#[pyclass]`.
///
/// Values are set either through the consuming Rust builder methods or through the
/// `*_py` setters, and are converted into tracker settings by `build`.
#[pyclass]
#[derive(Debug, Clone)]
pub struct VisualSortOptions {
/// Forwarded to `SortAttributesOptions::new` by `build`; defaults to 2.
/// NOTE(review): presumably the number of epochs a track may stay idle - confirm.
max_idle_epochs: usize,
/// Forwarded to `SortAttributesOptions::new` by `build`; defaults to 10.
/// The Rust builder setter rejects zero.
kept_history_length: usize,
/// Constraints forwarded to `SortAttributesOptions::new` by `build`.
spatio_temporal_constraints: SpatioTemporalConstraints,
/// Accumulates every metric-related setting; `build` turns it into a `VisualMetric`.
metric_builder: VisualMetricBuilder,
}
impl VisualSortOptions {
pub(crate) fn build(self) -> (SortAttributesOptions, VisualMetric) {
(
SortAttributesOptions::new(
Some(RwLock::new(HashMap::default())),
self.max_idle_epochs,
self.kept_history_length,
self.spatio_temporal_constraints,
),
self.metric_builder.build(),
)
}
pub fn max_idle_epochs(mut self, n: usize) -> Self {
self.max_idle_epochs = n;
self
}
pub fn kept_history_length(mut self, n: usize) -> Self {
assert!(n > 0, "History length must be a positive number");
self.kept_history_length = n;
self
}
pub fn visual_metric(mut self, metric: VisualSortMetricType) -> Self {
self.metric_builder = self.metric_builder.visual_metric(metric);
self
}
pub fn visual_min_votes(mut self, n: usize) -> Self {
self.metric_builder = self.metric_builder.visual_min_votes(n);
self
}
pub fn visual_max_observations(mut self, n: usize) -> Self {
self.metric_builder = self.metric_builder.visual_max_observations(n);
self
}
pub fn positional_min_confidence(mut self, conf: f32) -> Self {
self.metric_builder = self.metric_builder.positional_min_confidence(conf);
self
}
pub fn spatio_temporal_constraints(mut self, constraints: SpatioTemporalConstraints) -> Self {
self.spatio_temporal_constraints = constraints;
self
}
pub fn positional_metric(mut self, metric: PositionalMetricType) -> Self {
self.metric_builder = self.metric_builder.positional_metric(metric);
self
}
pub fn visual_minimal_track_length(mut self, length: usize) -> Self {
self.metric_builder = self.metric_builder.visual_minimal_track_length(length);
self
}
pub fn visual_minimal_area(mut self, area: f32) -> Self {
self.metric_builder = self.metric_builder.visual_minimal_area(area);
self
}
pub fn visual_minimal_quality_use(mut self, q: f32) -> Self {
self.metric_builder = self.metric_builder.visual_minimal_quality_use(q);
self
}
pub fn visual_minimal_quality_collect(mut self, q: f32) -> Self {
self.metric_builder = self.metric_builder.visual_minimal_quality_collect(q);
self
}
pub fn visual_minimal_own_area_percentage_use(mut self, area: f32) -> Self {
self.metric_builder = self
.metric_builder
.visual_minimal_own_area_percentage_use(area);
self
}
pub fn visual_minimal_own_area_percentage_collect(mut self, area: f32) -> Self {
self.metric_builder = self
.metric_builder
.visual_minimal_own_area_percentage_collect(area);
self
}
}
impl Default for VisualSortOptions {
fn default() -> Self {
Self {
max_idle_epochs: 2,
kept_history_length: 10,
metric_builder: VisualMetricBuilder::default(),
spatio_temporal_constraints: SpatioTemporalConstraints::default(),
}
}
}
#[pymethods]
impl VisualSortOptions {
    /// Creates options populated with the default values.
    #[new]
    fn new() -> Self {
        Self::default()
    }

    /// Sets `max_idle_epochs`.
    ///
    /// Panics if `n` is negative or does not fit into `usize`.
    #[pyo3(name = "max_idle_epochs", text_signature = "($self, n)")]
    fn max_idle_epochs_py(&mut self, n: i64) {
        self.max_idle_epochs = n.try_into().expect("Parameter must be a positive number");
    }

    /// Sets `kept_history_length`.
    ///
    /// Panics if `n` is negative, does not fit into `usize`, or is zero.
    #[pyo3(name = "kept_history_length", text_signature = "($self, n)")]
    fn kept_history_length_py(&mut self, n: i64) {
        let n: usize = n.try_into().expect("Parameter must be a positive number");
        // Enforce the same invariant as the Rust builder's `kept_history_length`,
        // which rejects zero; without this check the Python path could bypass it.
        assert!(n > 0, "History length must be a positive number");
        self.kept_history_length = n;
    }

    /// Forwards `n` to the metric builder's `visual_min_votes_py`.
    #[pyo3(name = "visual_min_votes", text_signature = "($self, n)")]
    fn visual_min_votes_py(&mut self, n: i64) {
        self.metric_builder.visual_min_votes_py(n);
    }

    /// Forwards the visual metric type to the metric builder.
    #[pyo3(name = "visual_metric", text_signature = "($self, metric)")]
    fn visual_metric_py(&mut self, metric: PyVisualSortMetricType) {
        self.metric_builder.visual_metric_py(metric);
    }

    /// Replaces the spatio-temporal constraints.
    #[pyo3(
        name = "spatio_temporal_constraints",
        text_signature = "($self, constraints)"
    )]
    fn spatio_temporal_constraints_py(&mut self, constraints: SpatioTemporalConstraints) {
        self.spatio_temporal_constraints = constraints;
    }

    /// Unwraps the Python-facing positional metric type and forwards it to the metric builder.
    #[pyo3(name = "positional_metric", text_signature = "($self, metric)")]
    fn positional_metric_py(&mut self, metric: PyPositionalMetricType) {
        self.metric_builder.positional_metric_py(metric.0);
    }

    /// Forwards `length` to the metric builder's `visual_minimal_track_length_py`.
    ///
    /// Panics if `length` is negative or does not fit into the target integer type.
    #[pyo3(
        name = "visual_minimal_track_length",
        text_signature = "($self, length)"
    )]
    fn visual_minimal_track_length_py(&mut self, length: i64) {
        self.metric_builder.visual_minimal_track_length_py(
            length
                .try_into()
                .expect("Parameter must be a positive number"),
        );
    }

    /// Forwards `area` to the metric builder's `visual_minimal_area_py`.
    #[pyo3(name = "visual_minimal_area", text_signature = "($self, area)")]
    fn visual_minimal_area_py(&mut self, area: f32) {
        self.metric_builder.visual_minimal_area_py(area);
    }

    /// Forwards `q` to the metric builder's `visual_minimal_quality_use_py`.
    #[pyo3(name = "visual_minimal_quality_use", text_signature = "($self, q)")]
    fn visual_minimal_quality_use_py(&mut self, q: f32) {
        self.metric_builder.visual_minimal_quality_use_py(q);
    }

    /// Forwards `conf` to the metric builder's `positional_min_confidence_py`.
    #[pyo3(name = "positional_min_confidence", text_signature = "($self, conf)")]
    fn positional_min_confidence_py(&mut self, conf: f32) {
        self.metric_builder.positional_min_confidence_py(conf);
    }

    /// Forwards `n` to the metric builder's `visual_max_observations_py`.
    ///
    /// Panics if `n` is negative or does not fit into the target integer type.
    #[pyo3(name = "visual_max_observations", text_signature = "($self, n)")]
    fn visual_max_observations_py(&mut self, n: i64) {
        self.metric_builder
            .visual_max_observations_py(n.try_into().expect("Parameter must be a positive number"));
    }

    /// Forwards `q` to the metric builder's `visual_minimal_quality_collect_py`.
    #[pyo3(name = "visual_minimal_quality_collect", text_signature = "($self, q)")]
    fn visual_minimal_quality_collect_py(&mut self, q: f32) {
        self.metric_builder.visual_minimal_quality_collect_py(q);
    }

    /// Forwards `area` to the metric builder's `visual_minimal_own_area_percentage_use_py`.
    #[pyo3(
        name = "visual_minimal_own_area_percentage_use",
        text_signature = "($self, area)"
    )]
    fn visual_minimal_own_area_percentage_use_py(&mut self, area: f32) {
        self.metric_builder
            .visual_minimal_own_area_percentage_use_py(area);
    }

    /// Forwards `area` to the metric builder's `visual_minimal_own_area_percentage_collect_py`.
    #[pyo3(
        name = "visual_minimal_own_area_percentage_collect",
        text_signature = "($self, area)"
    )]
    fn visual_minimal_own_area_percentage_collect_py(&mut self, area: f32) {
        self.metric_builder
            .visual_minimal_own_area_percentage_collect_py(area);
    }

    // Class attribute `__hash__` is set to `None`.
    #[classattr]
    const __hash__: Option<Py<PyAny>> = None;

    /// Returns the compact `Debug` representation of the options.
    fn __repr__(&self) -> String {
        format!("{:?}", self)
    }

    /// Returns the pretty-printed `Debug` representation of the options.
    fn __str__(&self) -> String {
        format!("{:#?}", self)
    }
}
#[cfg(test)]
mod tests {
    use crate::trackers::sort::{PositionalMetricType, PyPositionalMetricType};
    use crate::trackers::spatio_temporal_constraints::SpatioTemporalConstraints;
    use crate::trackers::visual_sort::metric::{PyVisualSortMetricType, VisualSortMetricType};
    use crate::trackers::visual_sort::options::VisualSortOptions;

    /// The consuming Rust builder and the Python-facing setters must produce the same
    /// attribute options and the same metric when given the same values.
    #[test]
    fn visual_sort_options_builder() {
        // Configuration assembled through the chained Rust builder API.
        let (opts, metric) = VisualSortOptions::new()
            .max_idle_epochs(3)
            .kept_history_length(10)
            .visual_metric(VisualSortMetricType::Euclidean(100.0))
            .positional_metric(PositionalMetricType::Mahalanobis)
            .visual_minimal_track_length(3)
            .visual_minimal_area(5.0)
            .visual_minimal_quality_use(0.45)
            .visual_minimal_quality_collect(0.5)
            .visual_max_observations(25)
            .visual_min_votes(5)
            .positional_min_confidence(0.13)
            .visual_minimal_own_area_percentage_use(0.1)
            .visual_minimal_own_area_percentage_collect(0.2)
            .spatio_temporal_constraints(
                SpatioTemporalConstraints::default().constraints(&[(5, 7.0)]),
            )
            .build();

        // The same configuration assembled through the in-place `*_py` setters.
        let mut opts_builder = VisualSortOptions::new();
        opts_builder.max_idle_epochs_py(3);
        opts_builder.kept_history_length_py(10);
        opts_builder.visual_metric_py(PyVisualSortMetricType::euclidean(100.0));
        opts_builder.positional_metric_py(PyPositionalMetricType::maha());
        opts_builder.visual_minimal_track_length_py(3);
        opts_builder.visual_minimal_area_py(5.0);
        opts_builder.visual_minimal_quality_use_py(0.45);
        opts_builder.visual_minimal_quality_collect_py(0.5);
        opts_builder.visual_max_observations_py(25);
        opts_builder.positional_min_confidence_py(0.13);
        opts_builder.visual_minimal_own_area_percentage_use_py(0.1);
        opts_builder.visual_minimal_own_area_percentage_collect_py(0.2);
        opts_builder.visual_min_votes_py(5);
        let mut constraints = SpatioTemporalConstraints::default();
        constraints.add_constraints(vec![(5, 7.0)]);
        opts_builder.spatio_temporal_constraints_py(constraints);
        let (opts_py, metric_py) = opts_builder.build();

        // The results are compared through their `Debug` output.
        assert_eq!(format!("{:?}", opts), format!("{:?}", opts_py));
        assert_eq!(format!("{:?}", metric), format!("{:?}", metric_py));
    }

    /// The Rust builder refuses a zero history length.
    #[test]
    #[should_panic(expected = "History length must be a positive number")]
    fn zero_history_length_is_rejected() {
        let _ = VisualSortOptions::new().kept_history_length(0);
    }
}