use crate::error::{CodecError, CodecResult};
/// Four-byte tag written at the start of every serialized stats buffer.
const STATS_MAGIC: &[u8; 4] = b"THRS";
/// Wire size of one `TwoPassStats` record, field by field:
/// frame_index (u64) + dct_energy (f64) + motion_magnitude (f64)
/// + is_scene_cut (1 byte) + frame_complexity (f64).
const STATS_ENTRY_SIZE: usize = 8 + 8 + 8 + 1 + 8;
/// Header size: the 4-byte magic followed by a little-endian `u32` entry count.
const STATS_HEADER_SIZE: usize = 4 + 4;
/// Rate-control parameters handed to both the first-pass analyzer and the
/// second-pass encoder.
#[derive(Debug, Clone)]
pub struct TwoPassConfig {
    /// Target bitrate; divided by `framerate` to obtain the average bits per frame.
    pub target_bitrate: u64,
    /// Frame rate; the second-pass encoder rejects non-positive values.
    pub framerate: f64,
    /// Frames whose index is a multiple of this value receive a keyframe bit boost.
    pub keyframe_interval: u32,
    /// Lowest quality value the second pass will report.
    pub quality_min: u8,
    /// Highest quality value the second pass will report.
    pub quality_max: u8,
}

impl TwoPassConfig {
    const DEFAULT_KEYFRAME_INTERVAL: u32 = 64;
    const DEFAULT_QUALITY_RANGE: (u8, u8) = (16, 56);

    /// Creates a config for the given bitrate and frame rate, with a keyframe
    /// interval of 64 and a quality range of 16..=56.
    #[must_use]
    pub fn new(target_bitrate: u64, framerate: f64) -> Self {
        let (quality_min, quality_max) = Self::DEFAULT_QUALITY_RANGE;
        Self {
            target_bitrate,
            framerate,
            keyframe_interval: Self::DEFAULT_KEYFRAME_INTERVAL,
            quality_min,
            quality_max,
        }
    }
}

impl Default for TwoPassConfig {
    /// 2,000,000 target bitrate at 30.0 fps, with the remaining defaults of `new`.
    fn default() -> Self {
        TwoPassConfig::new(2_000_000, 30.0)
    }
}
#[derive(Debug, Clone, PartialEq)]
pub struct TwoPassStats {
pub frame_index: u64,
pub dct_energy: f64,
pub motion_magnitude: f64,
pub is_scene_cut: bool,
pub frame_complexity: f64,
}
impl TwoPassStats {
fn to_bytes(&self) -> [u8; STATS_ENTRY_SIZE] {
let mut buf = [0u8; STATS_ENTRY_SIZE];
buf[0..8].copy_from_slice(&self.frame_index.to_le_bytes());
buf[8..16].copy_from_slice(&self.dct_energy.to_le_bytes());
buf[16..24].copy_from_slice(&self.motion_magnitude.to_le_bytes());
buf[24] = if self.is_scene_cut { 1 } else { 0 };
buf[25..33].copy_from_slice(&self.frame_complexity.to_le_bytes());
buf
}
fn from_bytes(src: &[u8]) -> CodecResult<Self> {
if src.len() < STATS_ENTRY_SIZE {
return Err(CodecError::InvalidBitstream(format!(
"TwoPassStats entry too short: {} < {}",
src.len(),
STATS_ENTRY_SIZE
)));
}
let frame_index =
u64::from_le_bytes(src[0..8].try_into().map_err(|_| {
CodecError::InvalidBitstream("frame_index slice error".to_string())
})?);
let dct_energy = f64::from_le_bytes(
src[8..16]
.try_into()
.map_err(|_| CodecError::InvalidBitstream("dct_energy slice error".to_string()))?,
);
let motion_magnitude = f64::from_le_bytes(src[16..24].try_into().map_err(|_| {
CodecError::InvalidBitstream("motion_magnitude slice error".to_string())
})?);
let is_scene_cut = src[24] != 0;
let frame_complexity = f64::from_le_bytes(src[25..33].try_into().map_err(|_| {
CodecError::InvalidBitstream("frame_complexity slice error".to_string())
})?);
Ok(Self {
frame_index,
dct_energy,
motion_magnitude,
is_scene_cut,
frame_complexity,
})
}
}
/// First pass of two-pass encoding: measures each luma plane and records one
/// [`TwoPassStats`] entry per analyzed frame.
pub struct TheoraFirstPassAnalyzer {
    config: TwoPassConfig,
    /// Entries produced so far, in call order.
    stats: Vec<TwoPassStats>,
    /// Copy of the previous frame's luma, exactly `width * height` bytes
    /// (padded with 128 when the input plane was shorter).
    prev_luma: Option<Vec<u8>>,
    /// Whole-frame variance of the previous frame, used for scene-cut detection.
    prev_variance: f64,
}

