use crate::{Result, TranscodeConfig, TranscodeError, TranscodeOutput};
use rayon::prelude::*;
use std::sync::{Arc, Mutex};
/// Settings controlling how many transcode jobs run concurrently.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Maximum number of jobs executed at the same time; must be > 0
    /// (enforced by [`ParallelConfig::validate`]).
    pub max_parallel: usize,
    /// Optional per-job core budget.
    /// NOTE(review): validated but never consulted by the executor in this
    /// module — confirm a downstream consumer reads it.
    pub cores_per_encode: Option<usize>,
    /// Whether execution should use a dedicated thread pool.
    /// NOTE(review): `execute_all` always builds a rayon pool regardless of
    /// this flag — confirm intended.
    pub use_thread_pool: bool,
    /// Scheduling priority hint for the batch.
    pub priority: ParallelPriority,
}
/// Relative scheduling priority hint for a batch of parallel jobs.
///
/// NOTE(review): stored and compared, but not consulted by the executor in
/// this module — presumably honored by a scheduler elsewhere; verify.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelPriority {
    Low,
    Normal,
    High,
}
impl Default for ParallelConfig {
    /// One concurrent job per logical CPU, pooled execution, normal priority,
    /// no explicit per-job core budget.
    fn default() -> Self {
        Self {
            priority: ParallelPriority::Normal,
            use_thread_pool: true,
            cores_per_encode: None,
            max_parallel: num_cpus(),
        }
    }
}
impl ParallelConfig {
    /// Fully automatic configuration; alias for [`Default::default`].
    #[must_use]
    pub fn auto() -> Self {
        Self::default()
    }

    /// Default configuration with an explicit cap on concurrent jobs.
    #[must_use]
    pub fn with_max_parallel(max: usize) -> Self {
        let defaults = Self::default();
        Self {
            max_parallel: max,
            ..defaults
        }
    }

    /// Sets the per-job core budget (builder style).
    #[must_use]
    pub fn cores_per_encode(mut self, cores: usize) -> Self {
        self.cores_per_encode = Some(cores);
        self
    }

    /// Sets the scheduling priority hint (builder style).
    #[must_use]
    pub fn priority(mut self, priority: ParallelPriority) -> Self {
        self.priority = priority;
        self
    }

