use core::{num::NonZeroUsize, time::Duration};
use derive_more::IsVariant;
use thiserror::Error;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::frame::{LumaFrame, Timebase, Timestamp};
use std::{vec, vec::Vec};
/// Errors reported by [`Detector::try_new`] when the supplied [`Options`] are invalid.
#[derive(Debug, Clone, Copy, PartialEq, IsVariant, Error)]
#[non_exhaustive]
pub enum Error {
    /// The bin count exceeds `u32::MAX` (bin indices are stored as `u32`), or the
    /// scratch-buffer length derived from it overflows `usize`.
    #[error("histogram bin count ({bins}) is too large")]
    BinCountTooLarge {
        /// The rejected bin count.
        bins: usize,
    },
    /// The threshold is NaN or lies outside the closed interval `[0.0, 1.0]`.
    #[error("threshold ({threshold}) must be in [0.0, 1.0]")]
    ThresholdOutOfRange {
        /// The rejected threshold.
        threshold: f64,
    },
}
/// Configuration for the histogram-based cut [`Detector`].
///
/// Start from [`Options::new`] (or `Default`) and adjust with the `with_*` / `set_*`
/// methods. Nothing is validated here; validation happens in [`Detector::try_new`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Options {
    /// Cut sensitivity in `[0.0, 1.0]`; a cut requires the histogram correlation of two
    /// consecutive frames to be at most `1.0 - threshold`.
    threshold: f64,
    /// Number of histogram bins the 256 possible luma values are bucketed into.
    bins: NonZeroUsize,
    /// Minimum elapsed time between two reported cuts. With the `serde` feature it is
    /// (de)serialized through `humantime_serde`.
    #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
    min_duration: Duration,
    /// Whether a cut may be reported without first waiting `min_duration` after the
    /// first frame of a stream.
    initial_cut: bool,
}
impl Default for Options {
#[cfg_attr(not(tarpaulin), inline(always))]
fn default() -> Self {
Self::new()
}
}
impl Options {
    /// Creates options with the default settings: `threshold = 0.5`, `bins = 256`,
    /// `min_duration = 1s` and `initial_cut = true`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn new() -> Self {
        Self {
            threshold: 0.5,
            bins: NonZeroUsize::new(256).unwrap(),
            min_duration: Duration::from_secs(1),
            initial_cut: true,
        }
    }

    /// Returns the cut threshold.
    ///
    /// A cut is detected when the correlation between consecutive frame histograms is
    /// at most `1.0 - threshold`. A larger threshold therefore demands a bigger change
    /// between frames, which yields fewer cuts.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn threshold(&self) -> f64 {
        self.threshold
    }

    /// Builder form of [`Options::set_threshold`].
    #[must_use]
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_threshold(mut self, val: f64) -> Self {
        self.set_threshold(val);
        self
    }

    /// Sets the cut threshold.
    ///
    /// The value must lie in `[0.0, 1.0]`. This setter does not check it;
    /// [`Detector::try_new`] rejects out-of-range and NaN values.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn set_threshold(&mut self, val: f64) -> &mut Self {
        self.threshold = val;
        self
    }

    /// Returns the number of histogram bins.
    ///
    /// Each luma value `v` in `0..=255` is counted in bin `v * bins / 256`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn bins(&self) -> usize {
        self.bins.get()
    }

    /// Builder form of [`Options::set_bins`].
    #[must_use]
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_bins(mut self, val: NonZeroUsize) -> Self {
        self.set_bins(val);
        self
    }

    /// Sets the number of histogram bins.
    ///
    /// Counts above `u32::MAX` are rejected by [`Detector::try_new`].
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn set_bins(&mut self, val: NonZeroUsize) -> &mut Self {
        self.bins = val;
        self
    }

    /// Returns the minimum elapsed time required between two reported cuts.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn min_duration(&self) -> Duration {
        self.min_duration
    }

    /// Builder form of [`Options::set_min_duration`].
    #[must_use]
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_min_duration(mut self, val: Duration) -> Self {
        self.set_min_duration(val);
        self
    }

    /// Sets the minimum elapsed time required between two reported cuts.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn set_min_duration(&mut self, val: Duration) -> &mut Self {
        self.min_duration = val;
        self
    }

    /// Builder form of [`Options::set_min_frames`].
    #[must_use]
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_min_frames(mut self, frames: u32, fps: Timebase) -> Self {
        self.set_min_frames(frames, fps);
        self
    }

    /// Sets `min_duration` to the length of `frames` frames at frame rate `fps`.
    ///
    /// For example, 15 frames at `30/1` fps gives 500 ms.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn set_min_frames(&mut self, frames: u32, fps: Timebase) -> &mut Self {
        self.min_duration = fps.frames_to_duration(frames);
        self
    }

    /// Returns whether an initial cut is allowed.
    ///
    /// When `true`, the detector acts as if a cut happened `min_duration` before the
    /// first frame, so a cut can be reported without waiting. When `false`, the first
    /// frame counts as the last cut, so nothing is reported until `min_duration` has
    /// elapsed from it.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn initial_cut(&self) -> bool {
        self.initial_cut
    }

    /// Builder form of [`Options::set_initial_cut`].
    #[must_use]
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_initial_cut(mut self, val: bool) -> Self {
        // Delegate to the setter, like every other builder in this impl.
        self.set_initial_cut(val);
        self
    }

    /// Sets whether an initial cut is allowed (see [`Options::initial_cut`]).
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn set_initial_cut(&mut self, val: bool) -> &mut Self {
        self.initial_cut = val;
        self
    }
}
/// Number of partial histograms that consecutive pixels are spread across (round-robin)
/// in `Detector::compute_histogram` before being summed into the final histogram.
/// NOTE(review): presumably chosen to break up increment dependency chains on runs of
/// equal pixels — confirm with benchmarks before changing.
const N_ACCUM: usize = 4;
/// Streaming scene-cut detector that compares luma histograms of consecutive frames.
///
/// Feed frames in order with [`Detector::process`]; call [`Detector::clear`] before
/// starting an unrelated stream.
#[derive(Debug, Clone)]
pub struct Detector {
    /// Validated configuration.
    options: Options,
    /// `1.0 - threshold`; a cut requires correlation at or below this value.
    corr_threshold: f64,
    /// Lookup table from luma value to histogram bin index.
    bin_of: [u32; 256],
    /// `N_ACCUM` partial histograms of `bins` counters each, stored back to back.
    scratch: Vec<u32>,
    /// Histogram of the frame currently being processed.
    current: Vec<u32>,
    /// Histogram of the previous frame; only meaningful while `has_previous` is true.
    previous: Vec<u32>,
    /// Whether `previous` holds a histogram from the current stream.
    has_previous: bool,
    /// Timestamp of the last cut (or the synthetic one seeded from the first frame);
    /// `None` until the first frame of a stream arrives.
    last_cut_ts: Option<Timestamp>,
    /// Correlation computed for the most recent frame pair, if any.
    last_hist_diff: Option<f64>,
}
impl Detector {
    /// Creates a detector from `options`.
    ///
    /// # Panics
    ///
    /// Panics if `options` are invalid; see [`Detector::try_new`] for the checks that
    /// are performed. Use `try_new` to handle invalid options without panicking.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn new(options: Options) -> Self {
        Self::try_new(options).expect("invalid histogram::Options")
    }

    /// Creates a detector from `options` after validating them.
    ///
    /// The scratch and histogram buffers are allocated once here and reused for every
    /// processed frame.
    ///
    /// # Errors
    ///
    /// - [`Error::ThresholdOutOfRange`] if the threshold is NaN or outside `[0.0, 1.0]`.
    /// - [`Error::BinCountTooLarge`] if the bin count exceeds `u32::MAX` (bin indices
    ///   are stored as `u32`) or if `N_ACCUM * bins` overflows `usize`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn try_new(options: Options) -> Result<Self, Error> {
        let threshold = options.threshold;
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected as well.
        if !(0.0..=1.0).contains(&threshold) {
            return Err(Error::ThresholdOutOfRange { threshold });
        }
        let bins = options.bins.get();
        // `bin_of` stores bin indices as `u32`.
        if bins > u32::MAX as usize {
            return Err(Error::BinCountTooLarge { bins });
        }
        let scratch_len = N_ACCUM
            .checked_mul(bins)
            .ok_or(Error::BinCountTooLarge { bins })?;
        // A cut is declared when the correlation drops to or below this value.
        let corr_threshold = (1.0 - threshold).clamp(0.0, 1.0);
        let bin_of = build_bin_lookup(bins);
        Ok(Self {
            options,
            corr_threshold,
            bin_of,
            scratch: vec![0u32; scratch_len],
            current: vec![0u32; bins],
            previous: vec![0u32; bins],
            has_previous: false,
            last_cut_ts: None,
            last_hist_diff: None,
        })
    }

    /// Returns the options this detector was built with.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn options(&self) -> &Options {
        &self.options
    }

    /// Returns the histogram correlation computed for the most recent frame.
    ///
    /// Returns `None` until two frames have been processed since construction or since
    /// the last [`Detector::clear`].
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn last_hist_diff(&self) -> Option<f64> {
        self.last_hist_diff
    }

    /// Resets per-stream state so the next frame is treated as the first of a new stream.
    ///
    /// The previous histogram is discarded and the last-cut timestamp is forgotten; it is
    /// re-seeded from the next frame according to `initial_cut`. `last_hist_diff` becomes
    /// `None`. Options and buffers are kept.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn clear(&mut self) {
        self.has_previous = false;
        self.last_cut_ts = None;
        self.last_hist_diff = None;
    }

    /// Feeds the next frame of the stream and returns the timestamp of a cut, if one is detected.
    ///
    /// The frame's luma histogram is compared with the previous frame's histogram using
    /// the correlation coefficient. A cut is reported at this frame's timestamp when both
    /// of these hold:
    ///
    /// - the correlation is at most `1.0 - threshold`;
    /// - at least `min_duration` has elapsed since the last cut.
    ///
    /// If `Timestamp::duration_since` yields no duration (presumably when timestamps go
    /// backwards), no cut is reported. The first frame after construction or
    /// [`Detector::clear`] never yields a cut; it only primes the history.
    pub fn process(&mut self, frame: LumaFrame<'_>) -> Option<Timestamp> {
        let ts = frame.timestamp();
        if self.last_cut_ts.is_none() {
            // Seed the "last cut" time. With `initial_cut`, pretend a cut happened
            // `min_duration` ago so the very next change can be reported immediately.
            self.last_cut_ts = Some(if self.options.initial_cut {
                ts.saturating_sub_duration(self.options.min_duration)
            } else {
                ts
            });
        }
        self.compute_histogram(&frame);
        let mut cut: Option<Timestamp> = None;
        if self.has_previous {
            let diff = correlation(&self.previous, &self.current);
            self.last_hist_diff = Some(diff);
            let min_elapsed = self
                .last_cut_ts
                .as_ref()
                .and_then(|last| ts.duration_since(last))
                .is_some_and(|d| d >= self.options.min_duration);
            if diff <= self.corr_threshold && min_elapsed {
                cut = Some(ts);
                self.last_cut_ts = Some(ts);
            }
        }
        // The fresh histogram becomes `previous`; the old buffer is reused as `current`.
        core::mem::swap(&mut self.current, &mut self.previous);
        self.has_previous = true;
        cut
    }

    /// Computes the luma histogram of `frame` into `self.current`.
    ///
    /// Only the first `width` bytes of each of the `height` rows are counted; rows start
    /// every `stride` bytes, so any row padding is ignored.
    fn compute_histogram(&mut self, frame: &LumaFrame<'_>) {
        let bins = self.options.bins.get();
        let data = frame.data();
        let w = frame.width() as usize;
        let h = frame.height() as usize;
        let s = frame.stride() as usize;
        let scratch = &mut self.scratch;
        let current = &mut self.current;
        let bin_of = &self.bin_of;
        debug_assert_eq!(scratch.len(), N_ACCUM * bins);
        debug_assert_eq!(current.len(), bins);
        scratch.fill(0);
        // Consecutive pixels are spread round-robin over `N_ACCUM` partial histograms,
        // which are summed at the end.
        let (acc0, rest) = scratch.split_at_mut(bins);
        let (acc1, rest) = rest.split_at_mut(bins);
        let (acc2, acc3) = rest.split_at_mut(bins);
        for y in 0..h {
            let row_start = y * s;
            let row = &data[row_start..row_start + w];
            let chunks = row.chunks_exact(N_ACCUM);
            let remainder = chunks.remainder();
            for chunk in chunks {
                acc0[bin_of[chunk[0] as usize] as usize] += 1;
                acc1[bin_of[chunk[1] as usize] as usize] += 1;
                acc2[bin_of[chunk[2] as usize] as usize] += 1;
                acc3[bin_of[chunk[3] as usize] as usize] += 1;
            }
            // The tail holds fewer than `N_ACCUM` pixels, so in practice only
            // indices 0..=2 occur here.
            for (i, &v) in remainder.iter().enumerate() {
                let idx = bin_of[v as usize] as usize;
                match i {
                    0 => acc0[idx] += 1,
                    1 => acc1[idx] += 1,
                    2 => acc2[idx] += 1,
                    _ => acc3[idx] += 1,
                }
            }
        }
        // Merge the partial histograms: copy the first, then add the rest element-wise.
        // Iterator zips avoid the per-index bounds checks of an indexed loop.
        let (first, others) = scratch.split_at(bins);
        current.copy_from_slice(first);
        for partial in others.chunks_exact(bins) {
            for (total, &count) in current.iter_mut().zip(partial) {
                *total += count;
            }
        }
    }
}
/// Builds the luma-value → bin-index table for a histogram with `bins` bins.
///
/// Value `v` maps to `floor(v * bins / 256)`, computed in `u64` so the product cannot
/// overflow. Callers must ensure the result fits in `u32`; `Detector::try_new` does this
/// by rejecting bin counts above `u32::MAX`.
fn build_bin_lookup(bins: usize) -> [u32; 256] {
    let bin_count = bins as u64;
    let mut table = [0u32; 256];
    for (value, slot) in table.iter_mut().enumerate() {
        *slot = (value as u64 * bin_count / 256) as u32;
    }
    table
}
/// Pearson correlation coefficient between two equal-length histograms.
///
/// The result is the mean-centred cross product divided by the square root of the
/// product of both sums of squared deviations. Degenerate cases:
/// - both histograms flat (zero variance): `1.0` if they are identical, otherwise `0.0`;
/// - exactly one histogram flat: `0.0`.
fn correlation(a: &[u32], b: &[u32]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    let len = a.len() as f64;
    let total = |hist: &[u32]| -> u64 { hist.iter().map(|&c| u64::from(c)).sum() };
    let mean_a = total(a) as f64 / len;
    let mean_b = total(b) as f64 / len;
    // Accumulate in the same left-to-right order as a plain loop for identical rounding.
    let (cross, spread_a, spread_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cross, spread_a, spread_b), (&x, &y)| {
            let dx = f64::from(x) - mean_a;
            let dy = f64::from(y) - mean_b;
            (cross + dx * dy, spread_a + dx * dx, spread_b + dy * dy)
        },
    );
    match (spread_a == 0.0, spread_b == 0.0) {
        (true, true) => {
            if a == b {
                1.0
            } else {
                0.0
            }
        }
        (true, false) | (false, true) => 0.0,
        (false, false) => cross / super::sqrt_64(spread_a * spread_b),
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::frame::Timebase;
    use core::num::NonZeroU32;

    /// Const-friendly `NonZeroU32` constructor for test literals; panics on zero.
    const fn nz32(n: u32) -> NonZeroU32 {
        match NonZeroU32::new(n) {
            Some(v) => v,
            None => panic!("zero"),
        }
    }

    /// Builds a tightly packed frame (`stride == width`) whose `pts` is in milliseconds.
    fn make_frame<'a>(data: &'a [u8], w: u32, h: u32, pts: i64) -> LumaFrame<'a> {
        let tb = Timebase::new(1, nz32(1000));
        LumaFrame::new(data, w, h, w, Timestamp::new(pts, tb))
    }

    #[test]
    fn identical_frames_produce_no_cut() {
        let mut det = Detector::new(Options::default());
        let buf = [128u8; 64 * 48];
        assert!(det.process(make_frame(&buf, 64, 48, 0)).is_none());
        assert!(det.process(make_frame(&buf, 64, 48, 2000)).is_none());
        assert!(det.process(make_frame(&buf, 64, 48, 4000)).is_none());
        assert_eq!(det.last_hist_diff(), Some(1.0));
    }

    #[test]
    fn very_different_frames_produce_cut() {
        let opts = Options::default().with_min_duration(Duration::from_millis(0));
        let mut det = Detector::new(opts);
        let black = [0u8; 64 * 48];
        let white = [255u8; 64 * 48];
        assert!(det.process(make_frame(&black, 64, 48, 0)).is_none());
        let cut = det.process(make_frame(&white, 64, 48, 33));
        assert!(
            cut.is_some(),
            "expected a cut at the black→white transition"
        );
        assert_eq!(cut.unwrap().pts(), 33);
    }

    #[test]
    fn min_duration_suppresses_rapid_cuts() {
        let opts = Options::default()
            .with_min_duration(Duration::from_secs(1))
            .with_initial_cut(false);
        let mut det = Detector::new(opts);
        let black = [0u8; 64 * 48];
        let white = [255u8; 64 * 48];
        let mut cuts = 0u32;
        // 30 alternating frames 33 ms apart span 957 ms, i.e. less than `min_duration`.
        for i in 0..30i64 {
            let frame_data = if i % 2 == 0 { &black } else { &white };
            let ts = i * 33;
            if det.process(make_frame(frame_data, 64, 48, ts)).is_some() {
                cuts += 1;
            }
        }
        assert_eq!(cuts, 0, "min_duration should suppress all cuts within 1s");
    }

    #[test]
    fn cut_reported_after_min_duration_elapsed() {
        let opts = Options::default()
            .with_min_duration(Duration::from_millis(500))
            .with_initial_cut(false);
        let mut det = Detector::new(opts);
        let black = [0u8; 64 * 48];
        let white = [255u8; 64 * 48];
        assert!(det.process(make_frame(&black, 64, 48, 0)).is_none());
        assert!(det.process(make_frame(&white, 64, 48, 100)).is_none());
        let cut = det.process(make_frame(&black, 64, 48, 600));
        assert!(cut.is_some(), "expected cut after min_duration elapsed");
    }

    #[test]
    fn clear_resets_stream_state() {
        let opts = Options::default().with_min_duration(Duration::from_millis(0));
        let mut det = Detector::new(opts);
        let black = [0u8; 64 * 48];
        let white = [255u8; 64 * 48];
        assert!(det.process(make_frame(&black, 64, 48, 0)).is_none());
        let cut = det.process(make_frame(&white, 64, 48, 33));
        assert!(cut.is_some());
        assert!(det.last_hist_diff().is_some());
        det.clear();
        assert!(det.process(make_frame(&black, 64, 48, 1_000_000)).is_none());
        assert!(
            det.last_hist_diff().is_none(),
            "last_hist_diff should be cleared"
        );
        let cut2 = det.process(make_frame(&white, 64, 48, 1_000_033));
        assert!(cut2.is_some(), "cut should still be detected on video 2");
    }

    #[test]
    fn compute_histogram_respects_stride() {
        // 4x2 frame with stride 8; the 0xFF bytes are row padding and must be skipped.
        let mut buf = [0xFFu8; 8 * 2];
        buf[0..4].copy_from_slice(&[10, 20, 30, 40]);
        buf[8..12].copy_from_slice(&[50, 60, 70, 80]);
        let mut det = Detector::new(Options::default());
        let tb = Timebase::new(1, nz32(1000));
        let frame = LumaFrame::new(&buf, 4, 2, 8, Timestamp::new(0, tb));
        det.compute_histogram(&frame);
        for v in [10, 20, 30, 40, 50, 60, 70, 80] {
            assert_eq!(det.current[v as usize], 1);
        }
        assert_eq!(det.current[0xFF], 0, "padding must not be counted");
        assert_eq!(det.current.iter().sum::<u32>(), 8);
    }

    #[test]
    fn compute_histogram_remainder_path() {
        // Width 7 leaves a 3-pixel tail per row after the 4-wide chunks.
        let mut buf = [0u8; 7 * 3];
        for (i, b) in buf.iter_mut().enumerate() {
            *b = i as u8;
        }
        let mut det = Detector::new(Options::default());
        let tb = Timebase::new(1, nz32(1000));
        let frame = LumaFrame::new(&buf, 7, 3, 7, Timestamp::new(0, tb));
        det.compute_histogram(&frame);
        for v in 0u8..21 {
            assert_eq!(
                det.current[v as usize], 1,
                "pixel value {v} should have count 1"
            );
        }
        assert_eq!(det.current.iter().sum::<u32>(), 21);
    }

    #[test]
    fn build_bin_lookup_matches_formula() {
        let t = build_bin_lookup(256);
        for v in 0..=255u32 {
            assert_eq!(t[v as usize], v);
        }
        let t = build_bin_lookup(128);
        for v in 0..=255u32 {
            assert_eq!(t[v as usize], v / 2);
        }
        let t = build_bin_lookup(1);
        for v in 0..=255u32 {
            assert_eq!(t[v as usize], 0);
        }
    }

    #[test]
    fn correlation_of_identical_is_one() {
        let a: Vec<u32> = vec![1, 2, 3, 4, 5];
        assert!((correlation(&a, &a) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn with_min_frames_matches_python_default() {
        let fps = Timebase::new(30, nz32(1));
        let opts = Options::default().with_min_frames(15, fps);
        assert_eq!(opts.min_duration(), Duration::from_millis(500));
    }

    #[test]
    fn with_min_frames_ntsc() {
        let fps = Timebase::new(30_000, nz32(1001));
        let opts = Options::default().with_min_frames(15, fps);
        assert_eq!(opts.min_duration(), Duration::from_nanos(500_500_000));
    }

    #[test]
    fn correlation_of_flat_frames() {
        let a = vec![4u32; 256];
        let b = vec![4u32; 256];
        assert_eq!(correlation(&a, &b), 1.0);
        let c = vec![7u32; 256];
        assert_eq!(correlation(&a, &c), 0.0);
    }

    #[test]
    fn try_new_rejects_overflowing_bin_count() {
        let opts = Options::default().with_bins(NonZeroUsize::new(usize::MAX).unwrap());
        let err = Detector::try_new(opts).expect_err("should fail");
        assert_eq!(err, Error::BinCountTooLarge { bins: usize::MAX });
    }

    #[test]
    fn try_new_rejects_out_of_range_threshold() {
        // NaN cannot be compared with `assert_eq!`, so match on the variant instead.
        for bad in [-0.1, 1.5, f64::NAN, f64::INFINITY] {
            let opts = Options::default().with_threshold(bad);
            let err = Detector::try_new(opts)
                .expect_err("threshold outside [0.0, 1.0] must be rejected");
            assert!(matches!(err, Error::ThresholdOutOfRange { .. }));
        }
        // Both ends of the closed interval are valid.
        for ok in [0.0, 1.0] {
            assert!(Detector::try_new(Options::default().with_threshold(ok)).is_ok());
        }
    }

    #[test]
    fn options_accessors_builders_setters_roundtrip() {
        let fps30 = Timebase::new(30, nz32(1));
        let opts = Options::default()
            .with_threshold(0.42)
            .with_bins(core::num::NonZeroUsize::new(128).unwrap())
            .with_min_duration(core::time::Duration::from_millis(500))
            .with_initial_cut(false);
        assert_eq!(opts.threshold(), 0.42);
        assert_eq!(opts.bins(), 128);
        assert_eq!(opts.min_duration(), core::time::Duration::from_millis(500));
        assert!(!opts.initial_cut());
        let opts_frames = Options::default().with_min_frames(15, fps30);
        assert_eq!(
            opts_frames.min_duration(),
            core::time::Duration::from_millis(500)
        );
        let mut opts = Options::default();
        opts
            .set_threshold(0.1)
            .set_bins(core::num::NonZeroUsize::new(64).unwrap())
            .set_min_duration(core::time::Duration::from_secs(1))
            .set_initial_cut(true);
        assert_eq!(opts.threshold(), 0.1);
        assert_eq!(opts.bins(), 64);
        assert!(opts.initial_cut());
        opts.set_min_frames(30, fps30);
        assert_eq!(opts.min_duration(), core::time::Duration::from_secs(1));
    }

    #[test]
    fn detector_options_and_last_hist_diff_accessors() {
        let opts = Options::default().with_min_duration(core::time::Duration::from_millis(0));
        let mut det = Detector::new(opts.clone());
        assert_eq!(det.options().threshold(), opts.threshold());
        assert!(det.last_hist_diff().is_none());
        let buf = vec![64u8; 32 * 32];
        det.process(make_frame(&buf, 32, 32, 0));
        det.process(make_frame(&buf, 32, 32, 33));
        assert!(det.last_hist_diff().is_some());
    }

    #[test]
    fn histogram_tail_three_exercises_three_remainder_pixels() {
        let buf = vec![100u8; 35];
        let mut det =
            Detector::new(Options::default().with_min_duration(core::time::Duration::from_millis(0)));
        det.process(make_frame(&buf, 7, 5, 0));
        det.process(make_frame(&buf, 7, 5, 33));
        assert_eq!(det.last_hist_diff(), Some(1.0));
    }
}