use std::ffi::c_void;
use std::os::raw::{c_int, c_uint, c_ulong};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr;
use crate::frame::ColorSpace;
use super::NvdecError;
use super::convert::validate_format;
use super::ffi::{
CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_HOST,
CUVID_CHROMA_420, CUVID_CREATE_PREFER_CUVID, CUVID_FMT_NV12, CUVID_FMT_P016, CUVID_H264,
CUdeviceptr, CUvideodecoder,
CudaMemcpy2D, CuVideoDecodeCaps, CuVideoDecodeCreateInfo, CuVideoDispInfo, CuVideoFormat,
CuVideoPicParams, CuVideoProcParams,
};
use super::state::{CallbackState, DecodedFrame};
/// Sequence callback handed to the cuvid parser.
///
/// Validates the announced stream format (codec, chroma, bit depth, size and
/// hardware capabilities), records colour metadata on the [`CallbackState`],
/// and creates the decoder the first time it is called.
///
/// Returns the number of decode surfaces (the stream's minimum, clamped to
/// `20..=32`) on success and `0` on any failure. Failures are also recorded on
/// the state via `set_error` / `set_typed_error`. A panic inside the body is
/// caught and reported as `0` so it never unwinds across the `extern "C"`
/// boundary.
///
/// If a decoder already exists and the stream announces a different coded size
/// or bit depth, the callback fails explicitly: this code has no path to
/// reconfigure the decoder, and continuing would leave `display_callback`
/// copying frames with geometry that no longer matches the decoder.
///
/// # Safety
///
/// `user_data` must be null or point to a live, exclusively accessible
/// `CallbackState`; `format` must be null or point to a valid `CuVideoFormat`.
pub unsafe extern "C" fn sequence_callback(
    user_data: *mut c_void,
    format: *mut CuVideoFormat,
) -> c_int {
    // SAFETY: pointer validity is the caller's contract (see `# Safety`); both
    // pointers are null-checked before they are dereferenced.
    unsafe {
        catch_unwind(AssertUnwindSafe(|| {
            if user_data.is_null() || format.is_null() {
                return 0;
            }
            let state = &mut *(user_data as *mut CallbackState);
            let fmt = &*format;

            // The parser was configured for one codec; anything else means the
            // struct layout or the session setup is off.
            if fmt.codec != state.codec_type {
                tracing::warn!(
                    expected = state.codec_type,
                    got = fmt.codec,
                    "NVDEC sequence_callback codec mismatch — ABI drift suspected"
                );
                state.set_error(format!(
                    "sequence_callback codec mismatch: expected {} got {}",
                    state.codec_type, fmt.codec
                ));
                return 0;
            }

            let num_surfaces = (fmt.min_num_decode_surfaces as c_uint).clamp(20, 32) as c_ulong;
            tracing::info!(
                codec = fmt.codec,
                width = fmt.coded_width,
                height = fmt.coded_height,
                chroma = fmt.chroma_format,
                bit_depth = fmt.bit_depth_luma_minus8 + 8,
                surfaces = num_surfaces,
                "NVDEC backend engaged"
            );

            // Software-side format gate: reject formats this backend cannot convert.
            if let Some(err) = validate_format(
                fmt.chroma_format,
                fmt.bit_depth_luma_minus8,
                fmt.coded_width,
                fmt.coded_height,
            ) {
                match &err {
                    NvdecError::UnsupportedChroma { label, .. } => {
                        tracing::warn!(
                            codec = state.codec_type,
                            w = fmt.coded_width,
                            h = fmt.coded_height,
                            chroma = fmt.chroma_format,
                            chroma_label = *label,
                            "NVDEC rejecting: chroma {} unsupported",
                            label
                        );
                    }
                    NvdecError::UnsupportedPixelFormat { bit_depth } => {
                        tracing::warn!(
                            codec = state.codec_type,
                            w = fmt.coded_width,
                            h = fmt.coded_height,
                            bit_depth = bit_depth,
                            "NVDEC rejecting: {}-bit content unsupported",
                            bit_depth
                        );
                    }
                    NvdecError::UnsupportedByHardware { reason } => {
                        tracing::warn!(codec = state.codec_type, "NVDEC rejecting: {reason}");
                    }
                }
                state.set_typed_error(err);
                return 0;
            }

            // Hardware-side gate: ask the driver whether this GPU can decode the
            // format. A failed query is treated as "unknown" and does not reject.
            {
                let mut caps: CuVideoDecodeCaps = std::mem::zeroed();
                caps.codec_type = state.codec_type;
                caps.chroma_format = fmt.chroma_format;
                caps.bit_depth_minus8 = fmt.bit_depth_luma_minus8 as u32;
                let caps_rc = (state.cuvid_get_decoder_caps)(&mut caps);
                if caps_rc != 0 {
                    tracing::debug!(
                        "cuvidGetDecoderCaps failed ({caps_rc}); skipping capability validation"
                    );
                } else {
                    if caps.is_supported == 0 {
                        let reason = format!(
                            "GPU NVDEC does not support codec={} chroma={} {}-bit",
                            state.codec_type,
                            fmt.chroma_format,
                            fmt.bit_depth_luma_minus8 + 8
                        );
                        tracing::warn!(
                            codec = state.codec_type,
                            chroma = fmt.chroma_format,
                            bit_depth = fmt.bit_depth_luma_minus8 + 8,
                            "NVDEC rejecting: {reason}"
                        );
                        state.set_typed_error(NvdecError::UnsupportedByHardware { reason });
                        return 0;
                    }
                    // Zero limits are treated as "not reported" and skipped.
                    if caps.max_width > 0
                        && caps.max_height > 0
                        && (fmt.coded_width > caps.max_width
                            || fmt.coded_height > caps.max_height)
                    {
                        let reason = format!(
                            "frame {}x{} exceeds NVDEC max {}x{}",
                            fmt.coded_width, fmt.coded_height, caps.max_width, caps.max_height
                        );
                        tracing::warn!(
                            w = fmt.coded_width,
                            h = fmt.coded_height,
                            max_w = caps.max_width,
                            max_h = caps.max_height,
                            "NVDEC rejecting: {reason}"
                        );
                        state.set_typed_error(NvdecError::UnsupportedByHardware { reason });
                        return 0;
                    }
                    tracing::debug!(
                        codec = state.codec_type,
                        max_w = caps.max_width,
                        max_h = caps.max_height,
                        "NVDEC capability validated"
                    );
                }
            }

            // The decoder is created once and never reconfigured here. If the
            // stream changes coded size or bit depth afterwards, fail loudly
            // instead of letting the state drift away from the live decoder.
            // NOTE(review): assumes the decoder is only ever created by this
            // callback, so `state.width`/`height`/`bit_depth_luma_minus8`
            // describe it — confirm no other code path creates it.
            if state.decoder.is_some()
                && (state.width != fmt.coded_width
                    || state.height != fmt.coded_height
                    || state.bit_depth_luma_minus8 != fmt.bit_depth_luma_minus8)
            {
                tracing::warn!(
                    w = fmt.coded_width,
                    h = fmt.coded_height,
                    bit_depth = fmt.bit_depth_luma_minus8 + 8,
                    "NVDEC rejecting: mid-stream format change with an existing decoder"
                );
                state.set_error(format!(
                    "sequence_callback format change unsupported: decoder is {}x{} {}-bit, stream is now {}x{} {}-bit",
                    state.width,
                    state.height,
                    state.bit_depth_luma_minus8 + 8,
                    fmt.coded_width,
                    fmt.coded_height,
                    fmt.bit_depth_luma_minus8 + 8
                ));
                return 0;
            }

            let is_high_bit_depth = fmt.bit_depth_luma_minus8 > 0;
            state.bit_depth_luma_minus8 = fmt.bit_depth_luma_minus8;

            // Byte 0 carries the full-range flag in bit 3; bytes 1..=3 are read
            // as colour primaries, transfer characteristics and matrix
            // coefficients.
            let cp = fmt.video_signal_description[1];
            let tc = fmt.video_signal_description[2];
            let mc = fmt.video_signal_description[3];
            let full_range = (fmt.video_signal_description[0] >> 3) & 1 == 1;
            state.vui_colour_primaries = cp;
            state.vui_transfer_characteristics = tc;
            state.vui_matrix_coefficients = mc;
            state.vui_full_range_flag = full_range;
            state.color_space = match mc {
                1 => ColorSpace::Bt709,
                5 | 6 => ColorSpace::Bt601,
                9 | 10 => ColorSpace::Bt2020,
                // Unknown matrix: guess from bit depth.
                _ => {
                    if is_high_bit_depth {
                        ColorSpace::Bt2020
                    } else {
                        ColorSpace::Bt709
                    }
                }
            };
            tracing::info!(
                matrix_coefficients = mc,
                color_primaries = cp,
                transfer = tc,
                color_space = ?state.color_space,
                "NVDEC color metadata"
            );

            if state.decoder.is_none() {
                let mut create_info: CuVideoDecodeCreateInfo = std::mem::zeroed();
                // NOTE(review): this field is spelled `code_width` while its
                // sibling is `coded_height` — confirm against the
                // `CuVideoDecodeCreateInfo` declaration.
                create_info.code_width = fmt.coded_width as c_ulong;
                create_info.coded_height = fmt.coded_height as c_ulong;
                create_info.num_decode_surfaces = num_surfaces;
                create_info.codec_type = state.codec_type;
                create_info.chroma_format = CUVID_CHROMA_420;
                create_info.creation_flags = CUVID_CREATE_PREFER_CUVID;
                create_info.bit_depth_minus8 = fmt.bit_depth_luma_minus8 as c_ulong;
                create_info.output_format = if is_high_bit_depth {
                    CUVID_FMT_P016
                } else {
                    CUVID_FMT_NV12
                };
                // Progressive content uses mode 0; interlaced H.264 uses 2 and
                // every other interlaced codec 1 (presumably weave / adaptive /
                // bob per the cuvid headers — confirm).
                create_info.deinterlace_mode = if fmt.progressive_sequence != 0 {
                    0
                } else if state.codec_type == CUVID_H264 {
                    2
                } else {
                    1
                };
                create_info.target_width = fmt.coded_width as c_ulong;
                create_info.target_height = fmt.coded_height as c_ulong;
                create_info.num_output_surfaces = 4;
                create_info.max_width = 0;
                create_info.max_height = 0;

                let mut decoder: CUvideodecoder = ptr::null_mut();
                let rc = (state.cuvid_create_decoder)(&mut decoder, &mut create_info);
                if rc != 0 {
                    state.set_error(format!("cuvidCreateDecoder failed: {rc}"));
                    return 0;
                }
                state.decoder = Some(decoder);
                // Commit the geometry only once a decoder really exists, so the
                // state never describes a decoder that failed to be created.
                state.width = fmt.coded_width;
                state.height = fmt.coded_height;
            }

            num_surfaces as c_int
        }))
        .unwrap_or(0)
    }
}
/// Decode callback handed to the cuvid parser: forwards one picture to the
/// decoder stored on the [`CallbackState`].
///
/// Returns `1` when the picture was submitted and `0` when a pointer is null,
/// no decoder exists yet, or the decode call reports a non-zero status. The
/// last two cases also record an error on the state. A panic inside the body
/// is caught and reported as `0`.
///
/// # Safety
///
/// `user_data` must be null or point to a live, exclusively accessible
/// `CallbackState`; `pic_params` must be null or a pointer the decoder may
/// read for the duration of the call.
pub unsafe extern "C" fn decode_callback(
    user_data: *mut c_void,
    pic_params: *mut CuVideoPicParams,
) -> c_int {
    // SAFETY: pointer validity is the caller's contract (see `# Safety`);
    // `user_data` is null-checked before it is dereferenced.
    unsafe {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let missing_argument = user_data.is_null() || pic_params.is_null();
            if missing_argument {
                return 0;
            }

            let state = &mut *user_data.cast::<CallbackState>();
            let decoder = match state.decoder {
                Some(handle) => handle,
                None => {
                    state.set_error("decode_callback before decoder created");
                    return 0;
                }
            };

            let rc = (state.cuvid_decode_picture)(decoder, pic_params);
            if rc == 0 {
                return 1;
            }
            state.set_error(format!("cuvidDecodePicture failed: {rc}"));
            0
        }));
        // A caught panic is reported to the parser as a plain failure.
        outcome.unwrap_or(0)
    }
}
/// Display callback handed to the cuvid parser: maps the finished picture,
/// copies it into a host buffer and queues it on the state's collector.
///
/// The host buffer holds `height` luma rows followed by `ceil(height / 2)`
/// chroma rows, each `width * bytes_per_sample` bytes wide (two bytes per
/// sample when the recorded bit depth is above 8, otherwise one).
///
/// Returns `1` once the frame has been queued. Returns `0` — after recording
/// an error on the state where one is available — when a pointer is null, no
/// decoder exists, the picture index is negative, mapping or either copy
/// fails, or the collector mutex is poisoned. A panic inside the body is
/// caught and reported as `0`.
///
/// # Safety
///
/// `user_data` must be null or point to a live, exclusively accessible
/// `CallbackState`; `disp_info` must be null or point to a valid
/// `CuVideoDispInfo`.
pub unsafe extern "C" fn display_callback(
    user_data: *mut c_void,
    disp_info: *mut CuVideoDispInfo,
) -> c_int {
    // SAFETY: pointer validity is the caller's contract (see `# Safety`); both
    // pointers are null-checked before they are dereferenced.
    unsafe {
        catch_unwind(AssertUnwindSafe(|| {
            if user_data.is_null() || disp_info.is_null() {
                return 0;
            }
            let state = &mut *(user_data as *mut CallbackState);
            let info = &*disp_info;
            let Some(decoder) = state.decoder else {
                state.set_error("display_callback before decoder created");
                return 0;
            };
            if info.picture_index < 0 {
                state.set_error(format!(
                    "display_callback picture_index invalid: {}",
                    info.picture_index
                ));
                return 0;
            }

            let mut proc_params: CuVideoProcParams = std::mem::zeroed();
            proc_params.progressive_frame = info.progressive_frame;
            proc_params.second_field = 0;
            proc_params.top_field_first = info.top_field_first;
            proc_params.unpaired_field = 0;

            let mut frame_ptr: CUdeviceptr = 0;
            let mut pitch: c_uint = 0;
            let rc = (state.cuvid_map_video_frame)(
                decoder,
                info.picture_index,
                &mut frame_ptr,
                &mut pitch,
                &mut proc_params,
            );
            if rc != 0 {
                state.set_error(format!("cuvidMapVideoFrame failed: {rc}"));
                return 0;
            }
            // From here on the frame is mapped: every exit path must unmap it.

            let width = state.width as usize;
            let height = state.height as usize;
            let bytes_per_sample = if state.bit_depth_luma_minus8 > 0 {
                2
            } else {
                1
            };
            let row_bytes = width * bytes_per_sample;
            let chroma_height = height.div_ceil(2);
            let y_bytes = row_bytes * height;
            let uv_bytes = row_bytes * chroma_height;
            let mut host_buf = vec![0u8; y_bytes + uv_bytes];

            // Luma plane: pitched device rows -> tightly packed host rows.
            let mut luma_copy: CudaMemcpy2D = std::mem::zeroed();
            luma_copy.src_memory_type = CU_MEMORYTYPE_DEVICE;
            luma_copy.src_device = frame_ptr;
            luma_copy.src_pitch = pitch as usize;
            luma_copy.dst_memory_type = CU_MEMORYTYPE_HOST;
            luma_copy.dst_host = host_buf.as_mut_ptr() as *mut c_void;
            luma_copy.dst_pitch = row_bytes;
            luma_copy.width_in_bytes = row_bytes;
            luma_copy.height = height;
            let rc = (state.cu_memcpy2d)(&luma_copy);
            if rc != 0 {
                // Best-effort unmap; the copy failure is the error worth reporting.
                let _ = (state.cuvid_unmap_video_frame)(decoder, frame_ptr);
                state.set_error(format!("cuMemcpy2D (luma) failed: {rc}"));
                return 0;
            }

            // Chroma plane is read from directly behind the luma rows.
            // NOTE(review): assumes the chroma plane starts exactly
            // `pitch * height` bytes after the luma plane — confirm this also
            // holds for odd heights.
            let chroma_src = frame_ptr + (pitch as CUdeviceptr) * (height as CUdeviceptr);
            let mut chroma_copy: CudaMemcpy2D = std::mem::zeroed();
            chroma_copy.src_memory_type = CU_MEMORYTYPE_DEVICE;
            chroma_copy.src_device = chroma_src;
            chroma_copy.src_pitch = pitch as usize;
            chroma_copy.dst_memory_type = CU_MEMORYTYPE_HOST;
            chroma_copy.dst_host = host_buf[y_bytes..].as_mut_ptr() as *mut c_void;
            chroma_copy.dst_pitch = row_bytes;
            chroma_copy.width_in_bytes = row_bytes;
            chroma_copy.height = chroma_height;
            let rc = (state.cu_memcpy2d)(&chroma_copy);
            // Unmap before inspecting `rc` so the surface is released on both paths.
            let _ = (state.cuvid_unmap_video_frame)(decoder, frame_ptr);
            if rc != 0 {
                state.set_error(format!("cuMemcpy2D (chroma) failed: {rc}"));
                return 0;
            }

            // Hand the frame to the consumer. The flag is computed in a `let`
            // so the lock result is dropped before `set_error` needs mutable
            // access to the state.
            let queued = match state.collector.lock() {
                Ok(mut collector) => {
                    collector.frames.push_back(DecodedFrame {
                        nv12: host_buf,
                        width: state.width,
                        height: state.height,
                        bit_depth_minus8: state.bit_depth_luma_minus8,
                        color_space: state.color_space,
                        timestamp: info.timestamp,
                    });
                    true
                }
                Err(_) => false,
            };
            if !queued {
                // A poisoned collector means the frame was dropped; reporting
                // success here would silently lose it.
                state.set_error("display_callback frame collector lock poisoned");
                return 0;
            }
            1
        }))
        .unwrap_or(0)
    }
}
/// Operating-point callback handed to the cuvid parser.
///
/// Ignores both arguments and always returns `0`.
///
/// # Safety
///
/// Neither pointer is dereferenced, so any values (including null) are
/// accepted.
pub unsafe extern "C" fn get_operating_point_callback(
    _user_data: *mut c_void,
    _op_info: *mut c_void,
) -> c_int {
    // Kept behind `catch_unwind` like the other callbacks so nothing can
    // unwind across the `extern "C"` boundary.
    let outcome = catch_unwind(AssertUnwindSafe(|| -> c_int { 0 }));
    match outcome {
        Ok(value) => value,
        Err(_) => 0,
    }
}