use std::{collections::HashMap, sync::Arc};
use ash::vk;
use h264_reader::nal::{
pps::PicParameterSet,
sps::{Profile, SeqParameterSet},
};
use images::DecodingImages;
use parameters::{SessionParams, VideoSessionParametersManager};
use crate::{
VulkanDecoderError,
device::DecodingDevice,
vulkan_decoder::{DecoderTracker, DecoderTrackerWaitState},
wrappers::{
DecodeInputBuffer, DecodingQueryPool, H264DecodeProfileInfo, OpenCommandBuffer,
SeqParameterSetExt, VideoSession, h264_level_idc_to_max_dpb_mbs, vk_to_h264_level_idc,
},
};
mod images;
mod parameters;
/// Everything that belongs to one H.264 decoding session: the Vulkan video session, the
/// parameters it was created with, the decoding images, the parameter sets received so far
/// and the auxiliary query pool / input buffer.
///
/// NOTE(review): Rust drops fields in declaration order; several of these fields wrap device
/// resources, so confirm drop-order requirements before reordering them.
pub(super) struct VideoSessionResources<'a> {
/// The current video session. `ensure_session` replaces it when the pending parameters are
/// not accepted by `parameters.is_valid(..)`.
pub(crate) video_session: VideoSession,
/// Parameters describing the current session (coded extent, DPB slots, active references,
/// reorder depth, profile and level).
pub(crate) parameters: SessionParams<'a>,
/// Receives every SPS/PPS through `put_sps` / `put_pps` and is moved to a new session with
/// `change_session` when the session is recreated.
pub(crate) parameters_manager: VideoSessionParametersManager,
/// Images created by `new_decoding_images`; rebuilt together with the session.
pub(crate) decoding_images: DecodingImages<'a>,
/// Sequence parameter sets seen so far, keyed by `sps.id().id()`.
pub(crate) sps: HashMap<u8, SeqParameterSet>,
/// Picture parameter sets seen so far, keyed by
/// `(seq_parameter_set_id, pic_parameter_set_id)`.
pub(crate) pps: HashMap<(u8, u8), PicParameterSet>,
/// `Some` only when the decode queues report support for result status queries.
pub(crate) decode_query_pool: Option<DecodingQueryPool>,
/// Input buffer created for the current profile; recreated when the profile changes.
/// (Presumably holds the bitstream handed to decode operations — confirm in `wrappers`.)
pub(crate) decode_buffer: DecodeInputBuffer,
/// Parameters accumulated by `process_sps` (via `SessionParams::combine`) that have not been
/// applied yet; consumed by `ensure_session`.
parameters_scheduled_for_reset: Option<SessionParams<'a>>,
}
/// Returns the maximum number of frames the decoder may have to reorder for `sps`.
///
/// When the SPS carries VUI bitstream restrictions, their `max_num_reorder_frames` is used
/// verbatim. Otherwise a fallback is derived:
///
/// * `0` for `profile_idc` 44, 86, 100, 110, 122 or 244 with constraint flag 3 set, and for
///   the Baseline profile;
/// * `min(MaxDpbMbs / (width_in_mbs * height_in_map_units), 16)` for everything else, where
///   `MaxDpbMbs` comes from `h264_level_idc_to_max_dpb_mbs`.
///
/// # Errors
///
/// Propagates the error of `h264_level_idc_to_max_dpb_mbs` when the level-based fallback has
/// to be computed. The fallback is computed even when the VUI value is present, so such an
/// error is reported in that case as well.
fn calculate_max_num_reorder_frames(sps: &SeqParameterSet) -> Result<u64, VulkanDecoderError> {
    const ZERO_REORDER_PROFILE_IDCS: [u8; 6] = [44, 86, 100, 110, 122, 244];

    let profile_idc: u8 = sps.profile_idc.into();
    let no_reordering_expected = (ZERO_REORDER_PROFILE_IDCS.contains(&profile_idc)
        && sps.constraint_flags.flag3())
        || matches!(sps.profile(), Profile::Baseline);

    let fallback_max_num_reorder_frames = if no_reordering_expected {
        0
    } else {
        let max_dpb_mbs = h264_level_idc_to_max_dpb_mbs(sps.level_idc)?;
        // Both factors are at least 1, so the division below cannot divide by zero.
        // NOTE(review): the height is taken in map units as-is; confirm whether streams that
        // are not frame-only need the extra field factor here.
        let frame_size_in_mbs = (sps.pic_width_in_mbs_minus1 as u64 + 1)
            * (sps.pic_height_in_map_units_minus1 as u64 + 1);
        // The clamp must apply to the quotient. A method call binds tighter than `/`, so
        // without the parentheses `.min(16)` would clamp the divisor instead and produce
        // values in the thousands.
        (max_dpb_mbs / frame_size_in_mbs).min(16)
    };

    let max_num_reorder_frames = sps
        .vui_parameters
        .as_ref()
        .and_then(|vui| vui.bitstream_restrictions.as_ref())
        .map(|restrictions| restrictions.max_num_reorder_frames as u64)
        .unwrap_or(fallback_max_num_reorder_frames);

    Ok(max_num_reorder_frames)
}
impl<'a> VideoSessionResources<'a> {
pub(crate) fn new_from_sps(
decoding_device: &DecodingDevice,
decode_buffer: OpenCommandBuffer,
sps: SeqParameterSet,
usage_info: vk::VideoDecodeUsageInfoKHR<'a>,
tracker: &mut DecoderTracker,
) -> Result<Self, VulkanDecoderError> {
let profile_info = Arc::new(H264DecodeProfileInfo::from_sps_decode(&sps, usage_info)?);
let level_idc = sps.level_idc;
let max_level_idc = vk_to_h264_level_idc(
decoding_device
.profile_capabilities
.h264_decode_capabilities
.max_level_idc,
)?;
if level_idc > max_level_idc {
return Err(VulkanDecoderError::InvalidInputData(format!(
"stream has level_idc = {level_idc}, while the GPU can decode at most {max_level_idc}"
)));
}
let max_coded_extent = sps.size()?;
let max_dpb_slots = sps.max_num_ref_frames + 1;
let max_active_references = sps.max_num_ref_frames;
let max_num_reorder_frames = calculate_max_num_reorder_frames(&sps)?;
let video_session = VideoSession::new(
&decoding_device.vulkan_device,
&decoding_device.h264_decode_queues,
&profile_info.profile_info.profile_info,
max_coded_extent,
max_dpb_slots,
max_active_references,
vk::VideoSessionCreateFlagsKHR::empty(),
&decoding_device
.profile_capabilities
.video_capabilities
.std_header_version,
)?;
let mut parameters_manager =
VideoSessionParametersManager::new(decoding_device, video_session.session)?;
parameters_manager.put_sps(&sps)?;
let decoding_images = Self::new_decoding_images(
decoding_device,
&profile_info,
max_coded_extent,
max_dpb_slots,
decode_buffer,
tracker,
)?;
let sps = HashMap::from_iter([(sps.id().id(), sps)]);
let decode_query_pool = if decoding_device
.h264_decode_queues
.supports_result_status_queries()
{
Some(DecodingQueryPool::new(
decoding_device.vulkan_device.device.clone(),
profile_info.profile_info.profile_info,
)?)
} else {
None
};
let decode_buffer =
DecodeInputBuffer::new(decoding_device.allocator.clone(), &profile_info)?;
let parameters = SessionParams {
max_coded_extent,
max_dpb_slots,
max_active_references,
max_num_reorder_frames,
profile_info,
level_idc,
};
Ok(VideoSessionResources {
parameters,
video_session,
parameters_manager,
decoding_images,
sps,
pps: HashMap::new(),
decode_query_pool,
decode_buffer,
parameters_scheduled_for_reset: None,
})
}
/// Registers a newly received SPS.
///
/// The parameters implied by `sps` are merged (via `SessionParams::combine`) into the
/// parameters pending for the next `ensure_session` call. If nothing is pending yet, the
/// merge starts from a clone of the current session parameters. Afterwards the SPS is handed
/// to the parameters manager and stored in `self.sps` under its id.
///
/// # Errors
///
/// Propagates failures from reading the SPS size, computing the reorder depth, building the
/// profile info, or `put_sps`.
pub(crate) fn process_sps(
    &mut self,
    sps: SeqParameterSet,
    usage_info: vk::VideoDecodeUsageInfoKHR<'a>,
) -> Result<(), VulkanDecoderError> {
    let max_coded_extent = sps.size()?;
    let max_active_references = sps.max_num_ref_frames;
    let max_dpb_slots = max_active_references + 1;
    let max_num_reorder_frames = calculate_max_num_reorder_frames(&sps)?;
    let profile_info = Arc::new(H264DecodeProfileInfo::from_sps_decode(&sps, usage_info)?);

    let requested = SessionParams {
        max_coded_extent,
        max_dpb_slots,
        max_active_references,
        max_num_reorder_frames,
        profile_info,
        level_idc: sps.level_idc,
    };

    // Build on top of whatever is already pending; otherwise on the live parameters.
    let base = match self.parameters_scheduled_for_reset.take() {
        Some(pending) => pending,
        None => self.parameters.clone(),
    };
    self.parameters_scheduled_for_reset = Some(SessionParams::combine(base, requested));

    self.parameters_manager.put_sps(&sps)?;
    self.sps.insert(sps.id().id(), sps);

    Ok(())
}
/// Registers a newly received PPS: passes it to the parameters manager and stores it in
/// `self.pps` under `(seq_parameter_set_id, pic_parameter_set_id)`.
///
/// # Errors
///
/// Propagates the error from `put_pps`; in that case the PPS is not stored.
pub(crate) fn process_pps(&mut self, pps: PicParameterSet) -> Result<(), VulkanDecoderError> {
    self.parameters_manager.put_pps(&pps)?;

    let key = (pps.seq_parameter_set_id.id(), pps.pic_parameter_set_id.id());
    self.pps.insert(key, pps);

    Ok(())
}
/// Applies the parameters scheduled by `process_sps`, recreating the session when needed.
///
/// * Nothing pending: returns immediately.
/// * Pending parameters accepted by `self.parameters.is_valid(..)`: only
///   `max_num_reorder_frames` is taken over; the session is kept.
/// * Otherwise: the level is checked against the device limit, the query pool and input
///   buffer are rebuilt if the profile changed, and a new video session plus decoding images
///   are created. `decode_buffer` and `tracker` are forwarded to `new_decoding_images`.
///
/// The pending parameters are taken out of `self` at the start, so they are consumed even
/// when this function returns an error.
///
/// # Errors
///
/// Returns `VulkanDecoderError::InvalidInputData` when the pending `level_idc` exceeds the
/// device maximum, and propagates any error from the resource constructors.
pub(crate) fn ensure_session(
    &mut self,
    decoding_device: &DecodingDevice,
    decode_buffer: OpenCommandBuffer,
    tracker: &mut DecoderTracker,
) -> Result<(), VulkanDecoderError> {
    let pending = match self.parameters_scheduled_for_reset.take() {
        Some(pending) => pending,
        None => return Ok(()),
    };

    // The existing session can serve the new parameters; just pick up the reorder depth.
    if self.parameters.is_valid(&pending) {
        self.parameters.max_num_reorder_frames = pending.max_num_reorder_frames;
        return Ok(());
    }

    let capabilities = &decoding_device.profile_capabilities;
    let queues = &decoding_device.h264_decode_queues;

    let max_level_idc = vk_to_h264_level_idc(capabilities.h264_decode_capabilities.max_level_idc)?;
    if pending.level_idc > max_level_idc {
        return Err(VulkanDecoderError::InvalidInputData(format!(
            "stream has level_idc = {}, while the GPU can decode at most {}",
            pending.level_idc, max_level_idc
        )));
    }

    // The query pool and the input buffer are created from the profile, so they are rebuilt
    // only when the profile changed.
    if self.parameters.profile_info != pending.profile_info {
        self.decode_query_pool = if queues.supports_result_status_queries() {
            Some(DecodingQueryPool::new(
                decoding_device.vulkan_device.device.clone(),
                pending.profile_info.profile_info.profile_info,
            )?)
        } else {
            None
        };

        self.decode_buffer =
            DecodeInputBuffer::new(decoding_device.allocator.clone(), &pending.profile_info)?;
    }

    self.video_session = VideoSession::new(
        &decoding_device.vulkan_device,
        queues,
        &pending.profile_info.profile_info.profile_info,
        pending.max_coded_extent,
        pending.max_dpb_slots,
        pending.max_active_references,
        vk::VideoSessionCreateFlagsKHR::empty(),
        &capabilities.video_capabilities.std_header_version,
    )?;

    self.parameters_manager
        .change_session(self.video_session.session)?;

    // Size the new images from what the freshly created session reports.
    let session_extent = self.video_session.max_coded_extent;
    let session_dpb_slots = self.video_session.max_dpb_slots;
    self.decoding_images = Self::new_decoding_images(
        decoding_device,
        &pending.profile_info,
        session_extent,
        session_dpb_slots,
        decode_buffer,
        tracker,
    )?;

    self.parameters = pending;

    Ok(())
}
/// Creates a fresh set of decoding images and submits the command buffer used while creating
/// them.
///
/// `decode_buffer` is passed to `DecodingImages::new`, then ended and submitted on the H.264
/// decode queues through `submit_chain_semaphore`, using the
/// `NewDecodingImagesLayoutTransition` wait state on `tracker`.
///
/// # Errors
///
/// Propagates failures from image creation, from ending the command buffer and from the
/// submission.
fn new_decoding_images(
    decoding_device: &DecodingDevice,
    profile: &H264DecodeProfileInfo,
    max_coded_extent: vk::Extent2D,
    max_dpb_slots: u32,
    mut decode_buffer: OpenCommandBuffer,
    tracker: &mut DecoderTracker,
) -> Result<DecodingImages<'a>, VulkanDecoderError> {
    let capabilities = &decoding_device.profile_capabilities;
    let layout_tracker = tracker.image_layout_tracker.clone();

    let images = DecodingImages::new(
        decoding_device,
        &mut decode_buffer,
        layout_tracker,
        profile,
        &capabilities.h264_dpb_format_properties,
        &capabilities.h264_dst_format_properties,
        max_coded_extent,
        max_dpb_slots,
    )?;

    // Close the command buffer and hand it to the decode queues.
    let recorded = decode_buffer.end()?;
    decoding_device.h264_decode_queues.submit_chain_semaphore(
        recorded,
        tracker,
        vk::PipelineStageFlags2::ALL_COMMANDS,
        vk::PipelineStageFlags2::ALL_COMMANDS,
        DecoderTrackerWaitState::NewDecodingImagesLayoutTransition,
    )?;

    Ok(images)
}
/// Forwards to `DecodingImages::free_reference_picture` for index `i`.
pub(crate) fn free_reference_picture(&mut self, i: usize) {
    let images = &mut self.decoding_images;
    images.free_reference_picture(i);
}
}