cubecl_runtime/channel/
mutex.rs1use super::ComputeChannel;
2use crate::server::{Binding, ComputeServer, CubeCount, Handle};
3use crate::storage::BindingResource;
4use crate::ExecutionMode;
5use alloc::sync::Arc;
6use alloc::vec::Vec;
7use cubecl_common::benchmark::TimestampsResult;
8use spin::Mutex;
9
/// A compute channel that serializes access to one shared compute server.
///
/// The server is stored behind an `Arc<Mutex<_>>` (a `spin::Mutex`). Every
/// clone of the channel therefore refers to the same server instance, and each
/// channel operation has to take the lock before it can reach the server.
#[derive(Debug)]
pub struct MutexComputeChannel<S> {
    /// Lock-protected server, shared by every clone of this channel.
    server: Arc<Mutex<S>>,
}
16
17impl<S> Clone for MutexComputeChannel<S> {
18 fn clone(&self) -> Self {
19 Self {
20 server: self.server.clone(),
21 }
22 }
23}
24impl<Server> MutexComputeChannel<Server>
25where
26 Server: ComputeServer,
27{
28 pub fn new(server: Server) -> Self {
30 Self {
31 server: Arc::new(Mutex::new(server)),
32 }
33 }
34}
35
36impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
37where
38 Server: ComputeServer,
39{
40 async fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
41 let fut = {
44 let mut server = self.server.lock();
45 server.read(bindings)
46 };
47 fut.await
48 }
49
50 fn get_resource(&self, binding: Binding) -> BindingResource<Server> {
51 self.server.lock().get_resource(binding)
52 }
53
54 fn create(&self, data: &[u8]) -> Handle {
55 self.server.lock().create(data)
56 }
57
58 fn empty(&self, size: usize) -> Handle {
59 self.server.lock().empty(size)
60 }
61
62 unsafe fn execute(
63 &self,
64 kernel: Server::Kernel,
65 count: CubeCount,
66 handles: Vec<Binding>,
67 kind: ExecutionMode,
68 ) {
69 self.server.lock().execute(kernel, count, handles, kind)
70 }
71
72 fn flush(&self) {
73 self.server.lock().flush();
74 }
75
76 async fn sync(&self) {
77 let fut = {
80 let mut server = self.server.lock();
81 server.sync()
82 };
83 fut.await
84 }
85
86 async fn sync_elapsed(&self) -> TimestampsResult {
87 let fut = {
90 let mut server = self.server.lock();
91 server.sync_elapsed()
92 };
93 fut.await
94 }
95
96 fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
97 self.server.lock().memory_usage()
98 }
99
100 fn enable_timestamps(&self) {
101 self.server.lock().enable_timestamps();
102 }
103
104 fn disable_timestamps(&self) {
105 self.server.lock().disable_timestamps();
106 }
107}