use super::{Hsv, Rgb};
use crate::error::{CvError, CvResult};
use oximedia_codec::VideoFrame;
use oximedia_core::PixelFormat;
use std::collections::HashMap;
/// Detects a chroma-key colour by finding the dominant saturated hue in a frame.
///
/// Pixels whose HSV saturation or value fall below the configured minimums are
/// ignored; the remaining pixels are grouped into hue buckets and the most
/// populated bucket determines the key colour.
#[derive(Debug, Clone)]
pub struct AutoKeyDetector {
    /// Minimum HSV saturation a pixel needs in order to be counted.
    min_saturation: f32,
    /// Minimum HSV value (brightness) a pixel needs in order to be counted.
    min_value: f32,
    /// Width of one hue histogram bucket; the hue is divided by this to get
    /// the bucket index.
    hue_bucket_size: f32,
}
impl AutoKeyDetector {
#[must_use]
pub fn new() -> Self {
Self {
min_saturation: 0.3,
min_value: 0.2,
hue_bucket_size: 10.0,
}
}
pub fn set_min_saturation(&mut self, saturation: f32) {
self.min_saturation = saturation.clamp(0.0, 1.0);
}
pub fn set_min_value(&mut self, value: f32) {
self.min_value = value.clamp(0.0, 1.0);
}
pub fn set_hue_bucket_size(&mut self, size: f32) {
self.hue_bucket_size = size.clamp(1.0, 90.0);
}
pub fn detect_from_region(
&self,
frame: &VideoFrame,
x: u32,
y: u32,
width: u32,
height: u32,
) -> CvResult<Rgb> {
if x + width > frame.width || y + height > frame.height {
return Err(CvError::invalid_roi(x, y, width, height));
}
let rgb_data = self.extract_region_rgb(frame, x, y, width, height)?;
self.detect_from_rgb_data(&rgb_data, width as usize, height as usize)
}
pub fn detect_from_frame(&self, frame: &VideoFrame) -> CvResult<Rgb> {
self.detect_from_region(frame, 0, 0, frame.width, frame.height)
}
#[allow(clippy::vec_init_then_push)]
pub fn detect_from_edges(&self, frame: &VideoFrame) -> CvResult<Rgb> {
let width = frame.width;
let height = frame.height;
let border_size = (width.min(height) / 10).max(20);
let mut samples = Vec::with_capacity(4);
samples.push(self.detect_from_region(frame, 0, 0, width, border_size)?);
samples.push(self.detect_from_region(
frame,
0,
height - border_size,
width,
border_size,
)?);
samples.push(self.detect_from_region(frame, 0, 0, border_size, height)?);
samples.push(self.detect_from_region(
frame,
width - border_size,
0,
border_size,
height,
)?);
let avg_r = samples.iter().map(|c| c.r).sum::<f32>() / samples.len() as f32;
let avg_g = samples.iter().map(|c| c.g).sum::<f32>() / samples.len() as f32;
let avg_b = samples.iter().map(|c| c.b).sum::<f32>() / samples.len() as f32;
Ok(Rgb::new(avg_r, avg_g, avg_b))
}
#[allow(clippy::vec_init_then_push)]
pub fn detect_from_corners(&self, frame: &VideoFrame) -> CvResult<Rgb> {
let width = frame.width;
let height = frame.height;
let sample_size = (width.min(height) / 8).max(50);
let mut samples = Vec::with_capacity(4);
samples.push(self.detect_from_region(frame, 0, 0, sample_size, sample_size)?);
samples.push(self.detect_from_region(
frame,
width - sample_size,
0,
sample_size,
sample_size,
)?);
samples.push(self.detect_from_region(
frame,
0,
height - sample_size,
sample_size,
sample_size,
)?);
samples.push(self.detect_from_region(
frame,
width - sample_size,
height - sample_size,
sample_size,
sample_size,
)?);
self.find_mode_color(&samples)
}
fn extract_region_rgb(
&self,
frame: &VideoFrame,
x: u32,
y: u32,
width: u32,
height: u32,
) -> CvResult<Vec<f32>> {
let region_size = (width * height) as usize;
let mut rgb_data = vec![0.0f32; region_size * 3];
match frame.format {
PixelFormat::Rgb24 => {
if frame.planes.is_empty() {
return Err(CvError::invalid_parameter("planes", "empty"));
}
let data = &frame.planes[0].data;
let stride = frame.planes[0].stride;
for row in 0..height as usize {
let src_y = y as usize + row;
let src_offset = src_y * stride + x as usize * 3;
let dst_offset = row * width as usize * 3;
for col in 0..width as usize {
let src_idx = src_offset + col * 3;
let dst_idx = dst_offset + col * 3;
rgb_data[dst_idx] = f32::from(data[src_idx]) / 255.0;
rgb_data[dst_idx + 1] = f32::from(data[src_idx + 1]) / 255.0;
rgb_data[dst_idx + 2] = f32::from(data[src_idx + 2]) / 255.0;
}
}
}
PixelFormat::Rgba32 => {
if frame.planes.is_empty() {
return Err(CvError::invalid_parameter("planes", "empty"));
}
let data = &frame.planes[0].data;
let stride = frame.planes[0].stride;
for row in 0..height as usize {
let src_y = y as usize + row;
let src_offset = src_y * stride + x as usize * 4;
let dst_offset = row * width as usize * 3;
for col in 0..width as usize {
let src_idx = src_offset + col * 4;
let dst_idx = dst_offset + col * 3;
rgb_data[dst_idx] = f32::from(data[src_idx]) / 255.0;
rgb_data[dst_idx + 1] = f32::from(data[src_idx + 1]) / 255.0;
rgb_data[dst_idx + 2] = f32::from(data[src_idx + 2]) / 255.0;
}
}
}
_ => {
return Err(CvError::unsupported_format(format!("{}", frame.format)));
}
}
Ok(rgb_data)
}
fn detect_from_rgb_data(&self, rgb_data: &[f32], width: usize, height: usize) -> CvResult<Rgb> {
let pixel_count = width * height;
let mut hue_histogram: HashMap<i32, ColorAccumulator> = HashMap::new();
for i in 0..pixel_count {
let r = rgb_data[i * 3];
let g = rgb_data[i * 3 + 1];
let b = rgb_data[i * 3 + 2];
let pixel = Rgb::new(r, g, b);
let hsv = pixel.to_hsv();
if hsv.s >= self.min_saturation && hsv.v >= self.min_value {
let hue_bucket = (hsv.h / self.hue_bucket_size) as i32;
let accumulator = hue_histogram.entry(hue_bucket).or_insert(ColorAccumulator {
count: 0,
sum_h: 0.0,
sum_s: 0.0,
sum_v: 0.0,
});
accumulator.count += 1;
accumulator.sum_h += hsv.h;
accumulator.sum_s += hsv.s;
accumulator.sum_v += hsv.v;
}
}
let dominant_bucket = hue_histogram
.iter()
.max_by_key(|(_, acc)| acc.count)
.ok_or_else(|| CvError::detection_failed("No suitable key color found"))?;
let accumulator = dominant_bucket.1;
let avg_hue = accumulator.sum_h / accumulator.count as f32;
let avg_sat = accumulator.sum_s / accumulator.count as f32;
let avg_val = accumulator.sum_v / accumulator.count as f32;
let key_hsv = Hsv::new(avg_hue, avg_sat, avg_val);
Ok(key_hsv.to_rgb())
}
fn find_mode_color(&self, samples: &[Rgb]) -> CvResult<Rgb> {
if samples.is_empty() {
return Err(CvError::detection_failed("No color samples provided"));
}
let hsv_samples: Vec<Hsv> = samples.iter().map(super::Rgb::to_hsv).collect();
let mut hue_groups: HashMap<i32, Vec<Hsv>> = HashMap::new();
for hsv in &hsv_samples {
if hsv.s >= self.min_saturation && hsv.v >= self.min_value {
let bucket = (hsv.h / self.hue_bucket_size) as i32;
hue_groups.entry(bucket).or_default().push(*hsv);
}
}
let largest_group = hue_groups
.values()
.max_by_key(|group| group.len())
.ok_or_else(|| CvError::detection_failed("No suitable key color found"))?;
let avg_h = largest_group.iter().map(|hsv| hsv.h).sum::<f32>() / largest_group.len() as f32;
let avg_s = largest_group.iter().map(|hsv| hsv.s).sum::<f32>() / largest_group.len() as f32;
let avg_v = largest_group.iter().map(|hsv| hsv.v).sum::<f32>() / largest_group.len() as f32;
let mode_hsv = Hsv::new(avg_h, avg_s, avg_v);
Ok(mode_hsv.to_rgb())
}
}
/// The default detector is identical to the one returned by `AutoKeyDetector::new`.
impl Default for AutoKeyDetector {
fn default() -> Self {
Self::new()
}
}
/// Running totals for one hue histogram bucket, used to compute the average
/// HSV colour of the pixels that fell into it.
#[derive(Debug, Clone, Copy)]
struct ColorAccumulator {
/// Number of pixels added to this bucket.
count: usize,
/// Sum of the hue components of those pixels.
sum_h: f32,
/// Sum of the saturation components of those pixels.
sum_s: f32,
/// Sum of the value components of those pixels.
sum_v: f32,
}
/// Classification of a key colour by its hue.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScreenType {
/// The hue lies within the green tolerance band.
GreenScreen,
/// The hue lies within the blue tolerance band.
BlueScreen,
/// The hue matches neither the green nor the blue band.
Unknown,
}
/// Classifies a colour as green screen, blue screen or unknown by comparing
/// its hue against two centre/tolerance bands.
#[derive(Debug, Clone)]
pub struct ScreenTypeDetector {
    /// Hue at the centre of the green band.
    green_hue_center: f32,
    /// Maximum hue distance from the green centre still counted as green.
    green_hue_tolerance: f32,
    /// Hue at the centre of the blue band.
    blue_hue_center: f32,
    /// Maximum hue distance from the blue centre still counted as blue.
    blue_hue_tolerance: f32,
}
impl ScreenTypeDetector {
    /// Creates a detector with a green band centred on hue `120.0` and a blue
    /// band centred on hue `240.0`, each with a tolerance of `30.0`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            green_hue_center: 120.0,
            green_hue_tolerance: 30.0,
            blue_hue_center: 240.0,
            blue_hue_tolerance: 30.0,
        }
    }

    /// Classifies `color` by the hue of its HSV representation.
    ///
    /// Green is tested before blue, so a hue inside both bands is reported as
    /// [`ScreenType::GreenScreen`]. Hue distances are measured around the
    /// colour wheel, so a band that straddles the 0/360 boundary matches on
    /// both sides of it.
    #[must_use]
    pub fn detect(&self, color: &Rgb) -> ScreenType {
        let hue = color.to_hsv().h;
        if Self::hue_distance(hue, self.green_hue_center) <= self.green_hue_tolerance {
            ScreenType::GreenScreen
        } else if Self::hue_distance(hue, self.blue_hue_center) <= self.blue_hue_tolerance {
            ScreenType::BlueScreen
        } else {
            ScreenType::Unknown
        }
    }

    /// Shortest angular distance between two hues in degrees, in `0.0..=180.0`.
    fn hue_distance(a: f32, b: f32) -> f32 {
        let raw = (a - b).abs() % 360.0;
        if raw > 180.0 {
            360.0 - raw
        } else {
            raw
        }
    }

    /// Returns a fixed pair of tuning values for the given screen type.
    ///
    /// NOTE(review): the meaning of the two numbers is not visible in this
    /// file — presumably keying threshold/tolerance parameters; confirm
    /// against the caller.
    #[must_use]
    pub fn recommend_config(&self, screen_type: ScreenType) -> (f32, f32) {
        match screen_type {
            ScreenType::GreenScreen => (0.35, 0.15),
            ScreenType::BlueScreen => (0.30, 0.12),
            ScreenType::Unknown => (0.30, 0.10),
        }
    }
}
/// The default detector is identical to the one returned by `ScreenTypeDetector::new`.
impl Default for ScreenTypeDetector {
fn default() -> Self {
Self::new()
}
}
/// Result of a k-means based background colour analysis.
#[derive(Debug, Clone)]
pub struct BackgroundColorAnalysis {
/// Centre of the most populated colour cluster, converted to RGB.
pub primary_color: Rgb,
/// Centre of the next most populated cluster; only set when the primary
/// cluster covers less than 70% of the analysed pixels.
pub secondary_color: Option<Rgb>,
/// Coverage multiplied by the primary cluster's saturation, clamped to `0.0..=1.0`.
pub confidence: f32,
/// Screen classification of `primary_color`.
pub screen_type: ScreenType,
/// Fraction of the analysed (threshold-passing) pixels assigned to the primary cluster.
pub coverage: f32,
}
/// Finds the dominant background colour of a frame by clustering its border
/// pixels in HSV space with k-means.
#[derive(Debug, Clone)]
pub struct KmeansBackgroundDetector {
    /// Number of clusters requested from k-means.
    num_clusters: usize,
    /// Upper bound on the number of k-means iterations.
    max_iterations: usize,
    /// Iteration stops once no cluster centre moved further than this.
    convergence_threshold: f32,
    /// Minimum HSV saturation a pixel needs in order to be clustered.
    min_saturation: f32,
    /// Minimum HSV value (brightness) a pixel needs in order to be clustered.
    min_value: f32,
}
impl KmeansBackgroundDetector {
#[must_use]
pub fn new() -> Self {
Self {
num_clusters: 4,
max_iterations: 20,
convergence_threshold: 0.5,
min_saturation: 0.2,
min_value: 0.1,
}
}
pub fn set_num_clusters(&mut self, k: usize) {
self.num_clusters = k.clamp(2, 16);
}
pub fn set_max_iterations(&mut self, iters: usize) {
self.max_iterations = iters.clamp(5, 100);
}
pub fn analyse(&self, frame: &VideoFrame) -> CvResult<BackgroundColorAnalysis> {
let pixels = self.sample_border_pixels(frame)?;
if pixels.len() < self.num_clusters {
return Err(CvError::detection_failed(
"insufficient pixels for background detection",
));
}
let filtered: Vec<Hsv> = pixels
.iter()
.map(|rgb| rgb.to_hsv())
.filter(|hsv| hsv.s >= self.min_saturation && hsv.v >= self.min_value)
.collect();
if filtered.is_empty() {
return Err(CvError::detection_failed("no saturated pixels found"));
}
let (centres, assignments) = self.kmeans_hsv(&filtered)?;
let mut cluster_counts = vec![0usize; centres.len()];
for &a in &assignments {
if a < cluster_counts.len() {
cluster_counts[a] += 1;
}
}
let primary_idx = cluster_counts
.iter()
.enumerate()
.max_by_key(|(_, &c)| c)
.map(|(i, _)| i)
.ok_or_else(|| CvError::detection_failed("k-means produced no clusters"))?;
let primary_hsv = centres[primary_idx];
let primary_rgb = primary_hsv.to_rgb();
let total_pixels = filtered.len() as f32;
let coverage = cluster_counts[primary_idx] as f32 / total_pixels;
let secondary_color = if coverage < 0.70 && centres.len() >= 2 {
let secondary_idx = cluster_counts
.iter()
.enumerate()
.filter(|(i, _)| *i != primary_idx)
.max_by_key(|(_, &c)| c)
.map(|(i, _)| i);
secondary_idx.map(|idx| {
let sec_hsv = centres[idx];
sec_hsv.to_rgb()
})
} else {
None
};
let confidence = (coverage * primary_hsv.s).clamp(0.0, 1.0);
let screen_type_detector = ScreenTypeDetector::new();
let screen_type = screen_type_detector.detect(&primary_rgb);
Ok(BackgroundColorAnalysis {
primary_color: primary_rgb,
secondary_color,
confidence,
screen_type,
coverage,
})
}
fn sample_border_pixels(&self, frame: &VideoFrame) -> CvResult<Vec<Rgb>> {
let w = frame.width as usize;
let h = frame.height as usize;
let border = (w.min(h) / 10).max(10);
match frame.format {
PixelFormat::Rgb24 => {
if frame.planes.is_empty() {
return Err(CvError::invalid_parameter("planes", "empty"));
}
let data = &frame.planes[0].data;
let stride = frame.planes[0].stride;
let mut pixels = Vec::new();
for y in (0..h).filter(|&y| y < border || y >= h - border) {
for x in 0..w {
let idx = y * stride + x * 3;
if idx + 2 < data.len() {
pixels.push(Rgb::new(
data[idx] as f32 / 255.0,
data[idx + 1] as f32 / 255.0,
data[idx + 2] as f32 / 255.0,
));
}
}
}
for y in border..h - border {
for x in (0..w).filter(|&x| x < border || x >= w - border) {
let idx = y * stride + x * 3;
if idx + 2 < data.len() {
pixels.push(Rgb::new(
data[idx] as f32 / 255.0,
data[idx + 1] as f32 / 255.0,
data[idx + 2] as f32 / 255.0,
));
}
}
}
Ok(pixels)
}
PixelFormat::Rgba32 => {
if frame.planes.is_empty() {
return Err(CvError::invalid_parameter("planes", "empty"));
}
let data = &frame.planes[0].data;
let stride = frame.planes[0].stride;
let mut pixels = Vec::new();
for y in (0..h).filter(|&y| y < border || y >= h - border) {
for x in 0..w {
let idx = y * stride + x * 4;
if idx + 2 < data.len() {
pixels.push(Rgb::new(
data[idx] as f32 / 255.0,
data[idx + 1] as f32 / 255.0,
data[idx + 2] as f32 / 255.0,
));
}
}
}
for y in border..h - border {
for x in (0..w).filter(|&x| x < border || x >= w - border) {
let idx = y * stride + x * 4;
if idx + 2 < data.len() {
pixels.push(Rgb::new(
data[idx] as f32 / 255.0,
data[idx + 1] as f32 / 255.0,
data[idx + 2] as f32 / 255.0,
));
}
}
}
Ok(pixels)
}
_ => Err(CvError::unsupported_format(format!("{}", frame.format))),
}
}
fn kmeans_hsv(&self, pixels: &[Hsv]) -> CvResult<(Vec<Hsv>, Vec<usize>)> {
let k = self.num_clusters.min(pixels.len());
if k == 0 {
return Err(CvError::detection_failed("no pixels to cluster"));
}
let mut centres: Vec<Hsv> = (0..k).map(|i| pixels[i * pixels.len() / k]).collect();
let mut assignments = vec![0usize; pixels.len()];
for _iter in 0..self.max_iterations {
let mut changed = false;
for (i, pixel) in pixels.iter().enumerate() {
let nearest = centres
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
hsv_distance(pixel, a)
.partial_cmp(&hsv_distance(pixel, b))
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(idx, _)| idx)
.unwrap_or(0);
if assignments[i] != nearest {
assignments[i] = nearest;
changed = true;
}
}
if !changed {
break;
}
let old_centres = centres.clone();
for c in 0..k {
let cluster_pixels: Vec<&Hsv> = pixels
.iter()
.enumerate()
.filter(|(i, _)| assignments[*i] == c)
.map(|(_, p)| p)
.collect();
if cluster_pixels.is_empty() {
continue;
}
let n = cluster_pixels.len() as f32;
let sin_sum: f32 = cluster_pixels
.iter()
.map(|p| (p.h * std::f32::consts::PI / 180.0).sin())
.sum();
let cos_sum: f32 = cluster_pixels
.iter()
.map(|p| (p.h * std::f32::consts::PI / 180.0).cos())
.sum();
let mean_h = sin_sum.atan2(cos_sum).to_degrees();
let mean_h = if mean_h < 0.0 { mean_h + 360.0 } else { mean_h };
let mean_s = cluster_pixels.iter().map(|p| p.s).sum::<f32>() / n;
let mean_v = cluster_pixels.iter().map(|p| p.v).sum::<f32>() / n;
centres[c] = Hsv::new(mean_h, mean_s, mean_v);
}
let max_move = centres
.iter()
.zip(old_centres.iter())
.map(|(a, b)| hsv_distance(a, b))
.fold(0.0_f32, f32::max);
if max_move < self.convergence_threshold {
break;
}
}
Ok((centres, assignments))
}
}
/// The default detector is identical to the one returned by `KmeansBackgroundDetector::new`.
impl Default for KmeansBackgroundDetector {
fn default() -> Self {
Self::new()
}
}
/// Distance between two HSV colours.
///
/// The hue difference is taken along the shorter arc of the colour wheel and
/// divided by 360 before being squared with a weight of 2; saturation and
/// value differences are squared unweighted. The result is the square root of
/// the weighted sum.
fn hsv_distance(a: &Hsv, b: &Hsv) -> f32 {
    let hue_gap = (a.h - b.h).abs();
    // For gaps up to 180 the direct gap is the smaller one, otherwise the
    // complement around the wheel is.
    let shortest_arc = hue_gap.min(360.0 - hue_gap);
    let dh = shortest_arc / 360.0;
    let ds = a.s - b.s;
    let dv = a.v - b.v;
    (2.0 * dh * dh + ds * ds + dv * dv).sqrt()
}
/// Accumulates per-frame key colour estimates over a sliding window of frames
/// and derives a single key colour from them.
pub struct MultiFrameDetector {
/// Detector used to obtain one colour estimate per frame.
auto_detector: AutoKeyDetector,
/// Colour estimates of the most recent frames, oldest first.
samples: Vec<Rgb>,
/// Maximum number of estimates kept; older ones are dropped.
max_samples: usize,
}
impl MultiFrameDetector {
    /// Creates a detector that keeps at most `max_samples` frame estimates.
    /// A value of `0` is treated as `1`.
    #[must_use]
    pub fn new(max_samples: usize) -> Self {
        Self {
            auto_detector: AutoKeyDetector::new(),
            samples: Vec::new(),
            max_samples: max_samples.max(1),
        }
    }

