use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::device::DeviceType;
use crate::error::{MlError, MlResult};
use crate::model::OnnxModel;
use crate::pipeline::{PipelineInfo, PipelineTask, TypedPipeline};
#[cfg(feature = "onnx")]
use crate::postprocess::sigmoid_slice;
use crate::preprocess::{ImagePreprocessor, InputRange, TensorLayout};
/// Tunable settings for the shot-boundary detector.
#[derive(Clone, Debug)]
pub struct ShotBoundaryConfig {
/// Frame size as `(width, height)`; both values are handed to the image
/// preprocessor and reported as the pipeline's input size.
pub frame_size: (u32, u32),
/// Window length. In this file it is only stored and shown in the detector's
/// `Debug` output. NOTE(review): presumably the number of frames the model
/// expects per call — confirm against the model contract.
pub window: usize,
/// Frames whose probability is below this value are never reported.
pub threshold: f32,
/// Minimum index distance to the previously reported boundary; candidates
/// closer than this are dropped.
pub min_gap: usize,
/// Name of the model input tensor. When `None`, the first input declared by
/// the model is used.
pub input_name: Option<String>,
/// Name of the model output tensor. When `None`, the first output declared by
/// the model is used.
pub output_name: Option<String>,
}
impl Default for ShotBoundaryConfig {
fn default() -> Self {
Self {
frame_size: (48, 27),
window: 100,
threshold: 0.5,
min_gap: 3,
input_name: None,
output_name: None,
}
}
}
/// A single detected shot boundary.
#[derive(Clone, Copy, Debug)]
pub struct ShotBoundary {
/// Position of the boundary within the scored probability sequence.
pub frame_index: usize,
/// Probability value that was scored at `frame_index`.
pub confidence: f32,
}
/// One video frame held as a raw byte buffer plus its dimensions.
///
/// `ShotFrame::new` checks that the buffer holds `width * height * 3` bytes
/// (RGB); the fields are public, so direct construction skips that check.
#[derive(Clone, Debug)]
pub struct ShotFrame {
/// Raw pixel bytes.
pub pixels: Vec<u8>,
/// Frame width in pixels.
pub width: u32,
/// Frame height in pixels.
pub height: u32,
}
impl ShotFrame {
    /// Wraps a raw RGB buffer after checking that its length matches the
    /// given dimensions.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when `pixels.len()` differs from
    /// `width * height * 3`, or when that byte count does not fit in `usize`.
    pub fn new(pixels: Vec<u8>, width: u32, height: u32) -> MlResult<Self> {
        // Checked arithmetic: a plain product panics in debug builds and wraps in
        // release builds, where a wrapped value could accept a wrongly sized buffer.
        let expected = (width as usize)
            .checked_mul(height as usize)
            .and_then(|pixel_count| pixel_count.checked_mul(3))
            .ok_or_else(|| {
                MlError::invalid_input(format!(
                    "shot frame: byte count for {width}x{height} RGB overflows usize"
                ))
            })?;
        if pixels.len() != expected {
            return Err(MlError::invalid_input(format!(
                "shot frame: expected {expected} bytes for {width}x{height} RGB, got {}",
                pixels.len()
            )));
        }
        Ok(Self {
            pixels,
            width,
            height,
        })
    }
}
/// Shot-boundary detector: scores each frame and reports the frames whose
/// score passes the configured threshold and gap filter.
pub struct ShotBoundaryDetector {
// Shared model handle; `None` for detectors created via `heuristic`.
model: Option<Arc<OnnxModel>>,
// Settings supplied at construction time.
config: ShotBoundaryConfig,
// Converts raw frame bytes into the float buffer used for scoring.
preprocessor: ImagePreprocessor,
// Path recorded when the detector was built from a model, if any.
model_path: Option<PathBuf>,
}
impl ShotBoundaryDetector {
    /// Creates a detector with no model attached; scoring then relies on the
    /// frame-difference heuristic only.
    #[must_use]
    pub fn heuristic(config: ShotBoundaryConfig) -> Self {
        // Shares the preprocessor setup with the model-backed constructors.
        Self::build(None, config, None)
    }

    /// Loads a model from `path` on `device` using the default configuration.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `OnnxModel::load`.
    pub fn load(path: impl AsRef<Path>, device: DeviceType) -> MlResult<Self> {
        Self::load_with_config(path, device, ShotBoundaryConfig::default())
    }

    /// Loads a model from `path` on `device` and pairs it with `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `OnnxModel::load`.
    pub fn load_with_config(
        path: impl AsRef<Path>,
        device: DeviceType,
        config: ShotBoundaryConfig,
    ) -> MlResult<Self> {
        let model_path = path.as_ref().to_path_buf();
        let model = Arc::new(OnnxModel::load(&model_path, device)?);
        Ok(Self::build(Some(model), config, Some(model_path)))
    }

