kornia_core/allocator.rs
1use std::alloc;
2use std::alloc::Layout;
3
4use thiserror::Error;
5
6/// An error type for tensor allocator operations.
7#[derive(Debug, Error, PartialEq)]
8pub enum TensorAllocatorError {
9 /// An error occurred during memory allocation.
10 #[error("Invalid tensor layout {0}")]
11 LayoutError(core::alloc::LayoutError),
12
13 /// An error occurred during memory allocation.
14 #[error("Null pointer")]
15 NullPointer,
16}
17
18/// A trait for allocating and deallocating memory for tensors.
19///
20/// # Safety
21///
22/// The tensor allocator must be thread-safe.
23///
24/// # Methods
25///
26/// * `alloc` - Allocates memory for a tensor with the given layout.
27/// * `dealloc` - Deallocates memory for a tensor with the given layout.
28pub trait TensorAllocator: Clone {
29 /// Allocates memory for a tensor with the given layout.
30 fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError>;
31
32 /// Deallocates memory for a tensor with the given layout.
33 fn dealloc(&self, ptr: *mut u8, layout: Layout);
34}
35
/// A tensor allocator that uses the system allocator.
///
/// This is a zero-sized handle; cloning or default-constructing it is free.
#[derive(Debug, Clone, Default)]
pub struct CpuAllocator;
46
47/// Implement the `TensorAllocator` trait for the `CpuAllocator` struct.
48impl TensorAllocator for CpuAllocator {
49 /// Allocates memory for a tensor with the given layout.
50 ///
51 /// # Arguments
52 ///
53 /// * `layout` - The layout of the tensor.
54 ///
55 /// # Returns
56 ///
57 /// A non-null pointer to the allocated memory if successful, otherwise an error.
58 fn alloc(&self, layout: Layout) -> Result<*mut u8, TensorAllocatorError> {
59 let ptr = unsafe { alloc::alloc(layout) };
60 if ptr.is_null() {
61 Err(TensorAllocatorError::NullPointer)?
62 }
63 Ok(ptr)
64 }
65
66 /// Deallocates memory for a tensor with the given layout.
67 ///
68 /// # Arguments
69 ///
70 /// * `ptr` - A non-null pointer to the allocated memory.
71 /// * `layout` - The layout of the tensor.
72 ///
73 /// # Safety
74 ///
75 /// The pointer must be non-null and the layout must be correct.
76 #[allow(clippy::not_unsafe_ptr_arg_deref)]
77 fn dealloc(&self, ptr: *mut u8, layout: Layout) {
78 unsafe { alloc::dealloc(ptr, layout) }
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a 1 KiB, 64-byte-aligned allocation through [`CpuAllocator`].
    ///
    /// Checks that the returned pointer is non-null, honors the requested alignment and
    /// points to writable memory of the requested size.
    #[test]
    fn test_cpu_allocator() -> Result<(), TensorAllocatorError> {
        let allocator = CpuAllocator;
        let layout =
            Layout::from_size_align(1024, 64).map_err(TensorAllocatorError::LayoutError)?;
        let ptr = allocator.alloc(layout)?;
        assert!(!ptr.is_null());
        assert_eq!(ptr as usize % layout.align(), 0);
        // SAFETY: `ptr` points to a live allocation of `layout.size()` bytes obtained above.
        unsafe {
            ptr.write_bytes(0xAB, layout.size());
            assert_eq!(*ptr.add(layout.size() - 1), 0xAB);
        }
        allocator.dealloc(ptr, layout);
        Ok(())
    }
}