cubecl_runtime/client.rs

use core::future::Future;

use crate::{
    channel::ComputeChannel,
    memory_management::MemoryUsage,
    server::{Binding, ComputeServer, CubeCount, Handle},
    storage::BindingResource,
    DeviceProperties, ExecutionMode,
};
use alloc::sync::Arc;
use alloc::vec::Vec;
use cubecl_common::benchmark::TimestampsResult;
/// The ComputeClient is the entry point to request tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
#[derive(Debug)]
pub struct ComputeClient<Server: ComputeServer, Channel> {
    channel: Channel,
    state: Arc<ComputeClientState<Server>>,
}

#[derive(new, Debug)]
struct ComputeClientState<Server: ComputeServer> {
    properties: DeviceProperties<Server::Feature>,
    timestamp_lock: async_lock::Mutex<()>,
}

impl<S, C> Clone for ComputeClient<S, C>
where
    S: ComputeServer,
    C: ComputeChannel<S>,
{
    fn clone(&self) -> Self {
        Self {
            channel: self.channel.clone(),
            state: self.state.clone(),
        }
    }
}

impl<Server, Channel> ComputeClient<Server, Channel>
where
    Server: ComputeServer,
    Channel: ComputeChannel<Server>,
{
    /// Create a new client.
    pub fn new(channel: Channel, properties: DeviceProperties<Server::Feature>) -> Self {
        let state = ComputeClientState::new(properties, async_lock::Mutex::new(()));
        Self {
            channel,
            state: Arc::new(state),
        }
    }

    /// Given bindings, returns owned resources as bytes.
    pub async fn read_async(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
        self.channel.read(bindings).await
    }

    /// Given bindings, returns owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    pub fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
        cubecl_common::reader::read_sync(self.channel.read(bindings))
    }

    /// Given a binding, returns the owned resource as bytes.
    pub async fn read_one_async(&self, binding: Binding) -> Vec<u8> {
        self.channel.read([binding].into()).await.remove(0)
    }

    /// Given a binding, returns the owned resource as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
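    ///
    /// # Example
    ///
    /// A minimal sketch of a write/read round trip; it assumes a [`Handle`] can be
    /// converted into a [`Binding`] via `Handle::binding()`.
    ///
    /// ```ignore
    /// // Upload four bytes to the server, then read them back.
    /// let handle = client.create(&[1u8, 2, 3, 4]);
    /// let bytes = client.read_one(handle.binding());
    /// assert_eq!(bytes, vec![1, 2, 3, 4]);
    /// ```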
    pub fn read_one(&self, binding: Binding) -> Vec<u8> {
        cubecl_common::reader::read_sync(self.channel.read([binding].into())).remove(0)
    }

    /// Given a binding, returns the storage resource.
    pub fn get_resource(&self, binding: Binding) -> BindingResource<Server> {
        self.channel.get_resource(binding)
    }

    /// Given a resource, stores it and returns the resource handle.
    pub fn create(&self, data: &[u8]) -> Handle {
        self.channel.create(data)
    }

    /// Reserves `size` bytes in the storage, and returns a handle over them.
    pub fn empty(&self, size: usize) -> Handle {
        self.channel.empty(size)
    }

    /// Executes the `kernel` over the given `bindings`.
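    ///
    /// # Example
    ///
    /// A hedged sketch; it assumes `kernel` is an already prepared `Server::Kernel`
    /// and that `input`/`output` are handles previously created on this client.
    ///
    /// ```ignore
    /// client.execute(
    ///     kernel,
    ///     CubeCount::Static(1, 1, 1),
    ///     vec![input.binding(), output.binding()],
    /// );
    /// ```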
    pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Vec<Binding>) {
        unsafe {
            self.channel
                .execute(kernel, count, bindings, ExecutionMode::Checked)
        }
    }

    /// Executes the `kernel` over the given `bindings` without performing any bounds checks.
    ///
    /// # Safety
    ///
    /// Without checks, out-of-bound reads and writes can happen.
    pub unsafe fn execute_unchecked(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Vec<Binding>,
    ) {
        self.channel
            .execute(kernel, count, bindings, ExecutionMode::Unchecked)
    }

    /// Flush all outstanding commands.
    pub fn flush(&self) {
        self.channel.flush();
    }

    /// Wait for the completion of every task in the server.
    pub async fn sync(&self) {
        self.channel.sync().await
    }

    /// Wait for the completion of every task in the server, and return the elapsed time
    /// measured since the last synchronization.
    pub async fn sync_elapsed(&self) -> TimestampsResult {
        self.channel.sync_elapsed().await
    }

    /// Get the properties of the compute server, including the features it supports.
    pub fn properties(&self) -> &DeviceProperties<Server::Feature> {
        &self.state.properties
    }

    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.channel.memory_usage()
    }

    /// Creates a profiling scope that enables safe timing measurements in concurrent contexts.
    ///
    /// Operations executed within this scope can safely call [`sync_elapsed()`](Self::sync_elapsed)
    /// to measure elapsed time, even in multithreaded workloads. The measurements are
    /// thread-safe and properly synchronized.
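    ///
    /// # Example
    ///
    /// A minimal sketch of timing a block of work inside the profiling scope; the
    /// kernel launch arguments shown here are placeholders.
    ///
    /// ```ignore
    /// let elapsed = client
    ///     .profile(|| async {
    ///         client.execute(kernel, count, bindings);
    ///         client.sync_elapsed().await
    ///     })
    ///     .await;
    /// ```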
    pub async fn profile<O, Fut, Func>(&self, func: Func) -> O
    where
        Fut: Future<Output = O>,
        Func: FnOnce() -> Fut,
    {
        let lock = &self.state.timestamp_lock;
        let guard = lock.lock().await;

        self.channel.enable_timestamps();

        // Reset the client's timestamp state.
        self.sync_elapsed().await.ok();

        // We can't simply take a future as argument, since we need to make sure it doesn't
        // start executing before the lock is acquired, which might be the case on `wasm`.
        let fut = func();
        let output = fut.await;

        self.channel.disable_timestamps();

        core::mem::drop(guard);
        output
    }

    /// Enable timestamp collection on the server for performance profiling.
    ///
    /// This feature records precise timing data for server operations, which can be used
    /// for performance analysis and benchmarking.
    ///
    /// # Warning
    ///
    /// This should only be used during development and benchmarking, not in production,
    /// as it significantly impacts server throughput and performance. The overhead comes
    /// from frequent timestamp collection and storage.
    ///
    /// # Example
    ///
    /// ```ignore
    /// client.enable_timestamps();
    /// // Run your benchmarks/operations.
    /// let timestamps = client.sync_elapsed().await;
    /// ```
    pub fn enable_timestamps(&self) {
        self.channel.enable_timestamps();
    }
}