cubecl_runtime/channel/mutex.rs

1use super::ComputeChannel;
2use crate::server::{Binding, ComputeServer, CubeCount, Handle};
3use crate::storage::BindingResource;
4use crate::ExecutionMode;
5use alloc::sync::Arc;
6use alloc::vec::Vec;
7use cubecl_common::benchmark::TimestampsResult;
8use spin::Mutex;
9
10/// The MutexComputeChannel ensures thread-safety by locking the server
11/// on every operation
12#[derive(Debug)]
13pub struct MutexComputeChannel<Server> {
14    server: Arc<Mutex<Server>>,
15}
16
17impl<S> Clone for MutexComputeChannel<S> {
18    fn clone(&self) -> Self {
19        Self {
20            server: self.server.clone(),
21        }
22    }
23}
24impl<Server> MutexComputeChannel<Server>
25where
26    Server: ComputeServer,
27{
28    /// Create a new mutex compute channel.
29    pub fn new(server: Server) -> Self {
30        Self {
31            server: Arc::new(Mutex::new(server)),
32        }
33    }
34}
35
36impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
37where
38    Server: ComputeServer,
39{
40    async fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
41        // Nb: The order here is really important - the mutex guard has to be dropped before
42        // the future is polled. Just calling lock().read().await can deadlock.
43        let fut = {
44            let mut server = self.server.lock();
45            server.read(bindings)
46        };
47        fut.await
48    }
49
50    fn get_resource(&self, binding: Binding) -> BindingResource<Server> {
51        self.server.lock().get_resource(binding)
52    }
53
54    fn create(&self, data: &[u8]) -> Handle {
55        self.server.lock().create(data)
56    }
57
58    fn empty(&self, size: usize) -> Handle {
59        self.server.lock().empty(size)
60    }
61
62    unsafe fn execute(
63        &self,
64        kernel: Server::Kernel,
65        count: CubeCount,
66        handles: Vec<Binding>,
67        kind: ExecutionMode,
68    ) {
69        self.server.lock().execute(kernel, count, handles, kind)
70    }
71
72    fn flush(&self) {
73        self.server.lock().flush();
74    }
75
76    async fn sync(&self) {
77        // Nb: The order here is really important - the mutex guard has to be dropped before
78        // the future is polled. Just calling lock().sync().await can deadlock.
79        let fut = {
80            let mut server = self.server.lock();
81            server.sync()
82        };
83        fut.await
84    }
85
86    async fn sync_elapsed(&self) -> TimestampsResult {
87        // Nb: The order here is really important - the mutex guard has to be dropped before
88        // the future is polled. Just calling lock().sync().await can deadlock.
89        let fut = {
90            let mut server = self.server.lock();
91            server.sync_elapsed()
92        };
93        fut.await
94    }
95
96    fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
97        self.server.lock().memory_usage()
98    }
99
100    fn enable_timestamps(&self) {
101        self.server.lock().enable_timestamps();
102    }
103
104    fn disable_timestamps(&self) {
105        self.server.lock().disable_timestamps();
106    }
107}