impl TheoraFirstPassAnalyzer {
    /// Creates an analyzer with no recorded statistics.
    #[must_use]
    pub fn new(config: TwoPassConfig) -> Self {
        Self {
            config,
            stats: Vec::new(),
            prev_luma: None,
            prev_variance: 0.0,
        }
    }

    /// Returns the configuration this analyzer was created with.
    #[must_use]
    pub fn config(&self) -> &TwoPassConfig {
        &self.config
    }

    /// Analyzes one frame's luma plane, records the resulting stats and returns them.
    ///
    /// `y_plane` is read with a stride of `width`. Motion is the mean absolute
    /// difference against the previously analyzed frame; the first frame gets
    /// 0.0 motion and is never a scene cut. The complexity score is
    /// `0.6 * dct_energy + 40.0 * motion_magnitude`.
    pub fn analyze_frame(
        &mut self,
        y_plane: &[u8],
        width: u32,
        height: u32,
        frame_idx: u64,
    ) -> TwoPassStats {
        let dct_energy = compute_dct_energy(y_plane, width, height);
        let curr_variance = compute_frame_variance(y_plane, width, height);
        let (motion_magnitude, is_scene_cut) = match &self.prev_luma {
            Some(prev) => {
                let mad = compute_mad(prev, y_plane);
                let scene_cut = is_scene_cut(mad, self.prev_variance, curr_variance);
                (mad, scene_cut)
            }
            None => (0.0, false),
        };
        let frame_complexity = dct_energy * 0.6 + motion_magnitude * 40.0;
        let entry = TwoPassStats {
            frame_index: frame_idx,
            dct_energy,
            motion_magnitude,
            is_scene_cut,
            frame_complexity,
        };

        // Keep this frame as the reference for the next call. The previous buffer is
        // reused rather than reallocated every frame; padding with 128 keeps it exactly
        // `width * height` bytes even when the input plane is short.
        let luma_len = (width as usize).saturating_mul(height as usize);
        let copy_len = luma_len.min(y_plane.len());
        let reference = self.prev_luma.get_or_insert_with(Vec::new);
        reference.clear();
        reference.extend_from_slice(&y_plane[..copy_len]);
        reference.resize(luma_len, 128);
        self.prev_variance = curr_variance;

        self.stats.push(entry.clone());
        entry
    }

    /// All statistics recorded so far, in the order frames were analyzed.
    #[must_use]
    pub fn collect_stats(&self) -> &[TwoPassStats] {
        &self.stats
    }

    /// Serializes the recorded statistics into the binary stats format.
    #[must_use]
    pub fn serialize_stats(&self) -> Vec<u8> {
        serialize_stats_slice(&self.stats)
    }