    /// Checks that the configured counts are usable.
    ///
    /// # Errors
    /// Returns a validation error when `max_parallel` is zero, or when
    /// `cores_per_encode` is set to zero.
    pub fn validate(&self) -> Result<()> {
        // Shared zero-check so both fields produce a uniformly-worded error.
        let require_nonzero = |value: usize, field: &str| -> Result<()> {
            if value == 0 {
                return Err(TranscodeError::ValidationError(
                    crate::ValidationError::Unsupported(format!(
                        "{field} must be greater than 0"
                    )),
                ));
            }
            Ok(())
        };
        require_nonzero(self.max_parallel, "max_parallel")?;
        if let Some(cores) = self.cores_per_encode {
            require_nonzero(cores, "cores_per_encode")?;
        }
        Ok(())
    }
}
/// Number of logical CPUs reported by the OS; falls back to 4 when the
/// query is unsupported or fails.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 4,
    }
}
/// Executes queued transcode jobs, in parallel via a rayon pool or
/// sequentially on the current thread.
pub struct ParallelEncoder {
    // Parallelism settings used when executing.
    config: ParallelConfig,
    // Jobs queued for the next execute_* call.
    jobs: Vec<TranscodeConfig>,
    // Accumulated per-job results, shared behind a mutex so snapshots can be
    // taken via get_results().
    results: Arc<Mutex<Vec<Result<TranscodeOutput>>>>,
}
impl ParallelEncoder {
    /// Creates an encoder with the given settings and an empty job queue.
    #[must_use]
    pub fn new(config: ParallelConfig) -> Self {
        Self {
            config,
            jobs: Vec::new(),
            results: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Queues one job for a later `execute_*` call.
    pub fn add_job(&mut self, job: TranscodeConfig) {
        self.jobs.push(job);
    }

    /// Queues a batch of jobs for a later `execute_*` call.
    pub fn add_jobs(&mut self, jobs: Vec<TranscodeConfig>) {
        self.jobs.extend(jobs);
    }

    /// Number of jobs currently queued.
    #[must_use]
    pub fn job_count(&self) -> usize {
        self.jobs.len()
    }

    /// Runs every queued job on a dedicated rayon pool sized to
    /// `config.max_parallel`, draining the queue.
    ///
    /// Per-job results come back in queue order; a clone of each result is
    /// also appended to the shared list returned by [`Self::get_results`].
    ///
    /// NOTE(review): although declared `async`, this blocks the calling
    /// thread inside `pool.install` until every job finishes — callers on a
    /// multi-task async runtime may want to wrap it in `spawn_blocking`;
    /// confirm intent.
    ///
    /// # Errors
    /// Fails if the configuration is invalid or the thread pool cannot be
    /// built. Individual job failures are reported inside the returned
    /// vector, not as an overall error.
    pub async fn execute_all(&mut self) -> Result<Vec<Result<TranscodeOutput>>> {
        self.config.validate()?;
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.config.max_parallel)
            .build()
            .map_err(|e| {
                TranscodeError::PipelineError(format!("Failed to create thread pool: {e}"))
            })?;
        // Take ownership of the queue so jobs can move into worker threads.
        let jobs = std::mem::take(&mut self.jobs);
        let job_results: Vec<Result<TranscodeOutput>> = pool.install(|| {
            jobs.into_par_iter()
                .map(Self::execute_job)
                .collect::<Vec<_>>()
        });
        match self.results.lock() {
            Ok(mut guard) => {
                guard.extend(job_results.iter().cloned());
            }
            Err(poisoned) => {
                // A poisoned mutex only means another thread panicked while
                // holding it; the Vec inside is still valid, so recover it
                // instead of propagating the panic.
                poisoned.into_inner().extend(job_results.iter().cloned());
            }
        }
        Ok(job_results)
    }

    /// Runs the queued jobs one at a time on the current thread.
    ///
    /// Unlike [`Self::execute_all`], this does NOT drain the queue, does not
    /// record into the shared results list, and stops at the first failure.
    ///
    /// # Errors
    /// Returns the first job error encountered.
    pub async fn execute_sequential(&mut self) -> Result<Vec<TranscodeOutput>> {
        let mut outputs = Vec::new();
        for job in &self.jobs {
            let output = Self::execute_job(job.clone())?;
            outputs.push(output);
        }
        Ok(outputs)
    }

    /// Builds a transcode pipeline from `job` and drives it to completion.
    ///
    /// A fresh single-threaded tokio runtime is created per job because this
    /// runs on rayon worker threads, which have no ambient async context.
    ///
    /// # Errors
    /// Fails when input/output paths are missing, the pipeline cannot be
    /// built, the runtime cannot be created, or the pipeline itself errors.
    #[cfg(not(target_arch = "wasm32"))]
    fn execute_job(job: TranscodeConfig) -> Result<TranscodeOutput> {
        let input = job
            .input
            .as_deref()
            .ok_or_else(|| TranscodeError::InvalidInput("No input file specified".to_string()))?;
        let output = job
            .output
            .as_deref()
            .ok_or_else(|| TranscodeError::InvalidOutput("No output file specified".to_string()))?;
        let mut pipeline_builder = crate::pipeline::TranscodePipelineBuilder::new()
            .input(input)
            .output(output);
        if let Some(ref vc) = job.video_codec {
            pipeline_builder = pipeline_builder.video_codec(vc);
        }
        if let Some(ref ac) = job.audio_codec {
            pipeline_builder = pipeline_builder.audio_codec(ac);
        }
        if let Some(mode) = job.multi_pass {
            pipeline_builder = pipeline_builder.multipass(mode);
        }
        let mut pipeline = pipeline_builder.build()?;
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| {
                TranscodeError::PipelineError(format!("Failed to create async runtime: {e}"))
            })?;
        rt.block_on(pipeline.execute())
    }

