use std::ffi::c_void;
use std::ptr::NonNull;
use edgefirst_tflite_sys::hal_ffi::{HalDmaBufFunctions, HalDmabufTensorInfo, HalDtype};
use edgefirst_tflite_sys::vx_ffi::{
VxDmaBufDesc, VxDmaBufFunctions, VxDmaBufOwnership, VxDmaBufSyncMode,
};
use edgefirst_tflite_sys::{kTfLiteNullBufferHandle, TfLiteDelegate};
use crate::error::{self, Error, Result};
/// Element type of a tensor exposed through the DMA-BUF API.
///
/// Each variant mirrors a Rust primitive of the same name (`F16` has no
/// stable Rust primitive and is identified only by name here).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DType {
    /// Unsigned 8-bit integer.
    U8,
    /// Signed 8-bit integer.
    I8,
    /// Unsigned 16-bit integer.
    U16,
    /// Signed 16-bit integer.
    I16,
    /// Unsigned 32-bit integer.
    U32,
    /// Signed 32-bit integer.
    I32,
    /// Unsigned 64-bit integer.
    U64,
    /// Signed 64-bit integer.
    I64,
    /// 16-bit float.
    F16,
    /// 32-bit float.
    F32,
    /// 64-bit float.
    F64,
}
impl DType {
fn from_hal(h: HalDtype) -> Self {
match h {
HalDtype::U8 => Self::U8,
HalDtype::I8 => Self::I8,
HalDtype::U16 => Self::U16,
HalDtype::I16 => Self::I16,
HalDtype::U32 => Self::U32,
HalDtype::I32 => Self::I32,
HalDtype::U64 => Self::U64,
HalDtype::I64 => Self::I64,
HalDtype::F16 => Self::F16,
HalDtype::F32 => Self::F32,
HalDtype::F64 => Self::F64,
}
}
}
impl std::fmt::Display for DType {
    /// Writes the lowercase Rust-style type name (`u8`, `i16`, `f32`, ...).
    ///
    /// Uses [`Formatter::pad`](std::fmt::Formatter::pad) rather than
    /// `write_str` so that width, fill/alignment and precision flags are
    /// honoured (e.g. `format!("{:>4}", DType::U8)` yields `"  u8"`), matching
    /// how `str` itself implements `Display`. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::U8 => "u8",
            Self::I8 => "i8",
            Self::U16 => "u16",
            Self::I16 => "i16",
            Self::U32 => "u32",
            Self::I32 => "i32",
            Self::U64 => "u64",
            Self::I64 => "i64",
            Self::F16 => "f16",
            Self::F32 => "f32",
            Self::F64 => "f64",
        };
        f.pad(name)
    }
}
#[derive(Debug, Clone)]
pub struct TensorInfo {
pub size: usize,
pub offset: usize,
pub shape: Vec<usize>,
pub fd: i32,
pub dtype: DType,
}
/// Cache-synchronisation mode for the legacy `VxDelegate` DMA-BUF API.
///
/// Each variant is converted one-to-one into the same-named
/// `VxDmaBufSyncMode` value before being handed to the delegate; the exact
/// semantics are defined by the `VxDelegate`.
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SyncMode {
    /// Maps to `VxDmaBufSyncMode::None`.
    None,
    /// Maps to `VxDmaBufSyncMode::Read`.
    Read,
    /// Maps to `VxDmaBufSyncMode::Write`.
    Write,
    /// Maps to `VxDmaBufSyncMode::ReadWrite`.
    ReadWrite,
}
#[allow(deprecated)]
impl SyncMode {
fn to_raw(self) -> VxDmaBufSyncMode {
match self {
Self::None => VxDmaBufSyncMode::None,
Self::Read => VxDmaBufSyncMode::Read,
Self::Write => VxDmaBufSyncMode::Write,
Self::ReadWrite => VxDmaBufSyncMode::ReadWrite,
}
}
}
/// Which side a buffer requested through the legacy `VxDelegate` API
/// belongs to.
///
/// Converted one-to-one into `VxDmaBufOwnership`; the precise ownership
/// rules are defined by the `VxDelegate`.
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Ownership {
    /// Maps to `VxDmaBufOwnership::Client`.
    Client,
    /// Maps to `VxDmaBufOwnership::Delegate`.
    Delegate,
}
#[allow(deprecated)]
impl Ownership {
fn to_raw(self) -> VxDmaBufOwnership {
match self {
Self::Client => VxDmaBufOwnership::Client,
Self::Delegate => VxDmaBufOwnership::Delegate,
}
}
}
/// Descriptor of a buffer returned by the legacy `VxDelegate` request call.
///
/// Values are copied from the descriptor the delegate filled in.
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[derive(Debug)]
pub struct BufferDesc {
    /// DMA-BUF file descriptor reported by the delegate.
    pub fd: i32,
    /// Buffer size reported back by the delegate.
    pub size: usize,
    /// CPU mapping reported by the delegate; `None` when it returned a null pointer.
    pub map_ptr: Option<*mut std::ffi::c_void>,
}
/// Opaque buffer handle issued by the legacy `VxDelegate` API.
///
/// Wraps the raw `i32` handle value; see `BufferHandle::raw` / `from_raw`.
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BufferHandle(i32);
#[allow(deprecated)]
impl BufferHandle {
#[must_use]
pub fn from_raw(value: i32) -> Self {
Self(value)
}
#[must_use]
pub fn raw(self) -> i32 {
self.0
}
}
/// Accessor for a delegate's DMA-BUF API.
///
/// Holds a raw pointer to the delegate plus whichever DMA-BUF function tables
/// (HAL Delegate and/or legacy `VxDelegate`) are available, borrowed for `'a`.
#[derive(Debug)]
pub struct DmaBuf<'a> {
    /// Delegate on which the `VxDelegate` functions are invoked.
    delegate: NonNull<TfLiteDelegate>,
    /// Optional handle passed to the HAL functions instead of `delegate`.
    hal_handle: Option<*mut c_void>,
    /// HAL Delegate DMA-BUF function table, if available.
    hal_fns: Option<&'a HalDmaBufFunctions>,
    /// Legacy `VxDelegate` DMA-BUF function table, if available.
    vx_fns: Option<&'a VxDmaBufFunctions>,
}
impl<'a> DmaBuf<'a> {
pub(crate) fn new(
delegate: NonNull<TfLiteDelegate>,
hal_handle: Option<*mut c_void>,
hal_fns: Option<&'a HalDmaBufFunctions>,
vx_fns: Option<&'a VxDmaBufFunctions>,
) -> Self {
Self {
delegate,
hal_handle,
hal_fns,
vx_fns,
}
}
fn hal_delegate_ptr(&self) -> *mut c_void {
self.hal_handle
.unwrap_or_else(|| self.delegate.as_ptr().cast::<c_void>())
}
#[must_use]
pub fn is_supported(&self) -> bool {
if let Some(hal) = self.hal_fns {
unsafe { (hal.is_supported)(self.hal_delegate_ptr()) == 1 }
} else if let Some(vx) = self.vx_fns {
unsafe { (vx.is_supported)(self.delegate.as_ptr()) }
} else {
false
}
}
pub fn tensor_info(&self, tensor_index: i32) -> Result<TensorInfo> {
let hal = self.hal_fns.ok_or_else(|| {
Error::invalid_argument("tensor_info requires the HAL Delegate DMA-BUF API")
})?;
let mut info = HalDmabufTensorInfo::default();
let ret = unsafe {
(hal.get_tensor_info)(
self.hal_delegate_ptr(),
tensor_index,
&mut info,
std::mem::size_of::<HalDmabufTensorInfo>(),
)
};
error::hal_to_result(ret, "hal_dmabuf_get_tensor_info")?;
let ndim = info.ndim.min(info.shape.len());
Ok(TensorInfo {
size: info.size,
offset: info.offset,
shape: info.shape[..ndim].to_vec(),
fd: info.fd,
dtype: DType::from_hal(info.dtype),
})
}
pub fn sync_for_device(&self, tensor_index: i32) -> Result<()> {
if let Some(hal) = self.hal_fns {
let ret = unsafe { (hal.sync_for_device)(self.hal_delegate_ptr(), tensor_index) };
error::hal_to_result(ret, "hal_dmabuf_sync_for_device")
} else if let Some(vx) = self.vx_fns {
let handle = unsafe { (vx.get_active_buffer)(self.delegate.as_ptr(), tensor_index) };
if handle == kTfLiteNullBufferHandle {
return Err(Error::invalid_argument(format!(
"no active DMA-BUF for tensor index {tensor_index}"
)));
}
let status = unsafe { (vx.sync_for_device)(self.delegate.as_ptr(), handle) };
error::status_to_result(status)
} else {
Err(Error::invalid_argument("no DMA-BUF backend available"))
}
}
pub fn sync_for_cpu(&self, tensor_index: i32) -> Result<()> {
if let Some(hal) = self.hal_fns {
let ret = unsafe { (hal.sync_for_cpu)(self.hal_delegate_ptr(), tensor_index) };
error::hal_to_result(ret, "hal_dmabuf_sync_for_cpu")
} else if let Some(vx) = self.vx_fns {
let handle = unsafe { (vx.get_active_buffer)(self.delegate.as_ptr(), tensor_index) };
if handle == kTfLiteNullBufferHandle {
return Err(Error::invalid_argument(format!(
"no active DMA-BUF for tensor index {tensor_index}"
)));
}
let status = unsafe { (vx.sync_for_cpu)(self.delegate.as_ptr(), handle) };
error::status_to_result(status)
} else {
Err(Error::invalid_argument("no DMA-BUF backend available"))
}
}
fn vx(&self) -> Result<&VxDmaBufFunctions> {
self.vx_fns.ok_or_else(|| {
Error::invalid_argument(
"this method requires the `VxDelegate` DMA-BUF API, which is not available",
)
})
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[allow(deprecated)]
pub fn register(&self, fd: i32, size: usize, sync_mode: SyncMode) -> Result<BufferHandle> {
let vx = self.vx()?;
let handle = unsafe { (vx.register)(self.delegate.as_ptr(), fd, size, sync_mode.to_raw()) };
if handle == kTfLiteNullBufferHandle {
return Err(Error::null_pointer(
"`VxDelegate`RegisterDmaBuf returned null handle",
));
}
Ok(BufferHandle(handle))
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[allow(deprecated)]
pub fn unregister(&self, handle: BufferHandle) -> Result<()> {
let vx = self.vx()?;
let status = unsafe { (vx.unregister)(self.delegate.as_ptr(), handle.0) };
error::status_to_result(status)
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[allow(deprecated)]
pub fn request(
&self,
tensor_index: i32,
ownership: Ownership,
size: usize,
) -> Result<(BufferHandle, BufferDesc)> {
let vx = self.vx()?;
let mut desc = VxDmaBufDesc {
size,
..VxDmaBufDesc::default()
};
let handle = unsafe {
(vx.request)(
self.delegate.as_ptr(),
tensor_index,
ownership.to_raw(),
&mut desc,
)
};
if handle == kTfLiteNullBufferHandle {
return Err(Error::null_pointer(
"`VxDelegate`RequestDmaBuf returned null handle",
));
}
let map_ptr = if desc.map_ptr.is_null() {
Option::None
} else {
Some(desc.map_ptr)
};
Ok((
BufferHandle(handle),
BufferDesc {
fd: desc.fd,
size: desc.size,
map_ptr,
},
))
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[allow(deprecated)]
pub fn release(&self, handle: BufferHandle) -> Result<()> {
let vx = self.vx()?;
let status = unsafe { (vx.release)(self.delegate.as_ptr(), handle.0) };
error::status_to_result(status)
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[allow(deprecated)]
pub fn bind_to_tensor(&self, handle: BufferHandle, tensor_index: i32) -> Result<()> {
let vx = self.vx()?;
let status = unsafe { (vx.bind_to_tensor)(self.delegate.as_ptr(), handle.0, tensor_index) };
error::status_to_result(status)
}
#[deprecated(
note = "`VxDelegate`-specific, use DmaBuf::tensor_info() instead; will be removed in a future release"
)]
#[allow(deprecated)]
pub fn buffer_fd(&self, handle: BufferHandle) -> Result<i32> {
let vx = self.vx()?;
let fd = unsafe { (vx.get_fd)(self.delegate.as_ptr(), handle.0) };
if fd < 0 {
return Err(Error::invalid_argument(format!(
"`VxDelegate`GetDmaBufFd returned {fd} for handle {}",
handle.0
)));
}
Ok(fd)
}
#[deprecated(
note = "`VxDelegate`-specific, use DmaBuf::sync_for_cpu() instead; will be removed in a future release"
)]
#[allow(deprecated)]
pub fn begin_cpu_access(&self, handle: BufferHandle, mode: SyncMode) -> Result<()> {
let vx = self.vx()?;
let status =
unsafe { (vx.begin_cpu_access)(self.delegate.as_ptr(), handle.0, mode.to_raw()) };
error::status_to_result(status)
}
#[deprecated(
note = "`VxDelegate`-specific, use DmaBuf::sync_for_device() instead; will be removed in a future release"
)]
#[allow(deprecated)]
pub fn end_cpu_access(&self, handle: BufferHandle, mode: SyncMode) -> Result<()> {
let vx = self.vx()?;
let status =
unsafe { (vx.end_cpu_access)(self.delegate.as_ptr(), handle.0, mode.to_raw()) };
error::status_to_result(status)
}
#[deprecated(
note = "`VxDelegate`-specific, use DmaBuf::sync_for_device(tensor_index) instead; will be removed in a future release"
)]
#[allow(deprecated)]
pub fn sync_for_device_by_handle(&self, handle: BufferHandle) -> Result<()> {
let vx = self.vx()?;
let status = unsafe { (vx.sync_for_device)(self.delegate.as_ptr(), handle.0) };
error::status_to_result(status)
}
#[deprecated(
note = "`VxDelegate`-specific, use DmaBuf::sync_for_cpu(tensor_index) instead; will be removed in a future release"
)]
#[allow(deprecated)]
pub fn sync_for_cpu_by_handle(&self, handle: BufferHandle) -> Result<()> {
let vx = self.vx()?;
let status = unsafe { (vx.sync_for_cpu)(self.delegate.as_ptr(), handle.0) };
error::status_to_result(status)
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[allow(deprecated)]
pub fn set_active(&self, tensor_index: i32, handle: BufferHandle) -> Result<()> {
let vx = self.vx()?;
let status = unsafe { (vx.set_active)(self.delegate.as_ptr(), tensor_index, handle.0) };
error::status_to_result(status)
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[allow(deprecated)]
#[must_use]
pub fn active_buffer(&self, tensor_index: i32) -> Option<BufferHandle> {
let vx = self.vx_fns?;
let handle = unsafe { (vx.get_active_buffer)(self.delegate.as_ptr(), tensor_index) };
if handle == kTfLiteNullBufferHandle {
None
} else {
Some(BufferHandle(handle))
}
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
pub fn invalidate_graph(&self) -> Result<()> {
let vx = self.vx()?;
let status = unsafe { (vx.invalidate_graph)(self.delegate.as_ptr()) };
error::status_to_result(status)
}
#[deprecated(note = "`VxDelegate`-specific, will be removed in a future release")]
#[must_use]
pub fn is_graph_compiled(&self) -> bool {
self.vx_fns.is_some_and(|vx| {
unsafe { (vx.is_graph_compiled)(self.delegate.as_ptr()) }
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dtype_from_hal_roundtrip() {
        let cases = [
            (HalDtype::U8, DType::U8),
            (HalDtype::I8, DType::I8),
            (HalDtype::U16, DType::U16),
            (HalDtype::I16, DType::I16),
            (HalDtype::U32, DType::U32),
            (HalDtype::I32, DType::I32),
            (HalDtype::U64, DType::U64),
            (HalDtype::I64, DType::I64),
            (HalDtype::F16, DType::F16),
            (HalDtype::F32, DType::F32),
            (HalDtype::F64, DType::F64),
        ];
        for (hal, expected) in cases {
            assert_eq!(DType::from_hal(hal), expected);
        }
    }

    /// Covers every variant so a typo in any display name is caught.
    #[test]
    fn dtype_display() {
        let cases = [
            (DType::U8, "u8"),
            (DType::I8, "i8"),
            (DType::U16, "u16"),
            (DType::I16, "i16"),
            (DType::U32, "u32"),
            (DType::I32, "i32"),
            (DType::U64, "u64"),
            (DType::I64, "i64"),
            (DType::F16, "f16"),
            (DType::F32, "f32"),
            (DType::F64, "f64"),
        ];
        for (dtype, expected) in cases {
            assert_eq!(dtype.to_string(), expected);
        }
    }

    #[test]
    fn dtype_clone_copy_eq_hash() {
        use std::collections::HashSet;
        let a = DType::F32;
        let b = a;
        assert_eq!(a, b);
        let mut set = HashSet::new();
        set.insert(DType::U8);
        set.insert(DType::I8);
        // Duplicate insert must not grow the set.
        set.insert(DType::U8);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn tensor_info_debug() {
        let info = TensorInfo {
            size: 4096,
            offset: 0,
            shape: vec![1, 3, 224, 224],
            fd: 5,
            dtype: DType::U8,
        };
        let debug = format!("{info:?}");
        assert!(debug.contains("4096"));
        assert!(debug.contains("224"));
        assert!(debug.contains("U8"));
    }

    #[allow(deprecated)]
    #[test]
    fn sync_mode_to_raw() {
        assert_eq!(SyncMode::None.to_raw(), VxDmaBufSyncMode::None);
        assert_eq!(SyncMode::Read.to_raw(), VxDmaBufSyncMode::Read);
        assert_eq!(SyncMode::Write.to_raw(), VxDmaBufSyncMode::Write);
        assert_eq!(SyncMode::ReadWrite.to_raw(), VxDmaBufSyncMode::ReadWrite);
    }

    #[allow(deprecated)]
    #[test]
    fn ownership_to_raw() {
        assert_eq!(Ownership::Client.to_raw(), VxDmaBufOwnership::Client);
        assert_eq!(Ownership::Delegate.to_raw(), VxDmaBufOwnership::Delegate);
    }

    #[allow(deprecated)]
    #[test]
    fn buffer_handle_raw() {
        let handle = BufferHandle(42);
        assert_eq!(handle.raw(), 42);
    }

    #[allow(deprecated)]
    #[test]
    fn buffer_handle_from_raw() {
        let handle = BufferHandle::from_raw(7);
        assert_eq!(handle.raw(), 7);
    }

    #[allow(deprecated)]
    #[test]
    fn buffer_handle_equality() {
        let a = BufferHandle(7);
        let b = BufferHandle(7);
        let c = BufferHandle(99);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[allow(deprecated)]
    #[test]
    fn buffer_desc_debug() {
        let desc = BufferDesc {
            fd: 3,
            size: 4096,
            map_ptr: Option::None,
        };
        let debug = format!("{desc:?}");
        assert!(debug.contains("fd"));
        assert!(debug.contains('3'));
    }
}