1use yscv_tensor::Tensor;
8
9use crate::{BoundingBox, Detection, non_max_suppression};
10
/// Configuration for YOLO pre- and post-processing.
#[derive(Debug, Clone)]
pub struct YoloConfig {
    /// Square model input resolution in pixels (e.g. 640 for 640x640).
    pub input_size: usize,
    /// Number of object classes the model predicts.
    pub num_classes: usize,
    /// Minimum class score for a prediction to be kept.
    pub conf_threshold: f32,
    /// IoU threshold used by non-max suppression.
    pub iou_threshold: f32,
    /// Human-readable label for each class, indexed by class id.
    pub class_labels: Vec<String>,
}
25
/// Returns the 80 COCO class labels in canonical class-id order.
#[rustfmt::skip]
pub fn coco_labels() -> Vec<String> {
    const LABELS: [&str; 80] = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "traffic light", "fire hydrant", "stop sign",
        "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
        "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
        "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
        "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
        "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
        "scissors", "teddy bear", "hair drier", "toothbrush",
    ];
    LABELS.into_iter().map(String::from).collect()
}
48
49pub fn yolov8_coco_config() -> YoloConfig {
51 YoloConfig {
52 input_size: 640,
53 num_classes: 80,
54 conf_threshold: 0.25,
55 iou_threshold: 0.45,
56 class_labels: coco_labels(),
57 }
58}
59
60pub fn decode_yolov8_output(
72 output: &Tensor,
73 config: &YoloConfig,
74 orig_width: usize,
75 orig_height: usize,
76) -> Vec<Detection> {
77 let shape = output.shape();
78 if shape.len() != 3 || shape[0] != 1 {
80 return Vec::new();
81 }
82 let rows = shape[1]; let num_preds = shape[2];
84 if rows < 5 {
85 return Vec::new();
86 }
87 let num_classes = rows - 4;
88
89 let data = output.data();
90
91 let scale = (config.input_size as f32 / orig_width as f32)
93 .min(config.input_size as f32 / orig_height as f32);
94 let new_w = orig_width as f32 * scale;
95 let new_h = orig_height as f32 * scale;
96 let pad_x = (config.input_size as f32 - new_w) / 2.0;
97 let pad_y = (config.input_size as f32 - new_h) / 2.0;
98
99 let mut candidates = Vec::new();
100
101 for i in 0..num_preds {
102 let cx = data[i];
104 let cy = data[num_preds + i];
105 let w = data[2 * num_preds + i];
106 let h = data[3 * num_preds + i];
107
108 let mut best_score = f32::NEG_INFINITY;
110 let mut best_class = 0usize;
111 for c in 0..num_classes {
112 let s = data[(4 + c) * num_preds + i];
113 if s > best_score {
114 best_score = s;
115 best_class = c;
116 }
117 }
118
119 if best_score < config.conf_threshold {
120 continue;
121 }
122
123 let x1 = ((cx - w / 2.0) - pad_x) / scale;
125 let y1 = ((cy - h / 2.0) - pad_y) / scale;
126 let x2 = ((cx + w / 2.0) - pad_x) / scale;
127 let y2 = ((cy + h / 2.0) - pad_y) / scale;
128
129 let x1 = x1.max(0.0).min(orig_width as f32);
131 let y1 = y1.max(0.0).min(orig_height as f32);
132 let x2 = x2.max(0.0).min(orig_width as f32);
133 let y2 = y2.max(0.0).min(orig_height as f32);
134
135 candidates.push(Detection {
136 bbox: BoundingBox { x1, y1, x2, y2 },
137 score: best_score,
138 class_id: best_class,
139 });
140 }
141
142 non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
143}
144
/// Returns the default configuration for YOLOv11 on COCO.
///
/// YOLOv11 exports use the same 640x640 input, 80 COCO classes, and
/// default thresholds as YOLOv8, so the v8 configuration is reused.
pub fn yolov11_coco_config() -> YoloConfig {
    yolov8_coco_config()
}
151
152pub fn decode_yolov11_output(
159 output: &Tensor,
160 config: &YoloConfig,
161 orig_width: usize,
162 orig_height: usize,
163) -> Vec<Detection> {
164 let shape = output.shape();
165 let (num_preds, cols) = if shape.len() == 3 {
167 (shape[1], shape[2])
168 } else if shape.len() == 2 {
169 (shape[0], shape[1])
170 } else {
171 return Vec::new();
172 };
173
174 if cols < 5 {
175 return Vec::new();
176 }
177 let num_classes = cols - 4;
178
179 let data = output.data();
180
181 let scale = (config.input_size as f32 / orig_width as f32)
182 .min(config.input_size as f32 / orig_height as f32);
183 let new_w = orig_width as f32 * scale;
184 let new_h = orig_height as f32 * scale;
185 let pad_x = (config.input_size as f32 - new_w) / 2.0;
186 let pad_y = (config.input_size as f32 - new_h) / 2.0;
187
188 let mut candidates = Vec::new();
189
190 let base = if shape.len() == 3 { 0 } else { 0 };
192
193 for i in 0..num_preds {
194 let row = base + i * cols;
195 let cx = data[row];
196 let cy = data[row + 1];
197 let w = data[row + 2];
198 let h = data[row + 3];
199
200 let mut best_score = f32::NEG_INFINITY;
201 let mut best_class = 0usize;
202 for c in 0..num_classes {
203 let s = data[row + 4 + c];
204 if s > best_score {
205 best_score = s;
206 best_class = c;
207 }
208 }
209
210 if best_score < config.conf_threshold {
211 continue;
212 }
213
214 let x1 = ((cx - w / 2.0) - pad_x) / scale;
215 let y1 = ((cy - h / 2.0) - pad_y) / scale;
216 let x2 = ((cx + w / 2.0) - pad_x) / scale;
217 let y2 = ((cy + h / 2.0) - pad_y) / scale;
218
219 let x1 = x1.max(0.0).min(orig_width as f32);
220 let y1 = y1.max(0.0).min(orig_height as f32);
221 let x2 = x2.max(0.0).min(orig_width as f32);
222 let y2 = y2.max(0.0).min(orig_height as f32);
223
224 candidates.push(Detection {
225 bbox: BoundingBox { x1, y1, x2, y2 },
226 score: best_score,
227 class_id: best_class,
228 });
229 }
230
231 non_max_suppression(&candidates, config.iou_threshold, candidates.len().max(1))
232}
233
234pub fn letterbox_preprocess(image: &Tensor, target_size: usize) -> (Tensor, f32, f32, f32) {
240 let shape = image.shape();
241 assert!(
242 shape.len() == 3 && shape[2] == 3,
243 "expected [H, W, 3] tensor"
244 );
245 let src_h = shape[0];
246 let src_w = shape[1];
247 let data = image.data();
248
249 let scale = (target_size as f32 / src_w as f32).min(target_size as f32 / src_h as f32);
250 let new_w = (src_w as f32 * scale).round() as usize;
251 let new_h = (src_h as f32 * scale).round() as usize;
252 let pad_x = (target_size - new_w) as f32 / 2.0;
253 let pad_y = (target_size - new_h) as f32 / 2.0;
254 let pad_left = pad_x.floor() as usize;
255 let pad_top = pad_y.floor() as usize;
256
257 let total = target_size * target_size * 3;
259 let mut out = vec![0.5f32; total];
260
261 let scale_x = src_w as f32 / new_w as f32;
263 let scale_y = src_h as f32 / new_h as f32;
264
265 for y in 0..new_h {
266 let src_y = ((y as f32 * scale_y) as usize).min(src_h - 1);
267 for x in 0..new_w {
268 let src_x = ((x as f32 * scale_x) as usize).min(src_w - 1);
269 let dst_idx = ((pad_top + y) * target_size + (pad_left + x)) * 3;
270 let src_idx = (src_y * src_w + src_x) * 3;
271 out[dst_idx] = data[src_idx];
272 out[dst_idx + 1] = data[src_idx + 1];
273 out[dst_idx + 2] = data[src_idx + 2];
274 }
275 }
276
277 let tensor = Tensor::from_vec(vec![target_size, target_size, 3], out)
278 .expect("letterbox output tensor creation");
279 (tensor, scale, pad_x, pad_y)
280}
281
282#[allow(dead_code)]
287fn hwc_to_nchw(hwc: &Tensor) -> Vec<f32> {
288 let shape = hwc.shape();
289 let h = shape[0];
290 let w = shape[1];
291 let data = hwc.data();
292 let mut nchw = vec![0.0f32; 3 * h * w];
293 for y in 0..h {
294 for x in 0..w {
295 let src = (y * w + x) * 3;
296 for c in 0..3 {
297 nchw[c * h * w + y * w + x] = data[src + c];
298 }
299 }
300 }
301 nchw
302}
303
/// Runs a YOLOv8 ONNX model on preprocessed NCHW image data and decodes
/// the result back into `img_width` x `img_height` coordinates.
///
/// `image_data` must already be letterboxed and laid out as
/// `[1, 3, config.input_size, config.input_size]` (see
/// [`letterbox_preprocess`] and [`hwc_to_nchw`]).
#[cfg(feature = "onnx")]
pub fn detect_yolov8_onnx(
    model: &yscv_onnx::OnnxModel,
    image_data: &[f32],
    img_height: usize,
    img_width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    use std::collections::HashMap;

    // Fall back to the conventional Ultralytics export tensor names when
    // the model does not declare its own.
    let input_name = match model.inputs.first() {
        Some(name) => name.clone(),
        None => "images".to_string(),
    };
    let output_name = match model.outputs.first() {
        Some(name) => name.clone(),
        None => "output0".to_string(),
    };

    let input_tensor = Tensor::from_vec(
        vec![1, 3, config.input_size, config.input_size],
        image_data.to_vec(),
    )?;

    let mut inputs = HashMap::new();
    inputs.insert(input_name, input_tensor);

    let mut outputs = yscv_onnx::run_onnx_model(model, inputs)?;

    // Take ownership of the output tensor; a missing entry is surfaced as
    // an OnnxError so callers see which tensor name was expected.
    let output_tensor = outputs
        .remove(&output_name)
        .ok_or_else(|| yscv_onnx::OnnxError::MissingInput {
            node: "model_output".to_string(),
            input: output_name,
        })?;

    Ok(decode_yolov8_output(
        &output_tensor,
        config,
        img_width,
        img_height,
    ))
}
356
/// Detects objects in a raw HWC RGB float image with a YOLOv8 ONNX model,
/// handling the full preprocessing pipeline (letterboxing to
/// `config.input_size`, then HWC -> NCHW conversion) before inference.
#[cfg(feature = "onnx")]
pub fn detect_yolov8_from_rgb(
    model: &yscv_onnx::OnnxModel,
    rgb_data: &[f32],
    height: usize,
    width: usize,
    config: &YoloConfig,
) -> Result<Vec<Detection>, crate::DetectError> {
    let image = Tensor::from_vec(vec![height, width, 3], rgb_data.to_vec())?;
    let (letterboxed, ..) = letterbox_preprocess(&image, config.input_size);
    let nchw = hwc_to_nchw(&letterboxed);
    detect_yolov8_onnx(model, &nchw, height, width, config)
}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Default YOLOv8/COCO config shared across the decode tests.
    fn test_config() -> YoloConfig {
        YoloConfig {
            input_size: 640,
            num_classes: 80,
            conf_threshold: 0.25,
            iou_threshold: 0.45,
            class_labels: coco_labels(),
        }
    }

    /// Builds a `[1, 84, 8400]` YOLOv8 output holding exactly one confident
    /// prediction: a 100x100 box centred at (320, 320), score 0.9, class 5.
    fn make_one_detection_tensor() -> Tensor {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        data[0] = 320.0; // cx
        data[num_preds] = 320.0; // cy
        data[2 * num_preds] = 100.0; // w
        data[3 * num_preds] = 100.0; // h
        data[(4 + 5) * num_preds] = 0.9; // score for class 5

        Tensor::from_vec(vec![1, rows, num_preds], data).unwrap()
    }

    #[test]
    fn test_coco_labels_count() {
        assert_eq!(coco_labels().len(), 80);
    }

    #[test]
    fn test_yolov8_coco_config_defaults() {
        let cfg = yolov8_coco_config();
        assert_eq!(cfg.input_size, 640);
        assert_eq!(cfg.num_classes, 80);
        assert!((cfg.conf_threshold - 0.25).abs() < 1e-6);
        assert!((cfg.iou_threshold - 0.45).abs() < 1e-6);
        assert_eq!(cfg.class_labels.len(), 80);
    }

    #[test]
    fn test_decode_yolov8_output_basic() {
        let tensor = make_one_detection_tensor();
        let dets = decode_yolov8_output(&tensor, &test_config(), 640, 640);

        assert_eq!(dets.len(), 1);
        assert_eq!(dets[0].class_id, 5);
        assert!((dets[0].score - 0.9).abs() < 1e-6);

        // A 640x640 source means no letterbox padding or rescaling, so the
        // box should land at 320 +/- 50 on both axes.
        let b = &dets[0].bbox;
        assert!((b.x1 - 270.0).abs() < 1.0);
        assert!((b.y1 - 270.0).abs() < 1.0);
        assert!((b.x2 - 370.0).abs() < 1.0);
        assert!((b.y2 - 370.0).abs() < 1.0);
    }

    #[test]
    fn test_decode_yolov8_output_confidence_filter() {
        let tensor = make_one_detection_tensor();
        // Raise the threshold above the 0.9 score so nothing survives.
        let config = YoloConfig {
            conf_threshold: 0.95,
            ..test_config()
        };
        let dets = decode_yolov8_output(&tensor, &config, 640, 640);
        assert!(dets.is_empty());
    }

    #[test]
    fn test_decode_yolov8_output_nms() {
        let num_classes = 80;
        let rows = 4 + num_classes;
        let num_preds = 8400;
        let mut data = vec![0.0f32; rows * num_preds];

        // Prediction 0: class-0 box at (320, 320), score 0.9.
        data[0] = 320.0;
        data[num_preds] = 320.0;
        data[2 * num_preds] = 100.0;
        data[3 * num_preds] = 100.0;
        data[4 * num_preds] = 0.9;

        // Prediction 1: heavily overlapping class-0 box with a lower score
        // of 0.8 — NMS should suppress it.
        data[1] = 325.0;
        data[num_preds + 1] = 325.0;
        data[2 * num_preds + 1] = 100.0;
        data[3 * num_preds + 1] = 100.0;
        data[4 * num_preds + 1] = 0.8;

        let tensor = Tensor::from_vec(vec![1, rows, num_preds], data).unwrap();
        let dets = decode_yolov8_output(&tensor, &test_config(), 640, 640);

        assert_eq!(dets.len(), 1);
        assert!((dets[0].score - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_letterbox_preprocess_square() {
        let img = Tensor::from_vec(vec![100, 100, 3], vec![0.5; 100 * 100 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 6.4).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!(pad_y.abs() < 1.0);
    }

    #[test]
    fn test_hwc_to_nchw_basic() {
        // 2x2 RGB image; one (r, g, b) pixel per line.
        let data = vec![
            0.1, 0.2, 0.3, // (0, 0)
            0.4, 0.5, 0.6, // (0, 1)
            0.7, 0.8, 0.9, // (1, 0)
            1.0, 0.0, 0.5, // (1, 1)
        ];
        let img = Tensor::from_vec(vec![2, 2, 3], data).unwrap();
        let nchw = hwc_to_nchw(&img);
        assert_eq!(nchw.len(), 12);

        // Expect the R plane, then G, then B.
        let expected = [
            0.1, 0.4, 0.7, 1.0, // R
            0.2, 0.5, 0.8, 0.0, // G
            0.3, 0.6, 0.9, 0.5, // B
        ];
        for (got, want) in nchw.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_letterbox_then_nchw_pipeline() {
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (letterboxed, _scale, _pad_x, _pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(letterboxed.shape(), &[640, 640, 3]);
        let nchw = hwc_to_nchw(&letterboxed);
        assert_eq!(nchw.len(), 3 * 640 * 640);
    }

    #[test]
    fn test_letterbox_preprocess_landscape() {
        // 200-wide x 100-tall source: the width limits the scale, so the
        // padding goes above and below the resized image.
        let img = Tensor::from_vec(vec![100, 200, 3], vec![0.4; 100 * 200 * 3]).unwrap();
        let (out, scale, pad_x, pad_y) = letterbox_preprocess(&img, 640);
        assert_eq!(out.shape(), &[640, 640, 3]);
        assert!((scale - 3.2).abs() < 0.01);
        assert!(pad_x.abs() < 1.0);
        assert!((pad_y - 160.0).abs() < 1.0);

        let top_pixel = &out.data()[0..3];
        for &v in top_pixel {
            assert!((v - 0.5).abs() < 1e-6, "top padding should be 0.5 grey");
        }
    }
}