use std::{ffi::c_void, ptr};
use cudarc::driver::{DevicePtr, MappedBuffer};
use super::{api::ENCODE_API, encoder::Encoder, result::EncodeError, session::Session};
use crate::sys::nvEncodeAPI::{
NV_ENC_BUFFER_FORMAT,
NV_ENC_CREATE_BITSTREAM_BUFFER,
NV_ENC_CREATE_BITSTREAM_BUFFER_VER,
NV_ENC_CREATE_INPUT_BUFFER,
NV_ENC_CREATE_INPUT_BUFFER_VER,
NV_ENC_INPUT_RESOURCE_TYPE,
NV_ENC_LOCK_BITSTREAM,
NV_ENC_LOCK_BITSTREAM_VER,
NV_ENC_LOCK_INPUT_BUFFER,
NV_ENC_LOCK_INPUT_BUFFER_VER,
NV_ENC_MAP_INPUT_RESOURCE,
NV_ENC_MAP_INPUT_RESOURCE_VER,
NV_ENC_PIC_TYPE,
NV_ENC_REGISTER_RESOURCE,
};
/// Something that can be handed to the encoder as an input picture.
///
/// Implementors expose the raw NVENC handle for the input together with the
/// row pitch the encoder should use when reading it.
pub trait EncoderInput {
    /// Returns the raw NVENC handle that identifies this input.
    fn handle(&mut self) -> *mut c_void;

