use crate::common::Rect;
use crate::error::{SceneError, SceneResult};
use serde::{Deserialize, Serialize};
/// Tuning parameters for [`PyramidDetector`].
#[derive(Debug, Clone)]
pub struct PyramidConfig {
/// Upper bound on the number of pyramid levels that `build_levels` produces.
pub num_levels: usize,
/// Multiplier applied to width, height and cumulative scale when stepping to the next level.
pub scale_factor: f32,
/// Level construction stops as soon as the width or height falls below this value.
pub min_dimension: usize,
/// A block is reported as a detection when its edge density is at least this value.
pub edge_threshold: f32,
/// Side length, in pixels, of the square blocks scanned on every level.
pub block_size: usize,
/// A detection is suppressed when its IoU with a higher-confidence detection exceeds this value.
pub nms_iou_threshold: f32,
}
impl Default for PyramidConfig {
fn default() -> Self {
Self {
num_levels: 4,
scale_factor: 0.5,
min_dimension: 32,
edge_threshold: 0.08,
block_size: 16,
nms_iou_threshold: 0.4,
}
}
}
/// Geometry of a single pyramid level as produced by `PyramidDetector::build_levels`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyramidLevel {
/// Zero-based index of the level; level 0 has the dimensions passed to `build_levels`.
pub level: usize,
/// Width of the level in pixels.
pub width: usize,
/// Height of the level in pixels.
pub height: usize,
/// Cumulative scale relative to level 0 (1.0 for level 0, multiplied by `scale_factor` per level).
pub scale: f32,
}
/// A block flagged by `PyramidDetector::detect`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyramidDetection {
/// Block rectangle, built from level coordinates multiplied by the inverse of the level's scale.
pub bbox: Rect,
/// Edge density of the block, clamped to `[0.0, 1.0]`.
pub confidence: f32,
/// Index of the pyramid level on which the block was found.
pub source_level: usize,
}
/// Multi-scale detector that scans fixed-size blocks for high edge density on
/// every level of an image pyramid.
///
/// The detector holds only its configuration, so it is cheap to clone and can
/// be printed with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct PyramidDetector {
    /// Parameters controlling pyramid construction, block scanning and NMS.
    config: PyramidConfig,
}
impl PyramidDetector {
/// Creates a detector configured with [`PyramidConfig::default`].
#[must_use]
pub fn new() -> Self {
    let config = PyramidConfig::default();
    Self { config }
}
#[must_use]
pub fn with_config(config: PyramidConfig) -> Self {
Self { config }
}
/// Computes the geometry of every pyramid level for a `width` x `height` image.
///
/// Level 0 has the input dimensions and scale 1.0. Each following level
/// multiplies width, height (truncated, but never below 1) and scale by
/// `scale_factor`. Construction stops after `num_levels` levels, or earlier as
/// soon as either dimension is smaller than `min_dimension`.
#[must_use]
pub fn build_levels(&self, width: usize, height: usize) -> Vec<PyramidLevel> {
    let cfg = &self.config;
    // Truncating resize of one dimension, clamped so a level never collapses to zero.
    let shrink = |dim: usize| ((dim as f32 * cfg.scale_factor) as usize).max(1);

    let mut pyramid = Vec::with_capacity(cfg.num_levels);
    let (mut level_w, mut level_h, mut level_scale) = (width, height, 1.0_f32);

    for level in 0..cfg.num_levels {
        let too_small = level_w.min(level_h) < cfg.min_dimension;
        if too_small {
            break;
        }

        pyramid.push(PyramidLevel {
            level,
            width: level_w,
            height: level_h,
            scale: level_scale,
        });

        level_w = shrink(level_w);
        level_h = shrink(level_h);
        level_scale *= cfg.scale_factor;
    }

    pyramid
}
/// Resizes an interleaved RGB buffer by `factor` using nearest-neighbour sampling.
///
/// * `rgb` - source pixels, 3 bytes per pixel, row-major, `src_w * src_h` pixels.
/// * `src_w`, `src_h` - source dimensions in pixels.
/// * `factor` - scale applied to both dimensions (0.5 halves the image).
///
/// Returns the resized buffer together with its width and height. Each output
/// dimension is the truncated product of the source dimension and `factor`,
/// but never less than 1 for a non-empty source. An empty source (zero width
/// or height) yields an empty buffer with dimensions `(0, 0)`.
///
/// # Panics
///
/// Panics if `rgb` is too short to contain a sampled pixel.
fn downsample(rgb: &[u8], src_w: usize, src_h: usize, factor: f32) -> (Vec<u8>, usize, usize) {
    // Without this guard `src_w - 1` / `src_h - 1` below would underflow.
    if src_w == 0 || src_h == 0 {
        return (Vec::new(), 0, 0);
    }

    let dst_w = ((src_w as f32 * factor) as usize).max(1);
    let dst_h = ((src_h as f32 * factor) as usize).max(1);

    // The source column depends only on the destination column, so compute the
    // mapping once instead of once per output row.
    let src_cols: Vec<usize> = (0..dst_w)
        .map(|dx| ((dx as f32 / factor) as usize).min(src_w - 1))
        .collect();

    let mut out = vec![0u8; dst_w * dst_h * 3];
    for (dy, dst_row) in out.chunks_exact_mut(dst_w * 3).enumerate() {
        let sy = ((dy as f32 / factor) as usize).min(src_h - 1);
        for (dst_px, &sx) in dst_row.chunks_exact_mut(3).zip(&src_cols) {
            let src_idx = (sy * src_w + sx) * 3;
            dst_px.copy_from_slice(&rgb[src_idx..src_idx + 3]);
        }
    }

    (out, dst_w, dst_h)
}
/// Measures the average edge strength inside one block of an interleaved RGB image.
///
/// * `rgb` - pixel buffer, 3 bytes per pixel, row-major.
/// * `width`, `height` - image dimensions in pixels.
/// * `bx`, `by` - top-left corner of the block.
/// * `bw`, `bh` - block width and height.
///
/// For every pixel in the block that has both a right and a lower neighbour,
/// the absolute per-channel differences to those two neighbours are averaged
/// and normalised to `[0, 1]`. The result is the mean of those values, or 0.0
/// when no pixel qualifies. Pixels in the last image column or row are never
/// used as centres, and pixels whose neighbours lie outside `rgb` are skipped.
fn block_edge_density(
    rgb: &[u8],
    width: usize,
    height: usize,
    bx: usize,
    by: usize,
    bw: usize,
    bh: usize,
) -> f32 {
    // Stop one short of the image border so every centre has both neighbours.
    let x_stop = (bx + bw).min(width.saturating_sub(1));
    let y_stop = (by + bh).min(height.saturating_sub(1));

    let mut total = 0.0_f64;
    let mut samples = 0_u64;

    for y in by..y_stop {
        for x in bx..x_stop {
            let here = (y * width + x) * 3;
            let right = (y * width + x + 1) * 3;
            let below = ((y + 1) * width + x) * 3;

            // Skip centres whose neighbours would fall outside the buffer.
            if right + 2 >= rgb.len() || below + 2 >= rgb.len() {
                continue;
            }

            let centre = &rgb[here..here + 3];
            let right_px = &rgb[right..right + 3];
            let below_px = &rgb[below..below + 3];

            let mut diff = 0.0_f32;
            for ((&c, &r), &b) in centre.iter().zip(right_px).zip(below_px) {
                let c = c as f32;
                diff += (c - r as f32).abs();
                diff += (c - b as f32).abs();
            }

            // Six differences per pixel, each at most 255.
            total += (diff / 6.0 / 255.0) as f64;
            samples += 1;
        }
    }

    if samples == 0 {
        return 0.0;
    }
    (total / samples as f64) as f32
}
/// Scans every pyramid level of an RGB image for blocks with high edge density.
///
/// * `rgb_data` - interleaved RGB pixels, row-major, exactly `width * height * 3` bytes.
/// * `width`, `height` - image dimensions in pixels.
///
/// On each level the image is divided into a grid of `block_size` squares
/// (`floor(dimension / block_size)` per axis, but at least one), so pixels
/// past the last whole block are not scanned. A block whose edge density is at
/// least `edge_threshold` becomes a detection; its rectangle is scaled by the
/// inverse of the level's scale, and its confidence is the density clamped to
/// `[0.0, 1.0]`. Between levels the image is resized by `scale_factor`. The
/// combined detections are passed through non-maximum suppression before
/// being returned.
///
/// # Errors
///
/// Returns [`SceneError::InvalidDimensions`] when `rgb_data.len()` differs from
/// `width * height * 3` (including when that product overflows), or when at
/// least one level has to be scanned and the configured `block_size` is zero.
pub fn detect(
    &self,
    rgb_data: &[u8],
    width: usize,
    height: usize,
) -> SceneResult<Vec<PyramidDetection>> {
    // Checked arithmetic: a wrapped product must never accidentally match the length.
    let expected_len = width
        .checked_mul(height)
        .and_then(|pixels| pixels.checked_mul(3));
    if expected_len != Some(rgb_data.len()) {
        return Err(SceneError::InvalidDimensions(
            "RGB data size mismatch".to_string(),
        ));
    }

    let levels = self.build_levels(width, height);
    if levels.is_empty() {
        // Image is smaller than `min_dimension` (or `num_levels` is 0): nothing to scan.
        return Ok(Vec::new());
    }

    let bs = self.config.block_size;
    if bs == 0 {
        // A zero block size would divide by zero when computing the block grid.
        return Err(SceneError::InvalidDimensions(
            "block_size must be greater than zero".to_string(),
        ));
    }

    let mut all_detections: Vec<PyramidDetection> = Vec::new();
    // Borrow the caller's buffer for level 0; only downsampled levels are owned.
    let mut current_rgb = std::borrow::Cow::Borrowed(rgb_data);
    let mut cur_w = width;
    let mut cur_h = height;

    for level_info in &levels {
        let inv_scale = 1.0 / level_info.scale;
        let blocks_x = (cur_w / bs).max(1);
        let blocks_y = (cur_h / bs).max(1);

        for by_idx in 0..blocks_y {
            for bx_idx in 0..blocks_x {
                let bx = bx_idx * bs;
                let by = by_idx * bs;
                // Only relevant when the level is smaller than one block.
                let bw = bs.min(cur_w - bx);
                let bh = bs.min(cur_h - by);

                let density =
                    Self::block_edge_density(&current_rgb, cur_w, cur_h, bx, by, bw, bh);
                if density >= self.config.edge_threshold {
                    all_detections.push(PyramidDetection {
                        bbox: Rect::new(
                            bx as f32 * inv_scale,
                            by as f32 * inv_scale,
                            bw as f32 * inv_scale,
                            bh as f32 * inv_scale,
                        ),
                        confidence: density.clamp(0.0, 1.0),
                        source_level: level_info.level,
                    });
                }
            }
        }

        // Prepare the next level's pixels unless this was the last level.
        if level_info.level + 1 < levels.len() {
            let (down, dw, dh) =
                Self::downsample(&current_rgb, cur_w, cur_h, self.config.scale_factor);
            current_rgb = std::borrow::Cow::Owned(down);
            cur_w = dw;
            cur_h = dh;
        }
    }

    self.apply_nms(&mut all_detections);
    Ok(all_detections)
}
/// Applies greedy non-maximum suppression to `detections` in place.
///
/// Detections are ordered by descending confidence (incomparable values are
/// treated as equal). Walking that order, every detection that has not been
/// suppressed removes all later detections whose IoU with it is strictly
/// greater than `nms_iou_threshold`. The survivors keep the sorted order.
fn apply_nms(&self, detections: &mut Vec<PyramidDetection>) {
    detections.sort_by(|lhs, rhs| {
        rhs.confidence
            .partial_cmp(&lhs.confidence)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let iou_limit = self.config.nms_iou_threshold;
    let total = detections.len();
    let mut dropped = vec![false; total];

    for keeper in 0..total {
        if dropped[keeper] {
            continue;
        }
        let (head, tail) = detections.split_at(keeper + 1);
        let anchor = &head[keeper];
        for (offset, candidate) in tail.iter().enumerate() {
            let slot = keeper + 1 + offset;
            if !dropped[slot] && anchor.bbox.iou(&candidate.bbox) > iou_limit {
                dropped[slot] = true;
            }
        }
    }

    // `retain` visits elements in order, so the flags line up one-to-one.
    let mut flags = dropped.into_iter();
    detections.retain(|_| !flags.next().unwrap_or(false));
}
#[must_use]
pub fn config(&self) -> &PyramidConfig {
&self.config
}
}
impl Default for PyramidDetector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w` x `h` RGB buffer in which every byte equals `v`.
    fn uniform_image(w: usize, h: usize, v: u8) -> Vec<u8> {
        let byte_count = w * h * 3;
        vec![v; byte_count]
    }

    /// Builds a `w` x `h` grey image of vertical stripes: two bright columns
    /// (200) followed by two dark columns (20), repeating.
    fn edgy_image(w: usize, h: usize) -> Vec<u8> {
        let mut pixels = Vec::with_capacity(w * h * 3);
        for _row in 0..h {
            for col in 0..w {
                let shade = if col % 4 < 2 { 200u8 } else { 20u8 };
                pixels.extend_from_slice(&[shade, shade, shade]);
            }
        }
        pixels
    }

    #[test]
    fn test_pyramid_config_defaults() {
        let defaults = PyramidConfig::default();
        assert_eq!(defaults.num_levels, 4);
        assert!((defaults.scale_factor - 0.5).abs() < f32::EPSILON);
        assert_eq!(defaults.min_dimension, 32);
        assert_eq!(defaults.block_size, 16);
    }

    #[test]
    fn test_build_levels() {
        let detector = PyramidDetector::new();
        let levels = detector.build_levels(640, 480);
        assert!(!levels.is_empty());

        let base = &levels[0];
        assert_eq!(base.width, 640);
        assert_eq!(base.height, 480);
        assert!((base.scale - 1.0).abs() < f32::EPSILON);

        // Every level must be strictly smaller than the one before it.
        for pair in levels.windows(2) {
            assert!(pair[1].width < pair[0].width);
            assert!(pair[1].height < pair[0].height);
        }
    }

    #[test]
    fn test_build_levels_min_dimension() {
        let detector = PyramidDetector::with_config(PyramidConfig {
            num_levels: 10,
            scale_factor: 0.5,
            min_dimension: 100,
            ..Default::default()
        });
        // 200 -> 100 -> 50: the third level is below `min_dimension`.
        let levels = detector.build_levels(200, 200);
        assert_eq!(levels.len(), 2);
    }

    #[test]
    fn test_detect_uniform_no_detections() {
        let (w, h) = (128, 128);
        let image = uniform_image(w, h, 128);
        let outcome = PyramidDetector::new().detect(&image, w, h);
        assert!(outcome.is_ok());
        let found = outcome.expect("should succeed");
        assert!(
            found.is_empty(),
            "uniform image should produce no detections"
        );
    }

    #[test]
    fn test_detect_edgy_image() {
        let (w, h) = (128, 128);
        let image = edgy_image(w, h);
        let outcome = PyramidDetector::new().detect(&image, w, h);
        assert!(outcome.is_ok());
        let found = outcome.expect("should succeed");
        assert!(!found.is_empty(), "edgy image should produce detections");
    }

    #[test]
    fn test_detect_invalid_dimensions() {
        let detector = PyramidDetector::new();
        // Ten bytes cannot be a 100 x 100 RGB image.
        let outcome = detector.detect(&[0u8; 10], 100, 100);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_detections_valid_bbox() {
        let (w, h) = (128, 128);
        let image = edgy_image(w, h);
        let found = PyramidDetector::new()
            .detect(&image, w, h)
            .expect("should succeed");
        for detection in &found {
            let bbox = &detection.bbox;
            assert!(bbox.x >= 0.0);
            assert!(bbox.y >= 0.0);
            assert!(bbox.width > 0.0);
            assert!(bbox.height > 0.0);
            assert!(detection.confidence > 0.0 && detection.confidence <= 1.0);
        }
    }

    #[test]
    fn test_detections_multi_level() {
        let detector = PyramidDetector::with_config(PyramidConfig {
            num_levels: 3,
            scale_factor: 0.5,
            min_dimension: 16,
            edge_threshold: 0.01,
            block_size: 8,
            nms_iou_threshold: 0.9,
        });
        let (w, h) = (128, 128);
        let image = edgy_image(w, h);
        let found = detector.detect(&image, w, h).expect("should succeed");

        let levels_seen: std::collections::HashSet<usize> =
            found.iter().map(|d| d.source_level).collect();
        assert!(
            levels_seen.len() >= 2,
            "expected detections from multiple levels, got {:?}",
            levels_seen
        );
    }

    #[test]
    fn test_nms_reduces_count() {
        let (w, h) = (64, 64);
        let image = edgy_image(w, h);

        let with_nms = PyramidDetector::with_config(PyramidConfig {
            nms_iou_threshold: 0.3,
            edge_threshold: 0.01,
            block_size: 8,
            ..Default::default()
        });
        let dets_nms = with_nms.detect(&image, w, h).expect("should succeed");

        let without_nms = PyramidDetector::with_config(PyramidConfig {
            nms_iou_threshold: 1.0,
            edge_threshold: 0.01,
            block_size: 8,
            ..Default::default()
        });
        let dets_all = without_nms.detect(&image, w, h).expect("should succeed");

        assert!(
            dets_nms.len() <= dets_all.len(),
            "NMS should not increase count: {} vs {}",
            dets_nms.len(),
            dets_all.len()
        );
    }

    #[test]
    fn test_config_accessor() {
        let detector = PyramidDetector::with_config(PyramidConfig {
            num_levels: 5,
            scale_factor: 0.75,
            min_dimension: 64,
            edge_threshold: 0.1,
            block_size: 32,
            nms_iou_threshold: 0.5,
        });
        let seen = detector.config();
        assert_eq!(seen.num_levels, 5);
        assert!((seen.scale_factor - 0.75).abs() < f32::EPSILON);
        assert_eq!(seen.min_dimension, 64);
    }

    #[test]
    fn test_downsample_dimensions() {
        let (w, h) = (100, 80);
        let image = uniform_image(w, h, 128);
        let (down, dw, dh) = PyramidDetector::downsample(&image, w, h, 0.5);
        assert_eq!(dw, 50);
        assert_eq!(dh, 40);
        assert_eq!(down.len(), dw * dh * 3);
    }

    #[test]
    fn test_single_level_pyramid() {
        let detector = PyramidDetector::with_config(PyramidConfig {
            num_levels: 1,
            ..Default::default()
        });
        let levels = detector.build_levels(256, 256);
        assert_eq!(levels.len(), 1);

        let image = edgy_image(256, 256);
        let outcome = detector.detect(&image, 256, 256);
        assert!(outcome.is_ok());
    }
}