kornia_core/
allocator.rs

1use std::alloc;
2use std::alloc::Layout;
3
4use thiserror::Error;
5
6/// An error type for tensor allocator operations.
7#[derive(Debug, Error, PartialEq)]
8pub enum TensorAllocatorError {
9    /// An error occurred during memory allocation.
10    #[error("Invalid tensor layout {0}")]
11    LayoutError(core::alloc::LayoutError),
12
13    /// An error occurred during memory allocation.
14    #[error("Null pointer")]
15    NullPointer,
16}
17
18/// A trait for allocating and deallocating memory for tensors.
19///
20/// # Safety
21///
22/// The tensor allocator must be thread-safe.
23///
24/// # Methods
25///
26/// * `alloc` - Allocates memory for a tensor with the given layout.
27/// * `dealloc` - Deallocates memory for a tensor with the given layout.
28pub trait TensorAllocator: Clone {
29    /// Allocates memory for a tensor with the given layout.
30    fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError>;
31
32    /// Deallocates memory for a tensor with the given layout.
33    fn dealloc(&self, ptr: *mut u8, layout: Layout);
34}
35
/// A tensor allocator that uses the system allocator.
///
/// The type is a stateless unit struct, so cloning it is free.
#[derive(Debug, Clone)]
pub struct CpuAllocator;
39
40/// Implement the `Default` trait for the `CpuAllocator` struct.
41impl Default for CpuAllocator {
42    fn default() -> Self {
43        Self
44    }
45}
46
47/// Implement the `TensorAllocator` trait for the `CpuAllocator` struct.
48impl TensorAllocator for CpuAllocator {
49    /// Allocates memory for a tensor with the given layout.
50    ///
51    /// # Arguments
52    ///
53    /// * `layout` - The layout of the tensor.
54    ///
55    /// # Returns
56    ///
57    /// A non-null pointer to the allocated memory if successful, otherwise an error.
58    fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError> {
59        let ptr = unsafe { alloc::alloc(layout) };
60        if ptr.is_null() {
61            Err(TensorAllocatorError::NullPointer)?
62        }
63        Ok(ptr)
64    }
65
66    /// Deallocates memory for a tensor with the given layout.
67    ///
68    /// # Arguments
69    ///
70    /// * `ptr` - A non-null pointer to the allocated memory.
71    /// * `layout` - The layout of the tensor.
72    ///
73    /// # Safety
74    ///
75    /// The pointer must be non-null and the layout must be correct.
76    #[allow(clippy::not_unsafe_ptr_arg_deref)]
77    fn dealloc(&self, ptr: *mut u8, layout: Layout) {
78        unsafe { alloc::dealloc(ptr, layout) }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a 1 KiB, 64-byte-aligned block through `CpuAllocator` and
    /// checks that the returned memory is non-null, aligned and usable.
    #[test]
    fn test_cpu_allocator() -> Result<(), TensorAllocatorError> {
        let allocator = CpuAllocator;
        let layout = Layout::from_size_align(1024, 64).unwrap();
        let ptr = allocator.alloc(layout)?;

        // The allocator must honour the requested alignment.
        assert!(!ptr.is_null());
        assert_eq!(ptr as usize % layout.align(), 0);

        // SAFETY: `ptr` points to `layout.size()` freshly allocated bytes.
        unsafe {
            std::ptr::write_bytes(ptr, 0x5A, layout.size());
            assert_eq!(*ptr.add(layout.size() - 1), 0x5A);
        }

        allocator.dealloc(ptr, layout);
        Ok(())
    }
}
94}