use crate::{
error::{Error, Result},
BufferIdentity, TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
};
use log::trace;
use num_traits::Num;
use std::{
ffi::c_void,
fmt,
marker::PhantomData,
ops::{Deref, DerefMut},
ptr::NonNull,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
};
pub struct PboMapping {
pub ptr: *mut u8,
pub size: usize,
}
unsafe impl Send for PboMapping {}
pub unsafe trait PboOps: Send + Sync {
fn map_buffer(&self, buffer_id: u32, size: usize) -> Result<PboMapping>;
fn unmap_buffer(&self, buffer_id: u32) -> Result<()>;
fn delete_buffer(&self, buffer_id: u32);
}
struct PboHandle {
ops: Arc<dyn PboOps>,
buffer_id: u32,
size: usize,
mapped: AtomicBool,
}
impl Drop for PboHandle {
fn drop(&mut self) {
self.ops.delete_buffer(self.buffer_id);
}
}
pub struct PboTensor<T>
where
T: Num + Clone + fmt::Debug + Send + Sync,
{
pub name: String,
pub shape: Vec<usize>,
handle: Arc<PboHandle>,
identity: BufferIdentity,
_marker: PhantomData<T>,
}
impl<T> fmt::Debug for PboTensor<T>
where
T: Num + Clone + fmt::Debug + Send + Sync,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PboTensor")
.field("name", &self.name)
.field("shape", &self.shape)
.field("buffer_id", &self.handle.buffer_id)
.field("size", &self.handle.size)
.finish()
}
}
unsafe impl<T> Send for PboTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
unsafe impl<T> Sync for PboTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
impl<T> PboTensor<T>
where
T: Num + Clone + fmt::Debug + Send + Sync,
{
pub fn from_pbo(
buffer_id: u32,
size: usize,
shape: &[usize],
name: Option<&str>,
ops: Arc<dyn PboOps>,
) -> Result<Self> {
let expected = shape.iter().product::<usize>() * std::mem::size_of::<T>();
if size != expected {
return Err(Error::ShapeMismatch(format!(
"PBO size {size} does not match shape {shape:?} * sizeof({}) = {expected}",
std::any::type_name::<T>(),
)));
}
if size == 0 {
return Err(Error::InvalidSize(0));
}
let name = name.unwrap_or("pbo_tensor").to_owned();
Ok(Self {
name,
shape: shape.to_vec(),
handle: Arc::new(PboHandle {
ops,
buffer_id,
size,
mapped: AtomicBool::new(false),
}),
identity: BufferIdentity::new(),
_marker: PhantomData,
})
}
pub fn buffer_id(&self) -> u32 {
self.handle.buffer_id
}
pub fn is_mapped(&self) -> bool {
self.handle.mapped.load(Ordering::Acquire)
}
}
impl<T> TensorTrait<T> for PboTensor<T>
where
T: Num + Clone + fmt::Debug + Send + Sync,
{
fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
Err(Error::NotImplemented(
"PboTensor cannot be created directly — use ImageProcessor::create_image()".to_owned(),
))
}
#[cfg(unix)]
fn from_fd(_fd: std::os::fd::OwnedFd, _shape: &[usize], _name: Option<&str>) -> Result<Self> {
Err(Error::NotImplemented(
"PboTensor does not support from_fd".to_owned(),
))
}
#[cfg(unix)]
fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
Err(Error::NotImplemented(
"PboTensor does not support clone_fd".to_owned(),
))
}
fn memory(&self) -> TensorMemory {
TensorMemory::Pbo
}
fn name(&self) -> String {
self.name.clone()
}
fn shape(&self) -> &[usize] {
&self.shape
}
fn reshape(&mut self, shape: &[usize]) -> Result<()> {
if shape.is_empty() {
return Err(Error::InvalidSize(0));
}
let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
if new_size != self.handle.size {
return Err(Error::ShapeMismatch(format!(
"Cannot reshape incompatible shape: {:?} to {:?}",
self.shape, shape
)));
}
self.shape = shape.to_vec();
Ok(())
}
fn map(&self) -> Result<TensorMap<T>> {
if self.handle.mapped.swap(true, Ordering::AcqRel) {
return Err(Error::PboMapped);
}
match self
.handle
.ops
.map_buffer(self.handle.buffer_id, self.handle.size)
{
Ok(mapping) => {
let pbo_ptr = PboPtr(
NonNull::new(mapping.ptr as *mut c_void)
.ok_or(Error::InvalidSize(self.handle.size))?,
);
Ok(TensorMap::Pbo(PboMap {
ptr: Arc::new(Mutex::new(pbo_ptr)),
shape: self.shape.clone(),
handle: Arc::clone(&self.handle),
_marker: PhantomData,
}))
}
Err(e) => {
self.handle.mapped.store(false, Ordering::Release);
Err(e)
}
}
}
fn buffer_identity(&self) -> &BufferIdentity {
&self.identity
}
}
#[derive(Debug)]
struct PboPtr(NonNull<c_void>);
impl Deref for PboPtr {
type Target = NonNull<c_void>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
unsafe impl Send for PboPtr {}
pub struct PboMap<T>
where
T: Num + Clone + fmt::Debug,
{
ptr: Arc<Mutex<PboPtr>>,
shape: Vec<usize>,
handle: Arc<PboHandle>,
_marker: PhantomData<T>,
}
impl<T> fmt::Debug for PboMap<T>
where
T: Num + Clone + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PboMap")
.field("shape", &self.shape)
.field("buffer_id", &self.handle.buffer_id)
.finish()
}
}
impl<T> TensorMapTrait<T> for PboMap<T>
where
T: Num + Clone + fmt::Debug,
{
fn shape(&self) -> &[usize] {
&self.shape
}
fn unmap(&mut self) {
trace!("Unmapping PboMap buffer_id={}", self.handle.buffer_id);
if let Err(e) = self.handle.ops.unmap_buffer(self.handle.buffer_id) {
log::warn!("Failed to unmap PBO buffer {}: {e}", self.handle.buffer_id);
}
self.handle.mapped.store(false, Ordering::Release);
}
fn as_slice(&self) -> &[T] {
let ptr = self.ptr.lock().expect("Failed to lock PboMap pointer");
unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
}
fn as_mut_slice(&mut self) -> &mut [T] {
let ptr = self.ptr.lock().expect("Failed to lock PboMap pointer");
unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
}
}
impl<T> Deref for PboMap<T>
where
T: Num + Clone + fmt::Debug,
{
type Target = [T];
fn deref(&self) -> &[T] {
self.as_slice()
}
}
impl<T> DerefMut for PboMap<T>
where
T: Num + Clone + fmt::Debug,
{
fn deref_mut(&mut self) -> &mut [T] {
self.as_mut_slice()
}
}
impl<T> Drop for PboMap<T>
where
T: Num + Clone + fmt::Debug,
{
fn drop(&mut self) {
self.unmap();
}
}
impl<T> Clone for PboTensor<T>
where
T: Num + Clone + fmt::Debug + Send + Sync,
{
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
shape: self.shape.clone(),
handle: Arc::clone(&self.handle),
identity: self.identity.clone(),
_marker: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
struct MockPboOps {
storage: Mutex<Vec<u8>>,
}
impl MockPboOps {
fn new(size: usize) -> Arc<Self> {
Arc::new(Self {
storage: Mutex::new(vec![0u8; size]),
})
}
}
unsafe impl PboOps for MockPboOps {
fn map_buffer(&self, _buffer_id: u32, size: usize) -> Result<PboMapping> {
let storage = self.storage.lock().expect("lock");
assert_eq!(storage.len(), size);
Ok(PboMapping {
ptr: storage.as_ptr() as *mut u8,
size,
})
}
fn unmap_buffer(&self, _buffer_id: u32) -> Result<()> {
Ok(())
}
fn delete_buffer(&self, _buffer_id: u32) {}
}
#[test]
fn test_pbo_tensor_create_and_metadata() {
let ops = MockPboOps::new(24);
let tensor = PboTensor::<u8>::from_pbo(42, 24, &[2, 3, 4], Some("test_pbo"), ops).unwrap();
assert_eq!(tensor.memory(), TensorMemory::Pbo);
assert_eq!(tensor.name(), "test_pbo");
assert_eq!(tensor.shape(), &[2, 3, 4]);
assert_eq!(tensor.buffer_id(), 42);
assert!(!tensor.is_mapped());
}
#[test]
fn test_pbo_tensor_map_write_read() {
let ops = MockPboOps::new(12);
let tensor = PboTensor::<u8>::from_pbo(1, 12, &[3, 4], Some("rw_test"), ops).unwrap();
{
let mut map = tensor.map().expect("map should succeed");
assert_eq!(map.shape(), &[3, 4]);
assert!(tensor.is_mapped());
map.as_mut_slice().fill(0xAB);
assert!(map.as_slice().iter().all(|&b| b == 0xAB));
}
assert!(!tensor.is_mapped());
}
#[test]
fn test_pbo_tensor_double_map_fails() {
let ops = MockPboOps::new(8);
let tensor = PboTensor::<u8>::from_pbo(2, 8, &[8], None, ops).unwrap();
let _map1 = tensor.map().expect("first map should succeed");
assert!(tensor.is_mapped());
let result = tensor.map();
assert!(result.is_err(), "second map while mapped should fail");
}
#[test]
fn test_pbo_tensor_reshape() {
let ops = MockPboOps::new(24);
let mut tensor = PboTensor::<u8>::from_pbo(3, 24, &[2, 3, 4], None, ops).unwrap();
tensor
.reshape(&[4, 6])
.expect("compatible reshape should succeed");
assert_eq!(tensor.shape(), &[4, 6]);
let result = tensor.reshape(&[100]);
assert!(result.is_err(), "incompatible reshape should fail");
}
#[test]
fn test_pbo_tensor_buffer_identity() {
let ops1 = MockPboOps::new(8);
let ops2 = MockPboOps::new(8);
let t1 = PboTensor::<u8>::from_pbo(1, 8, &[8], None, ops1).unwrap();
let t2 = PboTensor::<u8>::from_pbo(2, 8, &[8], None, ops2).unwrap();
assert_ne!(t1.buffer_identity().id(), t2.buffer_identity().id());
}
#[test]
fn test_pbo_tensor_new_returns_error() {
let result = PboTensor::<u8>::new(&[8], None);
assert!(result.is_err(), "PboTensor::new() should fail");
}
#[cfg(unix)]
#[test]
fn test_pbo_tensor_fd_ops_return_error() {
let ops = MockPboOps::new(8);
let tensor = PboTensor::<u8>::from_pbo(1, 8, &[8], None, ops).unwrap();
assert!(tensor.clone_fd().is_err());
}
#[test]
fn test_pbo_tensor_from_pbo_size_mismatch() {
let ops = MockPboOps::new(24);
let result = PboTensor::<u8>::from_pbo(1, 24, &[2, 3, 5], None, ops);
assert!(result.is_err(), "mismatched size/shape should fail");
}
#[test]
fn test_pbo_tensor_from_pbo_zero_size() {
let ops = MockPboOps::new(0);
let result = PboTensor::<u8>::from_pbo(1, 0, &[0], None, ops);
assert!(result.is_err(), "zero size should fail");
}
#[test]
fn test_pbo_via_tensor_enum() {
let ops = MockPboOps::new(12);
let pbo = PboTensor::<u8>::from_pbo(10, 12, &[3, 4], Some("enum_test"), ops).unwrap();
let tensor = crate::Tensor::wrap(crate::TensorStorage::Pbo(pbo));
assert_eq!(tensor.memory(), TensorMemory::Pbo);
assert_eq!(tensor.name(), "enum_test");
assert_eq!(tensor.shape(), &[3, 4]);
let mut map = tensor.map().expect("map via enum");
map.as_mut_slice().fill(42);
assert!(map.as_slice().iter().all(|&b| b == 42));
}
}