cubecl_runtime/channel/
mutex.rs1use super::ComputeChannel;
2use crate::logging::ServerLogger;
3use crate::server::{
4 Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle, ProfileError,
5 ProfilingToken,
6};
7use crate::storage::{BindingResource, ComputeStorage};
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10use cubecl_common::ExecutionMode;
11use cubecl_common::future::DynFut;
12use cubecl_common::profile::ProfileDuration;
13use spin::Mutex;
14
15#[derive(Debug)]
18pub struct MutexComputeChannel<Server> {
19 server: Arc<Mutex<Server>>,
20}
21
22impl<S> Clone for MutexComputeChannel<S> {
23 fn clone(&self) -> Self {
24 Self {
25 server: self.server.clone(),
26 }
27 }
28}
29impl<Server> MutexComputeChannel<Server>
30where
31 Server: ComputeServer,
32{
33 pub fn new(server: Server) -> Self {
35 Self {
36 server: Arc::new(Mutex::new(server)),
37 }
38 }
39}
40
41impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
42where
43 Server: ComputeServer,
44{
45 fn read(&self, bindings: Vec<Binding>) -> DynFut<Vec<Vec<u8>>> {
46 let mut server = self.server.lock();
47 server.read(bindings)
48 }
49
50 fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> DynFut<Vec<Vec<u8>>> {
51 let mut server = self.server.lock();
52 server.read_tensor(bindings)
53 }
54
55 fn sync(&self) -> DynFut<()> {
56 let mut server = self.server.lock();
57 server.sync()
58 }
59
60 fn get_resource(
61 &self,
62 binding: Binding,
63 ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
64 self.server.lock().get_resource(binding)
65 }
66
67 fn create(&self, data: &[u8]) -> Handle {
68 self.server.lock().create(data)
69 }
70
71 fn create_tensors(
72 &self,
73 data: Vec<&[u8]>,
74 shape: Vec<&[usize]>,
75 elem_size: Vec<usize>,
76 ) -> Vec<(Handle, Vec<usize>)> {
77 self.server.lock().create_tensors(data, shape, elem_size)
78 }
79
80 fn empty(&self, size: usize) -> Handle {
81 self.server.lock().empty(size)
82 }
83
84 fn empty_tensors(
85 &self,
86 shape: Vec<&[usize]>,
87 elem_size: Vec<usize>,
88 ) -> Vec<(Handle, Vec<usize>)> {
89 self.server.lock().empty_tensors(shape, elem_size)
90 }
91
92 unsafe fn execute(
93 &self,
94 kernel: Server::Kernel,
95 count: CubeCount,
96 handles: Bindings,
97 kind: ExecutionMode,
98 logger: Arc<ServerLogger>,
99 ) {
100 unsafe {
101 self.server
102 .lock()
103 .execute(kernel, count, handles, kind, logger)
104 }
105 }
106
107 fn flush(&self) {
108 self.server.lock().flush();
109 }
110
111 fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
112 self.server.lock().memory_usage()
113 }
114
115 fn memory_cleanup(&self) {
116 self.server.lock().memory_cleanup();
117 }
118
119 fn start_profile(&self) -> ProfilingToken {
120 self.server.lock().start_profile()
121 }
122
123 fn end_profile(&self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
124 self.server.lock().end_profile(token)
125 }
126}