// cubecl_runtime/allocator.rs
use crate::{
2 memory_management::optimal_align,
3 server::{
4 Handle, MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutPolicy, MemoryLayoutStrategy,
5 },
6};
7use alloc::vec::Vec;
8use cubecl_common::stream_id::StreamId;
9use cubecl_zspace::{Shape, Strides, strides};
10
/// [`MemoryLayoutPolicy`] that packs tensors densely (row-major, no row
/// padding) into a single allocation; each sub-buffer's start is only
/// padded up to `mem_alignment`.
pub struct ContiguousMemoryLayoutPolicy {
    /// Minimum byte alignment for each sub-allocation within the shared buffer.
    mem_alignment: usize,
}
15
/// [`MemoryLayoutPolicy`] that may pad each row (the innermost dimension)
/// of a tensor to an aligned pitch, depending on the descriptor's
/// [`MemoryLayoutStrategy`].
pub struct PitchedMemoryLayoutPolicy {
    /// Minimum byte alignment for each sub-allocation within the shared buffer;
    /// also passed to `optimal_align` as the upper bound for the row pitch.
    mem_alignment: usize,
}
20
21impl MemoryLayoutPolicy for PitchedMemoryLayoutPolicy {
22 fn apply(
23 &self,
24 stream_id: StreamId,
25 descriptors: &[MemoryLayoutDescriptor],
26 ) -> (Handle, Vec<MemoryLayout>) {
27 let mut total_size = 0u64;
28
29 let (sizes, strides): (Vec<_>, Vec<_>) = descriptors
30 .iter()
31 .map(|descriptor| {
32 let last_dim = descriptor.shape.last().copied().unwrap_or(1);
33 let pitch_align = match descriptor.strategy {
34 MemoryLayoutStrategy::Contiguous => 1,
35 MemoryLayoutStrategy::Optimized => {
36 optimal_align(last_dim, descriptor.elem_size, self.mem_alignment)
37 }
38 };
39
40 let rank = descriptor.shape.len();
41 let width = *descriptor.shape.last().unwrap_or(&1);
42 let height: usize = descriptor.shape.iter().rev().skip(1).product();
43 let height = Ord::max(height, 1);
44
45 let width_bytes = width * descriptor.elem_size;
46 let pitch = width_bytes.next_multiple_of(pitch_align);
47 let size = height * pitch;
48
49 let mut strides = strides![1; rank];
50 if rank > 1 {
51 strides[rank - 2] = pitch / descriptor.elem_size;
52 }
53 if rank > 2 {
54 for i in (0..rank - 2).rev() {
55 strides[i] = strides[i + 1] * descriptor.shape[i + 1];
56 }
57 }
58 total_size += size.next_multiple_of(self.mem_alignment) as u64;
59 (size, strides)
60 })
61 .unzip();
62
63 let base_handle = Handle::new(stream_id, total_size);
64
65 let layouts = offset_handles(base_handle.clone(), &sizes, self.mem_alignment)
66 .into_iter()
67 .zip(strides)
68 .map(|(handle, strides)| MemoryLayout::new(handle, strides))
69 .collect();
70 (base_handle, layouts)
71 }
72}
73
impl ContiguousMemoryLayoutPolicy {
    /// Creates a policy whose sub-allocations are aligned to `mem_alignment` bytes.
    pub fn new(mem_alignment: usize) -> Self {
        Self { mem_alignment }
    }
}
80
impl PitchedMemoryLayoutPolicy {
    /// Creates a policy whose sub-allocations are aligned to `mem_alignment`
    /// bytes; the same value bounds the row pitch chosen by `optimal_align`.
    pub fn new(mem_alignment: usize) -> Self {
        Self { mem_alignment }
    }
}
87
88impl MemoryLayoutPolicy for ContiguousMemoryLayoutPolicy {
89 fn apply(
90 &self,
91 stream_id: StreamId,
92 descriptors: &[MemoryLayoutDescriptor],
93 ) -> (Handle, Vec<MemoryLayout>) {
94 let mut total_size = 0u64;
95 let (sizes, strides): (Vec<_>, Vec<_>) = descriptors
96 .iter()
97 .map(|desc| {
98 let size = desc.shape.iter().product::<usize>() * desc.elem_size;
99 total_size += size.next_multiple_of(self.mem_alignment) as u64;
100 (size, contiguous_strides(&desc.shape))
101 })
102 .unzip();
103
104 let base_handle = Handle::new(stream_id, total_size);
105
106 let layouts = offset_handles(base_handle.clone(), &sizes, self.mem_alignment)
107 .into_iter()
108 .zip(strides)
109 .map(|(handle, stride)| MemoryLayout::new(handle, stride))
110 .collect();
111
112 (base_handle, layouts)
113 }
114}
115
116pub(crate) fn contiguous_strides(shape: &Shape) -> Strides {
117 let rank = shape.len();
118 let mut strides = strides![1; rank];
119 for i in (0..rank - 1).rev() {
120 strides[i] = strides[i + 1] * shape[i + 1];
121 }
122 strides
123}
124
125pub fn offset_handles(
128 base_handle: Handle,
129 sizes_bytes: &[usize],
130 buffer_align: usize,
131) -> Vec<Handle> {
132 let total_size = base_handle.size() as usize;
133 let mut offset = 0;
134 let mut out = Vec::new();
135
136 for size in sizes_bytes {
137 let handle = base_handle
138 .clone()
139 .offset_start(offset as u64)
140 .offset_end((total_size - offset - size) as u64);
141 out.push(handle);
142 offset += size.next_multiple_of(buffer_align);
143 }
144
145 out
146}