1use std::{sync::Arc, thread};
2
3use cubecl_common::{ExecutionMode, benchmark::ProfileDuration};
4
5use super::ComputeChannel;
6use crate::{
7 memory_management::MemoryUsage,
8 server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle},
9 storage::{BindingResource, ComputeStorage},
10};
11
/// A [`ComputeChannel`] that moves the [`ComputeServer`] onto a dedicated
/// worker thread and forwards every request to it over an MPSC queue.
///
/// Cloning the channel is cheap: all clones share the same `Arc`-held state
/// and therefore talk to the same server thread.
#[derive(Debug)]
pub struct MpscComputeChannel<Server>
where
    Server: ComputeServer,
{
    /// Shared message sender + worker-thread handle, reference-counted so the
    /// channel can be cloned and sent across threads.
    state: Arc<MpscComputeChannelState<Server>>,
}
21
/// Shared state behind an [`MpscComputeChannel`]: the queue feeding the server
/// thread, plus the thread's join handle.
#[derive(Debug)]
struct MpscComputeChannelState<Server>
where
    Server: ComputeServer,
{
    // Join handle of the worker thread. It is never joined here; it is stored
    // so the handle's lifetime is tied to the channel state.
    _handle: thread::JoinHandle<()>,
    // Sending side of the message queue consumed by the worker thread.
    sender: async_channel::Sender<Message<Server>>,
}
30
/// Reply channel on which the server thread sends a response back to a caller.
type Callback<Response> = async_channel::Sender<Response>;
32
/// Requests sent from channel clones to the server thread. Most variants carry
/// a [`Callback`] on which the server thread sends its reply.
enum Message<Server>
where
    Server: ComputeServer,
{
    /// Read the bytes behind each binding.
    Read(Vec<Binding>, Callback<Vec<Vec<u8>>>),
    /// Read tensor data described by bindings that carry extra metadata.
    ReadTensor(Vec<BindingWithMeta>, Callback<Vec<Vec<u8>>>),
    /// Fetch the underlying storage resource for a binding.
    GetResource(
        Binding,
        Callback<BindingResource<<Server::Storage as ComputeStorage>::Resource>>,
    ),
    /// Allocate a handle initialized with the given bytes.
    Create(Vec<u8>, Callback<Handle>),
    /// Allocate a tensor from (data, shape, element size); replies with the
    /// handle plus a `Vec<usize>` produced by the server (presumably strides —
    /// confirm against the server implementation).
    CreateTensor(Vec<u8>, Vec<usize>, usize, Callback<(Handle, Vec<usize>)>),
    /// Allocate an uninitialized buffer of the given size in bytes.
    Empty(usize, Callback<Handle>),
    /// Allocate an uninitialized tensor from (shape, element size).
    EmptyTensor(Vec<usize>, usize, Callback<(Handle, Vec<usize>)>),
    /// Launch a kernel: (kernel, cube count, execution mode) with its bindings.
    /// Fire-and-forget — no reply channel.
    ExecuteKernel((Server::Kernel, CubeCount, ExecutionMode), Bindings),
    /// Flush pending work; no reply.
    Flush,
    /// Wait for the server's `sync` to complete, then reply.
    Sync(Callback<()>),
    /// Query current memory usage statistics.
    MemoryUsage(Callback<MemoryUsage>),
    /// Ask the server to release unused memory; no reply.
    MemoryCleanup,
    /// Begin a profiling scope on the server; no reply.
    StartProfile,
    /// End the profiling scope and reply with the measured duration.
    StopMeasure(Callback<ProfileDuration>),
}
55
56impl<Server> MpscComputeChannel<Server>
57where
58 Server: ComputeServer + 'static,
59{
60 pub fn new(mut server: Server) -> Self {
62 let (sender, receiver) = async_channel::unbounded();
63
64 let _handle = thread::spawn(move || {
65 cubecl_common::future::block_on(async {
68 while let Ok(message) = receiver.recv().await {
69 match message {
70 Message::Read(bindings, callback) => {
71 let data = server.read(bindings).await;
72 callback.send(data).await.unwrap();
73 }
74 Message::ReadTensor(bindings, callback) => {
75 let data = server.read_tensor(bindings).await;
76 callback.send(data).await.unwrap();
77 }
78 Message::GetResource(binding, callback) => {
79 let data = server.get_resource(binding);
80 callback.send(data).await.unwrap();
81 }
82 Message::Create(data, callback) => {
83 let handle = server.create(&data);
84 callback.send(handle).await.unwrap();
85 }
86 Message::CreateTensor(data, shape, elem_size, callback) => {
87 let handle = server.create_tensor(&data, &shape, elem_size);
88 callback.send(handle).await.unwrap();
89 }
90 Message::Empty(size, callback) => {
91 let handle = server.empty(size);
92 callback.send(handle).await.unwrap();
93 }
94 Message::EmptyTensor(shape, elem_size, callback) => {
95 let handle = server.empty_tensor(&shape, elem_size);
96 callback.send(handle).await.unwrap();
97 }
98 Message::ExecuteKernel(kernel, bindings) => unsafe {
99 server.execute(kernel.0, kernel.1, bindings, kernel.2);
100 },
101 Message::Sync(callback) => {
102 server.sync().await;
103 callback.send(()).await.unwrap();
104 }
105 Message::Flush => {
106 server.flush();
107 }
108 Message::MemoryUsage(callback) => {
109 callback.send(server.memory_usage()).await.unwrap();
110 }
111 Message::MemoryCleanup => {
112 server.memory_cleanup();
113 }
114 Message::StartProfile => {
115 server.start_profile();
116 }
117 Message::StopMeasure(callback) => {
118 callback.send(server.end_profile()).await.unwrap();
119 }
120 };
121 }
122 });
123 });
124
125 let state = Arc::new(MpscComputeChannelState { sender, _handle });
126
127 Self { state }
128 }
129}
130
131impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
132 fn clone(&self) -> Self {
133 Self {
134 state: self.state.clone(),
135 }
136 }
137}
138
139impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
140where
141 Server: ComputeServer + 'static,
142{
143 async fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
144 let sender = self.state.sender.clone();
145 let (callback, response) = async_channel::unbounded();
146 sender
147 .send(Message::Read(bindings, callback))
148 .await
149 .unwrap();
150 handle_response(response.recv().await)
151 }
152
153 async fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
154 let sender = self.state.sender.clone();
155 let (callback, response) = async_channel::unbounded();
156 sender
157 .send(Message::ReadTensor(bindings, callback))
158 .await
159 .unwrap();
160 handle_response(response.recv().await)
161 }
162
163 fn get_resource(
164 &self,
165 binding: Binding,
166 ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
167 let (callback, response) = async_channel::unbounded();
168
169 self.state
170 .sender
171 .send_blocking(Message::GetResource(binding, callback))
172 .unwrap();
173
174 handle_response(response.recv_blocking())
175 }
176
177 fn create(&self, data: &[u8]) -> Handle {
178 let (callback, response) = async_channel::unbounded();
179
180 self.state
181 .sender
182 .send_blocking(Message::Create(data.to_vec(), callback))
183 .unwrap();
184
185 handle_response(response.recv_blocking())
186 }
187
188 fn create_tensor(
189 &self,
190 data: &[u8],
191 shape: &[usize],
192 elem_size: usize,
193 ) -> (Handle, Vec<usize>) {
194 let (callback, response) = async_channel::unbounded();
195
196 self.state
197 .sender
198 .send_blocking(Message::CreateTensor(
199 data.to_vec(),
200 shape.to_vec(),
201 elem_size,
202 callback,
203 ))
204 .unwrap();
205
206 handle_response(response.recv_blocking())
207 }
208
209 fn empty(&self, size: usize) -> Handle {
210 let (callback, response) = async_channel::unbounded();
211 self.state
212 .sender
213 .send_blocking(Message::Empty(size, callback))
214 .unwrap();
215
216 handle_response(response.recv_blocking())
217 }
218
219 fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
220 let (callback, response) = async_channel::unbounded();
221 self.state
222 .sender
223 .send_blocking(Message::EmptyTensor(shape.to_vec(), elem_size, callback))
224 .unwrap();
225
226 handle_response(response.recv_blocking())
227 }
228
229 unsafe fn execute(
230 &self,
231 kernel: Server::Kernel,
232 count: CubeCount,
233 bindings: Bindings,
234 kind: ExecutionMode,
235 ) {
236 self.state
237 .sender
238 .send_blocking(Message::ExecuteKernel((kernel, count, kind), bindings))
239 .unwrap();
240 }
241
242 fn flush(&self) {
243 self.state.sender.send_blocking(Message::Flush).unwrap()
244 }
245
246 async fn sync(&self) {
247 let (callback, response) = async_channel::unbounded();
248 self.state
249 .sender
250 .send(Message::Sync(callback))
251 .await
252 .unwrap();
253 handle_response(response.recv().await)
254 }
255
256 fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
257 let (callback, response) = async_channel::unbounded();
258 self.state
259 .sender
260 .send_blocking(Message::MemoryUsage(callback))
261 .unwrap();
262 handle_response(response.recv_blocking())
263 }
264
265 fn memory_cleanup(&self) {
266 self.state
267 .sender
268 .send_blocking(Message::MemoryCleanup)
269 .unwrap()
270 }
271
272 fn start_profile(&self) {
273 self.state
274 .sender
275 .send_blocking(Message::StartProfile)
276 .unwrap();
277 }
278
279 fn end_profile(&self) -> ProfileDuration {
280 let (callback, response) = async_channel::unbounded();
281 self.state
282 .sender
283 .send_blocking(Message::StopMeasure(callback))
284 .unwrap();
285 handle_response(response.recv_blocking())
286 }
287}
288
/// Unwrap a reply received from the server thread, panicking with a
/// descriptive message if the reply channel was closed (server thread gone).
fn handle_response<Response, Err: core::fmt::Debug>(response: Result<Response, Err>) -> Response {
    response.unwrap_or_else(|err| panic!("Can't connect to the server correctly {err:?}"))
}