    /// Builds a detector around an already loaded, shared model. `model_path`
    /// is only recorded; nothing is read from disk here.
    #[must_use]
    pub fn from_shared(
        model: Arc<OnnxModel>,
        config: ShotBoundaryConfig,
        model_path: PathBuf,
    ) -> Self {
        Self::build(Some(model), config, Some(model_path))
    }

    /// Common constructor: derives the preprocessor from `config.frame_size`
    /// (NCHW layout, U8 input range) and stores the remaining parts as given.
    fn build(
        model: Option<Arc<OnnxModel>>,
        config: ShotBoundaryConfig,
        model_path: Option<PathBuf>,
    ) -> Self {
        let (w, h) = config.frame_size;
        let preprocessor = ImagePreprocessor::new(w, h)
            .with_tensor_layout(TensorLayout::Nchw)
            .with_input_range(InputRange::U8);
        Self {
            model,
            config,
            preprocessor,
            model_path,
        }
    }

    /// Path recorded for the attached model, if there is one.
    #[must_use]
    pub fn model_path(&self) -> Option<&Path> {
        self.model_path.as_deref()
    }

    /// Whether a model is attached to this detector.
    #[must_use]
    pub fn has_model(&self) -> bool {
        self.model.is_some()
    }

    /// The configured detection threshold.
    #[must_use]
    pub fn threshold(&self) -> f32 {
        self.config.threshold
    }

    /// Turns per-frame probabilities into boundaries.
    ///
    /// A frame is reported when its probability is at least the threshold and
    /// it lies at least `min_gap` indices after the previously reported frame.
    /// The scan is greedy: the first qualifying frame of a run is kept, not
    /// the highest-scoring one. NaN probabilities are never reported.
    fn filter_peaks(&self, probs: &[f32]) -> Vec<ShotBoundary> {
        let mut out = Vec::new();
        let mut last_emitted: Option<usize> = None;
        for (idx, &p) in probs.iter().enumerate() {
            // `NaN < threshold` is false, so NaN needs an explicit rejection.
            if p.is_nan() || p < self.config.threshold {
                continue;
            }
            // `idx` only grows, so it is always past the last emitted index.
            if matches!(last_emitted, Some(prev) if idx - prev < self.config.min_gap) {
                continue;
            }
            out.push(ShotBoundary {
                frame_index: idx,
                confidence: p,
            });
            last_emitted = Some(idx);
        }
        out
    }

    /// Scores every frame by the mean absolute difference between its
    /// preprocessed buffer and the previous frame's, clamped to `[0, 1]`.
    /// The first frame, having no predecessor, always scores `0.0`.
    ///
    /// NOTE(review): the mean is clamped but not rescaled. If the preprocessor
    /// keeps 0..=255 values for `InputRange::U8`, almost any change saturates
    /// at 1.0 — confirm the intended scale.
    ///
    /// # Errors
    ///
    /// Propagates any preprocessing error.
    fn heuristic_probs(&self, frames: &[ShotFrame]) -> MlResult<Vec<f32>> {
        let mut probs = vec![0.0_f32; frames.len()];
        let mut previous: Option<Vec<f32>> = None;
        for (slot, frame) in probs.iter_mut().zip(frames) {
            let current = self
                .preprocessor
                .process_u8_rgb(&frame.pixels, frame.width, frame.height)?;
            if let Some(prev) = &previous {
                // Taken from the lengths rather than counted in an `f32`, which
                // stops increasing once it reaches 2^24.
                let compared = prev.len().min(current.len());
                if compared > 0 {
                    let total = prev
                        .iter()
                        .zip(current.iter())
                        .fold(0.0_f32, |acc, (a, b)| acc + (a - b).abs());
                    *slot = (total / compared as f32).clamp(0.0, 1.0);
                }
            }
            previous = Some(current);
        }
        Ok(probs)
    }
}
impl TypedPipeline for ShotBoundaryDetector {
    type Input = Vec<ShotFrame>;
    type Output = Vec<ShotBoundary>;

