// kornia_rs/tensor/allocator.rs

use std::alloc::{GlobalAlloc, Layout, System};

use thiserror::Error;

5#[derive(Debug, Error)]
6pub enum TensorAllocatorError {
7    #[error("Invalid tensor layout {0}")]
8    LayoutError(core::alloc::LayoutError),
9
10    #[error("Null pointer")]
11    NullPointer,
12}
13
14/// A trait for allocating and deallocating memory for tensors.
15///
16/// # Safety
17///
18/// The tensor allocator must be thread-safe.
19///
20/// # Methods
21///
22/// * `alloc` - Allocates memory for a tensor with the given layout.
23/// * `dealloc` - Deallocates memory for a tensor with the given layout.
24pub trait TensorAllocator: Clone {
25    fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError>;
26    fn dealloc(&self, ptr: *mut u8, layout: Layout);
27}
28
/// A tensor allocator that uses the system allocator.
#[derive(Clone, Debug)]
pub struct CpuAllocator;

33/// Implement the `Default` trait for the `CpuAllocator` struct.
34impl Default for CpuAllocator {
35    fn default() -> Self {
36        Self
37    }
38}
39
40/// Implement the `TensorAllocator` trait for the `CpuAllocator` struct.
41impl TensorAllocator for CpuAllocator {
42    /// Allocates memory for a tensor with the given layout.
43    ///
44    /// # Arguments
45    ///
46    /// * `layout` - The layout of the tensor.
47    ///
48    /// # Returns
49    ///
50    /// A non-null pointer to the allocated memory if successful, otherwise an error.
51    fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError> {
52        let ptr = unsafe { System.alloc(layout) };
53        if ptr.is_null() {
54            Err(TensorAllocatorError::NullPointer)?
55        }
56        Ok(ptr)
57    }
58
59    /// Deallocates memory for a tensor with the given layout.
60    ///
61    /// # Arguments
62    ///
63    /// * `ptr` - A non-null pointer to the allocated memory.
64    /// * `layout` - The layout of the tensor.
65    ///
66    /// # Safety
67    ///
68    /// The pointer must be non-null and the layout must be correct.
69    #[allow(clippy::not_unsafe_ptr_arg_deref)]
70    fn dealloc(&self, ptr: *mut u8, layout: Layout) {
71        unsafe { System.dealloc(ptr, layout) }
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_allocator() -> Result<(), TensorAllocatorError> {
        let allocator = CpuAllocator;
        let layout = Layout::from_size_align(1024, 64).unwrap();
        let ptr = allocator.alloc(layout)?;

        // The returned block must honour the requested alignment...
        assert_eq!(ptr as usize % layout.align(), 0);

        // ...and be usable for `layout.size()` bytes.
        // SAFETY: `ptr` is a live allocation of `layout.size()` bytes owned by this test, and
        // the slice is not used after `dealloc` below.
        let bytes = unsafe {
            std::ptr::write_bytes(ptr, 0xAB, layout.size());
            std::slice::from_raw_parts(ptr, layout.size())
        };
        assert!(bytes.iter().all(|&b| b == 0xAB));

        allocator.dealloc(ptr, layout);
        Ok(())
    }
}