1use crate::{
5 error::{Error, Result},
6 BufferIdentity, TensorMap, TensorMapTrait, TensorMemory, TensorTrait,
7};
8use log::trace;
9use num_traits::Num;
10use std::{
11 ffi::c_void,
12 fmt,
13 marker::PhantomData,
14 ops::{Deref, DerefMut},
15 ptr::NonNull,
16 sync::{
17 atomic::{AtomicBool, Ordering},
18 Arc, Mutex,
19 },
20};
21
/// CPU-visible view of a mapped pixel-buffer object, as produced by
/// `PboOps::map_buffer`.
///
/// This is a plain (address, length) pair: it owns nothing and frees nothing.
pub struct PboMapping {
    /// Start of the mapped region. A null pointer is treated by the tensor
    /// code as a failed mapping.
    pub ptr: *mut u8,
    /// Length of the mapped region in bytes, as reported by the backend.
    pub size: usize,
}

// SAFETY: `PboMapping` only carries an address and a length and has no
// methods that dereference the address; sending it to another thread moves
// the numbers, not any access. Whoever builds a slice from `ptr` is
// responsible for that slice's validity.
unsafe impl Send for PboMapping {}
34
35pub unsafe trait PboOps: Send + Sync {
49 fn map_buffer(&self, buffer_id: u32, size: usize) -> Result<PboMapping>;
52
53 fn unmap_buffer(&self, buffer_id: u32) -> Result<()>;
56
57 fn delete_buffer(&self, buffer_id: u32);
60}
61
62struct PboHandle {
64 ops: Arc<dyn PboOps>,
65 buffer_id: u32,
66 size: usize,
67 mapped: AtomicBool,
68}
69
70impl Drop for PboHandle {
71 fn drop(&mut self) {
72 self.ops.delete_buffer(self.buffer_id);
73 }
74}
75
76pub struct PboTensor<T>
78where
79 T: Num + Clone + fmt::Debug + Send + Sync,
80{
81 pub name: String,
82 pub shape: Vec<usize>,
83 handle: Arc<PboHandle>,
84 identity: BufferIdentity,
85 _marker: PhantomData<T>,
86}
87
88impl<T> fmt::Debug for PboTensor<T>
89where
90 T: Num + Clone + fmt::Debug + Send + Sync,
91{
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 f.debug_struct("PboTensor")
94 .field("name", &self.name)
95 .field("shape", &self.shape)
96 .field("buffer_id", &self.handle.buffer_id)
97 .field("size", &self.handle.size)
98 .finish()
99 }
100}
101
102unsafe impl<T> Send for PboTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
103unsafe impl<T> Sync for PboTensor<T> where T: Num + Clone + fmt::Debug + Send + Sync {}
104
105impl<T> PboTensor<T>
106where
107 T: Num + Clone + fmt::Debug + Send + Sync,
108{
109 pub fn from_pbo(
120 buffer_id: u32,
121 size: usize,
122 shape: &[usize],
123 name: Option<&str>,
124 ops: Arc<dyn PboOps>,
125 ) -> Result<Self> {
126 let expected = shape.iter().product::<usize>() * std::mem::size_of::<T>();
127 if size != expected {
128 return Err(Error::ShapeMismatch(format!(
129 "PBO size {size} does not match shape {shape:?} * sizeof({}) = {expected}",
130 std::any::type_name::<T>(),
131 )));
132 }
133 if size == 0 {
134 return Err(Error::InvalidSize(0));
135 }
136 let name = name.unwrap_or("pbo_tensor").to_owned();
137 Ok(Self {
138 name,
139 shape: shape.to_vec(),
140 handle: Arc::new(PboHandle {
141 ops,
142 buffer_id,
143 size,
144 mapped: AtomicBool::new(false),
145 }),
146 identity: BufferIdentity::new(),
147 _marker: PhantomData,
148 })
149 }
150
151 pub fn buffer_id(&self) -> u32 {
153 self.handle.buffer_id
154 }
155
156 pub fn is_mapped(&self) -> bool {
158 self.handle.mapped.load(Ordering::Acquire)
159 }
160}
161
162impl<T> TensorTrait<T> for PboTensor<T>
163where
164 T: Num + Clone + fmt::Debug + Send + Sync,
165{
166 fn new(_shape: &[usize], _name: Option<&str>) -> Result<Self> {
167 Err(Error::NotImplemented(
168 "PboTensor cannot be created directly — use ImageProcessor::create_image()".to_owned(),
169 ))
170 }
171
172 #[cfg(unix)]
173 fn from_fd(_fd: std::os::fd::OwnedFd, _shape: &[usize], _name: Option<&str>) -> Result<Self> {
174 Err(Error::NotImplemented(
175 "PboTensor does not support from_fd".to_owned(),
176 ))
177 }
178
179 #[cfg(unix)]
180 fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
181 Err(Error::NotImplemented(
182 "PboTensor does not support clone_fd".to_owned(),
183 ))
184 }
185
186 fn memory(&self) -> TensorMemory {
187 TensorMemory::Pbo
188 }
189
190 fn name(&self) -> String {
191 self.name.clone()
192 }
193
194 fn shape(&self) -> &[usize] {
195 &self.shape
196 }
197
198 fn reshape(&mut self, shape: &[usize]) -> Result<()> {
199 if shape.is_empty() {
200 return Err(Error::InvalidSize(0));
201 }
202 let new_size = shape.iter().product::<usize>() * std::mem::size_of::<T>();
203 if new_size != self.handle.size {
204 return Err(Error::ShapeMismatch(format!(
205 "Cannot reshape incompatible shape: {:?} to {:?}",
206 self.shape, shape
207 )));
208 }
209 self.shape = shape.to_vec();
210 Ok(())
211 }
212
213 fn map(&self) -> Result<TensorMap<T>> {
214 if self.handle.mapped.swap(true, Ordering::AcqRel) {
215 return Err(Error::PboMapped);
216 }
217 match self
218 .handle
219 .ops
220 .map_buffer(self.handle.buffer_id, self.handle.size)
221 {
222 Ok(mapping) => {
223 let pbo_ptr = PboPtr(
224 NonNull::new(mapping.ptr as *mut c_void)
225 .ok_or(Error::InvalidSize(self.handle.size))?,
226 );
227 Ok(TensorMap::Pbo(PboMap {
228 ptr: Arc::new(Mutex::new(pbo_ptr)),
229 shape: self.shape.clone(),
230 handle: Arc::clone(&self.handle),
231 _marker: PhantomData,
232 }))
233 }
234 Err(e) => {
235 self.handle.mapped.store(false, Ordering::Release);
236 Err(e)
237 }
238 }
239 }
240
241 fn buffer_identity(&self) -> &BufferIdentity {
242 &self.identity
243 }
244}
245
246#[derive(Debug)]
249struct PboPtr(NonNull<c_void>);
250
251impl Deref for PboPtr {
252 type Target = NonNull<c_void>;
253 fn deref(&self) -> &Self::Target {
254 &self.0
255 }
256}
257
258unsafe impl Send for PboPtr {}
259
260pub struct PboMap<T>
261where
262 T: Num + Clone + fmt::Debug,
263{
264 ptr: Arc<Mutex<PboPtr>>,
265 shape: Vec<usize>,
266 handle: Arc<PboHandle>,
267 _marker: PhantomData<T>,
268}
269
270impl<T> fmt::Debug for PboMap<T>
271where
272 T: Num + Clone + fmt::Debug,
273{
274 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275 f.debug_struct("PboMap")
276 .field("shape", &self.shape)
277 .field("buffer_id", &self.handle.buffer_id)
278 .finish()
279 }
280}
281
282impl<T> TensorMapTrait<T> for PboMap<T>
283where
284 T: Num + Clone + fmt::Debug,
285{
286 fn shape(&self) -> &[usize] {
287 &self.shape
288 }
289
290 fn unmap(&mut self) {
291 trace!("Unmapping PboMap buffer_id={}", self.handle.buffer_id);
292 if let Err(e) = self.handle.ops.unmap_buffer(self.handle.buffer_id) {
293 log::warn!("Failed to unmap PBO buffer {}: {e}", self.handle.buffer_id);
294 }
295 self.handle.mapped.store(false, Ordering::Release);
296 }
297
298 fn as_slice(&self) -> &[T] {
299 let ptr = self.ptr.lock().expect("Failed to lock PboMap pointer");
300 unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const T, self.len()) }
301 }
302
303 fn as_mut_slice(&mut self) -> &mut [T] {
304 let ptr = self.ptr.lock().expect("Failed to lock PboMap pointer");
305 unsafe { std::slice::from_raw_parts_mut(ptr.as_ptr() as *mut T, self.len()) }
306 }
307}
308
309impl<T> Deref for PboMap<T>
310where
311 T: Num + Clone + fmt::Debug,
312{
313 type Target = [T];
314 fn deref(&self) -> &[T] {
315 self.as_slice()
316 }
317}
318
319impl<T> DerefMut for PboMap<T>
320where
321 T: Num + Clone + fmt::Debug,
322{
323 fn deref_mut(&mut self) -> &mut [T] {
324 self.as_mut_slice()
325 }
326}
327
328impl<T> Drop for PboMap<T>
329where
330 T: Num + Clone + fmt::Debug,
331{
332 fn drop(&mut self) {
333 self.unmap();
334 }
335}
336
337impl<T> Clone for PboTensor<T>
338where
339 T: Num + Clone + fmt::Debug + Send + Sync,
340{
341 fn clone(&self) -> Self {
342 Self {
343 name: self.name.clone(),
344 shape: self.shape.clone(),
345 handle: Arc::clone(&self.handle),
346 identity: self.identity.clone(),
347 _marker: PhantomData,
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory stand-in for a GL backend: the "PBO" is a byte vector.
    struct MockPboOps {
        storage: Mutex<Vec<u8>>,
    }

    impl MockPboOps {
        /// Creates a zero-filled mock buffer of `size` bytes.
        fn new(size: usize) -> Arc<Self> {
            Arc::new(Self {
                storage: Mutex::new(vec![0u8; size]),
            })
        }
    }

    // SAFETY: `map_buffer` returns a pointer into `storage`, which is never
    // resized and lives as long as the `Arc<MockPboOps>` held by the tensor's
    // handle. The tests never hold two mappings of the same mock at once.
    unsafe impl PboOps for MockPboOps {
        fn map_buffer(&self, _buffer_id: u32, size: usize) -> Result<PboMapping> {
            let mut storage = self.storage.lock().expect("lock");
            assert_eq!(storage.len(), size);
            // `as_mut_ptr`, not `as_ptr`: the tests write through this pointer,
            // and `Vec::as_ptr` documents that its pointer must not be written through.
            Ok(PboMapping {
                ptr: storage.as_mut_ptr(),
                size,
            })
        }

        fn unmap_buffer(&self, _buffer_id: u32) -> Result<()> {
            Ok(())
        }

        fn delete_buffer(&self, _buffer_id: u32) {}
    }

    #[test]
    fn test_pbo_tensor_create_and_metadata() {
        let ops = MockPboOps::new(24);
        let tensor = PboTensor::<u8>::from_pbo(42, 24, &[2, 3, 4], Some("test_pbo"), ops).unwrap();
        assert_eq!(tensor.memory(), TensorMemory::Pbo);
        assert_eq!(tensor.name(), "test_pbo");
        assert_eq!(tensor.shape(), &[2, 3, 4]);
        assert_eq!(tensor.buffer_id(), 42);
        assert!(!tensor.is_mapped());
    }

    #[test]
    fn test_pbo_tensor_map_write_read() {
        let ops = MockPboOps::new(12);
        let tensor = PboTensor::<u8>::from_pbo(1, 12, &[3, 4], Some("rw_test"), ops).unwrap();
        {
            let mut map = tensor.map().expect("map should succeed");
            assert_eq!(map.shape(), &[3, 4]);
            assert!(tensor.is_mapped());
            map.as_mut_slice().fill(0xAB);
            assert!(map.as_slice().iter().all(|&b| b == 0xAB));
        }
        assert!(!tensor.is_mapped());
    }

    #[test]
    fn test_pbo_tensor_remap_after_drop() {
        let ops = MockPboOps::new(4);
        let tensor = PboTensor::<u8>::from_pbo(5, 4, &[4], None, ops).unwrap();
        {
            let mut map = tensor.map().expect("first map should succeed");
            map.as_mut_slice().copy_from_slice(&[1, 2, 3, 4]);
        }
        let map = tensor
            .map()
            .expect("map should succeed again once the first map is dropped");
        assert_eq!(map.as_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_pbo_tensor_double_map_fails() {
        let ops = MockPboOps::new(8);
        let tensor = PboTensor::<u8>::from_pbo(2, 8, &[8], None, ops).unwrap();
        let _map1 = tensor.map().expect("first map should succeed");
        assert!(tensor.is_mapped());
        let result = tensor.map();
        assert!(result.is_err(), "second map while mapped should fail");
    }

    #[test]
    fn test_pbo_tensor_clone_shares_buffer() {
        let ops = MockPboOps::new(8);
        let tensor = PboTensor::<u8>::from_pbo(7, 8, &[8], Some("orig"), ops).unwrap();
        let copy = tensor.clone();
        assert_eq!(copy.buffer_id(), 7);
        assert_eq!(copy.name(), "orig");
        let _map = tensor.map().expect("map original");
        assert!(copy.is_mapped(), "clones share one mapped flag");
        assert!(copy.map().is_err(), "clone cannot map while the original is mapped");
    }

    #[test]
    fn test_pbo_tensor_reshape() {
        let ops = MockPboOps::new(24);
        let mut tensor = PboTensor::<u8>::from_pbo(3, 24, &[2, 3, 4], None, ops).unwrap();
        tensor
            .reshape(&[4, 6])
            .expect("compatible reshape should succeed");
        assert_eq!(tensor.shape(), &[4, 6]);
        let result = tensor.reshape(&[100]);
        assert!(result.is_err(), "incompatible reshape should fail");
        assert!(tensor.reshape(&[]).is_err(), "empty shape should fail");
        assert_eq!(tensor.shape(), &[4, 6], "failed reshapes leave the shape unchanged");
    }

    #[test]
    fn test_pbo_tensor_buffer_identity() {
        let ops1 = MockPboOps::new(8);
        let ops2 = MockPboOps::new(8);
        let t1 = PboTensor::<u8>::from_pbo(1, 8, &[8], None, ops1).unwrap();
        let t2 = PboTensor::<u8>::from_pbo(2, 8, &[8], None, ops2).unwrap();
        assert_ne!(t1.buffer_identity().id(), t2.buffer_identity().id());
    }

    #[test]
    fn test_pbo_tensor_new_returns_error() {
        let result = PboTensor::<u8>::new(&[8], None);
        assert!(result.is_err(), "PboTensor::new() should fail");
    }

    #[cfg(unix)]
    #[test]
    fn test_pbo_tensor_fd_ops_return_error() {
        let ops = MockPboOps::new(8);
        let tensor = PboTensor::<u8>::from_pbo(1, 8, &[8], None, ops).unwrap();
        assert!(tensor.clone_fd().is_err());
        let fd: std::os::fd::OwnedFd = std::fs::File::open("/dev/null")
            .expect("open /dev/null")
            .into();
        assert!(PboTensor::<u8>::from_fd(fd, &[8], None).is_err());
    }

    #[test]
    fn test_pbo_tensor_from_pbo_size_mismatch() {
        let ops = MockPboOps::new(24);
        let result = PboTensor::<u8>::from_pbo(1, 24, &[2, 3, 5], None, ops);
        assert!(result.is_err(), "mismatched size/shape should fail");
    }

    #[test]
    fn test_pbo_tensor_from_pbo_zero_size() {
        let ops = MockPboOps::new(0);
        let result = PboTensor::<u8>::from_pbo(1, 0, &[0], None, ops);
        assert!(result.is_err(), "zero size should fail");
    }

    #[test]
    fn test_pbo_via_tensor_enum() {
        let ops = MockPboOps::new(12);
        let pbo = PboTensor::<u8>::from_pbo(10, 12, &[3, 4], Some("enum_test"), ops).unwrap();
        let tensor = crate::Tensor::wrap(crate::TensorStorage::Pbo(pbo));
        assert_eq!(tensor.memory(), TensorMemory::Pbo);
        assert_eq!(tensor.name(), "enum_test");
        assert_eq!(tensor.shape(), &[3, 4]);
        let mut map = tensor.map().expect("map via enum");
        map.as_mut_slice().fill(42);
        assert!(map.as_slice().iter().all(|&b| b == 42));
    }
}