1use crate::error::{MlError, MlResult};
42
/// Axis-aligned rectangle described by two corner points.
///
/// The struct itself performs no validation: corners are stored exactly as
/// given. A box whose `x1 < x0` or `y1 < y0` is reported as having zero
/// width/height by the extent helpers on this type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BoundingBox {
    /// X coordinate of the first corner (the lower x bound of a non-empty box).
    pub x0: f32,
    /// Y coordinate of the first corner (the lower y bound of a non-empty box).
    pub y0: f32,
    /// X coordinate of the second corner (the upper x bound of a non-empty box).
    pub x1: f32,
    /// Y coordinate of the second corner (the upper y bound of a non-empty box).
    pub y1: f32,
}
70
71impl BoundingBox {
72 #[must_use]
74 pub const fn new(x0: f32, y0: f32, x1: f32, y1: f32) -> Self {
75 Self { x0, y0, x1, y1 }
76 }
77
78 #[must_use]
80 pub fn width(&self) -> f32 {
81 (self.x1 - self.x0).max(0.0)
82 }
83
84 #[must_use]
86 pub fn height(&self) -> f32 {
87 (self.y1 - self.y0).max(0.0)
88 }
89
90 #[must_use]
92 pub fn area(&self) -> f32 {
93 self.width() * self.height()
94 }
95
96 #[must_use]
98 pub fn from_xywh_center(cx: f32, cy: f32, w: f32, h: f32) -> Self {
99 let half_w = w * 0.5;
100 let half_h = h * 0.5;
101 Self {
102 x0: cx - half_w,
103 y0: cy - half_h,
104 x1: cx + half_w,
105 y1: cy + half_h,
106 }
107 }
108}
109
/// Numerically stable softmax over `logits`.
///
/// The maximum logit is subtracted before exponentiation so the largest
/// exponent is `exp(0) = 1` and cannot overflow.
///
/// # Returns
///
/// A vector of the same length as `logits`:
/// * empty input yields an empty vector;
/// * for finite logits, the usual softmax probabilities (summing to ~1);
/// * if the normaliser is degenerate (not finite or not positive), which
///   happens when every logit is `-inf`, when any logit is `+inf`, or when a
///   NaN is present, a uniform distribution `1 / len` is returned instead of
///   propagating NaN to every element.
#[must_use]
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }

    // `f32::max` ignores NaN operands, so NaN logits never become the shift.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut probs: Vec<f32> = logits.iter().map(|&v| (v - max).exp()).collect();
    let sum: f32 = probs.iter().sum();

    if sum.is_finite() && sum > 0.0 {
        for p in &mut probs {
            *p /= sum;
        }
    } else {
        // With a finite max the sum is always >= 1, so reaching this branch
        // means the input was degenerate (`inf - inf` or NaN poisoned the sum).
        // Precision loss in the cast only matters beyond 2^24 elements.
        let uniform = 1.0 / (probs.len() as f32);
        probs.fill(uniform);
    }
    probs
}
151
152pub fn argmax(scores: &[f32]) -> MlResult<usize> {
169 if scores.is_empty() {
170 return Err(MlError::postprocess("argmax on empty slice"));
171 }
172 let mut best = 0usize;
173 let mut best_v = scores[0];
174 for (i, &v) in scores.iter().enumerate().skip(1) {
175 if v > best_v {
176 best = i;
177 best_v = v;
178 }
179 }
180 Ok(best)
181}
182
183pub fn top_k(scores: &[f32], k: usize) -> MlResult<Vec<(usize, f32)>> {
205 if scores.is_empty() {
206 return Err(MlError::postprocess("top_k on empty slice"));
207 }
208 if k == 0 {
209 return Ok(Vec::new());
210 }
211 let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
212 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
213 indexed.truncate(k);
214 Ok(indexed)
215}
216
/// Logistic function `1 / (1 + e^(-v))`.
///
/// Saturates cleanly: when `e^(-v)` overflows to infinity the result is
/// `0.0`, and when it underflows to zero the result is `1.0`.
#[must_use]
pub fn sigmoid(v: f32) -> f32 {
    let denom = 1.0 + f32::exp(-v);
    denom.recip()
}
222
223#[must_use]
225pub fn sigmoid_slice(values: &[f32]) -> Vec<f32> {
226 values.iter().copied().map(sigmoid).collect()
227}
228
229#[must_use]
234pub fn iou(a: &BoundingBox, b: &BoundingBox) -> f32 {
235 let ix0 = a.x0.max(b.x0);
236 let iy0 = a.y0.max(b.y0);
237 let ix1 = a.x1.min(b.x1);
238 let iy1 = a.y1.min(b.y1);
239 let iw = (ix1 - ix0).max(0.0);
240 let ih = (iy1 - iy0).max(0.0);
241 let inter = iw * ih;
242 if inter <= 0.0 {
243 return 0.0;
244 }
245 let area_a = a.area();
246 let area_b = b.area();
247 let union = area_a + area_b - inter;
248 if union <= 0.0 {
249 return 0.0;
250 }
251 (inter / union).clamp(0.0, 1.0)
252}
253
254#[must_use]
279pub fn nms(boxes: &[BoundingBox], scores: &[f32], iou_threshold: f32) -> Vec<usize> {
280 if boxes.len() != scores.len() || boxes.is_empty() {
281 return Vec::new();
282 }
283 let threshold = iou_threshold.clamp(0.0, 1.0);
284
285 let mut order: Vec<usize> = (0..boxes.len()).collect();
287 order.sort_by(|&a, &b| {
288 scores[b]
289 .partial_cmp(&scores[a])
290 .unwrap_or(std::cmp::Ordering::Equal)
291 });
292
293 let mut kept: Vec<usize> = Vec::with_capacity(order.len());
294 for &idx in &order {
295 let cand = &boxes[idx];
296 if cand.area() <= 0.0 {
297 continue;
298 }
299 let mut suppress = false;
300 for &keep_idx in &kept {
301 if iou(cand, &boxes[keep_idx]) > threshold {
302 suppress = true;
303 break;
304 }
305 }
306 if !suppress {
307 kept.push(idx);
308 }
309 }
310 kept
311}
312
/// Scales `v` in place so that its Euclidean (L2) norm becomes 1.
///
/// The slice is left untouched when the sum of squares is zero, negative,
/// NaN or infinite. That covers empty slices, all-zero vectors, and vectors
/// whose squared norm overflows `f32`.
pub fn l2_normalize(v: &mut [f32]) {
    let sum_sq = v.iter().map(|x| x * x).sum::<f32>();
    if sum_sq.is_finite() && sum_sq > 0.0 {
        let scale = 1.0 / sum_sq.sqrt();
        v.iter_mut().for_each(|x| *x *= scale);
    }
}
327
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// one has a non-positive squared norm (for example an all-zero vector).
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    if sq_a <= 0.0 || sq_b <= 0.0 {
        return 0.0;
    }

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom <= 0.0 {
        return 0.0;
    }
    // The clamp absorbs rounding that could push the ratio just past +/-1.
    (dot / denom).clamp(-1.0, 1.0)
}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// Euclidean norm of a slice, used to check normalisation results.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    // ---- softmax -------------------------------------------------------

    #[test]
    fn softmax_sums_to_one() {
        let total: f32 = softmax(&[1.0, 2.0, 3.0]).iter().sum();
        assert_close(total, 1.0, 1e-5);
    }

    #[test]
    fn softmax_empty_is_empty() {
        assert!(softmax(&[]).is_empty());
    }

    #[test]
    fn softmax_largest_input_is_largest_output() {
        let probs = softmax(&[0.1, 5.0, 0.3, 0.2]);
        for other in [0usize, 2, 3] {
            assert!(probs[1] > probs[other]);
        }
    }

    // ---- argmax / top_k ------------------------------------------------

    #[test]
    fn argmax_picks_max() {
        assert_eq!(argmax(&[0.1, 0.4, 0.2]).expect("ok"), 1);
    }

    #[test]
    fn argmax_empty_errors() {
        let err = argmax(&[]).expect_err("must fail");
        assert!(matches!(err, MlError::Postprocess(_)));
    }

    #[test]
    fn top_k_sorted_descending() {
        let ranked = top_k(&[0.1, 0.5, 0.3, 0.7, 0.2], 3).expect("ok");
        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].0, 3);
        assert_eq!(ranked[1].0, 1);
        assert_eq!(ranked[2].0, 2);
    }

    #[test]
    fn top_k_zero_returns_empty() {
        assert!(top_k(&[1.0, 2.0], 0).expect("ok").is_empty());
    }

    // ---- sigmoid -------------------------------------------------------

    #[test]
    fn sigmoid_zero_is_half() {
        assert_close(sigmoid(0.0), 0.5, 1e-6);
    }

    #[test]
    fn sigmoid_slice_matches() {
        let out = sigmoid_slice(&[-10.0, 0.0, 10.0]);
        assert!(out[0] < 0.001);
        assert_close(out[1], 0.5, 1e-6);
        assert!(out[2] > 0.999);
    }

    // ---- bounding boxes / IoU / NMS -------------------------------------

    #[test]
    fn bbox_xywh_center_round_trip() {
        let bbox = BoundingBox::from_xywh_center(10.0, 20.0, 4.0, 8.0);
        assert_close(bbox.x0, 8.0, 1e-5);
        assert_close(bbox.y0, 16.0, 1e-5);
        assert_close(bbox.x1, 12.0, 1e-5);
        assert_close(bbox.y1, 24.0, 1e-5);
        assert_close(bbox.area(), 32.0, 1e-5);
    }

    #[test]
    fn bbox_negative_extent_has_zero_area() {
        let inverted = BoundingBox::new(5.0, 5.0, 2.0, 2.0);
        assert_eq!(inverted.width(), 0.0);
        assert_eq!(inverted.height(), 0.0);
        assert_eq!(inverted.area(), 0.0);
    }

    #[test]
    fn iou_identical_boxes_is_one() {
        let bbox = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        assert_close(iou(&bbox, &bbox), 1.0, 1e-6);
    }

    #[test]
    fn iou_zero_area_returns_zero() {
        let degenerate = BoundingBox::new(0.0, 0.0, 0.0, 0.0);
        let regular = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        assert_eq!(iou(&degenerate, &regular), 0.0);
    }

    #[test]
    fn nms_handles_length_mismatch() {
        let boxes = vec![BoundingBox::new(0.0, 0.0, 1.0, 1.0)];
        let scores = vec![0.9_f32, 0.8];
        assert!(nms(&boxes, &scores, 0.5).is_empty());
    }

    // ---- vector helpers --------------------------------------------------

    #[test]
    fn l2_normalize_unit_vector_idempotent() {
        let mut v = vec![3.0_f32, 4.0];
        l2_normalize(&mut v);
        assert_close(l2_norm(&v), 1.0, 1e-5);
        // Normalising an already-unit vector must leave its norm at 1.
        l2_normalize(&mut v);
        assert_close(l2_norm(&v), 1.0, 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal_zero() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert_close(cosine_similarity(&a, &b), 0.0, 1e-6);
    }

    #[test]
    fn cosine_similarity_length_mismatch_zero() {
        let a = [1.0_f32, 2.0];
        let b = [1.0_f32, 2.0, 3.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }
}