    /// On wasm32, parallel job execution is unsupported; always errors.
    #[cfg(target_arch = "wasm32")]
    fn execute_job(_job: TranscodeConfig) -> Result<TranscodeOutput> {
        Err(TranscodeError::Unsupported(
            "Parallel job execution is not supported on wasm32".to_string(),
        ))
    }

    /// Snapshot (clone) of every result accumulated so far.
    #[must_use]
    pub fn get_results(&self) -> Vec<Result<TranscodeOutput>> {
        match self.results.lock() {
            Ok(guard) => guard.clone(),
            // Recover the data from a poisoned lock rather than panicking.
            Err(poisoned) => poisoned.into_inner().clone(),
        }
    }

    /// Removes all queued jobs and clears the accumulated results.
    pub fn clear(&mut self) {
        self.jobs.clear();
        match self.results.lock() {
            Ok(mut guard) => guard.clear(),
            Err(poisoned) => poisoned.into_inner().clear(),
        }
    }
}
/// AV1 tile-grid configuration, expressed as log2 tile counts as in the
/// AV1 bitstream's tile info.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Av1TileConfig {
    /// log2 of the number of tile columns (valid range 0–6).
    pub tile_cols_log2: u8,
    /// log2 of the number of tile rows (valid range 0–6).
    pub tile_rows_log2: u8,
    /// Worker thread count; 0 presumably means "auto" — TODO confirm, this
    /// module never reads the field.
    pub threads: usize,
    /// Enable row-based multithreading within tiles.
    pub row_mt: bool,
}
impl Default for Av1TileConfig {
    /// 2×2 tile grid (log2 = 1 on both axes), automatic thread count, and
    /// row multithreading enabled.
    fn default() -> Self {
        Self {
            row_mt: true,
            threads: 0,
            tile_rows_log2: 1,
            tile_cols_log2: 1,
        }
    }
}
impl Av1TileConfig {
    /// Builds a tile configuration, enforcing the AV1 upper bound of 6 on
    /// both log2 values. `row_mt` defaults to enabled.
    ///
    /// # Errors
    /// Returns a validation error when either log2 value exceeds 6.
    pub fn new(tile_cols_log2: u8, tile_rows_log2: u8, threads: usize) -> Result<Self> {
        // Uniform range check for both axes.
        let check_log2 = |value: u8, name: &str| -> Result<()> {
            if value > 6 {
                return Err(TranscodeError::ValidationError(
                    crate::ValidationError::Unsupported(format!(
                        "{name} must be 0–6, got {value}"
                    )),
                ));
            }
            Ok(())
        };
        check_log2(tile_cols_log2, "tile_cols_log2")?;
        check_log2(tile_rows_log2, "tile_rows_log2")?;
        Ok(Self {
            tile_cols_log2,
            tile_rows_log2,
            threads,
            row_mt: true,
        })
    }

    /// Chooses a tile grid based on frame height (width is currently
    /// unused). Larger frames get more tiles.
    #[must_use]
    pub fn auto(_width: u32, height: u32, threads: usize) -> Self {
        let (tile_cols_log2, tile_rows_log2) = match height {
            0..=720 => (1, 0),
            721..=1080 => (1, 1),
            1081..=2160 => (2, 2),
            _ => (3, 2),
        };
        Self {
            tile_cols_log2,
            tile_rows_log2,
            threads,
            row_mt: true,
        }
    }

    /// Number of tile columns (`2^tile_cols_log2`).
    #[must_use]
    pub fn tile_cols(&self) -> u32 {
        2u32.pow(u32::from(self.tile_cols_log2))
    }

    /// Number of tile rows (`2^tile_rows_log2`).
    #[must_use]
    pub fn tile_rows(&self) -> u32 {
        2u32.pow(u32::from(self.tile_rows_log2))
    }