    /// Scores the frames and returns the boundaries that survive filtering.
    ///
    /// An empty input yields an empty result. With the `onnx` feature enabled
    /// and a model attached, all frames are preprocessed, concatenated into one
    /// tensor of shape `[1, frames, 3, height, width]`, run through the model,
    /// and the selected output is passed through a sigmoid. In every other case
    /// the frame-difference heuristic provides the scores.
    ///
    /// # Errors
    ///
    /// Propagates preprocessing and model errors, and fails when no tensor
    /// name can be resolved or the chosen output is missing from the run.
    fn run(&self, input: Self::Input) -> MlResult<Self::Output> {
        if input.is_empty() {
            return Ok(Vec::new());
        }

        #[cfg(feature = "onnx")]
        {
            use oxionnx::Tensor;
            use std::collections::HashMap;

            if let Some(model) = self.model.as_ref() {
                let (width, height) = self.config.frame_size;
                let (width, height) = (width as usize, height as usize);
                let frame_count = input.len();

                // Frames are appended back to back, in input order.
                let mut data: Vec<f32> = Vec::with_capacity(width * height * 3 * frame_count);
                for frame in &input {
                    data.extend(self.preprocessor.process_u8_rgb(
                        &frame.pixels,
                        frame.width,
                        frame.height,
                    )?);
                }
                let tensor = Tensor {
                    data,
                    shape: vec![1, frame_count, 3, height, width],
                };

                // An explicit name from the config wins; otherwise fall back to
                // the first tensor the model declares.
                let input_name = match &self.config.input_name {
                    Some(name) => name.clone(),
                    None => model
                        .info()
                        .inputs
                        .first()
                        .map(|spec| spec.name.clone())
                        .ok_or_else(|| MlError::invalid_input("model has no declared inputs"))?,
                };
                let output_name = match &self.config.output_name {
                    Some(name) => name.clone(),
                    None => model
                        .info()
                        .outputs
                        .first()
                        .map(|spec| spec.name.clone())
                        .ok_or_else(|| MlError::invalid_input("model has no declared outputs"))?,
                };

                let mut feeds: HashMap<&str, Tensor> = HashMap::with_capacity(1);
                feeds.insert(input_name.as_str(), tensor);
                let outputs = model.run(&feeds)?;
                let raw = outputs.get(&output_name).ok_or_else(|| {
                    MlError::postprocess(format!("output '{output_name}' missing from model run"))
                })?;
                let probs = sigmoid_slice(&raw.data);
                return Ok(self.filter_peaks(&probs));
            }
        }

        // No model attached (or the `onnx` feature is disabled).
        self.heuristic_probs(&input)
            .map(|probs| self.filter_peaks(&probs))
    }

    /// Static description of this pipeline; the input size mirrors the
    /// configured frame size.
    fn info(&self) -> PipelineInfo {
        let input_size = Some(self.config.frame_size);
        PipelineInfo {
            id: "shot-boundary/transnet-v2",
            name: "Shot Boundary Detector",
            task: PipelineTask::ShotBoundary,
            input_size,
        }
    }
}
impl std::fmt::Debug for ShotBoundaryDetector {
    /// Prints the model path together with the configured threshold and
    /// window; the model handle and preprocessor are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut repr = f.debug_struct("ShotBoundaryDetector");
        repr.field("model_path", &self.model_path);
        repr.field("threshold", &cfg.threshold);
        repr.field("window", &cfg.window);
        repr.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w` x `h` frame in which every pixel has the colour `rgb`.
    fn solid_frame(w: u32, h: u32, rgb: [u8; 3]) -> ShotFrame {
        let pixel_count = w as usize * h as usize;
        let pixels = rgb.repeat(pixel_count);
        ShotFrame::new(pixels, w, h).expect("valid frame")
    }

    /// Running the heuristic detector on no frames yields no boundaries.
    #[test]
    fn empty_input_returns_empty() {
        let detector = ShotBoundaryDetector::heuristic(ShotBoundaryConfig::default());
        let boundaries = detector.run(Vec::new()).expect("ok");
        assert!(boundaries.is_empty());
    }

    /// A hard cut from black to white is reported at the first white frame.
    #[test]
    fn heuristic_detects_color_change() {
        let config = ShotBoundaryConfig {
            threshold: 0.1,
            min_gap: 0,
            ..Default::default()
        };
        let detector = ShotBoundaryDetector::heuristic(config);

        let black = [0, 0, 0];
        let white = [255, 255, 255];
        let frames: Vec<ShotFrame> = [black, black, white, white]
            .iter()
            .map(|&colour| solid_frame(48, 27, colour))
            .collect();

        let boundaries = detector.run(frames).expect("ok");
        assert!(boundaries.iter().any(|b| b.frame_index == 2));
    }

    /// The defaults keep the 48x27 frame size and the window of 100.
    #[test]
    fn default_config_is_transnet_shaped() {
        let ShotBoundaryConfig {
            frame_size, window, ..
        } = ShotBoundaryConfig::default();
        assert_eq!(frame_size, (48, 27));
        assert_eq!(window, 100);
    }

    /// A buffer whose length does not match the dimensions is rejected.
    #[test]
    fn shot_frame_rejects_wrong_buffer() {
        let too_short = vec![0u8; 10];
        let err = ShotFrame::new(too_short, 48, 27).expect_err("must fail");
        assert!(matches!(err, MlError::InvalidInput(_)));
    }
}