use crate::classify::content::{ContentClassifier, ContentType};
use crate::classify::quality::{QualityClassifier, QualityMetrics};
use crate::classify::scene::{SceneClassifier, SceneType};
use crate::error::{SceneError, SceneResult};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// Options selecting which classifiers run on each frame and when a batch is
/// large enough to be processed in parallel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchConfig {
    /// Run the scene classifier on every frame.
    pub classify_scene: bool,
    /// Run the quality classifier on every frame.
    pub classify_quality: bool,
    /// Run the content classifier on every frame.
    pub classify_content: bool,
    /// Batches containing at least this many frames are processed in parallel;
    /// smaller batches are processed sequentially.
    pub parallel_threshold: usize,
}

impl Default for BatchConfig {
    /// Enables all three classifiers and sets the parallel threshold to 4 frames.
    fn default() -> Self {
        Self {
            classify_scene: true,
            classify_quality: true,
            classify_content: true,
            parallel_threshold: 4,
        }
    }
}
/// Classification output for a single frame of a batch.
///
/// Each optional field is `Some` only when the corresponding classifier was
/// enabled in the [`BatchConfig`] used for the run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchFrameResult {
    /// Position of the frame within the batch that was classified.
    pub frame_index: usize,
    /// Scene label, present when scene classification was enabled.
    pub scene_type: Option<SceneType>,
    /// Quality metrics, present when quality classification was enabled.
    pub quality: Option<QualityMetrics>,
    /// Content label, present when content classification was enabled.
    pub content_type: Option<ContentType>,
}
/// Borrowed view of one RGB frame: the pixel bytes plus the frame dimensions.
///
/// The view is `Copy` because it only holds a slice reference and two sizes.
#[derive(Clone, Copy)]
pub struct FrameRef<'a> {
    /// Raw pixel bytes; batch classification requires `width * height * 3` bytes.
    pub rgb_data: &'a [u8],
    /// Frame width in pixels.
    pub width: usize,
    /// Frame height in pixels.
    pub height: usize,
}

impl<'a> FrameRef<'a> {
    /// Wraps a borrowed pixel buffer together with its dimensions.
    ///
    /// No validation is done here; the buffer length is checked against the
    /// dimensions when the frame is classified.
    #[must_use]
    pub fn new(rgb_data: &'a [u8], width: usize, height: usize) -> Self {
        Self {
            rgb_data,
            width,
            height,
        }
    }
}

