cubecl_runtime/
client.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
use core::future::Future;

use crate::{
    channel::ComputeChannel,
    memory_management::MemoryUsage,
    server::{Binding, ComputeServer, CubeCount, Handle},
    storage::BindingResource,
    DeviceProperties, ExecutionMode,
};
use alloc::sync::Arc;
use alloc::vec::Vec;
use cubecl_common::benchmark::TimestampsResult;

/// Entry point used to request tasks from a [`ComputeServer`].
///
/// A client should be obtained for a specific device via the Compute struct.
#[derive(Debug)]
pub struct ComputeClient<Server, Channel>
where
    Server: ComputeServer,
{
    /// Channel every request of this client is forwarded to.
    channel: Channel,
    /// Reference-counted state (device properties and timestamp lock).
    state: Arc<ComputeClientState<Server>>,
}

/// State stored behind an `Arc` inside a [`ComputeClient`].
#[derive(new, Debug)]
struct ComputeClientState<Server: ComputeServer> {
    /// Device properties, returned by reference from [`ComputeClient::properties`].
    properties: DeviceProperties<Server::Feature>,
    /// Lock held by [`ComputeClient::profile`] for the whole duration of a profiling scope.
    timestamp_lock: async_lock::Mutex<()>,
}

impl<S, C> Clone for ComputeClient<S, C>
where
    S: ComputeServer,
    C: ComputeChannel<S>,
{
    fn clone(&self) -> Self {
        Self {
            channel: self.channel.clone(),
            state: self.state.clone(),
        }
    }
}

impl<Server, Channel> ComputeClient<Server, Channel>
where
    Server: ComputeServer,
    Channel: ComputeChannel<Server>,
{
    /// Create a new client.
    pub fn new(channel: Channel, properties: DeviceProperties<Server::Feature>) -> Self {
        let state = ComputeClientState::new(properties, async_lock::Mutex::new(()));
        Self {
            channel,
            state: Arc::new(state),
        }
    }

    /// Given a binding, returns owned resource as bytes.
    pub async fn read_async(&self, binding: Binding) -> Vec<u8> {
        self.channel.read(binding).await
    }

    /// Given a binding, returns owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    pub fn read(&self, binding: Binding) -> Vec<u8> {
        cubecl_common::reader::read_sync(self.channel.read(binding))
    }

    /// Given a resource handle, returns the storage resource.
    pub fn get_resource(&self, binding: Binding) -> BindingResource<Server> {
        self.channel.get_resource(binding)
    }

    /// Given a resource, stores it and returns the resource handle.
    pub fn create(&self, data: &[u8]) -> Handle {
        self.channel.create(data)
    }

    /// Reserves `size` bytes in the storage, and returns a handle over them.
    pub fn empty(&self, size: usize) -> Handle {
        self.channel.empty(size)
    }

    /// Executes the `kernel` over the given `bindings`.
    pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Vec<Binding>) {
        unsafe {
            self.channel
                .execute(kernel, count, bindings, ExecutionMode::Checked)
        }
    }

    /// Executes the `kernel` over the given `bindings` without performing any bound checks.
    ///
    /// # Safety
    ///
    /// Without checks, the out-of-bound reads and writes can happen.
    pub unsafe fn execute_unchecked(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Vec<Binding>,
    ) {
        self.channel
            .execute(kernel, count, bindings, ExecutionMode::Unchecked)
    }

    /// Flush all outstanding commands.
    pub fn flush(&self) {
        self.channel.flush();
    }

    /// Wait for the completion of every task in the server.
    pub async fn sync(&self) {
        self.channel.sync().await
    }

    /// Wait for the completion of every task in the server.
    pub async fn sync_elapsed(&self) -> TimestampsResult {
        self.channel.sync_elapsed().await
    }

    /// Get the features supported by the compute server.
    pub fn properties(&self) -> &DeviceProperties<Server::Feature> {
        &self.state.properties
    }

    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.channel.memory_usage()
    }

    /// When executing operation within the profile scope, you can call
    /// [sync_elapsed](Self::sync_elapsed) safely even in multithreaded workloads.
    /// Creates a profiling scope that enables safe timing measurements in concurrent contexts.
    ///
    /// Operations executed within this scope can safely call [`sync_elapsed()`](Self::sync_elapsed)
    /// to measure elapsed time, even in multithreaded environments. The measurements are
    /// thread-safe and properly synchronized.
    pub async fn profile<O, Fut, Func>(&self, func: Func) -> O
    where
        Fut: Future<Output = O>,
        Func: FnOnce() -> Fut,
    {
        let lock = &self.state.timestamp_lock;
        let guard = lock.lock().await;

        self.channel.enable_timestamps();

        // Reset the client's timestamp state.
        self.sync_elapsed().await.ok();

        // We can't simply receive a future, since we need to make sure the future doesn't start
        // before the lock, which might be the case on `wasm`.
        let fut = func();
        let output = fut.await;

        self.channel.disable_timestamps();

        core::mem::drop(guard);
        output
    }

    /// Enable timestamp collection on the server for performance profiling.
    ///
    /// This feature records precise timing data for server operations, which can be used
    /// for performance analysis and benchmarking.
    ///
    /// # Warning
    ///
    /// This should only be used during development and benchmarking, not in production,
    /// as it significantly impacts server throughput and performance. The overhead comes
    /// from frequent timestamp collection and storage.
    ///
    /// # Example
    ///
    /// ```ignore
    /// server.enable_timestamps();
    /// // Run your benchmarks/operations
    /// let duration = server.sync_elapsed();
    /// ```
    pub fn enable_timestamps(&self) {
        self.channel.enable_timestamps();
    }
}