// cubecl_runtime/channel/cell.rs
use super::ComputeChannel;
use crate::server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle};
use crate::storage::{BindingResource, ComputeStorage};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::cell::RefCell;
use cubecl_common::ExecutionMode;
use cubecl_common::benchmark::ProfileDuration;

/// A [`ComputeChannel`] that gives access to the server through a
/// [`RefCell`](core::cell::RefCell) shared behind an [`Arc`].
///
/// Every operation takes a runtime-checked mutable borrow of the server.
/// Two overlapping operations on the same server therefore panic instead of
/// aliasing it mutably.
///
/// # Thread safety
///
/// `RefCell` performs no synchronization. This type is nevertheless marked
/// `Send` and `Sync` via `unsafe impl`s, so the compiler will not stop it from
/// crossing thread boundaries.
///
/// Only use this channel when every clone of it is accessed from a single
/// thread, for example on `no-std` targets without threads. Concurrent access
/// from several threads is a data race and thus undefined behavior.
#[derive(Debug)]
pub struct RefCellComputeChannel<Server> {
    /// The wrapped server. It is shared, not copied, by every clone of the channel.
    server: Arc<RefCell<Server>>,
}

23impl<S> Clone for RefCellComputeChannel<S> {
24 fn clone(&self) -> Self {
25 Self {
26 server: self.server.clone(),
27 }
28 }
29}
30
31impl<Server> RefCellComputeChannel<Server>
32where
33 Server: ComputeServer,
34{
35 pub fn new(server: Server) -> Self {
37 Self {
38 server: Arc::new(core::cell::RefCell::new(server)),
39 }
40 }
41}
42
43impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
44where
45 Server: ComputeServer + Send,
46{
47 async fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
48 let future = {
49 let mut server = self.server.borrow_mut();
50 server.read(bindings)
51 };
52 future.await
53 }
54
55 async fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
56 let future = {
57 let mut server = self.server.borrow_mut();
58 server.read_tensor(bindings)
59 };
60 future.await
61 }
62
63 fn get_resource(
64 &self,
65 binding: Binding,
66 ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
67 self.server.borrow_mut().get_resource(binding)
68 }
69
70 fn create(&self, resource: &[u8]) -> Handle {
71 self.server.borrow_mut().create(resource)
72 }
73
74 fn create_tensor(
75 &self,
76 data: &[u8],
77 shape: &[usize],
78 elem_size: usize,
79 ) -> (Handle, Vec<usize>) {
80 self.server
81 .borrow_mut()
82 .create_tensor(data, shape, elem_size)
83 }
84
85 fn empty(&self, size: usize) -> Handle {
86 self.server.borrow_mut().empty(size)
87 }
88
89 fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
90 self.server.borrow_mut().empty_tensor(shape, elem_size)
91 }
92
93 unsafe fn execute(
94 &self,
95 kernel_description: Server::Kernel,
96 count: CubeCount,
97 bindings: Bindings,
98 kind: ExecutionMode,
99 ) {
100 unsafe {
101 self.server
102 .borrow_mut()
103 .execute(kernel_description, count, bindings, kind)
104 }
105 }
106
107 fn flush(&self) {
108 self.server.borrow_mut().flush()
109 }
110
111 async fn sync(&self) {
112 let future = {
113 let mut server = self.server.borrow_mut();
114 server.sync()
115 };
116 future.await
117 }
118
119 fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
120 self.server.borrow_mut().memory_usage()
121 }
122
123 fn memory_cleanup(&self) {
124 self.server.borrow_mut().memory_cleanup();
125 }
126
127 fn start_profile(&self) {
128 self.server.borrow_mut().start_profile()
129 }
130
131 fn end_profile(&self) -> ProfileDuration {
132 self.server.borrow_mut().end_profile()
133 }
134}
135
136unsafe impl<Server: ComputeServer> Send for RefCellComputeChannel<Server> {}
139unsafe impl<Server: ComputeServer> Sync for RefCellComputeChannel<Server> {}