// cubecl_runtime/channel/mutex.rs

1use super::ComputeChannel;
2use crate::logging::ServerLogger;
3use crate::server::{
4    Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle, ProfileError,
5    ProfilingToken,
6};
7use crate::storage::{BindingResource, ComputeStorage};
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10use cubecl_common::ExecutionMode;
11use cubecl_common::future::DynFut;
12use cubecl_common::profile::ProfileDuration;
13use spin::Mutex;
14
15/// The MutexComputeChannel ensures thread-safety by locking the server
16/// on every operation
17#[derive(Debug)]
18pub struct MutexComputeChannel<Server> {
19    server: Arc<Mutex<Server>>,
20}
21
22impl<S> Clone for MutexComputeChannel<S> {
23    fn clone(&self) -> Self {
24        Self {
25            server: self.server.clone(),
26        }
27    }
28}
29impl<Server> MutexComputeChannel<Server>
30where
31    Server: ComputeServer,
32{
33    /// Create a new mutex compute channel.
34    pub fn new(server: Server) -> Self {
35        Self {
36            server: Arc::new(Mutex::new(server)),
37        }
38    }
39}
40
41impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
42where
43    Server: ComputeServer,
44{
45    fn read(&self, bindings: Vec<Binding>) -> DynFut<Vec<Vec<u8>>> {
46        let mut server = self.server.lock();
47        server.read(bindings)
48    }
49
50    fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> DynFut<Vec<Vec<u8>>> {
51        let mut server = self.server.lock();
52        server.read_tensor(bindings)
53    }
54
55    fn sync(&self) -> DynFut<()> {
56        let mut server = self.server.lock();
57        server.sync()
58    }
59
60    fn get_resource(
61        &self,
62        binding: Binding,
63    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
64        self.server.lock().get_resource(binding)
65    }
66
67    fn create(&self, data: &[u8]) -> Handle {
68        self.server.lock().create(data)
69    }
70
71    fn create_tensors(
72        &self,
73        data: Vec<&[u8]>,
74        shape: Vec<&[usize]>,
75        elem_size: Vec<usize>,
76    ) -> Vec<(Handle, Vec<usize>)> {
77        self.server.lock().create_tensors(data, shape, elem_size)
78    }
79
80    fn empty(&self, size: usize) -> Handle {
81        self.server.lock().empty(size)
82    }
83
84    fn empty_tensors(
85        &self,
86        shape: Vec<&[usize]>,
87        elem_size: Vec<usize>,
88    ) -> Vec<(Handle, Vec<usize>)> {
89        self.server.lock().empty_tensors(shape, elem_size)
90    }
91
92    unsafe fn execute(
93        &self,
94        kernel: Server::Kernel,
95        count: CubeCount,
96        handles: Bindings,
97        kind: ExecutionMode,
98        logger: Arc<ServerLogger>,
99    ) {
100        unsafe {
101            self.server
102                .lock()
103                .execute(kernel, count, handles, kind, logger)
104        }
105    }
106
107    fn flush(&self) {
108        self.server.lock().flush();
109    }
110
111    fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
112        self.server.lock().memory_usage()
113    }
114
115    fn memory_cleanup(&self) {
116        self.server.lock().memory_cleanup();
117    }
118
119    fn start_profile(&self) -> ProfilingToken {
120        self.server.lock().start_profile()
121    }
122
123    fn end_profile(&self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
124        self.server.lock().end_profile(token)
125    }
126}