use std::collections::HashMap;
use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis, s};
/// Per-stage timing for a single inference call.
///
/// Every stage is optional; an unmeasured stage counts as zero in [`Speed::total`].
/// NOTE(review): the unit is not fixed by this type (presumably milliseconds) — confirm with callers.
#[derive(Debug, Clone, Default)]
pub struct Speed {
    /// Duration of the preprocessing stage, if measured.
    pub preprocess: Option<f64>,
    /// Duration of the model-execution stage, if measured.
    pub inference: Option<f64>,
    /// Duration of the postprocessing stage, if measured.
    pub postprocess: Option<f64>,
}
impl Speed {
    /// Builds a `Speed` with all three stages recorded.
    #[must_use]
    pub const fn new(preprocess: f64, inference: f64, postprocess: f64) -> Self {
        Self {
            inference: Some(inference),
            preprocess: Some(preprocess),
            postprocess: Some(postprocess),
        }
    }

    /// Sum of all stages, treating any unmeasured stage as `0.0`.
    #[must_use]
    pub fn total(&self) -> f64 {
        let [pre, infer, post] =
            [self.preprocess, self.inference, self.postprocess].map(|stage| stage.unwrap_or(0.0));
        pre + infer + post
    }
}
#[derive(Debug, Clone)]
pub struct Results {
pub orig_img: Array3<u8>,
pub orig_shape: (u32, u32),
pub inference_shape: (u32, u32),
pub boxes: Option<Boxes>,
pub masks: Option<Masks>,
pub keypoints: Option<Keypoints>,
pub probs: Option<Probs>,
pub obb: Option<Obb>,
pub speed: Speed,
pub names: HashMap<usize, String>,
pub path: String,
}
impl Results {
#[must_use]
pub fn new(
orig_img: Array3<u8>,
path: String,
names: HashMap<usize, String>,
speed: Speed,
inference_shape: (u32, u32),
) -> Self {
let shape = orig_img.shape();
#[allow(clippy::cast_possible_truncation)]
let orig_shape = (shape[0] as u32, shape[1] as u32);
Self {
orig_img,
orig_shape,
inference_shape,
boxes: None,
masks: None,
keypoints: None,
probs: None,
obb: None,
speed,
names,
path,
}
}
#[must_use]
pub fn len(&self) -> usize {
if let Some(ref boxes) = self.boxes {
return boxes.len();
}
if let Some(ref masks) = self.masks {
return masks.len();
}
if let Some(ref keypoints) = self.keypoints {
return keypoints.len();
}
if let Some(ref probs) = self.probs {
return usize::from(!probs.data.is_empty());
}
if let Some(ref obb) = self.obb {
return obb.len();
}
0
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub const fn orig_shape(&self) -> (u32, u32) {
self.orig_shape
}
#[must_use]
pub const fn inference_shape(&self) -> (u32, u32) {
self.inference_shape
}
#[must_use]
pub fn verbose(&self) -> String {
if self.is_empty() {
if self.probs.is_some() {
return String::new();
}
return "(no detections), ".to_string();
}
if let Some(ref probs) = self.probs {
let top5: Vec<String> = probs
.top5()
.iter()
.map(|&i| {
let name = self.names.get(&i).cloned().unwrap_or_else(|| i.to_string());
format!("{} {:.2}", name, probs.data[i])
})
.collect();
return format!("{}, ", top5.join(", "));
}
if let Some(ref boxes) = self.boxes {
let cls = boxes.cls();
let mut counts: HashMap<usize, usize> = HashMap::new();
for &c in cls {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let c = c as usize;
*counts.entry(c).or_insert(0) += 1;
}
let mut parts = Vec::new();
for (class_id, count) in &counts {
let name = self
.names
.get(class_id)
.cloned()
.unwrap_or_else(|| class_id.to_string());
let suffix = if *count > 1 { "s" } else { "" };
parts.push(format!("{count} {name}{suffix}"));
}
return format!("{}, ", parts.join(", "));
}
String::new()
}
#[must_use]
pub fn summary(&self, normalize: bool) -> Vec<HashMap<String, SummaryValue>> {
let mut results = Vec::new();
if let Some(ref probs) = self.probs {
let class_id = probs.top1();
let mut entry = HashMap::new();
entry.insert(
"name".to_string(),
SummaryValue::String(
self.names
.get(&class_id)
.cloned()
.unwrap_or_else(|| class_id.to_string()),
),
);
entry.insert("class".to_string(), SummaryValue::Int(class_id));
entry.insert(
"confidence".to_string(),
SummaryValue::Float(probs.top1conf()),
);
results.push(entry);
return results;
}
if let Some(ref boxes) = self.boxes {
let (h, w) = if normalize {
#[allow(clippy::cast_precision_loss)]
(self.orig_shape.0 as f32, self.orig_shape.1 as f32)
} else {
(1.0, 1.0)
};
let xyxy = boxes.xyxy();
let conf = boxes.conf();
let cls = boxes.cls();
for i in 0..boxes.len() {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let class_id = cls[i] as usize;
let mut entry = HashMap::new();
entry.insert(
"name".to_string(),
SummaryValue::String(
self.names
.get(&class_id)
.cloned()
.unwrap_or_else(|| class_id.to_string()),
),
);
entry.insert("class".to_string(), SummaryValue::Int(class_id));
entry.insert("confidence".to_string(), SummaryValue::Float(conf[i]));
let mut box_coords = HashMap::new();
box_coords.insert("x1".to_string(), SummaryValue::Float(xyxy[[i, 0]] / w));
box_coords.insert("y1".to_string(), SummaryValue::Float(xyxy[[i, 1]] / h));
box_coords.insert("x2".to_string(), SummaryValue::Float(xyxy[[i, 2]] / w));
box_coords.insert("y2".to_string(), SummaryValue::Float(xyxy[[i, 3]] / h));
entry.insert("box".to_string(), SummaryValue::Box(box_coords));
results.push(entry);
}
}
results
}
#[cfg(feature = "annotate")]
pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> crate::error::Result<()> {
let img = crate::utils::array_to_image(&self.orig_img)?;
let annotated = crate::annotate::annotate_image(&img, self, None);
annotated
.save(path)
.map_err(|e| crate::error::InferenceError::ImageError(e.to_string()))
}
}
/// A single value in a map produced by `Results::summary`.
///
/// Derives `PartialEq` so summaries can be compared directly, e.g. in tests.
/// `Eq` is not derived because of the `f32` payload.
#[derive(Debug, Clone, PartialEq)]
pub enum SummaryValue {
    /// Text value, such as a class name.
    String(String),
    /// Unsigned integer value, such as a class id.
    Int(usize),
    /// Floating-point value, such as a confidence or a coordinate.
    Float(f32),
    /// Nested map, such as box coordinates keyed by name.
    Box(HashMap<String, Self>),
}
/// Axis-aligned detection boxes for a single image.
///
/// The first four columns of `data` are the corners `x1, y1, x2, y2`, and the
/// last two are `conf, cls`. A seven-column matrix is treated as tracker output,
/// with the track id stored in the column just before `conf`.
#[derive(Debug, Clone)]
pub struct Boxes {
    /// One row per box (see the type-level docs for the column layout).
    pub data: Array2<f32>,
    /// Image shape used as `(height, width)` when normalizing coordinates.
    pub orig_shape: (u32, u32),
    /// Set when `data` has seven columns, i.e. carries a track id.
    is_track: bool,
}
impl Boxes {
    /// Wraps a detection matrix; tracking is inferred from the column count.
    #[must_use]
    pub fn new(data: Array2<f32>, orig_shape: (u32, u32)) -> Self {
        let is_track = data.ncols() == 7;
        Self {
            data,
            orig_shape,
            is_track,
        }
    }