    /// Parses a buffer produced by [`Self::serialize_stats`].
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidBitstream` if the buffer is malformed or truncated.
    pub fn deserialize_stats(data: &[u8]) -> CodecResult<Vec<TwoPassStats>> {
        deserialize_stats_impl(data)
    }
}
/// Second pass of two-pass encoding: converts first-pass statistics into
/// per-frame bit budgets and quality values.
pub struct TheoraSecondPassEncoder {
    config: TwoPassConfig,
    stats: Vec<TwoPassStats>,
    /// Bit budget per frame, parallel to `stats`.
    frame_bits: Vec<u32>,
    /// Average budget per frame: `target_bitrate / framerate`.
    base_bits_per_frame: f64,
}

impl TheoraSecondPassEncoder {
    /// Validates the inputs and precomputes the per-frame bit allocation.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidParameter` if `stats` is empty, if
    /// `config.framerate` is not a finite positive number, or if
    /// `quality_min > quality_max`.
    pub fn new(config: TwoPassConfig, stats: Vec<TwoPassStats>) -> CodecResult<Self> {
        if stats.is_empty() {
            return Err(CodecError::InvalidParameter(
                "TwoPassStats must not be empty for second-pass encoding".to_string(),
            ));
        }
        // `<= 0.0` alone lets NaN (and +inf) through; NaN would make
        // `base_bits_per_frame` NaN and collapse every frame's budget.
        if !config.framerate.is_finite() || config.framerate <= 0.0 {
            return Err(CodecError::InvalidParameter(
                "framerate must be positive and finite".to_string(),
            ));
        }
        if config.quality_min > config.quality_max {
            return Err(CodecError::InvalidParameter(
                "quality_min must not exceed quality_max".to_string(),
            ));
        }
        let base_bits_per_frame = config.target_bitrate as f64 / config.framerate;
        let frame_bits = allocate_bits(&config, &stats, base_bits_per_frame);
        Ok(Self {
            config,
            stats,
            frame_bits,
            base_bits_per_frame,
        })
    }

    /// Quality value for `frame_idx`, derived from its bit budget.
    ///
    /// Frames outside the analyzed range get `quality_max`. `_is_keyframe` is
    /// currently unused.
    #[must_use]
    pub fn get_frame_quality(&self, frame_idx: u64, _is_keyframe: bool) -> u8 {
        self.planned_bits(frame_idx)
            .map_or(self.config.quality_max, |bits| {
                bits_to_quality(
                    bits,
                    self.base_bits_per_frame,
                    self.config.quality_min,
                    self.config.quality_max,
                )
            })
    }

    /// Bit budget for `frame_idx`.
    ///
    /// Frames outside the analyzed range get the average per-frame budget
    /// (truncated to `u32`). `_is_keyframe` is currently unused.
    #[must_use]
    pub fn allocate_bits(&self, frame_idx: u64, _is_keyframe: bool) -> u32 {
        self.planned_bits(frame_idx)
            .unwrap_or(self.base_bits_per_frame as u32)
    }

    /// Number of frames covered by the first-pass statistics.
    #[must_use]
    pub fn total_frames(&self) -> usize {
        self.stats.len()
    }

    /// The first-pass statistics this encoder was built from.
    #[must_use]
    pub fn stats(&self) -> &[TwoPassStats] {
        &self.stats
    }