    /// Total number of tiles in the grid.
    #[must_use]
    pub fn total_tiles(&self) -> u32 {
        self.tile_cols() * self.tile_rows()
    }

    /// Checks that every tile is at least 64×64 pixels for a frame of the
    /// given dimensions.
    ///
    /// # Errors
    /// Returns a validation error when the grid would produce tiles below
    /// the AV1 minimum tile dimension.
    pub fn validate_for_frame(&self, width: u32, height: u32) -> Result<()> {
        const MIN_TILE_DIM: u32 = 64;
        let tile_w = width / self.tile_cols();
        let tile_h = height / self.tile_rows();
        if tile_w >= MIN_TILE_DIM && tile_h >= MIN_TILE_DIM {
            return Ok(());
        }
        Err(TranscodeError::ValidationError(
            crate::ValidationError::Unsupported(format!(
                "Tile grid {}×{} produces tiles {}×{} which is smaller than \
                 the AV1 minimum {}×{} pixels",
                self.tile_cols(),
                self.tile_rows(),
                tile_w,
                tile_h,
                MIN_TILE_DIM,
                MIN_TILE_DIM
            )),
        ))
    }
}
/// Aggregate statistics accumulated across encoded frames.
#[derive(Debug, Clone, Default)]
pub struct Av1TileStats {
    /// Total number of tiles encoded so far.
    pub tiles_encoded: u32,
    /// Total size of all compressed tile payloads, in bytes.
    pub compressed_bytes: u64,
    /// Total wall-clock time spent encoding, in seconds.
    pub wall_time_secs: f64,
}
impl Av1TileStats {
    /// Average encode throughput in tiles per second; zero when no time has
    /// been recorded yet (avoids division by zero).
    #[must_use]
    pub fn tiles_per_second(&self) -> f64 {
        if self.wall_time_secs > 0.0 {
            return f64::from(self.tiles_encoded) / self.wall_time_secs;
        }
        0.0
    }

    /// Mean compressed payload size per tile, in bytes; zero when nothing
    /// has been encoded.
    #[must_use]
    pub fn avg_bytes_per_tile(&self) -> u64 {
        match self.tiles_encoded {
            0 => 0,
            n => self.compressed_bytes / u64::from(n),
        }
    }
}
/// Encodes RGBA frames by splitting them into an AV1-style tile grid and
/// compressing the tiles in parallel.
pub struct Av1TileParallelEncoder {
    // Tile grid layout used for every frame.
    tile_config: Av1TileConfig,
    // Frame dimensions in pixels, fixed at construction.
    frame_width: u32,
    frame_height: u32,
    // Running statistics across encode_frame_rgba calls.
    stats: Av1TileStats,
}
impl Av1TileParallelEncoder {
    /// Creates an encoder after checking that the tile grid fits the frame.
    ///
    /// # Errors
    /// Returns a validation error when any tile would be smaller than the
    /// AV1 minimum tile dimension.
    pub fn new(tile_config: Av1TileConfig, frame_width: u32, frame_height: u32) -> Result<Self> {
        tile_config.validate_for_frame(frame_width, frame_height)?;
        Ok(Self {
            tile_config,
            frame_width,
            frame_height,
            stats: Av1TileStats::default(),
        })
    }

    /// Tile layout used by this encoder.
    #[must_use]
    pub fn tile_config(&self) -> &Av1TileConfig {
        &self.tile_config
    }

    /// Accumulated encoding statistics.
    #[must_use]
    pub fn stats(&self) -> &Av1TileStats {
        &self.stats
    }

