/// Computes the intersection-over-union of two axis-aligned boxes.
///
/// Each box is read as `[x_min, y_min, x_max, y_max]`: elements 0 and 1 are
/// treated as the lower corner and elements 2 and 3 as the upper corner.
///
/// Returns the overlap area divided by the union area, or `0.0` when the
/// union area is not strictly positive.
#[must_use]
pub fn calculate_iou(box1: &[f32; 4], box2: &[f32; 4]) -> f32 {
    let [left_a, top_a, right_a, bottom_a] = *box1;
    let [left_b, top_b, right_b, bottom_b] = *box2;

    // Extent of the overlap along each axis; disjoint boxes clamp to zero.
    let overlap_w = (right_a.min(right_b) - left_a.max(left_b)).max(0.0);
    let overlap_h = (bottom_a.min(bottom_b) - top_a.max(top_b)).max(0.0);
    let intersection = overlap_w * overlap_h;

    let area_a = (right_a - left_a) * (bottom_a - top_a);
    let area_b = (right_b - left_b) * (bottom_b - top_b);
    let union = area_a + area_b - intersection;

    // `union > 0.0` is also false for NaN, so malformed input yields 0.0.
    if union > 0.0 {
        intersection / union
    } else {
        0.0
    }
}
/// Returns the entries `(a, b, c)` of the symmetric 2x2 matrix
/// `R * diag(w^2 / 12, h^2 / 12) * R^T`, where `R` is the rotation by `angle`
/// (in radians).
///
/// `a` and `b` are the diagonal entries and `c` is the off-diagonal entry.
fn get_covariance_params(w: f32, h: f32, angle: f32) -> (f32, f32, f32) {
    let var_w = w.powi(2) / 12.0;
    let var_h = h.powi(2) / 12.0;

    let cos_t = angle.cos();
    let sin_t = angle.sin();
    let cos_sq = cos_t.powi(2);
    let sin_sq = sin_t.powi(2);

    (
        var_w * cos_sq + var_h * sin_sq,
        var_w * sin_sq + var_h * cos_sq,
        (var_w - var_h) * cos_t * sin_t,
    )
}
/// Computes a probabilistic IoU-style similarity between two rotated boxes.
///
/// Each box is read as `[x, y, w, h, angle]`. `w`, `h` and `angle` are turned
/// into covariance terms by `get_covariance_params`, and `(x, y)` is used as
/// the position of the box (presumably its center; confirm against callers).
///
/// Three terms are summed into a distance that is clamped to `[eps, 100]`,
/// mapped through `sqrt(1 - exp(-d) + eps)`, and the result is subtracted
/// from one, so larger return values mean more similar boxes.
#[must_use]
pub fn calculate_probiou(box1: &[f32; 5], box2: &[f32; 5]) -> f32 {
    // Small constant that keeps the divisions, the logarithm and the square
    // root away from zero.
    let eps: f32 = 1e-7;

    let [x1, y1, w1, h1, r1] = *box1;
    let [x2, y2, w2, h2, r2] = *box2;

    let (a1, b1, c1) = get_covariance_params(w1, h1, r1);
    let (a2, b2, c2) = get_covariance_params(w2, h2, r2);

    let sum_a = a1 + a2;
    let sum_b = b1 + b2;
    let sum_c = c1 + c2;

    // Determinant of the summed covariance (plus eps); shared by the first
    // two terms.
    let denom = sum_a.mul_add(sum_b, -sum_c.powi(2) + eps);

    // Term 1: quadratic form of the position difference.
    let dy_sq = (y1 - y2).powi(2);
    let dx_sq = (x1 - x2).powi(2);
    let t1 = sum_a.mul_add(dy_sq, sum_b * dx_sq) / denom * 0.25;

    // Term 2: cross term coupling the x and y offsets.
    let t2 = (sum_c * (x2 - x1) * (y1 - y2)) / denom * 0.5;

    // Term 3: log-ratio of the summed determinant to the individual ones.
    let det1 = a1.mul_add(b1, -c1.powi(2)).max(0.0);
    let det2 = a2.mul_add(b2, -c2.powi(2)).max(0.0);
    let t3_num = sum_a.mul_add(sum_b, -sum_c.powi(2));
    let t3_den = 4.0f32.mul_add((det1 * det2).sqrt(), eps);
    let t3 = (t3_num / t3_den + eps).ln() * 0.5;

    let bd = (t1 + t2 + t3).clamp(eps, 100.0);
    let hd = (1.0 - (-bd).exp() + eps).sqrt();
    1.0 - hd
}
/// Runs greedy, class-aware non-maximum suppression over axis-aligned boxes.
///
/// Each entry of `boxes` is `(box, score, class)`. Candidates are visited from
/// highest to lowest score; every visited candidate that has not been
/// suppressed is kept, and it suppresses all lower-ranked boxes of the same
/// class whose IoU with it is strictly greater than `iou_threshold`.
///
/// Returns the indices (into `boxes`) of the kept entries, ordered by
/// descending score. Ties keep their original relative order. Entries whose
/// score is NaN are ranked last instead of causing a panic.
#[must_use]
pub fn nms_per_class(boxes: &[([f32; 4], f32, usize)], iou_threshold: f32) -> Vec<usize> {
    // NaN has no place in a partial order, so map it to the lowest possible
    // rank and compare with `total_cmp`, which is a total order and never
    // fails (it also places -0.0 below +0.0).
    let rank_score = |idx: usize| {
        let score = boxes[idx].1;
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    };

    let mut order: Vec<usize> = (0..boxes.len()).collect();
    // Stable sort: equal scores stay in input order.
    order.sort_by(|&a, &b| rank_score(b).total_cmp(&rank_score(a)));

    let mut keep = Vec::new();
    let mut suppressed = vec![false; boxes.len()];

    for (rank, &i) in order.iter().enumerate() {
        if suppressed[i] {
            continue;
        }
        keep.push(i);

        let anchor = &boxes[i].0;
        let class_i = boxes[i].2;

        // Everything ranked before `i` has already been kept or suppressed,
        // so only lower-ranked candidates can still change state.
        for &j in &order[rank + 1..] {
            if suppressed[j] || boxes[j].2 != class_i {
                continue;
            }
            if calculate_iou(anchor, &boxes[j].0) > iou_threshold {
                suppressed[j] = true;
            }
        }
    }

    keep
}
/// Runs greedy, class-aware non-maximum suppression over rotated boxes.
///
/// Each entry of `boxes` is `(box, score, class)`, where the box is the
/// five-element form accepted by `calculate_probiou`. Candidates are visited
/// from highest to lowest score; every visited candidate that has not been
/// suppressed is kept, and it suppresses all lower-ranked boxes of the same
/// class whose `calculate_probiou` value with it is strictly greater than
/// `iou_threshold`.
///
/// Returns the indices (into `boxes`) of the kept entries, ordered by
/// descending score. Ties keep their original relative order. Entries whose
/// score is NaN are ranked last instead of causing a panic.
#[must_use]
pub fn nms_rotated_per_class(boxes: &[([f32; 5], f32, usize)], iou_threshold: f32) -> Vec<usize> {
    // NaN has no place in a partial order, so map it to the lowest possible
    // rank and compare with `total_cmp`, which is a total order and never
    // fails (it also places -0.0 below +0.0).
    let rank_score = |idx: usize| {
        let score = boxes[idx].1;
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    };

    let mut order: Vec<usize> = (0..boxes.len()).collect();
    // Stable sort: equal scores stay in input order.
    order.sort_by(|&a, &b| rank_score(b).total_cmp(&rank_score(a)));

    let mut keep = Vec::new();
    let mut suppressed = vec![false; boxes.len()];

    for (rank, &i) in order.iter().enumerate() {
        if suppressed[i] {
            continue;
        }
        keep.push(i);

        let anchor = &boxes[i].0;
        let class_i = boxes[i].2;

        // Everything ranked before `i` has already been kept or suppressed,
        // so only lower-ranked candidates can still change state.
        for &j in &order[rank + 1..] {
            if suppressed[j] || boxes[j].2 != class_i {
                continue;
            }
            if calculate_probiou(anchor, &boxes[j].0) > iou_threshold {
                suppressed[j] = true;
            }
        }
    }

    keep
}
/// Returns a plural form of the lowercase English noun `word`.
///
/// A small table of irregular and invariant nouns is consulted first. Other
/// words follow these rules, in order:
///
/// 1. words ending in `s`, `x`, `ch` or `sh` get `es` appended;
/// 2. words ending in `y` preceded by anything other than a vowel have the
///    `y` replaced with `ies`;
/// 3. everything else gets `s` appended.
///
/// This is a heuristic, not a complete English pluralizer.
#[must_use]
pub fn pluralize(word: &str) -> String {
    match word {
        "person" => "persons".to_string(),
        "bus" => "buses".to_string(),
        "knife" => "knives".to_string(),
        "mouse" => "mice".to_string(),
        // Nouns whose plural is spelled like the input. Without this entry
        // "scissors" would hit the trailing-`s` rule and become "scissorses".
        "sheep" | "skis" | "scissors" => word.to_string(),
        _ => {
            let takes_es = word.ends_with('s')
                || word.ends_with('x')
                || word.ends_with("ch")
                || word.ends_with("sh");
            if takes_es {
                return format!("{word}es");
            }

            match word.strip_suffix('y') {
                // Vowel + y ("toy", "key", "guy") just takes an `s`; only a
                // `y` after any other character becomes `ies`.
                Some(stem)
                    if !matches!(
                        stem.chars().next_back(),
                        Some('a' | 'e' | 'i' | 'o' | 'u')
                    ) =>
                {
                    format!("{stem}ies")
                }
                _ => format!("{word}s"),
            }
        }
    }
}
use image::DynamicImage;
use ndarray::Array3;
use crate::error::{InferenceError, Result};
/// Converts a `u8` array indexed as `[row, column, channel]` into an RGB
/// [`DynamicImage`].
///
/// Channels 0, 1 and 2 of every pixel are copied as red, green and blue; any
/// further channels are ignored.
///
/// # Errors
///
/// Returns [`InferenceError::ImageError`] when
/// - the height or width does not fit in a `u32`,
/// - the array contains pixels but has fewer than three channels,
/// - the RGB buffer size overflows `usize`, or
/// - the image buffer cannot be constructed from the copied data.
pub fn array_to_image(arr: &Array3<u8>) -> Result<DynamicImage> {
    let shape = arr.shape();
    let (rows, cols, channels) = (shape[0], shape[1], shape[2]);

    let height = u32::try_from(rows)
        .map_err(|_| InferenceError::ImageError("Image height exceeds u32::MAX".to_string()))?;
    let width = u32::try_from(cols)
        .map_err(|_| InferenceError::ImageError("Image width exceeds u32::MAX".to_string()))?;

    // Indexing channel 2 below would panic on a narrower array. Arrays with
    // no pixels never have their channel axis read, so they still convert to
    // an empty image.
    if channels < 3 && rows != 0 && cols != 0 {
        return Err(InferenceError::ImageError(format!(
            "Expected at least 3 channels for RGB conversion, got {channels}"
        )));
    }

    // Size the buffer in `usize`: `height * width * 3` can overflow `u32`
    // even when each dimension fits.
    let byte_len = rows
        .checked_mul(cols)
        .and_then(|pixels| pixels.checked_mul(3))
        .ok_or_else(|| {
            InferenceError::ImageError("Image buffer size overflows usize".to_string())
        })?;

    let mut rgb_data = Vec::with_capacity(byte_len);
    for y in 0..rows {
        for x in 0..cols {
            rgb_data.extend_from_slice(&[arr[[y, x, 0]], arr[[y, x, 1]], arr[[y, x, 2]]]);
        }
    }

    let img_buffer = image::RgbImage::from_raw(width, height, rgb_data).ok_or_else(|| {
        InferenceError::ImageError("Failed to create image from array".to_string())
    })?;
    Ok(DynamicImage::ImageRgb8(img_buffer))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two 10x10 boxes offset by (5, 5) share a 5x5 square: 25 / 175.
    #[test]
    fn test_calculate_iou() {
        let first = [0.0, 0.0, 10.0, 10.0];
        let second = [5.0, 5.0, 15.0, 15.0];

        let error = (calculate_iou(&first, &second) - 0.142_857).abs();

        assert!(error < 0.001);
    }

    /// Overlapping boxes of different classes and a distant box all survive.
    #[test]
    fn test_nms_per_class() {
        let detections = [
            ([0.0, 0.0, 10.0, 10.0], 0.9, 0),
            ([1.0, 1.0, 11.0, 11.0], 0.8, 1),
            ([100.0, 100.0, 110.0, 110.0], 0.95, 0),
        ];

        let kept = nms_per_class(&detections, 0.5);

        assert_eq!(kept.len(), 3);
    }

    /// Of two heavily overlapping same-class boxes only the higher score stays.
    #[test]
    fn test_nms_per_class_suppression() {
        let detections = [
            ([0.0, 0.0, 10.0, 10.0], 0.9, 0),
            ([1.0, 1.0, 11.0, 11.0], 0.8, 0),
        ];

        let kept = nms_per_class(&detections, 0.5);

        assert_eq!(kept.len(), 1);
        assert!(kept.contains(&0));
    }

    #[test]
    fn test_pluralize() {
        let cases = [
            ("person", "persons"),
            ("bus", "buses"),
            ("match", "matches"),
            ("box", "boxes"),
            ("car", "cars"),
            ("baby", "babies"),
            ("toy", "toys"),
        ];

        for (singular, plural) in cases {
            assert_eq!(pluralize(singular), plural);
        }
    }

    /// A 1x2 image holding one red and one green pixel round-trips intact.
    #[test]
    fn test_array_to_image() {
        let pixels = vec![255, 0, 0, 0, 255, 0];
        let arr = Array3::from_shape_vec((1, 2, 3), pixels).unwrap();

        let img = array_to_image(&arr).unwrap();
        assert_eq!(img.width(), 2);
        assert_eq!(img.height(), 1);

        let rgb = img.to_rgb8();
        let expected_row = [(0, [255u8, 0, 0]), (1, [0u8, 255, 0])];
        for (x, expected) in expected_row {
            let pixel = rgb.get_pixel(x, 0);
            for (channel, &value) in expected.iter().enumerate() {
                assert_eq!(pixel[channel], value);
            }
        }
    }
}