cubecl_runtime/channel/
mutex.rs

1use super::ComputeChannel;
2use crate::server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle};
3use crate::storage::{BindingResource, ComputeStorage};
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6use cubecl_common::ExecutionMode;
7use cubecl_common::benchmark::ProfileDuration;
8use spin::Mutex;
9
/// The `MutexComputeChannel` ensures thread-safety by locking the server
/// on every operation.
///
/// Every method acquires the mutex for the duration of the underlying
/// server call, so concurrent callers are serialized.
#[derive(Debug)]
pub struct MutexComputeChannel<Server> {
    // Shared, mutex-guarded server. Cloning the channel clones this `Arc`
    // (see the `Clone` impl below), so all clones funnel through one lock.
    server: Arc<Mutex<Server>>,
}
16
17impl<S> Clone for MutexComputeChannel<S> {
18    fn clone(&self) -> Self {
19        Self {
20            server: self.server.clone(),
21        }
22    }
23}
24impl<Server> MutexComputeChannel<Server>
25where
26    Server: ComputeServer,
27{
28    /// Create a new mutex compute channel.
29    pub fn new(server: Server) -> Self {
30        Self {
31            server: Arc::new(Mutex::new(server)),
32        }
33    }
34}
35
36impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
37where
38    Server: ComputeServer,
39{
40    async fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
41        // Nb: The order here is really important - the mutex guard has to be dropped before
42        // the future is polled. Just calling lock().read().await can deadlock.
43        let fut = {
44            let mut server = self.server.lock();
45            server.read(bindings)
46        };
47        fut.await
48    }
49
50    async fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
51        // Nb: The order here is really important - the mutex guard has to be dropped before
52        // the future is polled. Just calling lock().read().await can deadlock.
53        let fut = {
54            let mut server = self.server.lock();
55            server.read_tensor(bindings)
56        };
57        fut.await
58    }
59
60    fn get_resource(
61        &self,
62        binding: Binding,
63    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
64        self.server.lock().get_resource(binding)
65    }
66
67    fn create(&self, data: &[u8]) -> Handle {
68        self.server.lock().create(data)
69    }
70
71    fn create_tensor(
72        &self,
73        data: &[u8],
74        shape: &[usize],
75        elem_size: usize,
76    ) -> (Handle, Vec<usize>) {
77        self.server.lock().create_tensor(data, shape, elem_size)
78    }
79
80    fn empty(&self, size: usize) -> Handle {
81        self.server.lock().empty(size)
82    }
83
84    fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
85        self.server.lock().empty_tensor(shape, elem_size)
86    }
87
88    unsafe fn execute(
89        &self,
90        kernel: Server::Kernel,
91        count: CubeCount,
92        handles: Bindings,
93        kind: ExecutionMode,
94    ) {
95        unsafe { self.server.lock().execute(kernel, count, handles, kind) }
96    }
97
98    fn flush(&self) {
99        self.server.lock().flush();
100    }
101
102    async fn sync(&self) {
103        // Nb: The order here is really important - the mutex guard has to be dropped before
104        // the future is polled. Just calling lock().sync().await can deadlock.
105        let fut = {
106            let mut server = self.server.lock();
107            server.sync()
108        };
109        fut.await
110    }
111
112    fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
113        self.server.lock().memory_usage()
114    }
115
116    fn memory_cleanup(&self) {
117        self.server.lock().memory_cleanup();
118    }
119
120    fn start_profile(&self) {
121        self.server.lock().start_profile();
122    }
123
124    fn end_profile(&self) -> ProfileDuration {
125        self.server.lock().end_profile()
126    }
127}