Skip to main content

cubecl_runtime/server/
handle.rs

1use cubecl_common::stream_id::StreamId;
2use cubecl_zspace::{Shape, Strides};
3
4use crate::{
5    memory_management::{ManagedMemoryBinding, ManagedMemoryHandle},
6    server::CopyDescriptor,
7};
8
9/// Server handle containing the [memory handle](crate::server::Handle).
10pub struct Handle {
11    /// Memory handle.
12    pub memory: ManagedMemoryHandle,
13    /// Memory offset in bytes.
14    pub offset_start: Option<u64>,
15    /// Memory offset in bytes.
16    pub offset_end: Option<u64>,
17    /// The stream where the data was created.
18    pub stream: StreamId,
19    /// Length of the underlying buffer ignoring offsets
20    pub(crate) size: u64,
21}
22
23impl core::fmt::Debug for Handle {
24    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
25        f.debug_struct("Handle")
26            .field("id", &self.memory)
27            .field("offset_start", &self.offset_start)
28            .field("offset_end", &self.offset_end)
29            .field("stream", &self.stream)
30            .field("size", &self.size)
31            .finish()
32    }
33}
34
35impl Clone for Handle {
36    fn clone(&self) -> Self {
37        Self {
38            memory: self.memory.clone(),
39            offset_start: self.offset_start,
40            offset_end: self.offset_end,
41            stream: self.stream,
42            size: self.size,
43        }
44    }
45}
46
47impl Handle {
48    /// Creates a new handle of the given size.
49    pub fn from_memory(id: ManagedMemoryHandle, stream: StreamId, size: u64) -> Self {
50        Self {
51            memory: id,
52            offset_start: None,
53            offset_end: None,
54            stream,
55            size,
56        }
57    }
58    /// Creates a new handle of the given size.
59    pub fn new(stream: StreamId, size: u64) -> Self {
60        Self {
61            memory: ManagedMemoryHandle::new(),
62            offset_start: None,
63            offset_end: None,
64            stream,
65            size,
66        }
67    }
68    /// Checks whether the handle can be mutated in-place without affecting other computation.
69    pub fn can_mut(&self) -> bool {
70        self.memory.can_mut()
71    }
72
73    /// Returns the [`Binding`] corresponding to the current handle.
74    pub fn binding(self) -> Binding {
75        Binding {
76            memory: self.memory.binding(),
77            offset_start: self.offset_start,
78            offset_end: self.offset_end,
79            stream: self.stream,
80            size: self.size,
81        }
82    }
83
84    /// Add to the current offset in bytes.
85    pub fn offset_start(mut self, offset: u64) -> Self {
86        if let Some(val) = &mut self.offset_start {
87            *val += offset;
88        } else {
89            self.offset_start = Some(offset);
90        }
91
92        self
93    }
94    /// Add to the current offset in bytes.
95    pub fn offset_end(mut self, offset: u64) -> Self {
96        if let Some(val) = &mut self.offset_end {
97            *val += offset;
98        } else {
99            self.offset_end = Some(offset);
100        }
101
102        self
103    }
104
105    /// Convert the [handle](Handle) into a [binding](Binding) with shape and stride metadata.
106    pub fn copy_descriptor(
107        self,
108        shape: Shape,
109        strides: Strides,
110        elem_size: usize,
111    ) -> CopyDescriptor {
112        CopyDescriptor {
113            shape,
114            strides,
115            elem_size,
116            handle: self.binding(),
117        }
118    }
119    /// Get the size of the handle, in bytes, accounting for offsets
120    pub fn size_in_used(&self) -> u64 {
121        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
122    }
123    /// Get the total size of the handle, in bytes.
124    pub fn size(&self) -> u64 {
125        self.size
126    }
127}
128
129/// A binding represents a [Handle] that is bound to managed memory.
130///
131/// The memory used is known by the compute server.
132/// A binding is only valid after being initlized with [`super::ComputeServer::initialize_bindings`]
133///
134/// # Notes
135///
136/// A binding is detached from a [`Handle`], meaning that is won't affect [`Handle::can_mut`].
137#[derive(Clone, Debug)]
138pub struct Binding {
139    /// The id of the handle the binding is bound to.
140    pub memory: ManagedMemoryBinding,
141    /// Memory offset in bytes.
142    pub offset_start: Option<u64>,
143    /// Memory offset in bytes.
144    pub offset_end: Option<u64>,
145    /// The stream where the data was created.
146    pub stream: StreamId,
147    /// Length of the underlying buffer ignoring offsets
148    pub size: u64,
149}
150
151impl Binding {
152    /// Get the size of the handle, in bytes, accounting for offsets
153    pub fn size_in_used(&self) -> u64 {
154        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
155    }
156    /// Get the total size of the handle, in bytes.
157    pub fn size(&self) -> u64 {
158        self.size
159    }
160}