    /// Looks up the precomputed budget for `frame_idx`.
    ///
    /// Returns `None` past the end. It also returns `None` when the index does not
    /// fit in `usize`, instead of truncating it onto a low frame on 32-bit targets.
    fn planned_bits(&self, frame_idx: u64) -> Option<u32> {
        usize::try_from(frame_idx)
            .ok()
            .and_then(|i| self.frame_bits.get(i).copied())
    }
}
/// Texture measure reported as "DCT energy".
///
/// It is the sum, over every 8x8 luma block (edge blocks may be smaller), of the
/// block's pixel variance. The plane is read with a stride equal to `width`, and
/// pixels whose offset lies past the end of `y_plane` are skipped. Returns 0.0
/// when either dimension is zero.
fn compute_dct_energy(y_plane: &[u8], width: u32, height: u32) -> f64 {
    const BLOCK: usize = 8;
    if width == 0 || height == 0 {
        return 0.0;
    }
    let cols = width as usize;
    let rows = height as usize;
    let mut energy = 0.0f64;
    for top in (0..rows).step_by(BLOCK) {
        let bottom = rows.min(top + BLOCK);
        for left in (0..cols).step_by(BLOCK) {
            let span = cols.min(left + BLOCK) - left;
            let (mut n, mut total, mut total_sq) = (0u64, 0u64, 0u64);
            for row in top..bottom {
                // `skip` past the end of the plane yields nothing, so missing pixels
                // are simply not counted.
                for &px in y_plane.iter().skip(row * cols + left).take(span) {
                    let v = u64::from(px);
                    total += v;
                    total_sq += v * v;
                    n += 1;
                }
            }
            if n == 0 {
                continue;
            }
            let mean = total as f64 / n as f64;
            let second_moment = total_sq as f64 / n as f64;
            energy += (second_moment - mean * mean).max(0.0);
        }
    }
    energy
}
/// Population variance of the frame's luma samples.
///
/// Only the first `width * height` bytes of `y_plane` are used, the same region
/// `compute_dct_energy` and the analyzer's reference copy cover. Returns 0.0 for a
/// zero-sized frame or an empty plane.
///
/// The previous implementation computed `sum_sq - sum² / n²` with integer division,
/// which is not the variance, and `sum * sum` could overflow `u64` on very large
/// frames.
fn compute_frame_variance(y_plane: &[u8], width: u32, height: u32) -> f64 {
    if width == 0 || height == 0 || y_plane.is_empty() {
        return 0.0;
    }
    let frame_len = (width as usize).saturating_mul(height as usize);
    let pixels = &y_plane[..frame_len.min(y_plane.len())];
    // Integer accumulation is exact (255² per sample leaves ample u64 headroom);
    // the division happens in f64 so the mean is not truncated.
    let (sum, sum_sq) = pixels.iter().fold((0u64, 0u64), |(s, s2), &p| {
        let v = u64::from(p);
        (s + v, s2 + v * v)
    });
    let n = pixels.len() as f64;
    let mean = sum as f64 / n;
    (sum_sq as f64 / n - mean * mean).max(0.0)
}
/// Mean absolute difference between two luma buffers over their common prefix.
///
/// Returns 0.0 when either buffer is empty.
fn compute_mad(prev: &[u8], curr: &[u8]) -> f64 {
    let overlap = prev.len().min(curr.len());
    if overlap == 0 {
        return 0.0;
    }
    let mut acc = 0u64;
    for (&a, &b) in prev.iter().zip(curr) {
        acc += u64::from(a.abs_diff(b));
    }
    acc as f64 / overlap as f64
}
/// Flags a scene cut if the frame-to-frame MAD exceeds 30, or if the frame's
/// variance grew by more than 2.5x over a non-zero previous variance.
fn is_scene_cut(mad: f64, prev_variance: f64, curr_variance: f64) -> bool {
    const MAD_THRESHOLD: f64 = 30.0;
    const VARIANCE_RATIO_THRESHOLD: f64 = 2.5;
    let variance_jump =
        prev_variance > 0.0 && curr_variance / prev_variance > VARIANCE_RATIO_THRESHOLD;
    mad > MAD_THRESHOLD || variance_jump
}
/// Distributes the sequence budget (`base_bpf * stats.len()` bits) across frames
/// in proportion to per-frame weights.
///
/// A frame's weight is its complexity relative to the sequence average. Scene
/// cuts double it, and keyframe positions (every `config.keyframe_interval`
/// frames, starting at frame 0) multiply it by 1.5. Weights are floored at 0.1.
/// A `keyframe_interval` of 0 means "no periodic keyframes": only frame 0 gets
/// the boost, where the previous code panicked with a modulo by zero.
///
/// Every frame receives at least 1 bit. Empty `stats` yields an empty vector.
fn allocate_bits(config: &TwoPassConfig, stats: &[TwoPassStats], base_bpf: f64) -> Vec<u32> {
    let n = stats.len();
    if n == 0 {
        return Vec::new();
    }
    let total_complexity: f64 = stats.iter().map(|s| s.frame_complexity).sum();
    // Fall back to 1.0 so all-zero (or non-positive) complexity cannot divide by zero.
    let avg_complexity = if total_complexity > 0.0 {
        total_complexity / n as f64
    } else {
        1.0
    };
    let total_bits = base_bpf * n as f64;
    let keyframe_interval = config.keyframe_interval as usize;
    let weights: Vec<f64> = stats
        .iter()
        .enumerate()
        .map(|(i, s)| {
            let base_weight = if avg_complexity > 0.0 {
                s.frame_complexity / avg_complexity
            } else {
                1.0
            };
            let scene_factor = if s.is_scene_cut { 2.0f64 } else { 1.0 };
            let on_keyframe_slot = match keyframe_interval {
                0 => i == 0,
                k => i % k == 0,
            };
            let key_factor = if on_keyframe_slot { 1.5f64 } else { 1.0 };
            (base_weight * scene_factor * key_factor).max(0.1)
        })
        .collect();
    let weight_sum: f64 = weights.iter().sum();
    weights
        .iter()
        .map(|&w| {
            let bits = if weight_sum > 0.0 {
                total_bits * w / weight_sum
            } else {
                base_bpf
            };
            bits.max(1.0) as u32
        })
        .collect()
}
/// Maps a frame's allocated bits to a quality value in `quality_min..=quality_max`.
///
/// The ratio `bits / base_bpf` is mapped linearly. A ratio of 0.5 or less gives
/// `quality_min`, 2.0 or more gives `quality_max`, and the result is truncated
/// toward zero. A non-positive or NaN `base_bpf` yields `quality_min`; before,
/// NaN propagated and `NaN as u8` returned 0, below the allowed range.
///
/// Requires `quality_min <= quality_max`, because `f64::clamp` panics otherwise.
fn bits_to_quality(bits: u32, base_bpf: f64, quality_min: u8, quality_max: u8) -> u8 {
    if base_bpf.is_nan() || base_bpf <= 0.0 {
        return quality_min;
    }
    let lo = f64::from(quality_min);
    let hi = f64::from(quality_max);
    let ratio = f64::from(bits) / base_bpf;
    // Map ratio 0.5..=2.0 onto 0.0..=1.0 of the quality range.
    let t = ((ratio - 0.5) / 1.5).clamp(0.0, 1.0);
    let quality = lo + t * (hi - lo);
    quality.clamp(lo, hi) as u8
}
fn serialize_stats_slice(stats: &[TwoPassStats]) -> Vec<u8> {
let count = stats.len() as u32;
let mut buf = Vec::with_capacity(STATS_HEADER_SIZE + stats.len() * STATS_ENTRY_SIZE);
buf.extend_from_slice(STATS_MAGIC);
buf.extend_from_slice(&count.to_le_bytes());
for s in stats {
buf.extend_from_slice(&s.to_bytes());
}
buf
}
/// Parses a buffer in the format written by `serialize_stats_slice`.
///
/// The format is the magic, a little-endian `u32` entry count, then `count`
/// fixed-size records. Bytes after the last declared record are ignored.
///
/// # Errors
///
/// Returns `CodecError::InvalidBitstream` when:
/// - the buffer is shorter than the header;
/// - the magic does not match;
/// - the declared count cannot be expressed as a byte length;
/// - the buffer is too short for the declared count.
fn deserialize_stats_impl(data: &[u8]) -> CodecResult<Vec<TwoPassStats>> {
    if data.len() < STATS_HEADER_SIZE {
        return Err(CodecError::InvalidBitstream(format!(
            "stats buffer too short for header: {} bytes",
            data.len()
        )));
    }
    if &data[0..4] != STATS_MAGIC {
        return Err(CodecError::InvalidBitstream(
            "invalid TwoPassStats magic bytes".to_string(),
        ));
    }
    let count = u32::from_le_bytes(
        data[4..8]
            .try_into()
            .map_err(|_| CodecError::InvalidBitstream("count slice error".to_string()))?,
    ) as usize;
    // `count` comes from untrusted input. On 32-bit targets `count * STATS_ENTRY_SIZE`
    // can overflow (panicking in debug, wrapping in release to a small length that
    // would slip past the size check), so the arithmetic is checked.
    let Some(expected_len) = count
        .checked_mul(STATS_ENTRY_SIZE)
        .and_then(|payload| payload.checked_add(STATS_HEADER_SIZE))
    else {
        return Err(CodecError::InvalidBitstream(format!(
            "stats entry count {count} overflows the addressable buffer size"
        )));
    };
    if data.len() < expected_len {
        return Err(CodecError::InvalidBitstream(format!(
            "stats buffer too short: have {}, need {} for {} entries",
            data.len(),
            expected_len,
            count
        )));
    }
    data[STATS_HEADER_SIZE..expected_len]
        .chunks_exact(STATS_ENTRY_SIZE)
        .map(TwoPassStats::from_bytes)
        .collect()
}
// Unit tests for configuration defaults, first-pass analysis, stats
// serialization/deserialization and second-pass bit allocation.
#[cfg(test)]
mod tests {
use super::*;
// Builds a `width * height` luma plane filled with `value`.
fn uniform_plane(width: u32, height: u32, value: u8) -> Vec<u8> {
vec![value; (width * height) as usize]
}
// Builds a `width * height` luma plane alternating 200 and 50 in a checkerboard.
fn checker_plane(width: u32, height: u32) -> Vec<u8> {
let mut v = vec![0u8; (width * height) as usize];
for y in 0..height as usize {
for x in 0..width as usize {
v[y * width as usize + x] = if (x + y) % 2 == 0 { 200 } else { 50 };
}
}
v
}
#[test]
fn test_two_pass_config_defaults() {
let cfg = TwoPassConfig::default();
assert_eq!(cfg.target_bitrate, 2_000_000);
assert!((cfg.framerate - 30.0).abs() < 1e-9);
assert_eq!(cfg.keyframe_interval, 64);
assert!(cfg.quality_min < cfg.quality_max);
}
#[test]
fn test_two_pass_config_new() {
let cfg = TwoPassConfig::new(4_000_000, 60.0);
assert_eq!(cfg.target_bitrate, 4_000_000);
assert!((cfg.framerate - 60.0).abs() < 1e-9);
}
#[test]
fn test_first_pass_analyzer_new() {
let analyzer = TheoraFirstPassAnalyzer::new(TwoPassConfig::default());
assert!(analyzer.collect_stats().is_empty());
}
// A flat plane has zero variance in every 8x8 block.
#[test]
fn test_analyze_frame_uniform_plane_zero_energy() {
let mut analyzer = TheoraFirstPassAnalyzer::new(TwoPassConfig::default());
let plane = uniform_plane(64, 64, 128);
let stats = analyzer.analyze_frame(&plane, 64, 64, 0);
assert!(
stats.dct_energy < 1.0,
"uniform plane should have near-zero DCT energy, got {}",
stats.dct_energy
);
}
#[test]
fn test_analyze_frame_varied_plane_nonzero_energy() {
let mut analyzer = TheoraFirstPassAnalyzer::new(TwoPassConfig::default());
let plane = checker_plane(64, 64);
let stats = analyzer.analyze_frame(&plane, 64, 64, 0);
assert!(
stats.dct_energy > 0.0,
"checkered plane must have non-zero DCT energy"
);
}
// With no previous frame, motion is 0.0 and no scene cut can be reported.
#[test]
fn test_analyze_frame_first_frame_zero_motion() {
let mut analyzer = TheoraFirstPassAnalyzer::new(TwoPassConfig::default());
let plane = checker_plane(32, 32);
let stats = analyzer.analyze_frame(&plane, 32, 32, 0);
assert_eq!(
stats.motion_magnitude, 0.0,
"first frame has no prior frame"
);
assert!(!stats.is_scene_cut, "first frame cannot be a scene cut");
}
// 10 -> 200 gives a mean absolute difference of 190, well above the scene-cut
// MAD threshold of 30.
#[test]
fn test_analyze_frame_motion_accumulates() {
let mut analyzer = TheoraFirstPassAnalyzer::new(TwoPassConfig::default());
let plane_a = uniform_plane(32, 32, 10);
let plane_b = uniform_plane(32, 32, 200);
analyzer.analyze_frame(&plane_a, 32, 32, 0);
let stats = analyzer.analyze_frame(&plane_b, 32, 32, 1);
assert!(
stats.motion_magnitude > 0.0,
"motion must be positive when frames differ"
);
assert!(
stats.is_scene_cut,
"large pixel shift should trigger scene cut"
);
}
#[test]
fn test_scene_cut_detection_high_mad() {
let result = is_scene_cut(35.0, 100.0, 100.0);
assert!(result);
}
// Variance ratio 300 / 100 = 3.0 exceeds the 2.5 threshold.
#[test]
fn test_scene_cut_detection_high_variance_ratio() {
let result = is_scene_cut(5.0, 100.0, 300.0);
assert!(result);
}
#[test]
fn test_scene_cut_detection_no_cut() {
let result = is_scene_cut(10.0, 100.0, 120.0);
assert!(!result);
}
#[test]
fn test_collect_stats_length() {
let mut analyzer = TheoraFirstPassAnalyzer::new(TwoPassConfig::default());
for i in 0..10u64 {
let plane = uniform_plane(16, 16, (i * 25 % 255) as u8);
analyzer.analyze_frame(&plane, 16, 16, i);
}
assert_eq!(analyzer.collect_stats().len(), 10);
}
// An empty analyzer serializes to just the header: magic plus a zero count.
#[test]
fn test_serialize_empty_stats() {
let analyzer = TheoraFirstPassAnalyzer::new(TwoPassConfig::default());
let bytes = analyzer.serialize_stats();
assert_eq!(bytes.len(), STATS_HEADER_SIZE);
assert_eq!(&bytes[0..4], STATS_MAGIC);
let count = u32::from_le_bytes(bytes[4..8].try_into().expect("slice"));
assert_eq!(count, 0);
}
#[test]
fn test_serialize_deserialize_roundtrip() {
let mut analyzer = TheoraFirstPassAnalyzer::new(TwoPassConfig::default());
let plane_a = uniform_plane(16, 16, 80);
let plane_b = checker_plane(16, 16);
let plane_c = uniform_plane(16, 16, 200);
analyzer.analyze_frame(&plane_a, 16, 16, 0);
analyzer.analyze_frame(&plane_b, 16, 16, 1);
analyzer.analyze_frame(&plane_c, 16, 16, 2);
let bytes = analyzer.serialize_stats();
let restored =
TheoraFirstPassAnalyzer::deserialize_stats(&bytes).expect("deserialize should succeed");
let original = analyzer.collect_stats();
assert_eq!(restored.len(), original.len());
for (r, o) in restored.iter().zip(original.iter()) {
assert_eq!(r.frame_index, o.frame_index);
assert!((r.dct_energy - o.dct_energy).abs() < 1e-6);
assert!((r.motion_magnitude - o.motion_magnitude).abs() < 1e-6);
assert_eq!(r.is_scene_cut, o.is_scene_cut);
assert!((r.frame_complexity - o.frame_complexity).abs() < 1e-6);
}
}
// Header-sized buffer whose first four bytes are not the stats magic.
#[test]
fn test_deserialize_invalid_magic() {
let bad: Vec<u8> = b"BADM\x01\x00\x00\x00".to_vec();
let result = TheoraFirstPassAnalyzer::deserialize_stats(&bad);
assert!(result.is_err(), "bad magic should return an error");
}
// Fewer bytes than the 8-byte header.
#[test]
fn test_deserialize_too_short() {
let short: Vec<u8> = vec![0u8; 3];
let result = TheoraFirstPassAnalyzer::deserialize_stats(&short);
assert!(result.is_err());
}
// Header declares two entries but only one entry's worth of bytes follows.
#[test]
fn test_deserialize_truncated_entries() {
let mut buf = Vec::new();
buf.extend_from_slice(STATS_MAGIC);
buf.extend_from_slice(&2u32.to_le_bytes());
buf.extend_from_slice(&[0u8; STATS_ENTRY_SIZE]); let result = TheoraFirstPassAnalyzer::deserialize_stats(&buf);
assert!(result.is_err());
}
// Synthetic stats: complexity grows by 10 per frame, with a single scene cut
// at frame 15.
fn make_stats(n: usize) -> Vec<TwoPassStats> {
(0..n)
.map(|i| TwoPassStats {
frame_index: i as u64,
dct_energy: 100.0 + i as f64 * 10.0,
motion_magnitude: 5.0,
is_scene_cut: i == 15,
frame_complexity: 100.0 + i as f64 * 10.0,
})
.collect()
}
#[test]
fn test_second_pass_encoder_new() {
let cfg = TwoPassConfig::default();
let stats = make_stats(30);
let enc = TheoraSecondPassEncoder::new(cfg, stats);
assert!(enc.is_ok());
}
#[test]
fn test_second_pass_encoder_empty_stats_errors() {
let cfg = TwoPassConfig::default();
let result = TheoraSecondPassEncoder::new(cfg, vec![]);
assert!(result.is_err());
}
#[test]
fn test_second_pass_encoder_invalid_framerate_errors() {
let mut cfg = TwoPassConfig::default();
cfg.framerate = 0.0;
let result = TheoraSecondPassEncoder::new(cfg, make_stats(5));
assert!(result.is_err());
}
// Every analyzed frame's quality must stay inside the configured range.
#[test]
fn test_second_pass_get_frame_quality_range() {
let cfg = TwoPassConfig::default();
let stats = make_stats(30);
let enc = TheoraSecondPassEncoder::new(cfg.clone(), stats).expect("ok");
for i in 0..30u64 {
let q = enc.get_frame_quality(i, i % cfg.keyframe_interval as u64 == 0);
assert!(
q >= cfg.quality_min && q <= cfg.quality_max,
"quality {} out of range [{}, {}] at frame {}",
q,
cfg.quality_min,
cfg.quality_max,
i
);
}
}
// With uniform complexity and no scene cuts, only the keyframe boost on
// frame 0 (interval 10) distinguishes it from frame 1.
#[test]
fn test_second_pass_allocate_bits_keyframe_higher() {
let cfg = TwoPassConfig {
target_bitrate: 2_000_000,
framerate: 30.0,
keyframe_interval: 10,
quality_min: 16,
quality_max: 56,
};
let stats: Vec<TwoPassStats> = (0..20)
.map(|i| TwoPassStats {
frame_index: i as u64,
dct_energy: 100.0,
motion_magnitude: 5.0,
is_scene_cut: false,
frame_complexity: 100.0,
})
.collect();
let enc = TheoraSecondPassEncoder::new(cfg.clone(), stats).expect("ok");
let key_bits = enc.allocate_bits(0, true);
let inter_bits = enc.allocate_bits(1, false);
assert!(
key_bits > inter_bits,
"keyframe should receive more bits: {} vs {}",
key_bits,
inter_bits
);
}
#[test]
fn test_second_pass_total_frames() {
let stats = make_stats(42);
let enc = TheoraSecondPassEncoder::new(TwoPassConfig::default(), stats).expect("ok");
assert_eq!(enc.total_frames(), 42);
}
// Frames past the end of the stats fall back to quality_max.
#[test]
fn test_second_pass_out_of_range_frame() {
let cfg = TwoPassConfig::default();
let enc = TheoraSecondPassEncoder::new(cfg.clone(), make_stats(5)).expect("ok");
let q = enc.get_frame_quality(999, false);
assert_eq!(q, cfg.quality_max);
}
}