cubecl_runtime/channel/
cell.rs1use super::ComputeChannel;
2use crate::logging::ServerLogger;
3use crate::server::{
4 Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle, ProfileError,
5 ProfilingToken,
6};
7use crate::storage::{BindingResource, ComputeStorage};
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10use cubecl_common::ExecutionMode;
11use cubecl_common::future::DynFut;
12use cubecl_common::profile::ProfileDuration;
13
/// A compute channel that reaches its server through a shared [`RefCell`](core::cell::RefCell).
///
/// The server sits behind `Arc<RefCell<_>>`: the `Arc` lets several channel values point at the
/// same server, and the `RefCell` enforces exclusive access with a runtime borrow check instead
/// of a lock.
#[derive(Debug)]
pub struct RefCellComputeChannel<S> {
    /// Shared, runtime-borrow-checked handle to the wrapped server.
    server: Arc<core::cell::RefCell<S>>,
}
27
28impl<S> Clone for RefCellComputeChannel<S> {
29 fn clone(&self) -> Self {
30 Self {
31 server: self.server.clone(),
32 }
33 }
34}
35
36impl<Server> RefCellComputeChannel<Server>
37where
38 Server: ComputeServer,
39{
40 pub fn new(server: Server) -> Self {
42 Self {
43 server: Arc::new(core::cell::RefCell::new(server)),
44 }
45 }
46}
47
48impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
49where
50 Server: ComputeServer + Send,
51{
52 fn read(&self, bindings: Vec<Binding>) -> DynFut<Vec<Vec<u8>>> {
53 let mut server = self.server.borrow_mut();
54 server.read(bindings)
55 }
56
57 fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> DynFut<Vec<Vec<u8>>> {
58 let mut server = self.server.borrow_mut();
59 server.read_tensor(bindings)
60 }
61
62 fn sync(&self) -> DynFut<()> {
63 let mut server = self.server.borrow_mut();
64 server.sync()
65 }
66
67 fn get_resource(
68 &self,
69 binding: Binding,
70 ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
71 self.server.borrow_mut().get_resource(binding)
72 }
73
74 fn create(&self, resource: &[u8]) -> Handle {
75 self.server.borrow_mut().create(resource)
76 }
77
78 fn create_tensors(
79 &self,
80 data: Vec<&[u8]>,
81 shape: Vec<&[usize]>,
82 elem_size: Vec<usize>,
83 ) -> Vec<(Handle, Vec<usize>)> {
84 self.server
85 .borrow_mut()
86 .create_tensors(data, shape, elem_size)
87 }
88
89 fn empty(&self, size: usize) -> Handle {
90 self.server.borrow_mut().empty(size)
91 }
92
93 fn empty_tensors(
94 &self,
95 shape: Vec<&[usize]>,
96 elem_size: Vec<usize>,
97 ) -> Vec<(Handle, Vec<usize>)> {
98 self.server.borrow_mut().empty_tensors(shape, elem_size)
99 }
100
101 unsafe fn execute(
102 &self,
103 kernel_description: Server::Kernel,
104 count: CubeCount,
105 bindings: Bindings,
106 kind: ExecutionMode,
107 logger: Arc<ServerLogger>,
108 ) {
109 unsafe {
110 self.server
111 .borrow_mut()
112 .execute(kernel_description, count, bindings, kind, logger)
113 }
114 }
115
116 fn flush(&self) {
117 self.server.borrow_mut().flush()
118 }
119
120 fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
121 self.server.borrow_mut().memory_usage()
122 }
123
124 fn memory_cleanup(&self) {
125 self.server.borrow_mut().memory_cleanup();
126 }
127
128 fn start_profile(&self) -> ProfilingToken {
129 self.server.borrow_mut().start_profile()
130 }
131
132 fn end_profile(&self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
133 self.server.borrow_mut().end_profile(token)
134 }
135}
136
137unsafe impl<Server: ComputeServer> Send for RefCellComputeChannel<Server> {}
140unsafe impl<Server: ComputeServer> Sync for RefCellComputeChannel<Server> {}