use crate::prelude::SortTrack;
use crate::trackers::visual::simple_api::options::VisualSortOptions;
use crate::trackers::visual::simple_api::VisualSort;
use crate::trackers::visual::{PyWastedVisualSortTrack, VisualObservation};
use crate::utils::bbox::Universal2DBox;
use pyo3::prelude::*;
/// A single observation handed from Python to the visual SORT tracker.
///
/// Exposed to Python as `VisualObservation`. Instances are gathered into a
/// `VisualObservationSet`, and `VisualSort.predict*` converts each one into a
/// tracker-level `VisualObservation` before prediction.
///
/// The advertised `text_signature` matches the real constructor parameter
/// names, so keyword arguments shown by `inspect.signature` actually work.
#[pyclass(text_signature = "(feature, feature_quality, bounding_box, custom_object_id)")]
#[derive(Debug, Clone)]
#[pyo3(name = "VisualObservation")]
pub struct PyVisualObservation {
    /// Optional feature vector for the observed object.
    pub feature: Option<Vec<f32>>,
    /// Optional quality score associated with `feature`.
    pub feature_quality: Option<f32>,
    /// Bounding box of the observed object.
    pub bounding_box: Universal2DBox,
    /// Optional caller-supplied identifier carried along with the observation.
    pub custom_object_id: Option<i64>,
}
#[pymethods]
impl PyVisualObservation {
#[new]
pub fn new(
feature: Option<Vec<f32>>,
feature_quality: Option<f32>,
bounding_box: Universal2DBox,
custom_object_id: Option<i64>,
) -> Self {
Self {
feature,
feature_quality,
bounding_box,
custom_object_id,
}
}
#[classattr]
const __hash__: Option<Py<PyAny>> = None;
fn __repr__(&self) -> String {
format!("{:?}", self)
}
fn __str__(&self) -> String {
format!("{:#?}", self)
}
}
/// An ordered, growable batch of `VisualObservation` objects.
///
/// Exposed to Python as `VisualObservationSet`. It is built with a no-argument
/// constructor and filled via `add`, then passed to `VisualSort.predict` or
/// `VisualSort.predict_with_scene`. The `text_signature` reflects that the
/// constructor takes no arguments.
#[pyclass(text_signature = "()")]
#[derive(Debug)]
#[pyo3(name = "VisualObservationSet")]
pub struct PyVisualObservationSet {
    /// Observations in the order they were added.
    inner: Vec<PyVisualObservation>,
}
#[pymethods]
impl PyVisualObservationSet {
    /// Creates an empty observation set.
    #[new]
    fn new() -> Self {
        let inner = Vec::new();
        Self { inner }
    }

    /// Appends `observation` to the end of the set.
    #[pyo3(text_signature = "($self, observation)")]
    fn add(&mut self, observation: PyVisualObservation) {
        let Self { inner } = self;
        inner.push(observation);
    }

    // A `None` `__hash__` class attribute makes the Python class unhashable.
    #[classattr]
    const __hash__: Option<Py<PyAny>> = None;

    /// Compact, single-line `Debug` rendering used by Python's `repr()`.
    fn __repr__(&self) -> String {
        format!("{:?}", *self)
    }

    /// Pretty-printed, multi-line `Debug` rendering used by Python's `str()`.
    fn __str__(&self) -> String {
        format!("{:#?}", *self)
    }
}
#[pymethods]
impl VisualSort {
    /// Creates a tracker whose store is split into `shards` shards.
    ///
    /// # Panics
    /// Panics if `shards <= 0`.
    #[new]
    pub fn new_py(shards: i64, opts: &VisualSortOptions) -> Self {
        assert!(shards > 0, "shards must be positive, got {}", shards);
        Self::new(shards.try_into().unwrap(), opts)
    }

    /// Advances the epoch counter by `n` (forwards to `skip_epochs`).
    ///
    /// # Panics
    /// Panics if `n <= 0`.
    #[pyo3(name = "skip_epochs", text_signature = "($self, n)")]
    pub fn skip_epochs_py(&mut self, n: i64) {
        assert!(n > 0, "n must be positive, got {}", n);
        self.skip_epochs(n.try_into().unwrap())
    }