    /// Number of boxes (rows of `data`).
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.nrows()
    }

    /// True when `data` contains no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Corner coordinates `x1, y1, x2, y2`, one row per box.
    #[must_use]
    pub fn xyxy(&self) -> ArrayView2<'_, f32> {
        self.data.slice(s![.., 0..4])
    }

    /// Confidence column (second to last).
    #[must_use]
    pub fn conf(&self) -> ArrayView1<'_, f32> {
        self.data.slice(s![.., -2])
    }

    /// Class-id column (last), stored as `f32`.
    #[must_use]
    pub fn cls(&self) -> ArrayView1<'_, f32> {
        self.data.slice(s![.., -1])
    }

    /// Track-id column (third from last); `None` unless this is tracker output.
    #[must_use]
    pub fn id(&self) -> Option<ArrayView1<'_, f32>> {
        self.is_track.then(|| self.data.slice(s![.., -3]))
    }

    /// Boxes as `center_x, center_y, width, height`.
    #[must_use]
    pub fn xywh(&self) -> Array2<f32> {
        let corners = self.xyxy();
        let mut out = Array2::<f32>::zeros((corners.nrows(), 4));
        for (src, mut dst) in corners.outer_iter().zip(out.outer_iter_mut()) {
            let (left, top, right, bottom) = (src[0], src[1], src[2], src[3]);
            dst[0] = f32::midpoint(left, right);
            dst[1] = f32::midpoint(top, bottom);
            dst[2] = right - left;
            dst[3] = bottom - top;
        }
        out
    }

    /// Divides columns 0 and 2 by the image width and columns 1 and 3 by its height.
    fn divide_by_image_size(&self, mut coords: Array2<f32>) -> Array2<f32> {
        #[allow(clippy::cast_precision_loss)]
        let (img_h, img_w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
        for (col, divisor) in [(0, img_w), (1, img_h), (2, img_w), (3, img_h)] {
            coords.column_mut(col).mapv_inplace(|v| v / divisor);
        }
        coords
    }

    /// Corner coordinates normalized by the image size.
    #[must_use]
    pub fn xyxyn(&self) -> Array2<f32> {
        self.divide_by_image_size(self.xyxy().to_owned())
    }

    /// Center/size coordinates normalized by the image size.
    #[must_use]
    pub fn xywhn(&self) -> Array2<f32> {
        self.divide_by_image_size(self.xywh())
    }

    /// Whether these boxes carry track ids.
    #[must_use]
    pub const fn is_track(&self) -> bool {
        self.is_track
    }
}
/// Segmentation masks for a single image.
#[derive(Debug, Clone)]
pub struct Masks {
    /// Mask stack; axis 0 indexes the mask.
    /// NOTE(review): presumably `(n, height, width)` — confirm with the producer.
    pub data: Array3<f32>,
    /// Shape of the original image these masks belong to.
    pub orig_shape: (u32, u32),
}
impl Masks {
    /// Wraps a mask stack.
    #[must_use]
    pub const fn new(data: Array3<f32>, orig_shape: (u32, u32)) -> Self {
        Self { data, orig_shape }
    }