    /// Estimates the key colour of `frame` from its edges and appends it to
    /// the sample window, dropping the oldest sample when the window is full.
    ///
    /// # Errors
    ///
    /// Propagates the edge-detection error; the window is left unchanged in
    /// that case.
    pub fn add_frame(&mut self, frame: &VideoFrame) -> CvResult<()> {
        let color = self.auto_detector.detect_from_edges(frame)?;
        self.samples.push(color);
        // At most one element over the limit, so a single front removal
        // restores the window size.
        if self.samples.len() > self.max_samples {
            self.samples.remove(0);
        }
        Ok(())
    }

    /// Returns the mode colour of the collected samples.
    ///
    /// # Errors
    ///
    /// Returns a detection failure when no frame has been sampled yet or when
    /// no sample passes the detector's saturation/value thresholds.
    pub fn get_key_color(&self) -> CvResult<Rgb> {
        if self.samples.is_empty() {
            return Err(CvError::detection_failed("No frames have been sampled"));
        }
        self.auto_detector.find_mode_color(&self.samples)
    }

    /// Number of samples currently held in the window.
    #[must_use]
    pub fn sample_count(&self) -> usize {
        self.samples.len()
    }

    /// Returns `true` once at least half of the window (and at least one
    /// sample) has been filled.
    ///
    /// The lower bound of one sample matters for `max_samples == 1`, where
    /// half the window rounds down to zero and an empty detector would
    /// otherwise report ready even though [`Self::get_key_color`] fails.
    #[must_use]
    pub fn is_ready(&self) -> bool {
        let required = (self.max_samples / 2).max(1);
        self.samples.len() >= required
    }

    /// Discards all collected samples.
    pub fn reset(&mut self) {
        self.samples.clear();
    }
}