cubecl_runtime/client.rs
1use crate::{
2 DeviceProperties, TimeMeasurement,
3 channel::ComputeChannel,
4 memory_management::MemoryUsage,
5 server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle},
6 storage::{BindingResource, ComputeStorage},
7};
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10use cubecl_common::{ExecutionMode, benchmark::ProfileDuration};
11use spin::Mutex;
12
13/// The ComputeClient is the entry point to require tasks from the ComputeServer.
14/// It should be obtained for a specific device via the Compute struct.
15#[derive(Debug)]
16pub struct ComputeClient<Server: ComputeServer, Channel> {
17 channel: Channel,
18 state: Arc<ComputeClientState<Server>>,
19}
20
21#[derive(new, Debug)]
22struct ComputeClientState<Server: ComputeServer> {
23 properties: DeviceProperties<Server::Feature>,
24 profile_lock: Mutex<()>,
25
26 info: Server::Info,
27}
28
29impl<S, C> Clone for ComputeClient<S, C>
30where
31 S: ComputeServer,
32 C: ComputeChannel<S>,
33{
34 fn clone(&self) -> Self {
35 Self {
36 channel: self.channel.clone(),
37 state: self.state.clone(),
38 }
39 }
40}
41
42impl<Server, Channel> ComputeClient<Server, Channel>
43where
44 Server: ComputeServer,
45 Channel: ComputeChannel<Server>,
46{
47 /// Get the info of the current backend.
48 pub fn info(&self) -> &Server::Info {
49 &self.state.info
50 }
51
52 /// Create a new client.
53 pub fn new(
54 channel: Channel,
55 properties: DeviceProperties<Server::Feature>,
56 info: Server::Info,
57 ) -> Self {
58 let state = ComputeClientState::new(properties, Mutex::new(()), info);
59 Self {
60 channel,
61 state: Arc::new(state),
62 }
63 }
64
65 /// Given bindings, returns owned resources as bytes.
66 pub async fn read_async(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
67 self.channel.read(bindings).await
68 }
69
70 /// Given bindings, returns owned resources as bytes.
71 ///
72 /// # Remarks
73 ///
74 /// Panics if the read operation fails.
75 pub fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
76 cubecl_common::reader::read_sync(self.channel.read(bindings))
77 }
78
79 /// Given a binding, returns owned resource as bytes.
80 pub async fn read_one_async(&self, binding: Binding) -> Vec<u8> {
81 self.channel.read([binding].into()).await.remove(0)
82 }
83
84 /// Given a binding, returns owned resource as bytes.
85 ///
86 /// # Remarks
87 /// Panics if the read operation fails.
88 pub fn read_one(&self, binding: Binding) -> Vec<u8> {
89 cubecl_common::reader::read_sync(self.channel.read([binding].into())).remove(0)
90 }
91
92 /// Given bindings, returns owned resources as bytes.
93 pub async fn read_tensor_async(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
94 self.channel.read_tensor(bindings).await
95 }
96
97 /// Given bindings, returns owned resources as bytes.
98 ///
99 /// # Remarks
100 ///
101 /// Panics if the read operation fails.
102 ///
103 /// The tensor must be in the same layout as created by the runtime, or more strict.
104 /// Contiguous tensors are always fine, strided tensors are only ok if the stride is similar to
105 /// the one created by the runtime (i.e. padded on only the last dimension). A way to check
106 /// stride compatiblity on the runtime will be added in the future.
107 ///
108 /// Also see [ComputeClient::create_tensor].
109 pub fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
110 cubecl_common::reader::read_sync(self.channel.read_tensor(bindings))
111 }
112
113 /// Given a binding, returns owned resource as bytes.
114 /// See [ComputeClient::read_tensor]
115 pub async fn read_one_tensor_async(&self, binding: BindingWithMeta) -> Vec<u8> {
116 self.channel.read_tensor([binding].into()).await.remove(0)
117 }
118
119 /// Given a binding, returns owned resource as bytes.
120 ///
121 /// # Remarks
122 /// Panics if the read operation fails.
123 /// See [ComputeClient::read_tensor]
124 pub fn read_one_tensor(&self, binding: BindingWithMeta) -> Vec<u8> {
125 cubecl_common::reader::read_sync(self.channel.read_tensor([binding].into())).remove(0)
126 }
127
128 /// Given a resource handle, returns the storage resource.
129 pub fn get_resource(
130 &self,
131 binding: Binding,
132 ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
133 self.channel.get_resource(binding)
134 }
135
136 /// Given a resource, stores it and returns the resource handle.
137 pub fn create(&self, data: &[u8]) -> Handle {
138 self.channel.create(data)
139 }
140
141 /// Given a resource and shape, stores it and returns the tensor handle and strides.
142 /// This may or may not return contiguous strides. The layout is up to the runtime, and care
143 /// should be taken when indexing.
144 ///
145 /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
146 /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
147 /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
148 /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
149 /// can load as much data as possible in a single instruction. It may be aligned even more to
150 /// also take cache lines into account.
151 ///
152 /// However, the stride must be taken into account when indexing and reading the tensor
153 /// (also see [ComputeClient::read_tensor]).
154 pub fn create_tensor(
155 &self,
156 data: &[u8],
157 shape: &[usize],
158 elem_size: usize,
159 ) -> (Handle, Vec<usize>) {
160 self.channel.create_tensor(data, shape, elem_size)
161 }
162
163 /// Reserves `size` bytes in the storage, and returns a handle over them.
164 pub fn empty(&self, size: usize) -> Handle {
165 self.channel.empty(size)
166 }
167
168 /// Reserves `shape` in the storage, and returns a tensor handle for it.
169 /// See [ComputeClient::create_tensor]
170 pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
171 self.channel.empty_tensor(shape, elem_size)
172 }
173
174 /// Executes the `kernel` over the given `bindings`.
175 pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
176 unsafe {
177 self.channel
178 .execute(kernel, count, bindings, ExecutionMode::Checked)
179 }
180 }
181
182 /// Executes the `kernel` over the given `bindings` without performing any bound checks.
183 ///
184 /// # Safety
185 ///
186 /// Without checks, the out-of-bound reads and writes can happen.
187 pub unsafe fn execute_unchecked(
188 &self,
189 kernel: Server::Kernel,
190 count: CubeCount,
191 bindings: Bindings,
192 ) {
193 unsafe {
194 self.channel
195 .execute(kernel, count, bindings, ExecutionMode::Unchecked)
196 }
197 }
198
199 /// Flush all outstanding commands.
200 pub fn flush(&self) {
201 self.channel.flush();
202 }
203
204 /// Wait for the completion of every task in the server.
205 pub async fn sync(&self) {
206 self.channel.sync().await
207 }
208
209 /// Get the features supported by the compute server.
210 pub fn properties(&self) -> &DeviceProperties<Server::Feature> {
211 &self.state.properties
212 }
213
214 /// Get the current memory usage of this client.
215 pub fn memory_usage(&self) -> MemoryUsage {
216 self.channel.memory_usage()
217 }
218
219 /// Ask the client to release memory that it can release.
220 ///
221 /// Nb: Results will vary on what the memory allocator deems beneficial,
222 /// so it's not guaranteed any memory is freed.
223 pub fn memory_cleanup(&self) {
224 self.channel.memory_cleanup()
225 }
226
227 /// Measure the execution time of some inner operations.
228 ///
229 /// Nb: this function will only allow one function at a time to be submitted when multithrading.
230 /// Recursive measurements are not allowed and will deadlock.
231 pub fn profile(&self, func: impl FnOnce()) -> ProfileDuration {
232 let guard = self.state.profile_lock.lock();
233 self.channel.start_profile();
234 func();
235 let result = self.channel.end_profile();
236 let result = match self.properties().time_measurement() {
237 TimeMeasurement::Device => result,
238 TimeMeasurement::System => {
239 #[cfg(target_family = "wasm")]
240 panic!("Can't use system timing mode on wasm");
241
242 #[cfg(not(target_family = "wasm"))]
243 {
244 // It is important to wait for the profiling to be done, since we're actually
245 // measuring its execution timing using 'real' time.
246 let duration = cubecl_common::future::block_on(result.resolve());
247 ProfileDuration::from_duration(duration)
248 }
249 }
250 };
251 core::mem::drop(guard);
252 result
253 }
254}