1use core::future::Future;
2
3use crate::{
4 channel::ComputeChannel,
5 memory_management::MemoryUsage,
6 server::{Binding, ComputeServer, CubeCount, Handle},
7 storage::BindingResource,
8 DeviceProperties, ExecutionMode,
9};
10use alloc::sync::Arc;
11use alloc::vec::Vec;
12use cubecl_common::benchmark::TimestampsResult;
13
14#[derive(Debug)]
17pub struct ComputeClient<Server: ComputeServer, Channel> {
18 channel: Channel,
19 state: Arc<ComputeClientState<Server>>,
20}
21
22#[derive(new, Debug)]
23struct ComputeClientState<Server: ComputeServer> {
24 properties: DeviceProperties<Server::Feature>,
25 timestamp_lock: async_lock::Mutex<()>,
26}
27
28impl<S, C> Clone for ComputeClient<S, C>
29where
30 S: ComputeServer,
31 C: ComputeChannel<S>,
32{
33 fn clone(&self) -> Self {
34 Self {
35 channel: self.channel.clone(),
36 state: self.state.clone(),
37 }
38 }
39}
40
41impl<Server, Channel> ComputeClient<Server, Channel>
42where
43 Server: ComputeServer,
44 Channel: ComputeChannel<Server>,
45{
46 pub fn new(channel: Channel, properties: DeviceProperties<Server::Feature>) -> Self {
48 let state = ComputeClientState::new(properties, async_lock::Mutex::new(()));
49 Self {
50 channel,
51 state: Arc::new(state),
52 }
53 }
54
55 pub async fn read_async(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
57 self.channel.read(bindings).await
58 }
59
60 pub fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
66 cubecl_common::reader::read_sync(self.channel.read(bindings))
67 }
68
69 pub async fn read_one_async(&self, binding: Binding) -> Vec<u8> {
71 self.channel.read([binding].into()).await.remove(0)
72 }
73
74 pub fn read_one(&self, binding: Binding) -> Vec<u8> {
79 cubecl_common::reader::read_sync(self.channel.read([binding].into())).remove(0)
80 }
81
82 pub fn get_resource(&self, binding: Binding) -> BindingResource<Server> {
84 self.channel.get_resource(binding)
85 }
86
87 pub fn create(&self, data: &[u8]) -> Handle {
89 self.channel.create(data)
90 }
91
92 pub fn empty(&self, size: usize) -> Handle {
94 self.channel.empty(size)
95 }
96
97 pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Vec<Binding>) {
99 unsafe {
100 self.channel
101 .execute(kernel, count, bindings, ExecutionMode::Checked)
102 }
103 }
104
105 pub unsafe fn execute_unchecked(
111 &self,
112 kernel: Server::Kernel,
113 count: CubeCount,
114 bindings: Vec<Binding>,
115 ) {
116 self.channel
117 .execute(kernel, count, bindings, ExecutionMode::Unchecked)
118 }
119
120 pub fn flush(&self) {
122 self.channel.flush();
123 }
124
125 pub async fn sync(&self) {
127 self.channel.sync().await
128 }
129
130 pub async fn sync_elapsed(&self) -> TimestampsResult {
132 self.channel.sync_elapsed().await
133 }
134
135 pub fn properties(&self) -> &DeviceProperties<Server::Feature> {
137 &self.state.properties
138 }
139
140 pub fn memory_usage(&self) -> MemoryUsage {
142 self.channel.memory_usage()
143 }
144
145 pub async fn profile<O, Fut, Func>(&self, func: Func) -> O
153 where
154 Fut: Future<Output = O>,
155 Func: FnOnce() -> Fut,
156 {
157 let lock = &self.state.timestamp_lock;
158 let guard = lock.lock().await;
159
160 self.channel.enable_timestamps();
161
162 self.sync_elapsed().await.ok();
164
165 let fut = func();
168 let output = fut.await;
169
170 self.channel.disable_timestamps();
171
172 core::mem::drop(guard);
173 output
174 }
175
176 pub fn enable_timestamps(&self) {
195 self.channel.enable_timestamps();
196 }
197}