    /// Splits an RGBA frame into tiles, compresses every tile in parallel,
    /// and assembles the per-tile bitstreams into a single buffer.
    ///
    /// # Errors
    /// Returns `TranscodeError::CodecError` when `rgba` holds fewer than
    /// `frame_width * frame_height * 4` bytes.
    pub fn encode_frame_rgba(&mut self, rgba: &[u8]) -> Result<Vec<u8>> {
        // Size math is done in usize: the former u32 expression
        // (w * h * 4) overflows for very large frames (e.g. 32768×32768
        // RGBA exceeds u32::MAX bytes), panicking in debug builds and
        // wrapping in release builds.
        let expected = self.frame_width as usize * self.frame_height as usize * 4;
        if rgba.len() < expected {
            return Err(TranscodeError::CodecError(format!(
                "RGBA buffer too small: got {} bytes, need {}",
                rgba.len(),
                expected
            )));
        }
        let start = std::time::Instant::now();
        let tile_cols = self.tile_config.tile_cols();
        let tile_rows = self.tile_config.tile_rows();
        let tile_w = self.frame_width / tile_cols;
        let tile_h = self.frame_height / tile_rows;
        // Row-major tile coordinates; the enumerate index below doubles as
        // the tile's ordering key in the assembled bitstream.
        let coords: Vec<(u32, u32)> = (0..tile_rows)
            .flat_map(|row| (0..tile_cols).map(move |col| (col, row)))
            .collect();
        let tile_bitstreams: Vec<(usize, Vec<u8>)> = {
            use rayon::prelude::*;
            // Captured as usize so per-row offset math below cannot
            // overflow u32 for large frames.
            let frame_width = self.frame_width as usize;
            coords
                .par_iter()
                .enumerate()
                .map(|(idx, &(col, row))| {
                    let x_start = (col * tile_w) as usize;
                    let y_start = row * tile_h;
                    let mut tile_buf =
                        Vec::with_capacity(tile_w as usize * tile_h as usize * 4);
                    for ty in 0..tile_h {
                        let src_row = (y_start + ty) as usize;
                        let src_start = (src_row * frame_width + x_start) * 4;
                        let src_end = src_start + tile_w as usize * 4;
                        // Defensive: skip any row that would fall outside
                        // the source buffer.
                        if src_end <= rgba.len() {
                            tile_buf.extend_from_slice(&rgba[src_start..src_end]);
                        }
                    }
                    let compressed = compress_tile_placeholder(&tile_buf);
                    (idx, compressed)
                })
                .collect()
        };
        let compressed_total: u64 = tile_bitstreams.iter().map(|(_, b)| b.len() as u64).sum();
        self.stats.tiles_encoded += tile_bitstreams.len() as u32;
        self.stats.compressed_bytes += compressed_total;
        self.stats.wall_time_secs += start.elapsed().as_secs_f64();
        Ok(assemble_av1_tile_bitstream(tile_bitstreams))
    }

