use anyhow::Result;
use cudarc::driver::CudaContext;
use nvidia_video_codec_sdk::{
sys::nvEncodeAPI::{
NV_ENC_BUFFER_FORMAT, NV_ENC_CODEC_AV1_GUID, NV_ENC_PIC_FLAGS, NV_ENC_PIC_TYPE,
NV_ENC_PRESET_P4_GUID, NV_ENC_PRESET_P7_GUID, NV_ENC_TUNING_INFO, _NV_ENC_PARAMS_RC_MODE,
},
Bitstream, Buffer, EncodePictureParams, Encoder, EncoderInitParams, Session,
};
/// AV1 encoder backed by an NVENC session, with a CPU-side staging buffer for one frame.
///
/// The session is heap-allocated and held through a raw pointer so that the input buffer and
/// output bitstream can borrow it with a fabricated `'static` lifetime; the `Drop` impl
/// releases the two buffers before it frees the session.
pub struct NvencAv1Encoder {
    // --- NVENC resources -------------------------------------------------------------------
    /// Owning pointer to the session, produced by `Box::into_raw` in `new_inner`.
    session: *mut Session,
    /// Device input buffer borrowed from `session`; dropped manually in `Drop`.
    input_buffer: std::mem::ManuallyDrop<Buffer<'static>>,
    /// Output bitstream borrowed from `session`; dropped manually in `Drop`.
    output_bitstream: std::mem::ManuallyDrop<Bitstream<'static>>,

    // --- Stream parameters -----------------------------------------------------------------
    /// Frame width in pixels, as passed to the constructor.
    pub width: u32,
    /// Frame height in pixels, as passed to the constructor.
    pub height: u32,
    /// Frames per second; also used as the IDR interval in `encode_frame`.
    fps: u32,
    /// Whether the session was opened with the 10-bit YUV 4:2:0 input format.
    ten_bit: bool,

