use anyhow::Result;
use cudarc::driver::sys::CUresult;
use cudarc::driver::CudaContext;
use nvidia_video_codec_sdk::sys::cuviddec::*;
use nvidia_video_codec_sdk::sys::nvcuvid::*;
use std::ffi::c_void;
use std::ptr;
/// Shorthand for the driver's success status; every CUDA / NVDEC call in this file is compared against it.
const CUDA_SUCCESS: CUresult = CUresult::CUDA_SUCCESS;
/// A decoded picture as returned by `NvdecAv1Decoder::decode`.
///
/// The layout of `data` depends on `bits_per_component`:
/// * `8`  – packed RGB, three bytes per pixel, `width * height * 3` bytes.
/// * `>8` – luma (Y) plane only, two bytes per sample, copied row by row
///   from the P016 surface.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedFrame {
    /// Picture width in pixels (the coded width reported by the stream).
    pub width: u32,
    /// Picture height in pixels (the coded height reported by the stream).
    pub height: u32,
    /// Pixel payload; see the type-level docs for the layout.
    pub data: Vec<u8>,
    /// Bit depth the frame was decoded with; selects the layout of `data`.
    pub bits_per_component: u8,
}
/// Raw surface copied off the GPU by the display callback and queued until
/// `decode` converts it into a `DecodedFrame`.
struct NvdecFrame {
// Luma rows followed by the interleaved chroma rows, packed without pitch padding.
data: Vec<u8>,
// Coded width of the sequence at the time the frame was copied.
coded_width: u32,
// Coded height of the sequence at the time the frame was copied.
coded_height: u32,
// Bit depth the decoder was tracking when the frame was copied; > 8 means 2 bytes per sample.
bit_depth: u8,
}
/// AV1 decoder built on the NVDEC parser/decoder pair.
///
/// The parser is created lazily on the first `decode` call and the hardware
/// decoder is (re)created from the parser's sequence callback.
pub struct NvdecAv1Decoder {
// NVDEC bitstream parser handle; null until `ensure_parser` succeeds.
parser: CUvideoparser,
// NVDEC decoder handle; null until the sequence callback has created it.
decoder: CUvideodecoder,
// CUDA context held for the lifetime of the decoder; never read directly.
_ctx: std::sync::Arc<CudaContext>,
// Coded width of the current sequence (0 before the first sequence callback).
width: u32,
// Coded height of the current sequence (0 before the first sequence callback).
height: u32,
// Caller's request for a 16-bit (P016) output surface instead of NV12.
high_bitdepth: bool,
// Bit depth used to size and interpret copied surfaces; starts at 10 or 8 depending on `high_bitdepth`.
bit_depth: u8,
// Frames pushed by the display callback and consumed by `decode`.
decoded_frames: Vec<NvdecFrame>,
}
impl NvdecAv1Decoder {
pub fn new(high_bitdepth: bool) -> Result<Self> {
let ctx = CudaContext::new(0)
.map_err(|e| anyhow::anyhow!("Failed to create CUDA context: {}", e))?;
Ok(NvdecAv1Decoder {
parser: ptr::null_mut(),
decoder: ptr::null_mut(),
_ctx: ctx,
width: 0,
height: 0,
high_bitdepth,
bit_depth: if high_bitdepth { 10 } else { 8 },
decoded_frames: Vec::new(),
})
}
/// Creates the NVDEC AV1 parser on first use; a no-op once the handle exists.
///
/// The parser is given a raw pointer to `self` as its user data, and the
/// three callbacks cast that pointer back to `&mut NvdecAv1Decoder`.
///
/// NOTE(review): the pointer is captured once, here. If the decoder value is
/// moved after this call, the callbacks would dereference a stale address —
/// confirm that callers keep the decoder at a stable location.
///
/// # Errors
/// Returns an error if `cuvidCreateVideoParser` does not report success.
fn ensure_parser(&mut self) -> Result<()> {
    if self.parser.is_null() {
        // Zero-initialise the FFI struct, then fill in only the fields we use.
        let mut params: CUVIDPARSERPARAMS = unsafe { std::mem::zeroed() };
        params.pUserData = self as *mut _ as *mut c_void;
        params.pfnSequenceCallback = Some(Self::sequence_callback);
        params.pfnDecodePicture = Some(Self::decode_callback);
        params.pfnDisplayPicture = Some(Self::display_callback);
        params.CodecType = cudaVideoCodec::cudaVideoCodec_AV1;
        params.ulMaxNumDecodeSurfaces = 8;
        params.ulMaxDisplayDelay = 0;

        let status = unsafe { cuvidCreateVideoParser(&mut self.parser, &mut params) };
        if status != CUDA_SUCCESS {
            anyhow::bail!("Failed to create AV1 video parser: {:?}", status);
        }
    }
    Ok(())
}
/// Feeds one chunk of AV1 bitstream to the parser and returns the oldest
/// decoded frame that is ready, if any.
///
/// Parsing runs the sequence / decode / display callbacks synchronously, so
/// any frames produced by `data` are already queued when the parse call
/// returns. Frames are handed out in the order the display callback queued
/// them (FIFO); when one call produces several frames, the remaining ones are
/// returned by subsequent calls.
///
/// The returned frame holds packed RGB for 8-bit output, or the 16-bit luma
/// plane only when the frame was decoded with more than 8 bits.
///
/// # Errors
/// Returns an error if the parser cannot be created or if
/// `cuvidParseVideoData` does not report success.
pub fn decode(&mut self, data: &[u8]) -> Result<Option<DecodedFrame>> {
    self.ensure_parser()?;

    let mut packet: CUVIDSOURCEDATAPACKET = unsafe { std::mem::zeroed() };
    packet.payload = data.as_ptr();
    packet.payload_size = data.len() as u64;
    // `data` outlives the call and the parser handle was just ensured non-null.
    let result = unsafe { cuvidParseVideoData(self.parser, &mut packet) };
    if result != CUDA_SUCCESS {
        anyhow::bail!("Failed to parse AV1 data: {:?}", result);
    }

    if self.decoded_frames.is_empty() {
        return Ok(None);
    }
    // Take the OLDEST queued frame. Popping from the back would return frames
    // in reverse order and could leave earlier frames stuck in the queue.
    let frame = self.decoded_frames.remove(0);

    let width = frame.coded_width;
    let height = frame.coded_height;
    let out = if frame.bit_depth > 8 {
        Self::extract_p016_y_plane(&frame.data, width as usize, height as usize)
    } else {
        Self::nv12_to_rgb(&frame.data, width as usize, height as usize)
    };
    Ok(Some(DecodedFrame {
        width,
        height,
        data: out,
        bits_per_component: frame.bit_depth,
    }))
}
/// Copies the luma plane out of a tightly packed P016 buffer.
///
/// Each luma row is `width * 2` bytes (two bytes per sample). Rows that are
/// not fully contained in `data` are skipped, so a truncated buffer yields a
/// shorter result rather than a panic.
fn extract_p016_y_plane(data: &[u8], width: usize, height: usize) -> Vec<u8> {
    let row_len = width * 2;
    let mut luma = Vec::with_capacity(width * height * 2);
    let complete_rows = (0..height).filter_map(|row| {
        let first = row * row_len;
        // `get` yields `None` for any row that runs past the end of `data`.
        data.get(first..first + row_len)
    });
    for row_bytes in complete_rows {
        luma.extend_from_slice(row_bytes);
    }
    luma
}
fn nv12_to_rgb(nv12: &[u8], width: usize, height: usize) -> Vec<u8> {
let y_plane_size = width * height;
let mut rgb = vec![0u8; width * height * 3];
for y in 0..height {
for x in 0..width {
let y_val = nv12.get(y * width + x).copied().unwrap_or(0) as f32;
let uv_idx = y_plane_size + (y / 2) * width + (x / 2) * 2;
let u = nv12.get(uv_idx).copied().unwrap_or(128) as f32;
let v = nv12.get(uv_idx + 1).copied().unwrap_or(128) as f32;
let c = y_val - 16.0;
let d = u - 128.0;
let e = v - 128.0;
let r = (1.164 * c + 1.596 * e).clamp(0.0, 255.0) as u8;
let g = (1.164 * c - 0.392 * d - 0.813 * e).clamp(0.0, 255.0) as u8;
let b = (1.164 * c + 2.017 * d).clamp(0.0, 255.0) as u8;
let rgb_idx = (y * width + x) * 3;
rgb[rgb_idx] = r;
rgb[rgb_idx + 1] = g;
rgb[rgb_idx + 2] = b;
}
}
rgb
}
extern "C" fn sequence_callback(
user_data: *mut c_void,
video_format: *mut CUVIDEOFORMAT,
) -> i32 {
let decoder = unsafe { &mut *(user_data as *mut NvdecAv1Decoder) };
let format = unsafe { &*video_format };
let output_format = if decoder.high_bitdepth {
cudaVideoSurfaceFormat::cudaVideoSurfaceFormat_P016
} else {
cudaVideoSurfaceFormat::cudaVideoSurfaceFormat_NV12
};
let num_surfaces = (format.min_num_decode_surfaces as u64).max(8);
let mut create_info: CUVIDDECODECREATEINFO = unsafe { std::mem::zeroed() };
create_info.ulWidth = format.coded_width as u64;
create_info.ulHeight = format.coded_height as u64;
create_info.ulNumDecodeSurfaces = num_surfaces;
create_info.CodecType = format.codec;
create_info.ChromaFormat = format.chroma_format;
create_info.ulCreationFlags = 0;
create_info.OutputFormat = output_format;
create_info.DeinterlaceMode = cudaVideoDeinterlaceMode::cudaVideoDeinterlaceMode_Adaptive;
create_info.ulTargetWidth = format.coded_width as u64;
create_info.ulTargetHeight = format.coded_height as u64;
create_info.ulNumOutputSurfaces = 4;
create_info.bitDepthMinus8 = format.bit_depth_luma_minus8 as u64;
if !decoder.decoder.is_null() {
let _ = unsafe { cuvidDestroyDecoder(decoder.decoder) };
decoder.decoder = ptr::null_mut();
}
let result = unsafe { cuvidCreateDecoder(&mut decoder.decoder, &mut create_info) };
if result != CUDA_SUCCESS {
tracing::error!("Failed to create AV1 NVDEC decoder: {:?}", result);
return 0;
}
decoder.width = format.coded_width;
decoder.height = format.coded_height;
let actual_bits = 8 + format.bit_depth_luma_minus8 as u8;
if actual_bits > 8 {
decoder.bit_depth = actual_bits;
}
tracing::info!(
"[nvdec-av1] sequence: {}x{}, bit_depth={}, format={:?}",
format.coded_width,
format.coded_height,
actual_bits,
output_format,
);
num_surfaces as i32
}
extern "C" fn decode_callback(user_data: *mut c_void, pic_params: *mut CUVIDPICPARAMS) -> i32 {
let decoder = unsafe { &mut *(user_data as *mut NvdecAv1Decoder) };
if decoder.decoder.is_null() {
return 0;
}
let result = unsafe { cuvidDecodePicture(decoder.decoder, pic_params) };
if result != CUDA_SUCCESS {
tracing::error!("Failed to decode AV1 picture: {:?}", result);
return 0;
}
1
}
extern "C" fn display_callback(
user_data: *mut c_void,
disp_info: *mut CUVIDPARSERDISPINFO,
) -> i32 {
let decoder = unsafe { &mut *(user_data as *mut NvdecAv1Decoder) };
let info = unsafe { &*disp_info };
if decoder.decoder.is_null() || info.picture_index < 0 {
return 0;
}
let mut proc_params: CUVIDPROCPARAMS = unsafe { std::mem::zeroed() };
proc_params.progressive_frame = info.progressive_frame as i32;
let mut dev_ptr: u64 = 0;
let mut pitch: u32 = 0;
let result = unsafe {
cuvidMapVideoFrame64(
decoder.decoder,
info.picture_index,
&mut dev_ptr,
&mut pitch,
&mut proc_params,
)
};
if result != CUDA_SUCCESS {
tracing::error!("Failed to map AV1 video frame: {:?}", result);
return 0;
}
let width = decoder.width as usize;
let height = decoder.height as usize;
let bytes_per_pixel = if decoder.bit_depth > 8 { 2 } else { 1 };
let row_bytes = width * bytes_per_pixel;
let total_size = row_bytes * height + row_bytes * (height / 2);
let mut frame_data = vec![0u8; total_size];
for y in 0..height {
let src = (dev_ptr + (y as u64) * (pitch as u64)) as cudarc::driver::sys::CUdeviceptr;
let dst = unsafe { frame_data.as_mut_ptr().add(y * row_bytes) as *mut c_void };
let r = unsafe { cudarc::driver::sys::cuMemcpyDtoH_v2(dst, src, row_bytes) };
if r != CUDA_SUCCESS {
tracing::error!("[nvdec-av1] Y row {} copy failed: {:?}", y, r);
break;
}
}
let uv_offset = row_bytes * height;
let uv_height = height / 2;
for y in 0..uv_height {
let src = (dev_ptr + ((height + y) as u64) * (pitch as u64))
as cudarc::driver::sys::CUdeviceptr;
let dst =
unsafe { frame_data.as_mut_ptr().add(uv_offset + y * row_bytes) as *mut c_void };
let r = unsafe { cudarc::driver::sys::cuMemcpyDtoH_v2(dst, src, row_bytes) };
if r != CUDA_SUCCESS {
tracing::error!("[nvdec-av1] UV row {} copy failed: {:?}", y, r);
break;
}
}
let _ = unsafe { cuvidUnmapVideoFrame64(decoder.decoder, dev_ptr) };
decoder.decoded_frames.push(NvdecFrame {
data: frame_data,
coded_width: decoder.width,
coded_height: decoder.height,
bit_depth: decoder.bit_depth,
});
1
}
}
/// Releases the NVDEC handles: the parser first, then the decoder.
/// Handles that were never created (still null) are skipped, and the status
/// codes of the destroy calls are intentionally ignored.
impl Drop for NvdecAv1Decoder {
    fn drop(&mut self) {
        let (parser_handle, decoder_handle) = (self.parser, self.decoder);
        if !parser_handle.is_null() {
            unsafe {
                let _ = cuvidDestroyVideoParser(parser_handle);
            }
        }
        if !decoder_handle.is_null() {
            unsafe {
                let _ = cuvidDestroyDecoder(decoder_handle);
            }
        }
    }
}
// The raw NVDEC handle fields make this type `!Send` by default; this impl opts back in.
// NOTE(review): this presumes the handles may be used from whichever thread owns the
// decoder and that the CUDA context is current on that thread — confirm against callers.
unsafe impl Send for NvdecAv1Decoder {}
/// Converts a 16-bit little-endian luma plane into depth values.
///
/// Each pair of bytes in `y_data` is read as a little-endian `u16`; its top
/// 10 bits are taken (`>> 6`) and then scaled by `1 << depth_shift`. A
/// trailing odd byte is ignored.
///
/// Results that do not fit in a `u16` saturate at `u16::MAX` instead of
/// wrapping, and any `depth_shift` is accepted without panicking.
pub fn p016_y_to_depth_mm(y_data: &[u8], depth_shift: u32) -> Vec<u16> {
    y_data
        .chunks_exact(2)
        .map(|pair| {
            let gray10 = u16::from_le_bytes([pair[0], pair[1]]) >> 6;
            if gray10 == 0 {
                0
            } else if depth_shift >= 16 {
                // Any non-zero sample shifted by 16 or more exceeds u16::MAX.
                u16::MAX
            } else {
                // gray10 < 2^10 and depth_shift < 16, so the shift fits in u32.
                (u32::from(gray10) << depth_shift).min(u32::from(u16::MAX)) as u16
            }
        })
        .collect()
}