    /// Resets accumulated statistics to zero.
    pub fn reset_stats(&mut self) {
        self.stats = Av1TileStats::default();
    }
}
/// Serializes compressed tiles into a single bitstream.
///
/// Layout (all integers little-endian `u32`): tile count, then for each
/// tile in ascending index order: tile index, payload length, payload
/// bytes. Input order does not matter; tiles are sorted by index.
#[must_use]
pub fn assemble_av1_tile_bitstream(tiles: Vec<(usize, Vec<u8>)>) -> Vec<u8> {
    let mut sorted = tiles;
    sorted.sort_by_key(|(idx, _)| *idx);
    // Preallocate exactly: 4-byte count header, then 8 bytes of index/length
    // prefix plus the payload per tile. Avoids repeated reallocation.
    let total = 4 + sorted
        .iter()
        .map(|(_, data)| 8 + data.len())
        .sum::<usize>();
    let mut out = Vec::with_capacity(total);
    out.extend_from_slice(&(sorted.len() as u32).to_le_bytes());
    for (idx, data) in sorted {
        out.extend_from_slice(&(idx as u32).to_le_bytes());
        out.extend_from_slice(&(data.len() as u32).to_le_bytes());
        out.extend_from_slice(&data);
    }
    out
}
/// Placeholder tile "compression": takes the first byte of each RGBA pixel
/// as a luma proxy and run-length encodes it as (value, run) byte pairs,
/// with runs capped at 255.
fn compress_tile_placeholder(rgba: &[u8]) -> Vec<u8> {
    if rgba.is_empty() {
        return Vec::new();
    }
    // One proxy-luma byte per 4-byte pixel (the R channel).
    let luma: Vec<u8> = rgba.chunks(4).map(|px| px[0]).collect();
    let mut encoded = Vec::with_capacity(luma.len());
    let mut pos = 0;
    while pos < luma.len() {
        let byte = luma[pos];
        let mut count: usize = 1;
        while count < 255 && pos + count < luma.len() && luma[pos + count] == byte {
            count += 1;
        }
        encoded.push(byte);
        encoded.push(count as u8);
        pos += count;
    }
    encoded
}
/// Fluent builder for [`ParallelEncoder`]: configure parallelism settings
/// and queue jobs before constructing the encoder.
pub struct ParallelEncodeBuilder {
    // Settings transferred to the encoder on build().
    config: ParallelConfig,
    // Jobs transferred to the encoder on build().
    jobs: Vec<TranscodeConfig>,
}
impl ParallelEncodeBuilder {
    /// Starts a builder with default parallelism settings and no jobs.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: ParallelConfig::default(),
            jobs: Vec::new(),
        }
    }

    /// Caps the number of jobs that run concurrently.
    #[must_use]
    pub fn max_parallel(mut self, max: usize) -> Self {
        self.config.max_parallel = max;
        self
    }

    /// Sets the per-job core budget.
    #[must_use]
    pub fn cores_per_encode(mut self, cores: usize) -> Self {
        self.config.cores_per_encode = Some(cores);
        self
    }

    /// Sets the scheduling priority hint.
    #[must_use]
    pub fn priority(mut self, priority: ParallelPriority) -> Self {
        self.config.priority = priority;
        self
    }

    /// Queues a single job.
    #[must_use]
    pub fn add_job(mut self, job: TranscodeConfig) -> Self {
        self.jobs.push(job);
        self
    }

    /// Queues a batch of jobs.
    #[must_use]
    pub fn add_jobs(mut self, jobs: Vec<TranscodeConfig>) -> Self {
        self.jobs.extend(jobs);
        self
    }

    /// Consumes the builder, producing an encoder with the configured
    /// settings and all queued jobs.
    #[must_use]
    pub fn build(self) -> ParallelEncoder {
        let Self { config, jobs } = self;
        let mut encoder = ParallelEncoder::new(config);
        encoder.add_jobs(jobs);
        encoder
    }
}
impl Default for ParallelEncodeBuilder {
    /// Equivalent to [`ParallelEncodeBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests: configuration validation, encoder job bookkeeping, the
// builder, AV1 tile-grid math, tile encoding, stats, and the placeholder
// bitstream/compression helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_config_default() {
        let config = ParallelConfig::default();
        assert!(config.max_parallel > 0);
        assert_eq!(config.priority, ParallelPriority::Normal);
        assert!(config.use_thread_pool);
    }

    #[test]
    fn test_parallel_config_validation() {
        let valid = ParallelConfig::with_max_parallel(4);
        assert!(valid.validate().is_ok());
        // max_parallel == 0 must be rejected.
        let invalid = ParallelConfig {
            max_parallel: 0,
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_parallel_config_cores_validation() {
        let valid = ParallelConfig::default().cores_per_encode(2);
        assert!(valid.validate().is_ok());
        // A zero core budget must be rejected.
        let invalid = ParallelConfig::default().cores_per_encode(0);
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_parallel_encoder_job_count() {
        let mut encoder = ParallelEncoder::new(ParallelConfig::default());
        assert_eq!(encoder.job_count(), 0);
        let job = TranscodeConfig {
            input: Some("/tmp/input.mp4".to_string()),
            output: Some("/tmp/output.mp4".to_string()),
            ..Default::default()
        };
        encoder.add_job(job);
        assert_eq!(encoder.job_count(), 1);
    }

    #[test]
    fn test_parallel_encoder_add_jobs() {
        let mut encoder = ParallelEncoder::new(ParallelConfig::default());
        let jobs = vec![
            TranscodeConfig {
                input: Some("/tmp/input1.mp4".to_string()),
                output: Some("/tmp/output1.mp4".to_string()),
                ..Default::default()
            },
            TranscodeConfig {
                input: Some("/tmp/input2.mp4".to_string()),
                output: Some("/tmp/output2.mp4".to_string()),
                ..Default::default()
            },
        ];
        encoder.add_jobs(jobs);
        assert_eq!(encoder.job_count(), 2);
    }

    #[test]
    fn test_parallel_encoder_clear() {
        let mut encoder = ParallelEncoder::new(ParallelConfig::default());
        let job = TranscodeConfig {
            input: Some("/tmp/input.mp4".to_string()),
            output: Some("/tmp/output.mp4".to_string()),
            ..Default::default()
        };
        encoder.add_job(job);
        assert_eq!(encoder.job_count(), 1);
        encoder.clear();
        assert_eq!(encoder.job_count(), 0);
    }

    #[test]
    fn test_parallel_builder() {
        let job = TranscodeConfig {
            input: Some("/tmp/input.mp4".to_string()),
            output: Some("/tmp/output.mp4".to_string()),
            ..Default::default()
        };
        let encoder = ParallelEncodeBuilder::new()
            .max_parallel(4)
            .cores_per_encode(2)
            .priority(ParallelPriority::High)
            .add_job(job)
            .build();
        assert_eq!(encoder.config.max_parallel, 4);
        assert_eq!(encoder.config.cores_per_encode, Some(2));
        assert_eq!(encoder.config.priority, ParallelPriority::High);
        assert_eq!(encoder.job_count(), 1);
    }

    #[test]
    fn test_num_cpus() {
        let cpus = num_cpus();
        assert!(cpus > 0);
        // Sanity cap: no supported host reports more logical CPUs than this.
        assert!(cpus <= 1024);
    }

    #[test]
    fn test_av1_tile_config_default() {
        // Default is log2 = 1 on both axes, i.e. a 2×2 grid.
        let cfg = Av1TileConfig::default();
        assert_eq!(cfg.tile_cols(), 2);
        assert_eq!(cfg.tile_rows(), 2);
        assert_eq!(cfg.total_tiles(), 4);
        assert!(cfg.row_mt);
    }

    #[test]
    fn test_av1_tile_config_new_valid() {
        let cfg = Av1TileConfig::new(2, 1, 4).expect("valid config");
        assert_eq!(cfg.tile_cols(), 4);
        assert_eq!(cfg.tile_rows(), 2);
        assert_eq!(cfg.total_tiles(), 8);
    }

    #[test]
    fn test_av1_tile_config_new_invalid_cols() {
        let result = Av1TileConfig::new(7, 1, 0);
        assert!(result.is_err(), "log2 > 6 should fail");
    }

    #[test]
    fn test_av1_tile_config_new_invalid_rows() {
        let result = Av1TileConfig::new(1, 7, 0);
        assert!(result.is_err(), "log2 > 6 should fail");
    }

    #[test]
    fn test_av1_tile_config_auto_720p() {
        let cfg = Av1TileConfig::auto(1280, 720, 4);
        assert_eq!(cfg.tile_cols_log2, 1);
        assert_eq!(cfg.tile_rows_log2, 0);
    }

    #[test]
    fn test_av1_tile_config_auto_1080p() {
        let cfg = Av1TileConfig::auto(1920, 1080, 4);
        assert_eq!(cfg.tile_cols_log2, 1);
        assert_eq!(cfg.tile_rows_log2, 1);
    }

    #[test]
    fn test_av1_tile_config_auto_4k() {
        let cfg = Av1TileConfig::auto(3840, 2160, 8);
        assert_eq!(cfg.tile_cols_log2, 2);
        assert_eq!(cfg.tile_rows_log2, 2);
    }

    #[test]
    fn test_av1_tile_config_validate_ok() {
        let cfg = Av1TileConfig::new(1, 1, 0).expect("valid");
        assert!(cfg.validate_for_frame(1920, 1080).is_ok());
    }

    #[test]
    fn test_av1_tile_config_validate_too_small() {
        // 8×8 grid over 256×256 gives 32-pixel tiles, below the 64 minimum.
        let cfg = Av1TileConfig::new(3, 3, 0).expect("valid config");
        assert!(cfg.validate_for_frame(256, 256).is_err());
    }

    #[test]
    fn test_av1_tile_parallel_encoder_encode_frame() {
        let cfg = Av1TileConfig::new(1, 1, 2).expect("valid");
        let mut encoder = Av1TileParallelEncoder::new(cfg, 512, 512).expect("encoder ok");
        let frame_data = vec![128u8; 512 * 512 * 4];
        let bitstream = encoder.encode_frame_rgba(&frame_data).expect("encode ok");
        assert!(bitstream.len() >= 4, "bitstream should have header");
        // First four bytes are the little-endian tile count.
        let tile_count =
            u32::from_le_bytes([bitstream[0], bitstream[1], bitstream[2], bitstream[3]]);
        assert_eq!(tile_count, 4, "should encode 4 tiles");
        assert_eq!(encoder.stats().tiles_encoded, 4);
        assert!(encoder.stats().compressed_bytes > 0);
    }

    #[test]
    fn test_av1_tile_parallel_encoder_undersized_frame() {
        let cfg = Av1TileConfig::default();
        let mut encoder = Av1TileParallelEncoder::new(cfg, 256, 256).expect("encoder ok");
        let result = encoder.encode_frame_rgba(&[0u8]);
        assert!(result.is_err(), "undersized frame should fail");
    }

    #[test]
    fn test_av1_tile_parallel_encoder_stats_reset() {
        let cfg = Av1TileConfig::new(1, 1, 2).expect("valid");
        let mut encoder = Av1TileParallelEncoder::new(cfg, 256, 256).expect("encoder ok");
        let frame_data = vec![0u8; 256 * 256 * 4];
        encoder.encode_frame_rgba(&frame_data).expect("encode ok");
        assert!(encoder.stats().tiles_encoded > 0);
        encoder.reset_stats();
        assert_eq!(encoder.stats().tiles_encoded, 0);
        assert_eq!(encoder.stats().compressed_bytes, 0);
    }

    #[test]
    fn test_av1_tile_stats_tiles_per_second() {
        let stats = Av1TileStats {
            tiles_encoded: 100,
            compressed_bytes: 50_000,
            wall_time_secs: 2.0,
        };
        assert!((stats.tiles_per_second() - 50.0).abs() < 1e-9);
        assert_eq!(stats.avg_bytes_per_tile(), 500);
    }

    #[test]
    fn test_av1_tile_stats_zero_time() {
        // Zero-valued stats must not divide by zero.
        let stats = Av1TileStats::default();
        assert!((stats.tiles_per_second()).abs() < 1e-9);
        assert_eq!(stats.avg_bytes_per_tile(), 0);
    }

    #[test]
    fn test_assemble_av1_tile_bitstream_order() {
        // Tiles supplied out of order must be emitted sorted by index.
        let tiles = vec![(1, vec![1u8, 2, 3]), (0, vec![4u8, 5, 6])];
        let bs = assemble_av1_tile_bitstream(tiles);
        let count = u32::from_le_bytes([bs[0], bs[1], bs[2], bs[3]]);
        assert_eq!(count, 2);
        let idx0 = u32::from_le_bytes([bs[4], bs[5], bs[6], bs[7]]);
        assert_eq!(idx0, 0);
    }

    #[test]
    fn test_compress_tile_placeholder_empty() {
        let result = compress_tile_placeholder(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_compress_tile_placeholder_rle() {
        // Four identical pixels → one (value, run) pair.
        let rgba = vec![
            200u8, 0, 0, 255, 200, 0, 0, 255, 200, 0, 0, 255, 200, 0, 0, 255,
        ];
        let compressed = compress_tile_placeholder(&rgba);
        assert_eq!(compressed, vec![200, 4]);
    }
}