// cubecl_runtime/memory_management/memory_pool/handle.rs

1use crate::memory_id_type;
2use crate::memory_management::MemoryHandle;
3use crate::{id::HandleRef, server::Handle};
4use alloc::vec::Vec;
5
// Invokes `memory_id_type!` for `SliceId`, `SliceHandle` and `SliceBinding` (presumably
// generating those three types — see the macro definition). The `SliceId` makes it possible
// to track how many references there are to a specific slice.
memory_id_type!(SliceId, SliceHandle, SliceBinding);
8
/// Exposes [`SliceHandle`] through the generic [`MemoryHandle`] interface, using
/// [`SliceBinding`] as its binding type.
impl MemoryHandle<SliceBinding> for SliceHandle {
    /// Returns the result of [`HandleRef::can_mut`] for this handle.
    fn can_mut(&self) -> bool {
        // Called via the `HandleRef` path rather than `self.can_mut()`: unless `SliceHandle`
        // has an inherent `can_mut`, method-call syntax would presumably pick this trait
        // method again and recurse. NOTE(review): `self` is a `&SliceHandle` here, so this
        // likely relies on a `Deref` to `HandleRef` generated by the macro — confirm.
        HandleRef::can_mut(self)
    }

    /// Consumes the handle and returns its [`SliceBinding`].
    fn binding(self) -> SliceBinding {
        // NOTE(review): presumably resolves to an inherent `SliceHandle::binding` generated by
        // `memory_id_type!` (inherent methods take precedence over trait methods), not to this
        // trait method — confirm against the macro definition before refactoring.
        self.binding()
    }
}
18
19/// Take a list of sub-slices of a buffer and create a list of offset handles.
20/// Sizes must be in bytes and handles will be aligned to the memory alignment.
21pub fn offset_handles(
22    base_handle: Handle,
23    sizes_bytes: &[usize],
24    buffer_align: usize,
25) -> Vec<Handle> {
26    let total_size = base_handle.size() as usize;
27    let mut offset = 0;
28    let mut out = Vec::new();
29
30    for size in sizes_bytes {
31        let handle = base_handle
32            .clone()
33            .offset_start(offset as u64)
34            .offset_end((total_size - offset - size) as u64);
35        out.push(handle);
36        offset += size.next_multiple_of(buffer_align);
37    }
38
39    out
40}
41
/// Calculates a best-effort heuristic for the alignment of row-aligned tensors.
///
/// - When `shape == 1`, returns `elem_size`, the contiguous alignment of a single element.
/// - Otherwise, takes the row size in bytes (`shape * elem_size`) and rounds it up to the next
///   power of two. The result is then bounded below by 16 bytes and above by `buffer_align`.
///
/// If `buffer_align` is below 16, the upper bound wins and `buffer_align` is returned for
/// non-unit shapes. Row sizes too large to represent, or to round up, in `usize` yield
/// `buffer_align`.
pub fn optimal_align(shape: usize, elem_size: usize, buffer_align: usize) -> usize {
    if shape == 1 {
        return elem_size;
    }

    // Saturate instead of overflowing: a row that large gets the maximum alignment anyway.
    let row_align = shape
        .saturating_mul(elem_size)
        .checked_next_power_of_two()
        .unwrap_or(buffer_align);

    // `max` then `min` instead of `clamp(16, buffer_align)`: identical when
    // `buffer_align >= 16`, but `clamp` panics when min > max, i.e. for small `buffer_align`.
    row_align.max(16).min(buffer_align)
}