    /// Number of masks (length of axis 0).
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.shape()[0]
    }

    /// True when there are no masks.
    ///
    /// Defined as `len() == 0` so it always agrees with [`Masks::len`]. The previous
    /// `data.is_empty()` also reported "empty" when a non-instance axis had length
    /// zero, while `len()` was still positive.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[derive(Debug, Clone)]
pub struct Keypoints {
pub data: Array3<f32>,
pub orig_shape: (u32, u32),
has_visible: bool,
}
impl Keypoints {
#[must_use]
pub fn new(data: Array3<f32>, orig_shape: (u32, u32)) -> Self {
let has_visible = data.shape()[2] == 3;
Self {
data,
orig_shape,
has_visible,
}
}
#[must_use]
pub fn len(&self) -> usize {
self.data.shape()[0]
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[must_use]
pub fn xy(&self) -> Array3<f32> {
self.data.slice(s![.., .., 0..2]).to_owned()
}
#[must_use]
pub fn xyn(&self) -> Array3<f32> {
let mut xyn = self.xy();
#[allow(clippy::cast_precision_loss)]
let (h, w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
for mut point in xyn.axis_iter_mut(Axis(2)) {
if point.shape()[0] > 0 {
point.mapv_inplace(|v| v / w);
}
if point.shape()[0] > 1 {
point.mapv_inplace(|v| v / h);
}
}
xyn
}
#[must_use]
pub fn conf(&self) -> Option<Array2<f32>> {
if self.has_visible {
Some(self.data.slice(s![.., .., 2]).to_owned())
} else {
None
}
}
}
/// Classification scores for a single image; the index into `data` is the class id.
#[derive(Debug, Clone)]
pub struct Probs {
    /// One score per class id.
    pub data: Array1<f32>,
}
impl Probs {
    /// Wraps a score vector.
    #[must_use]
    pub const fn new(data: Array1<f32>) -> Self {
        Self { data }
    }

    /// Index of the highest score, or `0` when `data` is empty.
    ///
    /// Uses `f32::total_cmp`, so NaN never panics. On ties the lowest index wins,
    /// matching the first element of [`Probs::top_k`].
    #[must_use]
    pub fn top1(&self) -> usize {
        self.data
            .iter()
            .enumerate()
            // Reverse index order as tie-breaker: among equal scores the smaller
            // index compares greater, so `max_by` picks it.
            .max_by(|(ia, a), (ib, b)| a.total_cmp(b).then_with(|| ib.cmp(ia)))
            .map_or(0, |(i, _)| i)
    }

    /// Indices of the five highest scores, highest first.
    #[must_use]
    pub fn top5(&self) -> Vec<usize> {
        self.top_k(5)
    }

    /// Indices of the `k` highest scores, highest first; equal scores keep ascending index order.
    #[must_use]
    pub fn top_k(&self, k: usize) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..self.data.len()).collect();
        // `total_cmp` is a genuine total order. The previous NaN -> `Equal` fallback
        // was not, and `sort_by` may panic on comparators that are not total orders.
        // The stable sort keeps ties in index order.
        indices.sort_by(|&a, &b| self.data[b].total_cmp(&self.data[a]));
        indices.truncate(k);
        indices
    }

