use super::obu::{encode_leb128, ObuHeader, ObuType};
use super::tile_encoder::{ParallelTileEncoder, TileEncoderConfig, TileInfoBuilder};
use crate::error::{CodecError, CodecResult};
use crate::frame::VideoFrame;
use crate::traits::{EncodedPacket, EncoderConfig, VideoEncoder};
use oximedia_core::CodecId;
#[derive(Debug)]
pub struct Av1Encoder {
config: EncoderConfig,
frame_count: u64,
output_queue: Vec<EncodedPacket>,
flushing: bool,
tile_encoder: Option<ParallelTileEncoder>,
quality: u8,
}
impl Av1Encoder {
pub fn new(config: EncoderConfig) -> CodecResult<Self> {
if config.width == 0 || config.height == 0 {
return Err(CodecError::InvalidParameter(
"Invalid frame dimensions".to_string(),
));
}
if config.codec != CodecId::Av1 {
return Err(CodecError::InvalidParameter(
"Expected AV1 codec".to_string(),
));
}
let tile_encoder = if config.threads > 1 {
let tile_config = TileEncoderConfig::auto(config.width, config.height, config.threads);
Some(ParallelTileEncoder::new(
tile_config,
config.width,
config.height,
)?)
} else {
None
};
let quality = Self::compute_quality(&config);
Ok(Self {
config,
frame_count: 0,
output_queue: Vec::new(),
flushing: false,
tile_encoder,
quality,
})
}
fn compute_quality(config: &EncoderConfig) -> u8 {
use crate::traits::BitrateMode;
match config.bitrate {
BitrateMode::Crf(crf) => {
let normalized = (crf / 51.0).clamp(0.0, 1.0);
(255.0 * (1.0 - normalized)) as u8
}
BitrateMode::Cbr(_) | BitrateMode::Vbr { .. } => {
128
}
BitrateMode::Lossless => {
255
}
}
}
fn encode_frame(&mut self, frame: &VideoFrame) {
let is_keyframe = self.frame_count % u64::from(self.config.keyint) == 0;
let mut data = Vec::new();
if is_keyframe {
self.write_sequence_header(&mut data);
}
if let Some(ref tile_encoder) = self.tile_encoder {
self.encode_frame_with_tiles(frame, &mut data, is_keyframe);
} else {
Self::write_frame_obu(&mut data, is_keyframe);
}
#[allow(clippy::cast_possible_wrap)]
let pts = self.frame_count as i64;
let dts = pts;
self.output_queue.push(EncodedPacket {
data,
pts,
dts,
keyframe: is_keyframe,
duration: Some(1),
});
self.frame_count += 1;
}
fn encode_frame_with_tiles(&self, frame: &VideoFrame, data: &mut Vec<u8>, is_keyframe: bool) {
if let Some(ref tile_encoder) = self.tile_encoder {
match tile_encoder.encode_frame(frame, self.quality, is_keyframe) {
Ok(tiles) => {
match tile_encoder.merge_tiles(&tiles) {
Ok(merged) => {
self.write_frame_header_with_tiles(data, is_keyframe, tile_encoder);
data.extend_from_slice(&merged);
}
Err(_) => {
Self::write_frame_obu(data, is_keyframe);
}
}
}
Err(_) => {
Self::write_frame_obu(data, is_keyframe);
}
}
}
}
fn write_frame_header_with_tiles(
&self,
data: &mut Vec<u8>,
is_keyframe: bool,
tile_encoder: &ParallelTileEncoder,
) {
let header = ObuHeader {
obu_type: ObuType::Frame,
has_extension: false,
has_size: true,
temporal_id: 0,
spatial_id: 0,
};
data.extend(header.to_bytes());
let mut frame_header = Vec::new();
let frame_type = u8::from(!is_keyframe);
frame_header.push((frame_type << 5) | 0x10);
let tile_info = TileInfoBuilder::from_config(
tile_encoder.config(),
self.config.width,
self.config.height,
);
if tile_info.tile_count() > 1 {
frame_header.push(tile_info.tile_cols_log2);
frame_header.push(tile_info.tile_rows_log2);
}
let size_bytes = encode_leb128(frame_header.len() as u64);
data.extend(size_bytes);
data.extend(frame_header);
}
fn write_sequence_header(&self, data: &mut Vec<u8>) {
let header = ObuHeader {
obu_type: ObuType::SequenceHeader,
has_extension: false,
has_size: true,
temporal_id: 0,
spatial_id: 0,
};
data.extend(header.to_bytes());
let payload = self.build_sequence_header_payload();
let size_bytes = encode_leb128(payload.len() as u64);
data.extend(size_bytes);
data.extend(payload);
}
#[allow(clippy::cast_possible_truncation)]
fn build_sequence_header_payload(&self) -> Vec<u8> {
let mut payload = Vec::new();
payload.push(0x00);
payload.push(0x00);
payload.push(0x00);
let width_bits = 32 - self.config.width.leading_zeros();
let height_bits = 32 - self.config.height.leading_zeros();
payload.push(
((width_bits.saturating_sub(1) as u8) << 4) | (height_bits.saturating_sub(1) as u8),
);
let width_minus_1 = self.config.width.saturating_sub(1);
let height_minus_1 = self.config.height.saturating_sub(1);
payload.extend(&width_minus_1.to_be_bytes()[2..]);
payload.extend(&height_minus_1.to_be_bytes()[2..]);
payload
}
fn write_frame_obu(data: &mut Vec<u8>, is_keyframe: bool) {
let header = ObuHeader {
obu_type: ObuType::Frame,
has_extension: false,
has_size: true,
temporal_id: 0,
spatial_id: 0,
};
data.extend(header.to_bytes());
let frame_data = Self::build_frame_payload(is_keyframe);
let size_bytes = encode_leb128(frame_data.len() as u64);
data.extend(size_bytes);
data.extend(frame_data);
}
fn build_frame_payload(is_keyframe: bool) -> Vec<u8> {
let mut payload = Vec::new();
let frame_type = u8::from(!is_keyframe);
payload.push((frame_type << 5) | 0x10);
payload.extend(&[0x00; 16]);
payload
}
}
impl VideoEncoder for Av1Encoder {
fn codec(&self) -> CodecId {
CodecId::Av1
}
fn send_frame(&mut self, frame: &VideoFrame) -> CodecResult<()> {
if self.flushing {
return Err(CodecError::InvalidParameter(
"Cannot send frame while flushing".to_string(),
));
}
if frame.width != self.config.width || frame.height != self.config.height {
return Err(CodecError::InvalidParameter(format!(
"Frame dimensions {}x{} don't match encoder config {}x{}",
frame.width, frame.height, self.config.width, self.config.height
)));
}
self.encode_frame(frame);
Ok(())
}
fn receive_packet(&mut self) -> CodecResult<Option<EncodedPacket>> {
if self.output_queue.is_empty() {
return Ok(None);
}
Ok(Some(self.output_queue.remove(0)))
}
fn flush(&mut self) -> CodecResult<()> {
self.flushing = true;
Ok(())
}
fn config(&self) -> &EncoderConfig {
&self.config
}
}
impl Av1Encoder {
#[must_use]
pub fn tile_config(&self) -> Option<&TileEncoderConfig> {
self.tile_encoder.as_ref().map(|e| e.config())
}
#[must_use]
pub fn has_tile_encoding(&self) -> bool {
self.tile_encoder.is_some()
}
#[must_use]
pub fn tile_count(&self) -> usize {
self.tile_encoder.as_ref().map_or(1, |e| e.tile_count())
}
pub fn set_tile_config(&mut self, tile_config: TileEncoderConfig) -> CodecResult<()> {
tile_config.validate()?;
self.tile_encoder = Some(ParallelTileEncoder::new(
tile_config,
self.config.width,
self.config.height,
)?);
Ok(())
}
pub fn disable_tile_encoding(&mut self) {
self.tile_encoder = None;
}
}
#[cfg(test)]
mod tests {
use super::*;
use oximedia_core::PixelFormat;
#[test]
fn test_encoder_creation() {
let config = EncoderConfig::av1(1920, 1080);
let encoder = Av1Encoder::new(config);
assert!(encoder.is_ok());
}
#[test]
fn test_encoder_invalid_dimensions() {
let config = EncoderConfig::av1(0, 0);
let encoder = Av1Encoder::new(config);
assert!(encoder.is_err());
}
#[test]
fn test_encoder_codec_id() {
let config = EncoderConfig::av1(1920, 1080);
let encoder = Av1Encoder::new(config).expect("should succeed");
assert_eq!(encoder.codec(), CodecId::Av1);
}
#[test]
fn test_encode_frame() {
let config = EncoderConfig::av1(320, 240);
let mut encoder = Av1Encoder::new(config).expect("should succeed");
let mut frame = VideoFrame::new(PixelFormat::Yuv420p, 320, 240);
frame.allocate();
assert!(encoder.send_frame(&frame).is_ok());
let packet = encoder.receive_packet().expect("should succeed");
assert!(packet.is_some());
let packet = packet.expect("should succeed");
assert!(packet.keyframe);
assert!(!packet.data.is_empty());
}
#[test]
fn test_frame_dimension_mismatch() {
let config = EncoderConfig::av1(320, 240);
let mut encoder = Av1Encoder::new(config).expect("should succeed");
let frame = VideoFrame::new(PixelFormat::Yuv420p, 640, 480);
let result = encoder.send_frame(&frame);
assert!(result.is_err());
}
}