cubecl_runtime/
client.rs

1use crate::{
2    DeviceProperties, TimeMeasurement,
3    channel::ComputeChannel,
4    memory_management::MemoryUsage,
5    server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle},
6    storage::{BindingResource, ComputeStorage},
7};
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10use cubecl_common::{ExecutionMode, benchmark::ProfileDuration};
11use spin::Mutex;
12
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
#[derive(Debug)]
pub struct ComputeClient<Server: ComputeServer, Channel> {
    // Channel used to submit work to the compute server.
    channel: Channel,
    // State shared by every clone of this client (see `Clone` impl below).
    state: Arc<ComputeClientState<Server>>,
}
20
/// State shared between all clones of a [`ComputeClient`] through an `Arc`.
#[derive(new, Debug)]
struct ComputeClientState<Server: ComputeServer> {
    // Properties (features, timing mode, ...) of the device this client targets.
    properties: DeviceProperties<Server::Feature>,
    // Serializes calls to [`ComputeClient::profile`]: only one profiling
    // session may run at a time (recursive profiling would deadlock).
    profile_lock: Mutex<()>,

    // Backend-specific information about the server.
    info: Server::Info,
}
28
29impl<S, C> Clone for ComputeClient<S, C>
30where
31    S: ComputeServer,
32    C: ComputeChannel<S>,
33{
34    fn clone(&self) -> Self {
35        Self {
36            channel: self.channel.clone(),
37            state: self.state.clone(),
38        }
39    }
40}
41
42impl<Server, Channel> ComputeClient<Server, Channel>
43where
44    Server: ComputeServer,
45    Channel: ComputeChannel<Server>,
46{
47    /// Get the info of the current backend.
48    pub fn info(&self) -> &Server::Info {
49        &self.state.info
50    }
51
52    /// Create a new client.
53    pub fn new(
54        channel: Channel,
55        properties: DeviceProperties<Server::Feature>,
56        info: Server::Info,
57    ) -> Self {
58        let state = ComputeClientState::new(properties, Mutex::new(()), info);
59        Self {
60            channel,
61            state: Arc::new(state),
62        }
63    }
64
65    /// Given bindings, returns owned resources as bytes.
66    pub async fn read_async(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
67        self.channel.read(bindings).await
68    }
69
70    /// Given bindings, returns owned resources as bytes.
71    ///
72    /// # Remarks
73    ///
74    /// Panics if the read operation fails.
75    pub fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
76        cubecl_common::reader::read_sync(self.channel.read(bindings))
77    }
78
79    /// Given a binding, returns owned resource as bytes.
80    pub async fn read_one_async(&self, binding: Binding) -> Vec<u8> {
81        self.channel.read([binding].into()).await.remove(0)
82    }
83
84    /// Given a binding, returns owned resource as bytes.
85    ///
86    /// # Remarks
87    /// Panics if the read operation fails.
88    pub fn read_one(&self, binding: Binding) -> Vec<u8> {
89        cubecl_common::reader::read_sync(self.channel.read([binding].into())).remove(0)
90    }
91
92    /// Given bindings, returns owned resources as bytes.
93    pub async fn read_tensor_async(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
94        self.channel.read_tensor(bindings).await
95    }
96
97    /// Given bindings, returns owned resources as bytes.
98    ///
99    /// # Remarks
100    ///
101    /// Panics if the read operation fails.
102    ///
103    /// The tensor must be in the same layout as created by the runtime, or more strict.
104    /// Contiguous tensors are always fine, strided tensors are only ok if the stride is similar to
105    /// the one created by the runtime (i.e. padded on only the last dimension). A way to check
106    /// stride compatiblity on the runtime will be added in the future.
107    ///
108    /// Also see [ComputeClient::create_tensor].
109    pub fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
110        cubecl_common::reader::read_sync(self.channel.read_tensor(bindings))
111    }
112
113    /// Given a binding, returns owned resource as bytes.
114    /// See [ComputeClient::read_tensor]
115    pub async fn read_one_tensor_async(&self, binding: BindingWithMeta) -> Vec<u8> {
116        self.channel.read_tensor([binding].into()).await.remove(0)
117    }
118
119    /// Given a binding, returns owned resource as bytes.
120    ///
121    /// # Remarks
122    /// Panics if the read operation fails.
123    /// See [ComputeClient::read_tensor]
124    pub fn read_one_tensor(&self, binding: BindingWithMeta) -> Vec<u8> {
125        cubecl_common::reader::read_sync(self.channel.read_tensor([binding].into())).remove(0)
126    }
127
128    /// Given a resource handle, returns the storage resource.
129    pub fn get_resource(
130        &self,
131        binding: Binding,
132    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
133        self.channel.get_resource(binding)
134    }
135
136    /// Given a resource, stores it and returns the resource handle.
137    pub fn create(&self, data: &[u8]) -> Handle {
138        self.channel.create(data)
139    }
140
141    /// Given a resource and shape, stores it and returns the tensor handle and strides.
142    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
143    /// should be taken when indexing.
144    ///
145    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
146    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
147    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
148    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
149    /// can load as much data as possible in a single instruction. It may be aligned even more to
150    /// also take cache lines into account.
151    ///
152    /// However, the stride must be taken into account when indexing and reading the tensor
153    /// (also see [ComputeClient::read_tensor]).
154    pub fn create_tensor(
155        &self,
156        data: &[u8],
157        shape: &[usize],
158        elem_size: usize,
159    ) -> (Handle, Vec<usize>) {
160        self.channel.create_tensor(data, shape, elem_size)
161    }
162
163    /// Reserves `size` bytes in the storage, and returns a handle over them.
164    pub fn empty(&self, size: usize) -> Handle {
165        self.channel.empty(size)
166    }
167
168    /// Reserves `shape` in the storage, and returns a tensor handle for it.
169    /// See [ComputeClient::create_tensor]
170    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
171        self.channel.empty_tensor(shape, elem_size)
172    }
173
174    /// Executes the `kernel` over the given `bindings`.
175    pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
176        unsafe {
177            self.channel
178                .execute(kernel, count, bindings, ExecutionMode::Checked)
179        }
180    }
181
182    /// Executes the `kernel` over the given `bindings` without performing any bound checks.
183    ///
184    /// # Safety
185    ///
186    /// Without checks, the out-of-bound reads and writes can happen.
187    pub unsafe fn execute_unchecked(
188        &self,
189        kernel: Server::Kernel,
190        count: CubeCount,
191        bindings: Bindings,
192    ) {
193        unsafe {
194            self.channel
195                .execute(kernel, count, bindings, ExecutionMode::Unchecked)
196        }
197    }
198
199    /// Flush all outstanding commands.
200    pub fn flush(&self) {
201        self.channel.flush();
202    }
203
204    /// Wait for the completion of every task in the server.
205    pub async fn sync(&self) {
206        self.channel.sync().await
207    }
208
209    /// Get the features supported by the compute server.
210    pub fn properties(&self) -> &DeviceProperties<Server::Feature> {
211        &self.state.properties
212    }
213
214    /// Get the current memory usage of this client.
215    pub fn memory_usage(&self) -> MemoryUsage {
216        self.channel.memory_usage()
217    }
218
219    /// Ask the client to release memory that it can release.
220    ///
221    /// Nb: Results will vary on what the memory allocator deems beneficial,
222    /// so it's not guaranteed any memory is freed.
223    pub fn memory_cleanup(&self) {
224        self.channel.memory_cleanup()
225    }
226
227    /// Measure the execution time of some inner operations.
228    ///
229    /// Nb: this function will only allow one function at a time to be submitted when multithrading.
230    /// Recursive measurements are not allowed and will deadlock.
231    pub fn profile(&self, func: impl FnOnce()) -> ProfileDuration {
232        let guard = self.state.profile_lock.lock();
233        self.channel.start_profile();
234        func();
235        let result = self.channel.end_profile();
236        let result = match self.properties().time_measurement() {
237            TimeMeasurement::Device => result,
238            TimeMeasurement::System => {
239                #[cfg(target_family = "wasm")]
240                panic!("Can't use system timing mode on wasm");
241
242                #[cfg(not(target_family = "wasm"))]
243                {
244                    // It is important to wait for the profiling to be done, since we're actually
245                    // measuring its execution timing using 'real' time.
246                    let duration = cubecl_common::future::block_on(result.resolve());
247                    ProfileDuration::from_duration(duration)
248                }
249            }
250        };
251        core::mem::drop(guard);
252        result
253    }
254}