    /// Score of the top-1 class.
    ///
    /// # Panics
    /// Panics if `data` is empty.
    #[must_use]
    pub fn top1conf(&self) -> f32 {
        self.data[self.top1()]
    }

    /// Scores of the top-5 classes, in the same order as [`Probs::top5`].
    #[must_use]
    pub fn top5conf(&self) -> Vec<f32> {
        self.top5().into_iter().map(|i| self.data[i]).collect()
    }
}
/// Oriented (rotated) bounding boxes for a single image.
///
/// Each row of `data` starts with `cx, cy, w, h, angle` and ends with `conf, cls`.
/// The angle is fed to `sin`/`cos`, so it is in radians. An eight-column matrix is
/// treated as tracker output, with the track id just before `conf`.
#[derive(Debug, Clone)]
pub struct Obb {
    /// One row per oriented box (see the type-level docs for the column layout).
    pub data: Array2<f32>,
    /// Image shape used as `(height, width)` when clamping `xyxy`.
    pub orig_shape: (u32, u32),
    /// Set when `data` has eight columns, i.e. carries a track id.
    is_track: bool,
}
impl Obb {
    /// Wraps an oriented-box matrix; tracking is inferred from the column count.
    #[must_use]
    pub fn new(data: Array2<f32>, orig_shape: (u32, u32)) -> Self {
        let is_track = data.ncols() == 8;
        Self {
            data,
            orig_shape,
            is_track,
        }
    }

