//! allocator.rs (cubecl_runtime) — memory-layout allocation policies that pack
//! multiple tensor allocations into a single aligned backing buffer.

1use crate::{
2    memory_management::optimal_align,
3    server::{
4        Handle, MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutPolicy, MemoryLayoutStrategy,
5    },
6};
7use alloc::vec::Vec;
8use cubecl_common::stream_id::StreamId;
9use cubecl_zspace::{Shape, Strides, strides};
10
/// Layout policy where every allocation is laid out as dense, contiguous
/// (row-major, unpadded) memory.
pub struct ContiguousMemoryLayoutPolicy {
    /// Byte alignment at which each sub-allocation inside the shared buffer begins.
    mem_alignment: usize,
}
15
/// Layout policy where some allocations can leverage a pitched layout,
/// i.e. each row is padded up to a pitch boundary (see
/// [`MemoryLayoutStrategy::Optimized`]).
pub struct PitchedMemoryLayoutPolicy {
    /// Byte alignment at which each sub-allocation inside the shared buffer begins.
    mem_alignment: usize,
}
20
21impl MemoryLayoutPolicy for PitchedMemoryLayoutPolicy {
22    fn apply(
23        &self,
24        stream_id: StreamId,
25        descriptors: &[MemoryLayoutDescriptor],
26    ) -> (Handle, Vec<MemoryLayout>) {
27        let mut total_size = 0u64;
28
29        let (sizes, strides): (Vec<_>, Vec<_>) = descriptors
30            .iter()
31            .map(|descriptor| {
32                let last_dim = descriptor.shape.last().copied().unwrap_or(1);
33                let pitch_align = match descriptor.strategy {
34                    MemoryLayoutStrategy::Contiguous => 1,
35                    MemoryLayoutStrategy::Optimized => {
36                        optimal_align(last_dim, descriptor.elem_size, self.mem_alignment)
37                    }
38                };
39
40                let rank = descriptor.shape.len();
41                let width = *descriptor.shape.last().unwrap_or(&1);
42                let height: usize = descriptor.shape.iter().rev().skip(1).product();
43                let height = Ord::max(height, 1);
44
45                let width_bytes = width * descriptor.elem_size;
46                let pitch = width_bytes.next_multiple_of(pitch_align);
47                let size = height * pitch;
48
49                let mut strides = strides![1; rank];
50                if rank > 1 {
51                    strides[rank - 2] = pitch / descriptor.elem_size;
52                }
53                if rank > 2 {
54                    for i in (0..rank - 2).rev() {
55                        strides[i] = strides[i + 1] * descriptor.shape[i + 1];
56                    }
57                }
58                total_size += size.next_multiple_of(self.mem_alignment) as u64;
59                (size, strides)
60            })
61            .unzip();
62
63        let base_handle = Handle::new(stream_id, total_size);
64
65        let layouts = offset_handles(base_handle.clone(), &sizes, self.mem_alignment)
66            .into_iter()
67            .zip(strides)
68            .map(|(handle, strides)| MemoryLayout::new(handle, strides))
69            .collect();
70        (base_handle, layouts)
71    }
72}
73
74impl ContiguousMemoryLayoutPolicy {
75    /// Creates a new allocator with the given memory alignment.
76    pub fn new(mem_alignment: usize) -> Self {
77        Self { mem_alignment }
78    }
79}
80
81impl PitchedMemoryLayoutPolicy {
82    /// Creates a new allocator with the given memory alignment.
83    pub fn new(mem_alignment: usize) -> Self {
84        Self { mem_alignment }
85    }
86}
87
88impl MemoryLayoutPolicy for ContiguousMemoryLayoutPolicy {
89    fn apply(
90        &self,
91        stream_id: StreamId,
92        descriptors: &[MemoryLayoutDescriptor],
93    ) -> (Handle, Vec<MemoryLayout>) {
94        let mut total_size = 0u64;
95        let (sizes, strides): (Vec<_>, Vec<_>) = descriptors
96            .iter()
97            .map(|desc| {
98                let size = desc.shape.iter().product::<usize>() * desc.elem_size;
99                total_size += size.next_multiple_of(self.mem_alignment) as u64;
100                (size, contiguous_strides(&desc.shape))
101            })
102            .unzip();
103
104        let base_handle = Handle::new(stream_id, total_size);
105
106        let layouts = offset_handles(base_handle.clone(), &sizes, self.mem_alignment)
107            .into_iter()
108            .zip(strides)
109            .map(|(handle, stride)| MemoryLayout::new(handle, stride))
110            .collect();
111
112        (base_handle, layouts)
113    }
114}
115
116pub(crate) fn contiguous_strides(shape: &Shape) -> Strides {
117    let rank = shape.len();
118    let mut strides = strides![1; rank];
119    for i in (0..rank - 1).rev() {
120        strides[i] = strides[i + 1] * shape[i + 1];
121    }
122    strides
123}
124
125/// Take a list of sub-slices of a buffer and create a list of offset handles.
126/// Sizes must be in bytes and handles will be aligned to the memory alignment.
127pub fn offset_handles(
128    base_handle: Handle,
129    sizes_bytes: &[usize],
130    buffer_align: usize,
131) -> Vec<Handle> {
132    let total_size = base_handle.size() as usize;
133    let mut offset = 0;
134    let mut out = Vec::new();
135
136    for size in sizes_bytes {
137        let handle = base_handle
138            .clone()
139            .offset_start(offset as u64)
140            .offset_end((total_size - offset - size) as u64);
141        out.push(handle);
142        offset += size.next_multiple_of(buffer_align);
143    }
144
145    out
146}