    // --- Per-frame state -------------------------------------------------------------------
    /// Staging area for the next frame: `width * height * 3 / 2` bytes in 8-bit mode,
    /// `width * height * 3` bytes in 10-bit mode (despite the name, it holds either layout).
    nv12_buffer: Vec<u8>,
    /// Number of frames submitted through `encode_frame`; incremented before the encode call,
    /// so it also counts frames whose encode subsequently failed.
    pub frame_count: u64,
}
// The raw `*mut Session` field makes this type `!Send` by default, so `Send` is asserted by hand.
// NOTE(review): this is only sound if the NVENC session, its buffers and the underlying CUDA
// context may be used and destroyed from a thread other than the one that created them —
// confirm against the `nvidia_video_codec_sdk` / `cudarc` thread-safety guarantees.
unsafe impl Send for NvencAv1Encoder {}
impl Drop for NvencAv1Encoder {
    /// Tears the encoder down in dependency order: the input buffer and output bitstream
    /// first (they borrow the session), then the boxed session itself.
    fn drop(&mut self) {
        use std::mem::ManuallyDrop;
        use std::panic::{catch_unwind, AssertUnwindSafe};

        // SAFETY: both fields are initialised in `new_inner` and are dropped only here, so each
        // is dropped exactly once. The session they borrow is still alive at this point.
        unsafe {
            ManuallyDrop::drop(&mut self.input_buffer);
            ManuallyDrop::drop(&mut self.output_bitstream);
        }

        if self.session.is_null() {
            return;
        }

        // SAFETY: a non-null `session` is the pointer produced by `Box::into_raw` in
        // `new_inner`, and nothing else reclaims it.
        let owned_session = unsafe { Box::from_raw(self.session) };
        // A panic raised while the session is being destroyed is swallowed so that it cannot
        // escape from `drop`.
        let _ = catch_unwind(AssertUnwindSafe(move || drop(owned_session)));
    }
}
/// Rate-control strategy selected by the public constructors and consumed by `new_inner`.
enum RateControl {
    /// Target bitrate; `new_inner` writes it to both `averageBitRate` and `maxBitRate`
    /// (presumably bits per second — confirm against the NVENC header).
    Bitrate(u32),
    /// Constant quantizer; the same value is applied to intra, inter-P and inter-B frames.
    ConstQP(u32),
}
impl NvencAv1Encoder {
/// Creates an encoder driven by a target bitrate.
///
/// This path selects the P4 preset with ultra-low-latency tuning. `bitrate` is used as both
/// the average and the maximum bitrate (presumably bits per second — confirm). `ten_bit`
/// switches the session to the 10-bit YUV 4:2:0 input format.
///
/// # Errors
///
/// Returns an error if the CUDA context, the NVENC session or its buffers cannot be created.
pub fn new(width: u32, height: u32, fps: u32, bitrate: u32, ten_bit: bool) -> Result<Self> {
Self::new_inner(width, height, fps, ten_bit, RateControl::Bitrate(bitrate))
}
/// Creates an encoder that uses a constant quantizer instead of a bitrate target.
///
/// This path selects the P7 preset with high-quality tuning. `qp` is applied to intra,
/// inter-P and inter-B frames alike. `ten_bit` switches the session to the 10-bit YUV 4:2:0
/// input format.
///
/// # Errors
///
/// Returns an error if the CUDA context, the NVENC session or its buffers cannot be created.
pub fn new_cqp(width: u32, height: u32, fps: u32, qp: u32, ten_bit: bool) -> Result<Self> {
Self::new_inner(width, height, fps, ten_bit, RateControl::ConstQP(qp))
}
/// Builds the CUDA context, NVENC session and I/O buffers shared by [`Self::new`] and
/// [`Self::new_cqp`].
///
/// `rc` selects both the rate-control parameters and the preset/tuning pair: bitrate mode
/// uses P4 with ultra-low-latency tuning, constant-QP mode uses P7 with high-quality tuning.
/// GOP length and IDR period are both set to `fps`, i.e. one keyframe per `fps` frames.
///
/// # Errors
///
/// Returns an error when `fps` is zero, when the staging-buffer size does not fit in `usize`,
/// or when any CUDA/NVENC call fails. If buffer creation fails after the session has been
/// started, the session is destroyed before the error is returned.
fn new_inner(
    width: u32,
    height: u32,
    fps: u32,
    ten_bit: bool,
    rc: RateControl,
) -> Result<Self> {
    // Reclaims and destroys a session that was leaked with `Box::into_raw`, mirroring the
    // panic-swallowing teardown used by the `Drop` impl.
    fn release_session(session_ptr: *mut Session) {
        // SAFETY: callers pass the pointer obtained from `Box::into_raw` below, exactly once,
        // after every buffer borrowing the session has been dropped.
        let session = unsafe { Box::from_raw(session_ptr) };
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| drop(session)));
    }

    // `fps` is used as a divisor (VBV size here, IDR cadence in `encode_frame`), so zero must
    // be rejected up front instead of panicking later.
    if fps == 0 {
        return Err(anyhow::anyhow!("fps must be greater than zero"));
    }

    // Size the staging buffer in `usize` so large frames cannot overflow `u32` arithmetic.
    // 8-bit 4:2:0 needs 1.5 bytes per pixel, the 10-bit layout twice that.
    let buffer_size = (width as usize)
        .checked_mul(height as usize)
        .and_then(|pixels| pixels.checked_mul(3))
        .map(|bytes| if ten_bit { bytes } else { bytes / 2 })
        .ok_or_else(|| anyhow::anyhow!("Frame size {}x{} is too large", width, height))?;

    let (preset, tuning) = match rc {
        RateControl::Bitrate(_) => (
            NV_ENC_PRESET_P4_GUID,
            NV_ENC_TUNING_INFO::NV_ENC_TUNING_INFO_ULTRA_LOW_LATENCY,
        ),
        RateControl::ConstQP(_) => (
            NV_ENC_PRESET_P7_GUID,
            NV_ENC_TUNING_INFO::NV_ENC_TUNING_INFO_HIGH_QUALITY,
        ),
    };

    let cuda_ctx = CudaContext::new(0)
        .map_err(|e| anyhow::anyhow!("Failed to create CUDA context: {}", e))?;
    let encoder = Encoder::initialize_with_cuda(cuda_ctx)
        .map_err(|e| anyhow::anyhow!("Failed to initialize NVENC: {:?}", e))?;
    let mut preset_config = encoder
        .get_preset_config(NV_ENC_CODEC_AV1_GUID, preset, tuning)
        .map_err(|e| anyhow::anyhow!("Failed to get AV1 preset config: {:?}", e))?;

    let config = &mut preset_config.presetCfg;
    config.gopLength = fps;
    // One frame between P-frames, i.e. no B-frames.
    config.frameIntervalP = 1;
    match rc {
        RateControl::Bitrate(bitrate) => {
            // The rate-control mode itself is left at whatever the preset selected.
            config.rcParams.averageBitRate = bitrate;
            config.rcParams.maxBitRate = bitrate;
            // VBV buffer sized to one frame's share of the bitrate.
            config.rcParams.vbvBufferSize = bitrate / fps;
        }
        RateControl::ConstQP(qp) => {
            config.rcParams.rateControlMode = _NV_ENC_PARAMS_RC_MODE::NV_ENC_PARAMS_RC_CONSTQP;
            config.rcParams.constQP.qpInterP = qp;
            config.rcParams.constQP.qpInterB = qp;
            config.rcParams.constQP.qpIntra = qp;
        }
    }

    // SAFETY: the preset was queried for the AV1 codec GUID, so the AV1 member of the codec
    // configuration is the one in use.
    // NOTE(review): assumes `encodeCodecConfig` is the bindgen union from nvEncodeAPI.h.
    unsafe {
        let av1_config = &mut config.encodeCodecConfig.av1Config;
        av1_config.set_enableIntraRefresh(0);
        av1_config.idrPeriod = fps;
        av1_config.set_repeatSeqHdr(1);
        // NOTE(review): 1 presumably selects full-range colour — confirm against the header.
        av1_config.colorRange = 1;
        if ten_bit {
            av1_config.set_inputPixelBitDepthMinus8(2);
            av1_config.set_pixelBitDepthMinus8(2);
        } else {
            av1_config.set_inputPixelBitDepthMinus8(0);
            av1_config.set_pixelBitDepthMinus8(0);
        }
    }

    let mut init_params = EncoderInitParams::new(NV_ENC_CODEC_AV1_GUID, width, height);
    init_params
        .preset_guid(preset)
        .tuning_info(tuning)
        .framerate(fps, 1)
        .encode_config(config);

    let buffer_format = if ten_bit {
        NV_ENC_BUFFER_FORMAT::NV_ENC_BUFFER_FORMAT_YUV420_10BIT
    } else {
        NV_ENC_BUFFER_FORMAT::NV_ENC_BUFFER_FORMAT_NV12
    };
    let session = encoder
        .start_session(buffer_format, init_params)
        .map_err(|e| anyhow::anyhow!("Failed to start AV1 session: {:?}", e))?;

    // The session is moved to the heap and leaked so the buffers can borrow it for `'static`;
    // ownership is reclaimed either by `Drop` or by `release_session` on the error paths below.
    let session_ptr = Box::into_raw(Box::new(session));
    // SAFETY: `session_ptr` was just produced by `Box::into_raw`, so it is non-null, aligned
    // and points to a live `Session`.
    let session_ref: &'static Session = unsafe { &*session_ptr };

    let input_buffer = match session_ref.create_input_buffer() {
        Ok(buffer) => buffer,
        Err(e) => {
            let err = anyhow::anyhow!("Failed to create input buffer: {:?}", e);
            release_session(session_ptr);
            return Err(err);
        }
    };
    let output_bitstream = match session_ref.create_output_bitstream() {
        Ok(bitstream) => bitstream,
        Err(e) => {
            let err = anyhow::anyhow!("Failed to create output bitstream: {:?}", e);
            // The input buffer borrows the session, so it must go first.
            drop(input_buffer);
            release_session(session_ptr);
            return Err(err);
        }
    };

    Ok(Self {
        session: session_ptr,
        input_buffer: std::mem::ManuallyDrop::new(input_buffer),
        output_bitstream: std::mem::ManuallyDrop::new(output_bitstream),
        width,
        height,
        nv12_buffer: vec![0u8; buffer_size],
        frame_count: 0,
        fps,
        ten_bit,
    })
}
/// Encodes the frame currently staged in the internal buffer and returns the bitstream bytes.
///
/// The staged bytes are uploaded with `write_p010` for 10-bit sessions and `write_nv12`
/// otherwise. Every `fps`-th frame (starting with the first) is submitted as an IDR picture
/// with the `NV_ENC_PIC_FLAG_OUTPUT_SPSPPS` flag; all others are P pictures. `frame_count`
/// is advanced before the encode call, so it moves forward even if encoding fails.
///
/// # Errors
///
/// Returns an error if locking the input buffer, encoding, or locking the output fails.
///
/// # Panics
///
/// Panics if the encoder was created with `fps == 0` (the IDR cadence is a modulo by `fps`).
pub fn encode_frame(&mut self, timestamp_us: u64) -> Result<Vec<u8>> {
    // Upload phase: the guard lives only inside this scope so the buffer is unlocked again
    // before it is handed to the encoder.
    {
        let mut upload = match self.input_buffer.lock() {
            Ok(guard) => guard,
            Err(e) => return Err(anyhow::anyhow!("Failed to lock input: {:?}", e)),
        };
        let staged = &self.nv12_buffer;
        let (w, h) = (self.width, self.height);
        // NOTE(review): the unsafe contract of `write_p010`/`write_nv12` is defined by the
        // SDK wrapper and is not visible here; presumably the slice must cover a full frame.
        if self.ten_bit {
            unsafe { upload.write_p010(staged, w, h) };
        } else {
            unsafe { upload.write_nv12(staged, w, h) };
        }
    }

    // Keyframe cadence: frame 0 and every multiple of `fps` afterwards.
    let (picture_type, encode_pic_flags) = if self.frame_count % self.fps as u64 == 0 {
        (
            NV_ENC_PIC_TYPE::NV_ENC_PIC_TYPE_IDR,
            NV_ENC_PIC_FLAGS::NV_ENC_PIC_FLAG_OUTPUT_SPSPPS as u32,
        )
    } else {
        (NV_ENC_PIC_TYPE::NV_ENC_PIC_TYPE_P, 0)
    };
    self.frame_count += 1;

    let params = EncodePictureParams {
        input_timestamp: timestamp_us,
        picture_type,
        encode_pic_flags,
        ..Default::default()
    };

    // SAFETY: `self.session` comes from `Box::into_raw` in `new_inner` and is only freed in
    // `Drop`, so it points to a live session for as long as `self` exists.
    let session: &Session = unsafe { &*self.session };
    if let Err(e) = session.encode_picture(
        &mut *self.input_buffer,
        &mut *self.output_bitstream,
        params,
    ) {
        return Err(anyhow::anyhow!("Failed to encode: {:?}", e));
    }

    // Copy the encoded bytes out while the output lock is held.
    let download = self
        .output_bitstream
        .lock()
        .map_err(|e| anyhow::anyhow!("Failed to lock output: {:?}", e))?;
    let encoded = download.data().to_vec();
    Ok(encoded)
}
pub fn encode_p010(&mut self, data: &[u8], timestamp_us: u64) -> Result<Vec<u8>> {
self.nv12_buffer[..data.len()].copy_from_slice(data);
self.encode_frame(timestamp_us)
}
/// Converts a packed RGB frame (3 bytes per pixel, row-major) to NV12 and encodes it.
///
/// NOTE(review): the conversion always produces 8-bit NV12, while `encode_frame` uploads with
/// `write_p010` when the encoder was created with `ten_bit = true`; this entry point looks
/// intended for 8-bit sessions only — confirm.
///
/// # Errors
///
/// Returns whatever error `encode_frame` reports.
pub fn encode_rgb(&mut self, rgb: &[u8], timestamp_us: u64) -> Result<Vec<u8>> {
self.rgb_to_nv12(rgb);
self.encode_frame(timestamp_us)
}
/// Converts a packed YUYV frame (2 bytes per pixel, row-major) to NV12 and encodes it.
///
/// NOTE(review): the conversion always produces 8-bit NV12, while `encode_frame` uploads with
/// `write_p010` when the encoder was created with `ten_bit = true`; this entry point looks
/// intended for 8-bit sessions only — confirm.
///
/// # Errors
///
/// Returns whatever error `encode_frame` reports.
pub fn encode_yuyv(&mut self, yuyv: &[u8], timestamp_us: u64) -> Result<Vec<u8>> {
self.yuyv_to_nv12(yuyv);
self.encode_frame(timestamp_us)
}
/// Stages an 8-bit greyscale frame (1 byte per pixel) as NV12 with neutral chroma and
/// encodes it.
///
/// NOTE(review): the conversion always produces 8-bit NV12, while `encode_frame` uploads with
/// `write_p010` when the encoder was created with `ten_bit = true`; this entry point looks
/// intended for 8-bit sessions only — confirm.
///
/// # Errors
///
/// Returns whatever error `encode_frame` reports.
pub fn encode_grey(&mut self, grey: &[u8], timestamp_us: u64) -> Result<Vec<u8>> {
self.grey_to_nv12(grey);
self.encode_frame(timestamp_us)
}
/// Converts packed 8-bit RGB (3 bytes per pixel, row-major) into NV12 in the staging buffer.
///
/// Luma uses the 0.299 / 0.587 / 0.114 weights with no range scaling. Chroma is not averaged:
/// each 2x2 block takes its U/V from the top-left pixel. Source bytes missing from a short
/// `rgb` slice are read as zero.
///
/// The chroma indexing assumes even `width` and `height`. With an odd width the U/V pair of
/// the last column runs past the end of its chroma row, and in 8-bit mode this can index past
/// the end of the staging buffer.
fn rgb_to_nv12(&mut self, rgb: &[u8]) {
    let (cols, rows) = (self.width as usize, self.height as usize);
    let chroma_base = cols * rows;
    // Out-of-range reads yield 0 so a truncated source frame does not panic on the read side.
    let channel = |idx: usize| rgb.get(idx).copied().unwrap_or(0) as f32;

    for row in 0..rows {
        let luma_row = row * cols;
        let chroma_row = chroma_base + (row / 2) * cols;
        let row_carries_chroma = row % 2 == 0;

        for col in 0..cols {
            let src = (luma_row + col) * 3;
            let (r, g, b) = (channel(src), channel(src + 1), channel(src + 2));

            let luma = (0.299 * r + 0.587 * g + 0.114 * b).clamp(0.0, 255.0) as u8;
            self.nv12_buffer[luma_row + col] = luma;

            // Only the top-left pixel of each 2x2 block contributes chroma.
            if !(row_carries_chroma && col % 2 == 0) {
                continue;
            }
            let cb = ((-0.169 * r - 0.331 * g + 0.500 * b) + 128.0).clamp(0.0, 255.0) as u8;
            let cr = ((0.500 * r - 0.419 * g - 0.081 * b) + 128.0).clamp(0.0, 255.0) as u8;
            let pair = chroma_row + col;
            self.nv12_buffer[pair] = cb;
            self.nv12_buffer[pair + 1] = cr;
        }
    }
}
/// Repacks a YUYV frame (bytes `Y0 U Y1 V` per pixel pair, row-major) into NV12 in the
/// staging buffer.
///
/// Luma is copied for every row; chroma is taken from even rows only and the chroma of odd
/// rows is discarded. Bytes missing from a short `yuyv` slice default to 0 for luma and 128
/// for chroma.
///
/// The indexing assumes even `width` and `height`. With an odd width the second luma sample
/// and the U/V pair of the last column land outside their row.
fn yuyv_to_nv12(&mut self, yuyv: &[u8]) {
    let (cols, rows) = (self.width as usize, self.height as usize);
    let chroma_base = cols * rows;
    let byte_or = |idx: usize, fallback: u8| yuyv.get(idx).copied().unwrap_or(fallback);

    for row in 0..rows {
        let luma_row = row * cols;
        let chroma_row = chroma_base + (row / 2) * cols;
        let row_carries_chroma = row % 2 == 0;

        // Walk the row one pixel pair (four source bytes) at a time.
        let mut col = 0;
        while col < cols {
            let src = (luma_row + col) * 2;
            let first_luma = byte_or(src, 0);
            let second_luma = byte_or(src + 2, 0);
            self.nv12_buffer[luma_row + col] = first_luma;
            self.nv12_buffer[luma_row + col + 1] = second_luma;

            if row_carries_chroma {
                let cb = byte_or(src + 1, 128);
                let cr = byte_or(src + 3, 128);
                let pair = chroma_row + col;
                self.nv12_buffer[pair] = cb;
                self.nv12_buffer[pair + 1] = cr;
            }
            col += 2;
        }
    }
}
/// Stages an 8-bit greyscale frame as NV12: `grey` becomes the luma plane and everything
/// after it in the staging buffer is filled with 128 (neutral chroma).
///
/// A `grey` slice shorter than `width * height` no longer panics: the missing luma samples
/// are set to 0, matching how the RGB and YUYV converters pad short input. Bytes beyond
/// `width * height` are ignored.
fn grey_to_nv12(&mut self, grey: &[u8]) {
    let y_size = (self.width as usize) * (self.height as usize);
    let available = grey.len().min(y_size);
    self.nv12_buffer[..available].copy_from_slice(&grey[..available]);
    // Pad whatever part of the luma plane the caller did not supply.
    self.nv12_buffer[available..y_size].fill(0);
    self.nv12_buffer[y_size..].fill(128);
}
}