    /// Number of boxes (rows of `data`).
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.nrows()
    }

    /// True when `data` contains no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// `cx, cy, w, h, angle` for each box.
    #[must_use]
    pub fn xywhr(&self) -> ArrayView2<'_, f32> {
        self.data.slice(s![.., 0..5])
    }

    /// Confidence column (second to last).
    #[must_use]
    pub fn conf(&self) -> ArrayView1<'_, f32> {
        self.data.slice(s![.., -2])
    }

    /// Class-id column (last), stored as `f32`.
    #[must_use]
    pub fn cls(&self) -> ArrayView1<'_, f32> {
        self.data.slice(s![.., -1])
    }

    /// Track-id column (third from last); `None` unless this is tracker output.
    #[must_use]
    pub fn id(&self) -> Option<ArrayView1<'_, f32>> {
        self.is_track.then(|| self.data.slice(s![.., -3]))
    }

    /// The four rotated corners of each box as `[box, corner, (x, y)]`.
    ///
    /// Corner offsets from the center are taken in the order
    /// `(-w/2, -h/2), (w/2, -h/2), (w/2, h/2), (-w/2, h/2)` and rotated by `angle`.
    #[must_use]
    pub fn xyxyxyxy(&self) -> Array3<f32> {
        let mut out = Array3::<f32>::zeros((self.len(), 4, 2));
        for (row, mut quad) in self.data.outer_iter().zip(out.outer_iter_mut()) {
            let (cx, cy) = (row[0], row[1]);
            let (half_w, half_h) = (row[2] / 2.0, row[3] / 2.0);
            let theta = row[4];
            let (cos_t, sin_t) = (theta.cos(), theta.sin());
            let offsets = [
                (-half_w, -half_h),
                (half_w, -half_h),
                (half_w, half_h),
                (-half_w, half_h),
            ];
            for (mut corner, &(dx, dy)) in quad.outer_iter_mut().zip(offsets.iter()) {
                corner[0] = cx + (dx * cos_t - dy * sin_t);
                corner[1] = cy + (dx * sin_t + dy * cos_t);
            }
        }
        out
    }

    /// Axis-aligned bounds of each rotated box, clamped to `[0, width] x [0, height]`.
    #[must_use]
    pub fn xyxy(&self) -> Array2<f32> {
        let quads = self.xyxyxyxy();
        #[allow(clippy::cast_precision_loss)]
        let (img_h, img_w) = (self.orig_shape.0 as f32, self.orig_shape.1 as f32);
        let mut out = Array2::<f32>::zeros((self.len(), 4));
        for (quad, mut bounds) in quads.outer_iter().zip(out.outer_iter_mut()) {
            let (mut lo_x, mut lo_y) = (f32::INFINITY, f32::INFINITY);
            let (mut hi_x, mut hi_y) = (f32::NEG_INFINITY, f32::NEG_INFINITY);
            for corner in quad.outer_iter() {
                lo_x = lo_x.min(corner[0]);
                lo_y = lo_y.min(corner[1]);
                hi_x = hi_x.max(corner[0]);
                hi_y = hi_y.max(corner[1]);
            }
            bounds[0] = lo_x.max(0.0).min(img_w);
            bounds[1] = lo_y.max(0.0).min(img_h);
            bounds[2] = hi_x.max(0.0).min(img_w);
            bounds[3] = hi_y.max(0.0).min(img_h);
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Fails the test unless `actual` is within `1e-6` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_boxes_xyxy() {
        let boxes = Boxes::new(array![[10.0, 20.0, 100.0, 200.0, 0.95, 0.0]], (480, 640));
        assert_eq!(boxes.len(), 1);
        assert_close(boxes.conf()[0], 0.95);
        assert_close(boxes.cls()[0], 0.0);
    }

    #[test]
    fn test_boxes_xywh() {
        let boxes = Boxes::new(array![[0.0, 0.0, 100.0, 100.0, 0.9, 1.0]], (640, 640));
        let xywh = boxes.xywh();
        for (col, &expected) in [50.0_f32, 50.0, 100.0, 100.0].iter().enumerate() {
            assert_close(xywh[[0, col]], expected);
        }
    }

    #[test]
    fn test_boxes_normalized() {
        let boxes = Boxes::new(array![[0.0, 0.0, 320.0, 240.0, 0.9, 0.0]], (480, 640));
        let xyxyn = boxes.xyxyn();
        for (col, &expected) in [0.0_f32, 0.0, 0.5, 0.5].iter().enumerate() {
            assert_close(xyxyn[[0, col]], expected);
        }
    }

    #[test]
    fn test_probs() {
        let probs = Probs::new(array![0.1, 0.3, 0.6]);
        assert_eq!(probs.top1(), 2);
        assert_eq!(probs.top5(), vec![2, 1, 0]);
        assert_close(probs.top1conf(), 0.6);
    }

    #[test]
    fn test_speed() {
        let total = Speed::new(10.0, 20.0, 5.0).total();
        assert!((total - 35.0).abs() < 1e-6);
    }

    #[test]
    fn test_results_verbose() {
        let names = HashMap::from([(0, "person".to_string())]);
        let results = Results::new(
            Array3::zeros((100, 100, 3)),
            "test.jpg".to_string(),
            names,
            Speed::default(),
            (640, 640),
        );
        assert!(results.is_empty());
        assert_eq!(results.verbose(), "(no detections), ");
    }
}