cubecl_runtime/channel/
mpsc.rs

1use std::{sync::Arc, thread};
2
3use cubecl_common::{ExecutionMode, benchmark::ProfileDuration};
4
5use super::ComputeChannel;
6use crate::{
7    memory_management::MemoryUsage,
8    server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle},
9    storage::{BindingResource, ComputeStorage},
10};
11
/// A channel using a multi-producer, single-consumer message queue to communicate with
/// the compute server, which is spawned on its own thread.
///
/// Cloning this channel is cheap: all clones share the same server thread
/// through the inner [`Arc`].
#[derive(Debug)]
pub struct MpscComputeChannel<Server>
where
    Server: ComputeServer,
{
    // Shared between all clones; holds the sender half of the queue and keeps
    // the server thread handle alive.
    state: Arc<MpscComputeChannelState<Server>>,
}
21
/// Shared state behind an [`MpscComputeChannel`]: the server thread handle and
/// the sending half of the message queue.
#[derive(Debug)]
struct MpscComputeChannelState<Server>
where
    Server: ComputeServer,
{
    // Kept only so the thread handle is owned for the lifetime of the state;
    // it is never joined explicitly (the server loop exits once all senders drop).
    _handle: thread::JoinHandle<()>,
    // Used by every channel clone to submit messages to the server loop.
    sender: async_channel::Sender<Message<Server>>,
}
30
31type Callback<Response> = async_channel::Sender<Response>;
32
/// Requests sent from channel clones to the server thread.
///
/// Most variants carry a [`Callback`] on which the server loop delivers the
/// result; variants without one are fire-and-forget.
enum Message<Server>
where
    Server: ComputeServer,
{
    /// Read the given bindings back as raw bytes.
    Read(Vec<Binding>, Callback<Vec<Vec<u8>>>),
    /// Read tensor bindings (with metadata) back as raw bytes.
    ReadTensor(Vec<BindingWithMeta>, Callback<Vec<Vec<u8>>>),
    /// Fetch the underlying storage resource for a binding.
    GetResource(
        Binding,
        Callback<BindingResource<<Server::Storage as ComputeStorage>::Resource>>,
    ),
    /// Allocate a handle initialized with the given bytes.
    Create(Vec<u8>, Callback<Handle>),
    /// Allocate a tensor from (data, shape, element size); replies with the
    /// handle and a layout vector (presumably strides — confirm with the server impl).
    CreateTensor(Vec<u8>, Vec<usize>, usize, Callback<(Handle, Vec<usize>)>),
    /// Allocate an uninitialized buffer of `usize` bytes.
    Empty(usize, Callback<Handle>),
    /// Allocate an uninitialized tensor from (shape, element size); replies like
    /// `CreateTensor`.
    EmptyTensor(Vec<usize>, usize, Callback<(Handle, Vec<usize>)>),
    /// Launch a kernel with the given cube count, execution mode and bindings.
    ExecuteKernel((Server::Kernel, CubeCount, ExecutionMode), Bindings),
    /// Flush pending work to the device without waiting for completion.
    Flush,
    /// Wait for all submitted work to complete, then reply.
    Sync(Callback<()>),
    /// Query current memory usage statistics.
    MemoryUsage(Callback<MemoryUsage>),
    /// Trigger a manual memory cleanup pass.
    MemoryCleanup,
    /// Start a profiling measurement.
    StartProfile,
    /// Stop the current profiling measurement and reply with its duration.
    StopMeasure(Callback<ProfileDuration>),
}
55
impl<Server> MpscComputeChannel<Server>
where
    Server: ComputeServer + 'static,
{
    /// Create a new mpsc compute channel.
    ///
    /// Takes ownership of `server` and moves it onto a dedicated thread that
    /// serially drains the message queue. The thread (and the server) shuts
    /// down once every sender — i.e. every clone of this channel — is dropped.
    pub fn new(mut server: Server) -> Self {
        let (sender, receiver) = async_channel::unbounded();

        let _handle = thread::spawn(move || {
            // Run the whole procedure as one blocking future. This is much simpler than trying
            // to use some multithreaded executor.
            cubecl_common::future::block_on(async {
                // `recv` errors once all senders are dropped, which ends the
                // loop and lets the thread exit cleanly.
                while let Ok(message) = receiver.recv().await {
                    match message {
                        Message::Read(bindings, callback) => {
                            let data = server.read(bindings).await;
                            callback.send(data).await.unwrap();
                        }
                        Message::ReadTensor(bindings, callback) => {
                            let data = server.read_tensor(bindings).await;
                            callback.send(data).await.unwrap();
                        }
                        Message::GetResource(binding, callback) => {
                            let data = server.get_resource(binding);
                            callback.send(data).await.unwrap();
                        }
                        Message::Create(data, callback) => {
                            let handle = server.create(&data);
                            callback.send(handle).await.unwrap();
                        }
                        Message::CreateTensor(data, shape, elem_size, callback) => {
                            let handle = server.create_tensor(&data, &shape, elem_size);
                            callback.send(handle).await.unwrap();
                        }
                        Message::Empty(size, callback) => {
                            let handle = server.empty(size);
                            callback.send(handle).await.unwrap();
                        }
                        Message::EmptyTensor(shape, elem_size, callback) => {
                            let handle = server.empty_tensor(&shape, elem_size);
                            callback.send(handle).await.unwrap();
                        }
                        // SAFETY: the kernel-launch contract is the caller's
                        // responsibility — it was accepted by the unsafe
                        // `execute` channel method and is forwarded here as-is.
                        Message::ExecuteKernel(kernel, bindings) => unsafe {
                            server.execute(kernel.0, kernel.1, bindings, kernel.2);
                        },
                        Message::Sync(callback) => {
                            server.sync().await;
                            callback.send(()).await.unwrap();
                        }
                        Message::Flush => {
                            server.flush();
                        }
                        Message::MemoryUsage(callback) => {
                            callback.send(server.memory_usage()).await.unwrap();
                        }
                        Message::MemoryCleanup => {
                            server.memory_cleanup();
                        }
                        Message::StartProfile => {
                            server.start_profile();
                        }
                        Message::StopMeasure(callback) => {
                            callback.send(server.end_profile()).await.unwrap();
                        }
                    };
                }
            });
        });

        let state = Arc::new(MpscComputeChannelState { sender, _handle });

        Self { state }
    }
}
130
131impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
132    fn clone(&self) -> Self {
133        Self {
134            state: self.state.clone(),
135        }
136    }
137}
138
139impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
140where
141    Server: ComputeServer + 'static,
142{
143    async fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
144        let sender = self.state.sender.clone();
145        let (callback, response) = async_channel::unbounded();
146        sender
147            .send(Message::Read(bindings, callback))
148            .await
149            .unwrap();
150        handle_response(response.recv().await)
151    }
152
153    async fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
154        let sender = self.state.sender.clone();
155        let (callback, response) = async_channel::unbounded();
156        sender
157            .send(Message::ReadTensor(bindings, callback))
158            .await
159            .unwrap();
160        handle_response(response.recv().await)
161    }
162
163    fn get_resource(
164        &self,
165        binding: Binding,
166    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
167        let (callback, response) = async_channel::unbounded();
168
169        self.state
170            .sender
171            .send_blocking(Message::GetResource(binding, callback))
172            .unwrap();
173
174        handle_response(response.recv_blocking())
175    }
176
177    fn create(&self, data: &[u8]) -> Handle {
178        let (callback, response) = async_channel::unbounded();
179
180        self.state
181            .sender
182            .send_blocking(Message::Create(data.to_vec(), callback))
183            .unwrap();
184
185        handle_response(response.recv_blocking())
186    }
187
188    fn create_tensor(
189        &self,
190        data: &[u8],
191        shape: &[usize],
192        elem_size: usize,
193    ) -> (Handle, Vec<usize>) {
194        let (callback, response) = async_channel::unbounded();
195
196        self.state
197            .sender
198            .send_blocking(Message::CreateTensor(
199                data.to_vec(),
200                shape.to_vec(),
201                elem_size,
202                callback,
203            ))
204            .unwrap();
205
206        handle_response(response.recv_blocking())
207    }
208
209    fn empty(&self, size: usize) -> Handle {
210        let (callback, response) = async_channel::unbounded();
211        self.state
212            .sender
213            .send_blocking(Message::Empty(size, callback))
214            .unwrap();
215
216        handle_response(response.recv_blocking())
217    }
218
219    fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
220        let (callback, response) = async_channel::unbounded();
221        self.state
222            .sender
223            .send_blocking(Message::EmptyTensor(shape.to_vec(), elem_size, callback))
224            .unwrap();
225
226        handle_response(response.recv_blocking())
227    }
228
229    unsafe fn execute(
230        &self,
231        kernel: Server::Kernel,
232        count: CubeCount,
233        bindings: Bindings,
234        kind: ExecutionMode,
235    ) {
236        self.state
237            .sender
238            .send_blocking(Message::ExecuteKernel((kernel, count, kind), bindings))
239            .unwrap();
240    }
241
242    fn flush(&self) {
243        self.state.sender.send_blocking(Message::Flush).unwrap()
244    }
245
246    async fn sync(&self) {
247        let (callback, response) = async_channel::unbounded();
248        self.state
249            .sender
250            .send(Message::Sync(callback))
251            .await
252            .unwrap();
253        handle_response(response.recv().await)
254    }
255
256    fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
257        let (callback, response) = async_channel::unbounded();
258        self.state
259            .sender
260            .send_blocking(Message::MemoryUsage(callback))
261            .unwrap();
262        handle_response(response.recv_blocking())
263    }
264
265    fn memory_cleanup(&self) {
266        self.state
267            .sender
268            .send_blocking(Message::MemoryCleanup)
269            .unwrap()
270    }
271
272    fn start_profile(&self) {
273        self.state
274            .sender
275            .send_blocking(Message::StartProfile)
276            .unwrap();
277    }
278
279    fn end_profile(&self) -> ProfileDuration {
280        let (callback, response) = async_channel::unbounded();
281        self.state
282            .sender
283            .send_blocking(Message::StopMeasure(callback))
284            .unwrap();
285        handle_response(response.recv_blocking())
286    }
287}
288
/// Unwrap a reply received from the server thread, panicking with a
/// diagnostic message if the receive failed (e.g. the server thread died
/// and its side of the channel closed).
fn handle_response<Response, Err: core::fmt::Debug>(response: Result<Response, Err>) -> Response {
    response.unwrap_or_else(|err| panic!("Can't connect to the server correctly {err:?}"))
}