    /// Advances the epoch counter of scene `scene_id` by `n`
    /// (forwards to `skip_epochs_for_scene`).
    ///
    /// # Panics
    /// Panics if `n <= 0` or `scene_id < 0`.
    #[pyo3(
        name = "skip_epochs_for_scene",
        text_signature = "($self, scene_id, n)"
    )]
    pub fn skip_epochs_for_scene_py(&mut self, scene_id: i64, n: i64) {
        assert!(
            n > 0 && scene_id >= 0,
            "expected n > 0 and scene_id >= 0, got n = {}, scene_id = {}",
            n,
            scene_id
        );
        self.skip_epochs_for_scene(scene_id.try_into().unwrap(), n.try_into().unwrap())
    }

    /// Returns the per-shard statistics reported by the store.
    ///
    /// The store is queried with the GIL released.
    #[pyo3(name = "shard_stats", text_signature = "($self)")]
    pub fn shard_stats_py(&self) -> Vec<i64> {
        Python::with_gil(|py| {
            py.allow_threads(|| {
                self.store
                    .shard_stats()
                    .into_iter()
                    .map(|e| i64::try_from(e).unwrap())
                    .collect()
            })
        })
    }

    /// Returns the current epoch of the default scene (`scene_id = 0`).
    #[pyo3(name = "current_epoch", text_signature = "($self)")]
    pub fn current_epoch_py(&self) -> i64 {
        self.current_epoch_with_scene(0).try_into().unwrap()
    }

    /// Returns the current epoch of scene `scene_id`.
    ///
    /// NOTE(review): returns `isize`, while `current_epoch_py` returns `i64`.
    /// Both are plain `int` in Python. The Rust type is kept as-is for
    /// compatibility with existing Rust callers.
    ///
    /// # Panics
    /// Panics if `scene_id < 0`.
    #[pyo3(
        name = "current_epoch_with_scene",
        text_signature = "($self, scene_id)"
    )]
    pub fn current_epoch_with_scene_py(&self, scene_id: i64) -> isize {
        assert!(scene_id >= 0, "scene_id must be non-negative, got {}", scene_id);
        self.current_epoch_with_scene(scene_id.try_into().unwrap())
            .try_into()
            .unwrap()
    }

    /// Runs a prediction step for the default scene (`scene_id = 0`).
    #[pyo3(name = "predict", text_signature = "($self, observation_set)")]
    pub fn predict_py(&mut self, observation_set: &PyVisualObservationSet) -> Vec<SortTrack> {
        self.predict_with_scene_py(0, observation_set)
    }

    /// Runs a prediction step for scene `scene_id` with the given observations.
    ///
    /// Each `VisualObservation` in the set is converted to a tracker
    /// observation that borrows its feature vector. Prediction itself runs
    /// with the GIL released.
    ///
    /// # Panics
    /// Panics if `scene_id < 0`.
    #[pyo3(
        name = "predict_with_scene",
        text_signature = "($self, scene_id, observation_set)"
    )]
    pub fn predict_with_scene_py(
        &mut self,
        scene_id: i64,
        observation_set: &PyVisualObservationSet,
    ) -> Vec<SortTrack> {
        assert!(scene_id >= 0, "scene_id must be non-negative, got {}", scene_id);
        let observations = observation_set
            .inner
            .iter()
            .map(|e| {
                VisualObservation::new(
                    e.feature.as_ref(),
                    e.feature_quality,
                    e.bounding_box.clone(),
                    e.custom_object_id,
                )
            })
            .collect::<Vec<_>>();
        Python::with_gil(|py| {
            py.allow_threads(|| {
                self.predict_with_scene(scene_id.try_into().unwrap(), &observations)
            })
        })
    }

    /// Collects the tracks returned by `wasted` and wraps each as a
    /// Python-facing `PyWastedVisualSortTrack`.
    ///
    /// The work runs with the GIL released.
    #[pyo3(name = "wasted", text_signature = "($self)")]
    pub fn wasted_py(&mut self) -> Vec<PyWastedVisualSortTrack> {
        Python::with_gil(|py| {
            py.allow_threads(|| {
                self.wasted()
                    .into_iter()
                    .map(PyWastedVisualSortTrack::from)
                    .collect()
            })
        })
    }
}