1use yscv_tensor::Tensor;
8
9use crate::{BoundingBox, Detection, non_max_suppression};
10
/// Configuration for decoding YOLO-family model outputs.
#[derive(Debug, Clone)]
pub struct YoloConfig {
    /// Side length in pixels of the square model input (e.g. 640).
    pub input_size: usize,
    /// Number of object classes the model predicts.
    pub num_classes: usize,
    /// Minimum class score for a prediction to be kept.
    pub conf_threshold: f32,
    /// IoU threshold used by non-max suppression.
    pub iou_threshold: f32,
    /// Human-readable label for each class id, indexed by class id.
    pub class_labels: Vec<String>,
}
25
/// Returns the 80 COCO object-category names, in model class-id order.
#[rustfmt::skip]
pub fn coco_labels() -> Vec<String> {
    // Index in this array == class id emitted by COCO-trained YOLO models.
    const NAMES: [&str; 80] = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "traffic light", "fire hydrant", "stop sign",
        "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
        "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
        "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
        "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
        "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
        "scissors", "teddy bear", "hair drier", "toothbrush",
    ];
    NAMES.iter().copied().map(String::from).collect()
}
48
49pub fn yolov8_coco_config() -> YoloConfig {
51 YoloConfig {
52 input_size: 640,
53 num_classes: 80,
54 conf_threshold: 0.25,
55 iou_threshold: 0.45,
56 class_labels: coco_labels(),
57 }
58}
59
/// Decodes a raw YOLOv8 output tensor into non-max-suppressed detections.
///
/// Expects `output` shaped `[1, 4 + C, N]` in channel-major layout: four
/// planes of length `N` holding `cx`, `cy`, `w`, `h`, followed by `C`
/// per-class score planes. The class count is derived from the tensor
/// shape rather than `config.num_classes`.
///
/// `orig_width`/`orig_height` are the pre-letterbox image dimensions; the
/// same scale/padding math as `letterbox_preprocess` is reconstructed here
/// so boxes are returned in original-image coordinates.
///
/// Returns an empty vec for unexpected shapes or a short data buffer.
pub fn decode_yolov8_output(
    output: &Tensor,
    config: &YoloConfig,
    orig_width: usize,
    orig_height: usize,
) -> Vec<Detection> {
    let shape = output.shape();
    // Only single-image batches are supported.
    if shape.len() != 3 || shape[0] != 1 {
        return Vec::new();
    }
    let rows = shape[1]; let num_preds = shape[2];
    // Need at least cx/cy/w/h plus one class score.
    if rows < 5 {
        return Vec::new();
    }
    let num_classes = rows - 4;

    let data = output.data();

    // Reconstruct the letterbox transform: one uniform scale factor plus
    // centered padding on the shorter axis.
    let scale = (config.input_size as f32 / orig_width as f32)
        .min(config.input_size as f32 / orig_height as f32);
    let new_w = orig_width as f32 * scale;
    let new_h = orig_height as f32 * scale;
    let pad_x = (config.input_size as f32 - new_w) / 2.0;
    let pad_y = (config.input_size as f32 - new_h) / 2.0;

    let mut candidates = Vec::new();

    // Guard against a data buffer shorter than the declared shape implies.
    let required_len = (4 + num_classes) * num_preds;
    if data.len() < required_len {
        return Vec::new();
    }

    for i in 0..num_preds {
        // Column i across the four geometry planes.
        let cx = data[i];
        let cy = data[num_preds + i];
        let w = data[2 * num_preds + i];
        let h = data[3 * num_preds + i];

        // Argmax over the per-class score planes for this prediction.
        let mut best_score = f32::NEG_INFINITY;
        let mut best_class = 0usize;
        for c in 0..num_classes {
            let s = data[(4 + c) * num_preds + i];
            if s > best_score {
                best_score = s;
                best_class = c;
            }
        }

        if best_score < config.conf_threshold {
            continue;
        }

        // Center/size -> corners, then undo letterbox padding and scaling.
        let x1 = ((cx - w / 2.0) - pad_x) / scale;
        let y1 = ((cy - h / 2.0) - pad_y) / scale;
        let x2 = ((cx + w / 2.0) - pad_x) / scale;
        let y2 = ((cy + h / 2.0) - pad_y) / scale;

        // Clamp to the original image bounds.
        let x1 = x1.max(0.0).min(orig_width as f32);
        let y1 = y1.max(0.0).min(orig_height as f32);
        let x2 = x2.max(0.0).min(orig_width as f32);
        let y2 = y2.max(0.0).min(orig_height as f32);

        candidates.push(Detection {
            bbox: BoundingBox { x1, y1, x2, y2 },
            score: best_score,
            class_id: best_class,
        });
    }

    // The max-detections cap is set to the candidate count, i.e. no cap.
    non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
}
150
/// Default COCO configuration for YOLOv11 models.
///
/// YOLOv11 COCO checkpoints use the same input size, thresholds, and
/// labels as the v8 family here, so this simply reuses that config.
pub fn yolov11_coco_config() -> YoloConfig {
    yolov8_coco_config()
}
157
158pub fn decode_yolov11_output(
165 output: &Tensor,
166 config: &YoloConfig,
167 orig_width: usize,
168 orig_height: usize,
169) -> Vec<Detection> {
170 let shape = output.shape();
171 let (num_preds, cols) = if shape.len() == 3 {
173 (shape[1], shape[2])
174 } else if shape.len() == 2 {
175 (shape[0], shape[1])
176 } else {
177 return Vec::new();
178 };
179
180 if cols < 5 {
181 return Vec::new();
182 }
183 let num_classes = cols - 4;
184
185 let data = output.data();
186
187 let scale = (config.input_size as f32 / orig_width as f32)
188 .min(config.input_size as f32 / orig_height as f32);
189 let new_w = orig_width as f32 * scale;
190 let new_h = orig_height as f32 * scale;
191 let pad_x = (config.input_size as f32 - new_w) / 2.0;
192 let pad_y = (config.input_size as f32 - new_h) / 2.0;
193
194 let mut candidates = Vec::new();
195
196 let required_len = num_preds * cols;
198 if data.len() < required_len {
199 return Vec::new();
200 }
201
202 for i in 0..num_preds {
203 let row = i * cols;
204 let cx = data[row];
205 let cy = data[row + 1];
206 let w = data[row + 2];
207 let h = data[row + 3];
208
209 let mut best_score = f32::NEG_INFINITY;
210 let mut best_class = 0usize;
211 for c in 0..num_classes {
212 let s = data[row + 4 + c];
213 if s > best_score {
214 best_score = s;
215 best_class = c;
216 }
217 }
218
219 if best_score < config.conf_threshold {
220 continue;
221 }
222
223 let x1 = ((cx - w / 2.0) - pad_x) / scale;
224 let y1 = ((cy - h / 2.0) - pad_y) / scale;
225 let x2 = ((cx + w / 2.0) - pad_x) / scale;
226 let y2 = ((cy + h / 2.0) - pad_y) / scale;
227
228 let x1 = x1.max(0.0).min(orig_width as f32);
229 let y1 = y1.max(0.0).min(orig_height as f32);
230 let x2 = x2.max(0.0).min(orig_width as f32);
231 let y2 = y2.max(0.0).min(orig_height as f32);
232
233 candidates.push(Detection {
234 bbox: BoundingBox { x1, y1, x2, y2 },
235 score: best_score,
236 class_id: best_class,
237 });
238 }
239
240 non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
241}
242
/// Resizes a `[H, W, 3]` f32 image to fit inside a `target_size` square
/// while preserving aspect ratio, padding the remainder with the
/// conventional YOLO letterbox grey (114/255).
///
/// Resampling uses a separable triangle (tent) filter whose support grows
/// with the downscale factor, so shrinking averages over all covered
/// source pixels instead of skipping them; upscaling reduces to a
/// bilinear-style two-tap blend.
///
/// Returns `(letterboxed, scale, pad_x, pad_y)` so callers can map model
/// outputs back to source coordinates.
///
/// # Panics
/// Panics if `image` is not rank-3 with exactly 3 channels.
pub fn letterbox_preprocess(image: &Tensor, target_size: usize) -> (Tensor, f32, f32, f32) {
    let shape = image.shape();
    assert!(
        shape.len() == 3 && shape[2] == 3,
        "expected [H, W, 3] tensor"
    );
    let src_h = shape[0];
    let src_w = shape[1];
    let data = image.data();

    // Uniform scale chosen so the larger dimension exactly fills the square.
    let scale = (target_size as f32 / src_w as f32).min(target_size as f32 / src_h as f32);
    let new_w = ((src_w as f32 * scale).round() as usize).min(target_size);
    let new_h = ((src_h as f32 * scale).round() as usize).min(target_size);
    // Fractional pads are what the caller gets back; the floored integer
    // values below are used for actual pixel placement.
    let pad_x = (target_size - new_w) as f32 / 2.0;
    let pad_y = (target_size - new_h) as f32 / 2.0;
    let pad_left = pad_x.floor() as usize;
    let pad_top = pad_y.floor() as usize;

    // Pre-fill the whole canvas with the letterbox grey.
    let total = target_size * target_size * 3;
    let mut out = vec![114.0f32 / 255.0; total];

    // Filter support is at least one source pixel, and widens to the
    // inverse scale when downscaling so no source pixel is dropped.
    let inv_scale_x = src_w as f32 / new_w as f32;
    let inv_scale_y = src_h as f32 / new_h as f32;
    let support_x = if inv_scale_x > 1.0 { inv_scale_x } else { 1.0 };
    let support_y = if inv_scale_y > 1.0 { inv_scale_y } else { 1.0 };

    for y in 0..new_h {
        // Map the destination pixel center into source coordinates.
        let center_y = (y as f32 + 0.5) * inv_scale_y - 0.5;
        let y_min = ((center_y - support_y).ceil() as isize).max(0) as usize;
        let y_max = ((center_y + support_y).floor() as isize).min(src_h as isize - 1) as usize;

        for x in 0..new_w {
            let center_x = (x as f32 + 0.5) * inv_scale_x - 0.5;
            let x_min = ((center_x - support_x).ceil() as isize).max(0) as usize;
            let x_max = ((center_x + support_x).floor() as isize).min(src_w as isize - 1) as usize;

            let dst_idx = ((pad_top + y) * target_size + (pad_left + x)) * 3;
            let mut sum = [0.0f32; 3];
            let mut weight_sum = 0.0f32;

            // Weighted average of the source pixels under the tent filter.
            for sy in y_min..=y_max {
                let wy = 1.0 - (sy as f32 - center_y).abs() / support_y;
                if wy <= 0.0 {
                    continue;
                }
                for sx in x_min..=x_max {
                    let wx = 1.0 - (sx as f32 - center_x).abs() / support_x;
                    if wx <= 0.0 {
                        continue;
                    }
                    let w = wx * wy;
                    let src_idx = (sy * src_w + sx) * 3;
                    sum[0] += data[src_idx] * w;
                    sum[1] += data[src_idx + 1] * w;
                    sum[2] += data[src_idx + 2] * w;
                    weight_sum += w;
                }
            }

            // Normalize; if no source pixel contributed, the grey fill stays.
            if weight_sum > 0.0 {
                let inv_w = 1.0 / weight_sum;
                out[dst_idx] = sum[0] * inv_w;
                out[dst_idx + 1] = sum[1] * inv_w;
                out[dst_idx + 2] = sum[2] * inv_w;
            }
        }
    }

    // Shape and data length match by construction, so this cannot fail.
    let tensor = Tensor::from_vec(vec![target_size, target_size, 3], out)
        .unwrap_or_else(|_| unreachable!("letterbox: shape matches pre-allocated output"));
    (tensor, scale, pad_x, pad_y)
}
327
/// Converts an interleaved `[H, W, 3]` tensor into a planar CHW buffer
/// (all of channel 0, then channel 1, then channel 2).
#[cfg(any(feature = "onnx", test))]
fn hwc_to_nchw(hwc: &Tensor) -> Vec<f32> {
    let shape = hwc.shape();
    let (h, w) = (shape[0], shape[1]);
    let plane = h * w;
    let interleaved = hwc.data();
    let mut planar = vec![0.0f32; 3 * plane];
    // Each source pixel is a 3-value triple at `pixel * 3`; scatter its
    // channels into the three destination planes at the same pixel offset.
    for pixel in 0..plane {
        let base = pixel * 3;
        planar[pixel] = interleaved[base];
        planar[plane + pixel] = interleaved[base + 1];
        planar[2 * plane + pixel] = interleaved[base + 2];
    }
    planar
}
349
/// Runs a YOLOv8 ONNX model on already-preprocessed NCHW image data and
/// decodes the result into detections.
///
/// `image_data` must be a letterboxed, planar buffer of length
/// `3 * config.input_size * config.input_size`. `img_height`/`img_width`
/// are the ORIGINAL image dimensions, used to map boxes back out of the
/// letterboxed space.
///
/// # Errors
/// Fails if the input tensor cannot be constructed, inference fails, or
/// the expected output tensor is missing from the run results.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_onnx(
    model: &yscv_onnx::OnnxModel,
    image_data: &[f32],
    img_height: usize,
    img_width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    use std::collections::HashMap;

    // Fall back to the conventional YOLO input name when the model does
    // not declare any inputs.
    let input_name = model
        .inputs
        .first()
        .cloned()
        .unwrap_or_else(|| "images".to_string());

    let tensor = Tensor::from_vec(
        vec![1, 3, config.input_size, config.input_size],
        image_data.to_vec(),
    )?;

    let mut inputs = HashMap::new();
    inputs.insert(input_name, tensor);

    let outputs = yscv_onnx::run_onnx_model(model, inputs)?;

    // Conventional YOLO output name as the fallback.
    let output_name = model
        .outputs
        .first()
        .cloned()
        .unwrap_or_else(|| "output0".to_string());

    // NOTE(review): a missing model output is reported via MissingInput
    // with a synthetic "model_output" node name — presumably no dedicated
    // error variant exists; confirm against yscv_onnx.
    let output_tensor =
        outputs
            .get(&output_name)
            .ok_or_else(|| yscv_onnx::OnnxError::MissingInput {
                node: "model_output".to_string(),
                input: output_name,
            })?;

    Ok(decode_yolov8_output(
        output_tensor,
        config,
        img_width,
        img_height,
    ))
}
402
/// End-to-end YOLOv8 detection from raw interleaved RGB pixel data.
///
/// Letterboxes the `[height, width, 3]` image to the model input size,
/// converts it to planar NCHW, runs the ONNX model, and returns boxes in
/// original-image coordinates.
///
/// # Errors
/// Propagates tensor-construction and ONNX inference failures.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_from_rgb(
    model: &yscv_onnx::OnnxModel,
    rgb_data: &[f32],
    height: usize,
    width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    let hwc = Tensor::from_vec(vec![height, width, 3], rgb_data.to_vec())?;
    let (letterboxed, ..) = letterbox_preprocess(&hwc, config.input_size);
    let planar = hwc_to_nchw(&letterboxed);
    // The original dimensions (not the letterboxed ones) are forwarded so
    // decoded boxes land in source-image coordinates.
    detect_yolov8_onnx(model, &planar, height, width, config)
}
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coco_labels_count() {
        assert_eq!(coco_labels().len(), 80);
    }

    #[test]
    fn test_yolov8_coco_config_defaults() {
        let cfg = yolov8_coco_config();
        assert_eq!(cfg.input_size, 640);
        assert_eq!(cfg.num_classes, 80);
        assert!((cfg.conf_threshold - 0.25).abs() < 1e-6);
        assert!((cfg.iou_threshold - 0.45).abs() < 1e-6);
        assert_eq!(cfg.class_labels.len(), 80);
    }

    /// Builds a `[1, 84, 8400]` channel-major tensor holding one confident
    /// prediction: a 100x100 box centered at (320, 320) with class 5
    /// scored 0.9. All other entries are zero.
    fn make_one_detection_tensor() -> Tensor {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Prediction 0 occupies offset 0 of each plane.
        data[0] = 320.0; // cx
        data[num_preds] = 320.0; // cy
        data[2 * num_preds] = 100.0; // w
        data[3 * num_preds] = 100.0; // h
        data[(4 + 5) * num_preds] = 0.9; // class-5 score

        Tensor::from_vec(vec![1, rows, num_preds], data).unwrap()
    }

    #[test]
    fn test_decode_yolov8_output_basic() {
        let tensor = make_one_detection_tensor();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };

        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert_eq!(dets.len(), 1);
        assert_eq!(dets[0].class_id, 5);
        assert!((dets[0].score - 0.9).abs() < 1e-6);

        // 640x640 source means scale 1 and zero padding, so the box
        // decodes directly to 320 +/- 50 on each axis.
        let b = &dets[0].bbox;
        assert!((b.x1 - 270.0).abs() < 1.0);
        assert!((b.y1 - 270.0).abs() < 1.0);
        assert!((b.x2 - 370.0).abs() < 1.0);
        assert!((b.y2 - 370.0).abs() < 1.0);
    }

    #[test]
    fn test_decode_yolov8_output_confidence_filter() {
        let tensor = make_one_detection_tensor();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            // Above the fixture's 0.9 score, so everything is filtered out.
            conf_threshold: 0.95,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };
        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert!(dets.is_empty());
    }

    #[test]
    fn test_decode_yolov8_output_nms() {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Two heavily overlapping class-0 boxes; NMS should keep only the
        // higher-scoring one.
        data[0] = 320.0;
        data[num_preds] = 320.0;
        data[2 * num_preds] = 100.0;
        data[3 * num_preds] = 100.0;
        data[4 * num_preds] = 0.9;

        data[1] = 325.0;
        data[num_preds + 1] = 325.0;
        data[2 * num_preds + 1] = 100.0;
        data[3 * num_preds + 1] = 100.0;
        data[4 * num_preds + 1] = 0.8;

        let tensor = Tensor::from_vec(vec![1, rows, num_preds], data).unwrap();
        let config = YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        };

        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert_eq!(dets.len(), 1);
        assert!((dets[0].score - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_preprocess_square() {
        // A square source scales uniformly with no padding.
        let img = Tensor::from_vec(vec![100, 100, 3], vec![0.5; 100 * 100 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 6.4).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!(pad_y.abs() < 1.0);
    }

    #[test]
    fn test_hwc_to_nchw_basic() {
        // 2x2 RGB image, pixels in row-major order.
        let data = vec![
            0.1, 0.2, 0.3, // (0, 0)
            0.4, 0.5, 0.6, // (0, 1)
            0.7, 0.8, 0.9, // (1, 0)
            1.0, 0.0, 0.5, // (1, 1)
        ];
        let img = Tensor::from_vec(vec![2, 2, 3], data).unwrap();
        let nchw = hwc_to_nchw(&img);
        assert_eq!(nchw.len(), 12);

        // Channel-0 plane, then channel-1, then channel-2.
        assert!((nchw[0] - 0.1).abs() < 1e-6);
        assert!((nchw[1] - 0.4).abs() < 1e-6);
        assert!((nchw[2] - 0.7).abs() < 1e-6);
        assert!((nchw[3] - 1.0).abs() < 1e-6);
        assert!((nchw[4] - 0.2).abs() < 1e-6);
        assert!((nchw[5] - 0.5).abs() < 1e-6);
        assert!((nchw[6] - 0.8).abs() < 1e-6);
        assert!((nchw[7] - 0.0).abs() < 1e-6);
        assert!((nchw[8] - 0.3).abs() < 1e-6);
        assert!((nchw[9] - 0.6).abs() < 1e-6);
        assert!((nchw[10] - 0.9).abs() < 1e-6);
        assert!((nchw[11] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_then_nchw_pipeline() {
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (letterboxed, _scale, _pad_x, _pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(letterboxed.shape(), &[640, 640, 3]);
        let nchw = hwc_to_nchw(&letterboxed);
        assert_eq!(nchw.len(), 3 * 640 * 640);
    }

    #[test]
    fn test_letterbox_preprocess_landscape() {
        // 200-wide x 100-tall source: width limits the scale, so the
        // height gets 160 px of padding above and below.
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 3.2).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!((pad_y - 160.0).abs() < 1.0);

        // The padded band at the top must hold the letterbox grey value.
        let top_pixel = &out.data()[0..3];
        for &v in top_pixel {
            assert!(
                (v - 114.0 / 255.0).abs() < 1e-6,
                "top padding should be 114/255 grey"
            );
        }
    }
}