    /// Returns the row pitch associated with this input.
    fn pitch(&self) -> u32;
}
/// Something the encoder can write its output into.
pub trait EncoderOutput {
    /// Returns the raw NVENC handle that identifies this output.
    fn handle(&mut self) -> *mut c_void;
}
impl Session {
    /// Creates an input buffer allocated and owned by NVENC.
    ///
    /// The buffer uses this session's width, height and buffer format. Its
    /// pitch starts out as the session width and is replaced with the pitch
    /// NVENC reports each time the buffer is locked.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if NVENC fails to create the buffer.
    pub fn create_input_buffer(&self) -> Result<Buffer<'_>, EncodeError> {
        let mut create_input_buffer_params = NV_ENC_CREATE_INPUT_BUFFER {
            version: NV_ENC_CREATE_INPUT_BUFFER_VER,
            width: self.width,
            height: self.height,
            bufferFmt: self.buffer_format,
            inputBuffer: ptr::null_mut(),
            ..Default::default()
        };
        // SAFETY: `self.encoder.ptr` is the encoder handle this session holds
        // (assumed valid while `self` is alive), and the parameter struct
        // outlives the call.
        unsafe {
            (ENCODE_API.create_input_buffer)(self.encoder.ptr, &mut create_input_buffer_params)
        }
        .result(&self.encoder)?;
        Ok(Buffer {
            ptr: create_input_buffer_params.inputBuffer,
            pitch: self.width,
            encoder: &self.encoder,
        })
    }

    /// Creates an output bitstream buffer allocated and owned by NVENC.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if NVENC fails to create the bitstream buffer.
    pub fn create_output_bitstream(&self) -> Result<Bitstream<'_>, EncodeError> {
        let mut create_bitstream_buffer_params = NV_ENC_CREATE_BITSTREAM_BUFFER {
            version: NV_ENC_CREATE_BITSTREAM_BUFFER_VER,
            bitstreamBuffer: ptr::null_mut(),
            ..Default::default()
        };
        // SAFETY: same encoder-handle invariant as above; the parameter struct
        // outlives the call.
        unsafe {
            (ENCODE_API.create_bitstream_buffer)(
                self.encoder.ptr,
                &mut create_bitstream_buffer_params,
            )
        }
        .result(&self.encoder)?;
        Ok(Bitstream {
            ptr: create_bitstream_buffer_params.bitstreamBuffer,
            encoder: &self.encoder,
        })
    }

    /// Registers a CUDA mapped buffer with NVENC and maps it as encoder input.
    ///
    /// The buffer's device pointer, obtained on the context's default stream,
    /// is registered as an `NV_ENC_INPUT_RESOURCE_TYPE_CUDADEVICEPTR` resource
    /// with the given `pitch`. `mapped_buffer` is moved into the returned
    /// [`RegisteredResource`] so it stays alive as long as the registration.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if registering or mapping the resource fails.
    pub fn register_cuda_resource(
        &self,
        pitch: u32,
        mapped_buffer: MappedBuffer,
    ) -> Result<RegisteredResource<'_, MappedBuffer>, EncodeError> {
        let stream = self.encoder.ctx.default_stream();
        // NOTE(review): the second value returned by `device_ptr` is dropped at
        // the end of this statement; if it is a stream-synchronisation guard,
        // confirm that releasing it before registration is intended.
        let (device_ptr, _) = mapped_buffer.device_ptr(&stream);
        self.register_generic_resource(
            mapped_buffer,
            NV_ENC_INPUT_RESOURCE_TYPE::NV_ENC_INPUT_RESOURCE_TYPE_CUDADEVICEPTR,
            device_ptr as *mut c_void,
            pitch,
        )
    }

    /// Registers an arbitrary resource with NVENC and maps it as encoder input.
    ///
    /// `marker` is stored in the returned [`RegisteredResource`] and dropped
    /// together with it. Use it to keep whatever owns `resource_to_register`
    /// alive for the lifetime of the registration.
    ///
    /// If mapping fails after a successful registration, the registration is
    /// undone before the error is returned, so nothing is leaked.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if registering or mapping the resource fails.
    pub fn register_generic_resource<T>(
        &self,
        marker: T,
        resource_type: NV_ENC_INPUT_RESOURCE_TYPE,
        resource_to_register: *mut c_void,
        pitch: u32,
    ) -> Result<RegisteredResource<'_, T>, EncodeError> {
        let mut register_resource_params = NV_ENC_REGISTER_RESOURCE::new(
            resource_type,
            self.width,
            self.height,
            resource_to_register,
            self.buffer_format,
        )
        .pitch(pitch);
        // SAFETY: valid encoder handle; the parameter struct outlives the call.
        unsafe { (ENCODE_API.register_resource)(self.encoder.ptr, &mut register_resource_params) }
            .result(&self.encoder)?;
        let registered_resource = register_resource_params.registeredResource;

        let mut map_input_resource_params = NV_ENC_MAP_INPUT_RESOURCE {
            version: NV_ENC_MAP_INPUT_RESOURCE_VER,
            registeredResource: registered_resource,
            mappedResource: ptr::null_mut(),
            mappedBufferFmt: NV_ENC_BUFFER_FORMAT::NV_ENC_BUFFER_FORMAT_UNDEFINED,
            ..Default::default()
        };
        // SAFETY: valid encoder handle; `registered_resource` was just returned
        // by a successful `register_resource` call.
        let map_result = unsafe {
            (ENCODE_API.map_input_resource)(self.encoder.ptr, &mut map_input_resource_params)
        }
        .result(&self.encoder);
        if let Err(map_error) = map_result {
            // No `RegisteredResource` exists yet, so its `Drop` cannot release
            // the registration. Undo it here so a failed map does not leak it.
            // The mapping error is the one worth reporting, so a failure to
            // unregister is deliberately ignored.
            // SAFETY: `registered_resource` is a live registration that was
            // never mapped.
            let _ = unsafe {
                (ENCODE_API.unregister_resource)(self.encoder.ptr, registered_resource)
            }
            .result(&self.encoder);
            return Err(map_error);
        }
        let mapped_resource = map_input_resource_params.mappedResource;
        Ok(RegisteredResource {
            reg_ptr: registered_resource,
            map_ptr: mapped_resource,
            pitch,
            encoder: &self.encoder,
            _marker: marker,
        })
    }
}
/// An input buffer allocated by NVENC (see `Session::create_input_buffer`).
///
/// The buffer is destroyed through NVENC when this value is dropped.
#[derive(Debug)]
pub struct Buffer<'a> {
    /// Raw NVENC input buffer handle.
    pub(crate) ptr: *mut c_void,
    /// Row pitch: initialised to the session width at creation and updated
    /// to the pitch NVENC reports on every successful lock.
    pitch: u32,
    /// Encoder that owns the buffer; borrowed so the buffer cannot outlive it.
    encoder: &'a Encoder,
}
// SAFETY: `Buffer` holds a raw pointer and `&Encoder`, so it is not `Send`
// automatically. NOTE(review): soundness relies on NVENC allowing this buffer
// to be used (locked/unlocked/destroyed) from a thread other than the one that
// created it, while the same encoder may be in use elsewhere; confirm this
// against the NVENC threading rules.
unsafe impl Send for Buffer<'_> {}
impl<'a> Buffer<'a> {
    /// Locks the buffer for CPU access without setting NVENC's `doNotWait` flag.
    ///
    /// On success the buffer's stored pitch is updated to the one NVENC reports.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if NVENC fails to lock the buffer.
    pub fn lock<'b>(&'b mut self) -> Result<BufferLock<'b, 'a>, EncodeError> {
        let should_wait = true;
        self.lock_inner(should_wait)
    }

    /// Locks the buffer for CPU access with NVENC's `doNotWait` flag set.
    ///
    /// On success the buffer's stored pitch is updated to the one NVENC reports.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if NVENC fails to lock the buffer.
    pub fn try_lock<'b>(&'b mut self) -> Result<BufferLock<'b, 'a>, EncodeError> {
        let should_wait = false;
        self.lock_inner(should_wait)
    }

    /// Shared implementation of [`Buffer::lock`] and [`Buffer::try_lock`].
    #[inline]
    fn lock_inner<'b>(&'b mut self, wait: bool) -> Result<BufferLock<'b, 'a>, EncodeError> {
        let mut params = NV_ENC_LOCK_INPUT_BUFFER {
            version: NV_ENC_LOCK_INPUT_BUFFER_VER,
            inputBuffer: self.ptr,
            ..Default::default()
        };
        if !wait {
            params.set_doNotWait(1);
        }
        let encoder = self.encoder;
        // SAFETY: `encoder.ptr` is the owning encoder's handle and `params`
        // outlives the call.
        let status = unsafe { (ENCODE_API.lock_input_buffer)(encoder.ptr, &mut params) };
        status.result(encoder)?;

        let (locked_data, locked_pitch) = (params.bufferDataPtr, params.pitch);
        // Remember the pitch NVENC chose so later encodes report it.
        self.pitch = locked_pitch;
        Ok(BufferLock {
            buffer: self,
            data_ptr: locked_data,
            pitch: locked_pitch,
        })
    }
}
impl Drop for Buffer<'_> {
    /// Destroys the NVENC input buffer.
    ///
    /// # Panics
    ///
    /// Panics if NVENC reports an error while destroying the buffer.
    fn drop(&mut self) {
        let encoder = self.encoder;
        // SAFETY: `self.ptr` was created by this encoder and is destroyed once.
        let status = unsafe { (ENCODE_API.destroy_input_buffer)(encoder.ptr, self.ptr) };
        status
            .result(encoder)
            .expect("The encoder and buffer pointers should be valid.");
    }
}
impl EncoderInput for Buffer<'_> {
    /// Returns the raw NVENC input buffer handle.
    fn handle(&mut self) -> *mut c_void {
        let raw_handle = self.ptr;
        raw_handle
    }

    /// Returns the pitch last reported by NVENC (or the session width if the
    /// buffer has never been locked).
    fn pitch(&self) -> u32 {
        let current_pitch = self.pitch;
        current_pitch
    }
}
/// Guard for a locked NVENC input buffer; the buffer is unlocked on drop.
// The type name repeats the module name. NOTE(review): presumably the module is
// `buffer`; the lint is silenced because `BufferLock` is the intended public name.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct BufferLock<'a, 'b> {
    /// The buffer this lock was taken on.
    buffer: &'a Buffer<'b>,
    /// CPU-visible pointer to the locked buffer data, as reported by NVENC.
    data_ptr: *mut c_void,
    // Pitch reported by NVENC at lock time. It is not read anywhere in the
    // code visible here (hence `dead_code`), but it is kept for `Debug` output.
    #[allow(dead_code)]
    pitch: u32,
}
impl BufferLock<'_, '_> {
    /// Copies `data` byte-for-byte into the locked input buffer, starting at
    /// its first byte.
    ///
    /// The copy is a single contiguous block; no per-row padding is inserted.
    /// `data` is therefore expected to already be laid out using the pitch
    /// NVENC reported when the buffer was locked.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `data.len()` does not exceed the size of
    /// the locked NVENC buffer. That size is determined by the reported pitch,
    /// the picture height and the buffer format, none of which are checked
    /// here. Writing past the end of the buffer is undefined behaviour.
    pub unsafe fn write(&mut self, data: &[u8]) {
        // SAFETY: `data` is valid for `data.len()` reads, and the caller
        // guarantees `self.data_ptr` is valid for that many writes. `copy_to`
        // tolerates overlap, so aliasing cannot cause UB here.
        data.as_ptr()
            .copy_to(self.data_ptr.cast::<u8>(), data.len());
    }
}
impl Drop for BufferLock<'_, '_> {
    /// Unlocks the input buffer.
    ///
    /// # Panics
    ///
    /// Panics if NVENC reports an error while unlocking the buffer.
    fn drop(&mut self) {
        let locked = self.buffer;
        // SAFETY: `locked.ptr` was locked through this encoder and is unlocked once.
        let status = unsafe { (ENCODE_API.unlock_input_buffer)(locked.encoder.ptr, locked.ptr) };
        status
            .result(locked.encoder)
            .expect("The encoder and buffer pointers should be valid.");
    }
}
/// An output bitstream buffer allocated by NVENC
/// (see `Session::create_output_bitstream`).
///
/// The buffer is destroyed through NVENC when this value is dropped.
#[derive(Debug)]
pub struct Bitstream<'a> {
    /// Raw NVENC bitstream buffer handle.
    pub(crate) ptr: *mut c_void,
    /// Encoder that owns the buffer; borrowed so the buffer cannot outlive it.
    encoder: &'a Encoder,
}
// SAFETY: `Bitstream` holds a raw pointer and `&Encoder`, so it is not `Send`
// automatically. NOTE(review): soundness relies on NVENC allowing the bitstream
// to be locked/unlocked/destroyed from another thread while the encoder may be
// in use elsewhere; confirm this against the NVENC threading rules.
unsafe impl Send for Bitstream<'_> {}
impl Bitstream<'_> {
    /// Locks the bitstream without setting NVENC's `doNotWait` flag and
    /// exposes the encoded output.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if NVENC fails to lock the bitstream.
    pub fn lock(&mut self) -> Result<BitstreamLock<'_, '_>, EncodeError> {
        self.lock_inner(true)
    }

    /// Locks the bitstream with NVENC's `doNotWait` flag set and exposes the
    /// encoded output.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if NVENC fails to lock the bitstream.
    pub fn try_lock(&mut self) -> Result<BitstreamLock<'_, '_>, EncodeError> {
        self.lock_inner(false)
    }

    /// Shared implementation of [`Bitstream::lock`] and [`Bitstream::try_lock`].
    fn lock_inner(&mut self, wait: bool) -> Result<BitstreamLock<'_, '_>, EncodeError> {
        let mut lock_bitstream_buffer_params = NV_ENC_LOCK_BITSTREAM {
            version: NV_ENC_LOCK_BITSTREAM_VER,
            outputBitstream: self.ptr,
            ..Default::default()
        };
        if !wait {
            lock_bitstream_buffer_params.set_doNotWait(1);
        }
        // SAFETY: `self.encoder.ptr` is the owning encoder's handle and the
        // parameter struct outlives the call.
        unsafe { (ENCODE_API.lock_bitstream)(self.encoder.ptr, &mut lock_bitstream_buffer_params) }
            .result(self.encoder)?;
        let data_ptr = lock_bitstream_buffer_params.bitstreamBufferPtr;
        let data_size = lock_bitstream_buffer_params.bitstreamSizeInBytes as usize;
        // `slice::from_raw_parts` requires a non-null pointer even for a
        // zero-length slice, so an empty or absent bitstream maps to `&[]`
        // instead of UB. The data is only ever read, so a shared slice is
        // created rather than a mutable one.
        let data: &[u8] = if data_ptr.is_null() || data_size == 0 {
            &[]
        } else {
            // SAFETY: NVENC reported `data_size` bytes at the non-null
            // `data_ptr`. The returned `BitstreamLock` borrows `self` and
            // unlocks the bitstream only when dropped, so the memory stays
            // locked for as long as `data` is reachable.
            unsafe { std::slice::from_raw_parts(data_ptr.cast::<u8>(), data_size) }
        };
        Ok(BitstreamLock {
            bitstream: self,
            data,
            frame_index: lock_bitstream_buffer_params.frameIdx,
            timestamp: lock_bitstream_buffer_params.outputTimeStamp,
            duration: lock_bitstream_buffer_params.outputDuration,
            picture_type: lock_bitstream_buffer_params.pictureType,
        })
    }
}
impl Drop for Bitstream<'_> {
    /// Destroys the NVENC bitstream buffer.
    ///
    /// # Panics
    ///
    /// Panics if NVENC reports an error while destroying the buffer.
    fn drop(&mut self) {
        let encoder = self.encoder;
        // SAFETY: `self.ptr` was created by this encoder and is destroyed once.
        let status = unsafe { (ENCODE_API.destroy_bitstream_buffer)(encoder.ptr, self.ptr) };
        status
            .result(encoder)
            .expect("The encoder and bitstream pointers should be valid.");
    }
}
impl EncoderOutput for Bitstream<'_> {
    /// Returns the raw NVENC bitstream buffer handle.
    fn handle(&mut self) -> *mut c_void {
        let raw_handle = self.ptr;
        raw_handle
    }
}
/// Guard for a locked NVENC bitstream; the bitstream is unlocked on drop.
///
/// Holds the encoded bytes and the per-frame metadata NVENC returned from the
/// lock call.
#[derive(Debug)]
pub struct BitstreamLock<'a, 'b> {
    /// The bitstream this lock was taken on.
    bitstream: &'a Bitstream<'b>,
    /// Encoded bytes (`bitstreamBufferPtr` / `bitstreamSizeInBytes`).
    data: &'a [u8],
    /// NVENC's `frameIdx` for this output.
    frame_index: u32,
    /// NVENC's `outputTimeStamp` for this output.
    timestamp: u64,
    /// NVENC's `outputDuration` for this output.
    duration: u64,
    /// NVENC's `pictureType` for this output.
    picture_type: NV_ENC_PIC_TYPE,
}
impl BitstreamLock<'_, '_> {
    /// Encoded bytes held in the locked bitstream.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        let Self { data, .. } = *self;
        data
    }

    /// Frame index NVENC reported (`frameIdx`).
    #[must_use]
    pub fn frame_index(&self) -> u32 {
        let Self { frame_index, .. } = *self;
        frame_index
    }

    /// Output timestamp NVENC reported (`outputTimeStamp`).
    #[must_use]
    pub fn timestamp(&self) -> u64 {
        let Self { timestamp, .. } = *self;
        timestamp
    }

    /// Output duration NVENC reported (`outputDuration`).
    #[must_use]
    pub fn duration(&self) -> u64 {
        let Self { duration, .. } = *self;
        duration
    }

    /// Picture type NVENC reported (`pictureType`).
    #[must_use]
    pub fn picture_type(&self) -> NV_ENC_PIC_TYPE {
        let Self { picture_type, .. } = *self;
        picture_type
    }
}
impl Drop for BitstreamLock<'_, '_> {
    /// Unlocks the bitstream.
    ///
    /// # Panics
    ///
    /// Panics if NVENC reports an error while unlocking the bitstream.
    fn drop(&mut self) {
        let locked = self.bitstream;
        // SAFETY: `locked.ptr` was locked through this encoder and is unlocked once.
        let status = unsafe { (ENCODE_API.unlock_bitstream)(locked.encoder.ptr, locked.ptr) };
        status
            .result(locked.encoder)
            .expect("The encoder and bitstream pointers should be valid.");
    }
}
/// An external resource registered with NVENC and mapped for use as input.
///
/// On drop the resource is first unmapped and then unregistered.
#[derive(Debug)]
pub struct RegisteredResource<'a, T> {
    /// Handle returned by NVENC's `register_resource`.
    pub(crate) reg_ptr: *mut c_void,
    /// Handle returned by NVENC's `map_input_resource`; this is what is given
    /// to the encoder as input.
    pub(crate) map_ptr: *mut c_void,
    /// Pitch the resource was registered with.
    pitch: u32,
    /// Encoder the resource is registered with.
    encoder: &'a Encoder,
    /// Value kept alive for as long as the registration, e.g. the
    /// `MappedBuffer` that owns the registered CUDA memory.
    _marker: T,
}
// SAFETY: only implemented for `MappedBuffer` markers. NOTE(review): soundness
// relies on `MappedBuffer` being safe to move across threads and on NVENC
// allowing unmap/unregister from another thread; confirm both.
unsafe impl Send for RegisteredResource<'_, MappedBuffer> {}
impl<T> Drop for RegisteredResource<'_, T> {
    /// Unmaps and then unregisters the resource.
    ///
    /// This is the reverse of the order in which the resource was registered
    /// and mapped.
    ///
    /// # Panics
    ///
    /// Panics if NVENC reports an error from either call.
    fn drop(&mut self) {
        let encoder = self.encoder;

        // SAFETY: `map_ptr` came from a successful map on this encoder.
        let unmap_status = unsafe { (ENCODE_API.unmap_input_resource)(encoder.ptr, self.map_ptr) };
        unmap_status
            .result(encoder)
            .expect("The encoder pointer and map handle should be valid.");

        // SAFETY: `reg_ptr` came from a successful registration on this encoder.
        let unregister_status =
            unsafe { (ENCODE_API.unregister_resource)(encoder.ptr, self.reg_ptr) };
        unregister_status
            .result(encoder)
            .expect("The encoder pointer and resource handle should be valid.");
    }
}
impl<T> EncoderInput for RegisteredResource<'_, T> {
    /// Returns the *mapped* resource handle (not the registration handle),
    /// which is what the encoder consumes.
    fn handle(&mut self) -> *mut c_void {
        let mapped = self.map_ptr;
        mapped
    }

    /// Returns the pitch the resource was registered with.
    fn pitch(&self) -> u32 {
        let registered_pitch = self.pitch;
        registered_pitch
    }
}