cubecl_runtime/channel/
mutex.rs1use super::ComputeChannel;
2use crate::server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle};
3use crate::storage::{BindingResource, ComputeStorage};
4use alloc::sync::Arc;
5use alloc::vec::Vec;
6use cubecl_common::ExecutionMode;
7use cubecl_common::benchmark::ProfileDuration;
8use spin::Mutex;

/// A [`ComputeChannel`] that gives thread-safe access to a single [`ComputeServer`] by
/// guarding it with a `spin::Mutex`.
///
/// Each channel operation acquires the lock before calling into the server. Cloning the
/// channel yields another handle to the same server, not a new server.
#[derive(Debug)]
pub struct MutexComputeChannel<Server> {
    /// The wrapped server, shared (via reference counting) by every clone of this channel.
    server: Arc<Mutex<Server>>,
}

17impl<S> Clone for MutexComputeChannel<S> {
18 fn clone(&self) -> Self {
19 Self {
20 server: self.server.clone(),
21 }
22 }
23}
24impl<Server> MutexComputeChannel<Server>
25where
26 Server: ComputeServer,
27{
28 pub fn new(server: Server) -> Self {
30 Self {
31 server: Arc::new(Mutex::new(server)),
32 }
33 }
34}
35
36impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
37where
38 Server: ComputeServer,
39{
40 async fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
41 let fut = {
44 let mut server = self.server.lock();
45 server.read(bindings)
46 };
47 fut.await
48 }
49
50 async fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
51 let fut = {
54 let mut server = self.server.lock();
55 server.read_tensor(bindings)
56 };
57 fut.await
58 }
59
60 fn get_resource(
61 &self,
62 binding: Binding,
63 ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
64 self.server.lock().get_resource(binding)
65 }
66
67 fn create(&self, data: &[u8]) -> Handle {
68 self.server.lock().create(data)
69 }
70
71 fn create_tensor(
72 &self,
73 data: &[u8],
74 shape: &[usize],
75 elem_size: usize,
76 ) -> (Handle, Vec<usize>) {
77 self.server.lock().create_tensor(data, shape, elem_size)
78 }
79
80 fn empty(&self, size: usize) -> Handle {
81 self.server.lock().empty(size)
82 }
83
84 fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
85 self.server.lock().empty_tensor(shape, elem_size)
86 }
87
88 unsafe fn execute(
89 &self,
90 kernel: Server::Kernel,
91 count: CubeCount,
92 handles: Bindings,
93 kind: ExecutionMode,
94 ) {
95 unsafe { self.server.lock().execute(kernel, count, handles, kind) }
96 }
97
98 fn flush(&self) {
99 self.server.lock().flush();
100 }
101
102 async fn sync(&self) {
103 let fut = {
106 let mut server = self.server.lock();
107 server.sync()
108 };
109 fut.await
110 }
111
112 fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
113 self.server.lock().memory_usage()
114 }
115
116 fn memory_cleanup(&self) {
117 self.server.lock().memory_cleanup();
118 }
119
120 fn start_profile(&self) {
121 self.server.lock().start_profile();
122 }
123
124 fn end_profile(&self) -> ProfileDuration {
125 self.server.lock().end_profile()
126 }
127}