impl<'a> std::fmt::Debug for FrameRef<'a> {
    /// Prints the dimensions and the buffer length rather than every pixel
    /// byte, so debug output stays readable for full-size frames.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FrameRef")
            .field("rgb_data_len", &self.rgb_data.len())
            .field("width", &self.width)
            .field("height", &self.height)
            .finish()
    }
}
/// Runs the scene, quality and content classifiers over batches of frames,
/// as selected by a [`BatchConfig`].
pub struct BatchClassifier {
    /// Which classifiers to run and the parallel-processing threshold.
    config: BatchConfig,
    /// Classifier that produces the per-frame scene label.
    scene_classifier: SceneClassifier,
    /// Classifier that produces the per-frame quality metrics.
    quality_classifier: QualityClassifier,
    /// Classifier that produces the per-frame content label.
    content_classifier: ContentClassifier,
}
impl BatchClassifier {
    /// Creates a batch classifier with the default [`BatchConfig`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(BatchConfig::default())
    }

    /// Creates a batch classifier with the supplied configuration.
    ///
    /// The three underlying classifiers are built with their `new()` constructors.
    #[must_use]
    pub fn with_config(config: BatchConfig) -> Self {
        Self {
            config,
            scene_classifier: SceneClassifier::new(),
            quality_classifier: QualityClassifier::new(),
            content_classifier: ContentClassifier::new(),
        }
    }

    /// Runs every enabled classifier on a single frame.
    ///
    /// `frame_index` is copied into the result and used in error messages.
    ///
    /// # Errors
    ///
    /// Returns [`SceneError::InvalidDimensions`] when `width * height * 3`
    /// overflows `usize` or differs from the buffer length, and propagates any
    /// error reported by an enabled classifier.
    fn classify_frame(
        &self,
        frame: &FrameRef<'_>,
        frame_index: usize,
    ) -> SceneResult<BatchFrameResult> {
        // Checked arithmetic: an unchecked product panics in debug builds and
        // wraps in release builds, where a wrapped value could equal the buffer
        // length and let bogus dimensions reach the classifiers.
        let expected = frame
            .width
            .checked_mul(frame.height)
            .and_then(|pixels| pixels.checked_mul(3))
            .ok_or_else(|| {
                SceneError::InvalidDimensions(format!(
                    "Frame {}: dimensions {}x{} overflow the addressable byte count",
                    frame_index, frame.width, frame.height
                ))
            })?;

        if frame.rgb_data.len() != expected {
            return Err(SceneError::InvalidDimensions(format!(
                "Frame {}: expected {} bytes for {}x{}, got {}",
                frame_index,
                expected,
                frame.width,
                frame.height,
                frame.rgb_data.len()
            )));
        }

        let scene_type = if self.config.classify_scene {
            Some(
                self.scene_classifier
                    .classify(frame.rgb_data, frame.width, frame.height)?
                    .scene_type,
            )
        } else {
            None
        };

        let quality = if self.config.classify_quality {
            Some(
                self.quality_classifier
                    .analyze(frame.rgb_data, frame.width, frame.height)?,
            )
        } else {
            None
        };

        let content_type = if self.config.classify_content {
            // The content classifier takes a slice of frames; the same buffer is
            // supplied three times.
            // NOTE(review): presumably it expects a short temporal window --
            // confirm the required frame count against `ContentClassifier`.
            let frames_slice: Vec<&[u8]> = vec![frame.rgb_data; 3];
            Some(
                self.content_classifier
                    .classify(&frames_slice, frame.width, frame.height)?
                    .content_type,
            )
        } else {
            None
        };

        Ok(BatchFrameResult {
            frame_index,
            scene_type,
            quality,
            content_type,
        })
    }

    /// Classifies every frame in `frames`, returning one result per frame.
    ///
    /// An empty batch yields an empty vector. Batches with at least
    /// `parallel_threshold` frames are processed in parallel; smaller batches
    /// are processed sequentially.
    ///
    /// # Errors
    ///
    /// Returns the error of the first frame (in batch order) that fails
    /// validation or classification.
    pub fn classify_batch(&self, frames: &[FrameRef<'_>]) -> SceneResult<Vec<BatchFrameResult>> {
        if frames.is_empty() {
            return Ok(Vec::new());
        }
        if frames.len() >= self.config.parallel_threshold {
            self.classify_batch_parallel(frames)
        } else {
            self.classify_batch_sequential(frames)
        }
    }

    /// Classifies the frames one after another, stopping at the first error.
    fn classify_batch_sequential(
        &self,
        frames: &[FrameRef<'_>],
    ) -> SceneResult<Vec<BatchFrameResult>> {
        let mut results = Vec::with_capacity(frames.len());
        for (i, frame) in frames.iter().enumerate() {
            results.push(self.classify_frame(frame, i)?);
        }
        Ok(results)
    }

    /// Classifies the frames on rayon's thread pool.
    ///
    /// Every frame is classified before errors are inspected. The per-frame
    /// outcomes are gathered into a `Vec` first (rayon keeps input order when
    /// collecting into a `Vec`), so the error reported is always the one from
    /// the lowest failing frame index rather than whichever worker failed first.
    fn classify_batch_parallel(
        &self,
        frames: &[FrameRef<'_>],
    ) -> SceneResult<Vec<BatchFrameResult>> {
        let results: Vec<SceneResult<BatchFrameResult>> = frames
            .par_iter()
            .enumerate()
            .map(|(i, frame)| self.classify_frame(frame, i))
            .collect();
        results.into_iter().collect()
    }

    /// Returns the configuration this classifier was built with.
    #[must_use]
    pub fn config(&self) -> &BatchConfig {
        &self.config
    }
}
impl Default for BatchClassifier {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w` x `h` RGB buffer with every byte set to `v`.
    fn uniform_frame(w: usize, h: usize, v: u8) -> Vec<u8> {
        std::iter::repeat(v).take(w * h * 3).collect()
    }

    /// Builds a `w` x `h` RGB buffer whose red channel ramps with the column,
    /// green channel ramps with the row, and blue channel is fixed at 128.
    fn gradient_frame(w: usize, h: usize) -> Vec<u8> {
        let mut pixels = vec![0u8; w * h * 3];
        for (pixel_idx, rgb) in pixels.chunks_exact_mut(3).enumerate() {
            // The buffer is non-empty only when `w > 0`, so `%` and `/` are safe.
            let (col, row) = (pixel_idx % w, pixel_idx / w);
            rgb[0] = (col * 255 / w.max(1)) as u8;
            rgb[1] = (row * 255 / h.max(1)) as u8;
            rgb[2] = 128;
        }
        pixels
    }

    /// Wraps each owned buffer in a `FrameRef` with the given dimensions.
    fn as_refs(buffers: &[Vec<u8>], w: usize, h: usize) -> Vec<FrameRef<'_>> {
        buffers.iter().map(|buf| FrameRef::new(buf, w, h)).collect()
    }

    /// Classifies `frames` and unwraps the result, failing the test on error.
    fn run_ok(classifier: &BatchClassifier, frames: &[FrameRef<'_>]) -> Vec<BatchFrameResult> {
        let outcome = classifier.classify_batch(frames);
        assert!(outcome.is_ok());
        outcome.expect("should succeed")
    }

    #[test]
    fn test_batch_empty() {
        let classifier = BatchClassifier::new();
        let out = run_ok(&classifier, &[]);
        assert!(out.is_empty());
    }

    #[test]
    fn test_batch_single_frame() {
        let classifier = BatchClassifier::new();
        let pixels = uniform_frame(100, 100, 128);
        let out = run_ok(&classifier, &[FrameRef::new(&pixels, 100, 100)]);
        assert_eq!(out.len(), 1);
        let first = &out[0];
        assert_eq!(first.frame_index, 0);
        assert!(first.scene_type.is_some());
        assert!(first.quality.is_some());
        assert!(first.content_type.is_some());
    }

    #[test]
    fn test_batch_sequential() {
        // A threshold above the batch size forces the sequential path.
        let classifier = BatchClassifier::with_config(BatchConfig {
            parallel_threshold: 100,
            ..BatchConfig::default()
        });
        let buffers = vec![
            uniform_frame(50, 50, 100),
            uniform_frame(50, 50, 200),
            gradient_frame(50, 50),
        ];
        let out = run_ok(&classifier, &as_refs(&buffers, 50, 50));
        assert_eq!(out.len(), 3);
        for (position, item) in out.iter().enumerate() {
            assert_eq!(item.frame_index, position);
        }
    }

    #[test]
    fn test_batch_parallel() {
        // Five frames with a threshold of two takes the parallel path.
        let classifier = BatchClassifier::with_config(BatchConfig {
            parallel_threshold: 2,
            ..BatchConfig::default()
        });
        let buffers: Vec<Vec<u8>> = (0u8..5)
            .map(|step| uniform_frame(40, 40, step * 50))
            .collect();
        let out = run_ok(&classifier, &as_refs(&buffers, 40, 40));
        assert_eq!(out.len(), 5);
    }

    #[test]
    fn test_batch_invalid_dimensions() {
        let classifier = BatchClassifier::new();
        // Ten bytes cannot hold a 100x100 RGB frame.
        let too_small = [0u8; 10];
        let frames = [FrameRef::new(&too_small, 100, 100)];
        assert!(classifier.classify_batch(&frames).is_err());
    }

    #[test]
    fn test_batch_scene_only() {
        let classifier = BatchClassifier::with_config(BatchConfig {
            parallel_threshold: 10,
            classify_content: false,
            classify_quality: false,
            classify_scene: true,
        });
        let pixels = uniform_frame(60, 60, 128);
        let out = run_ok(&classifier, &[FrameRef::new(&pixels, 60, 60)]);
        let first = &out[0];
        assert!(first.scene_type.is_some());
        assert!(first.quality.is_none());
        assert!(first.content_type.is_none());
    }

    #[test]
    fn test_batch_quality_only() {
        let classifier = BatchClassifier::with_config(BatchConfig {
            parallel_threshold: 10,
            classify_content: false,
            classify_quality: true,
            classify_scene: false,
        });
        let pixels = gradient_frame(80, 80);
        let out = run_ok(&classifier, &[FrameRef::new(&pixels, 80, 80)]);
        let first = &out[0];
        assert!(first.scene_type.is_none());
        assert!(first.quality.is_some());
        assert!(first.content_type.is_none());
    }

    #[test]
    fn test_batch_config_defaults() {
        let BatchConfig {
            classify_scene,
            classify_quality,
            classify_content,
            parallel_threshold,
        } = BatchConfig::default();
        assert!(classify_scene);
        assert!(classify_quality);
        assert!(classify_content);
        assert_eq!(parallel_threshold, 4);
    }

    #[test]
    fn test_batch_parallel_frame_indices() {
        // A threshold of one sends every non-empty batch down the parallel path.
        let classifier = BatchClassifier::with_config(BatchConfig {
            parallel_threshold: 1,
            ..BatchConfig::default()
        });
        let buffers: Vec<Vec<u8>> = (0u8..8)
            .map(|step| uniform_frame(30, 30, step * 30))
            .collect();
        let frames = as_refs(&buffers, 30, 30);
        let out = classifier.classify_batch(&frames).expect("should succeed");
        let mut indices: Vec<usize> = out.iter().map(|item| item.frame_index).collect();
        indices.sort_unstable();
        let expected: Vec<usize> = (0..8).collect();
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_batch_config_accessor() {
        let classifier = BatchClassifier::with_config(BatchConfig {
            parallel_threshold: 42,
            classify_content: true,
            classify_quality: false,
            classify_scene: false,
        });
        let seen = classifier.config();
        assert!(!seen.classify_scene);
        assert!(!seen.classify_quality);
        assert!(seen.classify_content);
        assert_eq!(seen.parallel_threshold, 42);
    }

    #[test]
    fn test_classification_temporal_consistency() {
        use std::collections::HashMap;

        const W: usize = 32;
        const H: usize = 32;
        const N: usize = 10;
        // Minimum number of frames (out of `N`) that must share the dominant label.
        const MIN_DOMINANT_FRACTION: usize = 9;

        let (base_r, base_g, base_b) = (200u8, 140u8, 80u8);

        // Nearly identical frames: a 0..=2 offset on red and green that shifts
        // by one pixel per frame.
        let mut frames_data: Vec<Vec<u8>> = Vec::with_capacity(N);
        for frame_idx in 0..N {
            let mut pixels = vec![0u8; W * H * 3];
            for (i, rgb) in pixels.chunks_exact_mut(3).enumerate() {
                let noise = ((i + frame_idx) % 3) as u8;
                rgb[0] = base_r.saturating_add(noise);
                rgb[1] = base_g.saturating_add(noise);
                rgb[2] = base_b;
            }
            frames_data.push(pixels);
        }

        // Threshold above `N` keeps the run on the sequential path.
        let classifier = BatchClassifier::with_config(BatchConfig {
            parallel_threshold: N + 1,
            ..Default::default()
        });
        let frames = as_refs(&frames_data, W, H);
        let results = classifier
            .classify_batch(&frames)
            .expect("batch classification should succeed");
        assert_eq!(results.len(), N, "all frames should produce a result");

        let mut label_counts: HashMap<String, usize> = HashMap::new();
        for scene in results.iter().filter_map(|r| r.scene_type) {
            *label_counts.entry(format!("{scene:?}")).or_default() += 1;
        }

        let dominant_count = label_counts.values().copied().max().unwrap_or(0);
        assert!(
            dominant_count >= MIN_DOMINANT_FRACTION,
            "dominant label should appear in >= {MIN_DOMINANT_FRACTION}/{N} frames \
(no flicker), got {dominant_count} — label counts: {label_counts:?}"
        );
    }
}