cubecl_core/compute/
launcher.rs1use alloc::{boxed::Box, vec::Vec};
2use core::marker::PhantomData;
3
4use crate::Runtime;
5use crate::prelude::{ArrayArg, TensorArg, TensorMapArg, TensorMapKind};
6use crate::{InfoBuilder, KernelSettings, ScalarArgType};
7#[cfg(feature = "std")]
8use core::cell::RefCell;
9use cubecl_ir::{AddressType, Scope, StorageType, Type};
10use cubecl_runtime::server::{Binding, CubeCount, TensorMapBinding};
11use cubecl_runtime::{
12 client::ComputeClient,
13 kernel::{CubeKernel, KernelTask},
14 server::KernelArguments,
15};
16
#[cfg(feature = "std")]
std::thread_local! {
    /// Per-thread root [`Scope`] used by `KernelLauncher::with_scope` when the `std`
    /// feature is enabled (it is created once per thread with `Scope::root(false)`).
    static SCOPE: RefCell<Scope> = RefCell::new(Scope::root(false));
    /// Per-thread [`InfoBuilder`] that accumulates scalar and metadata arguments for
    /// `KernelLauncher` when the `std` feature is enabled.
    static INFO: RefCell<InfoBuilder> = RefCell::new(InfoBuilder::default());
}
23
24pub struct KernelLauncher<R: Runtime> {
26 buffers: Vec<Binding>,
27 tensor_maps: Vec<TensorMapBinding>,
28 address_type: AddressType,
29 pub settings: KernelSettings,
30 #[cfg(not(feature = "std"))]
31 info: InfoBuilder,
32 #[cfg(not(feature = "std"))]
33 pub scope: Scope,
34 _runtime: PhantomData<R>,
35}
36
37impl<R: Runtime> KernelLauncher<R> {
38 #[cfg(feature = "std")]
39 pub fn with_scope<T>(&mut self, fun: impl FnMut(&mut Scope) -> T) -> T {
40 SCOPE.with_borrow_mut(fun)
41 }
42
43 #[cfg(not(feature = "std"))]
44 pub fn with_scope<T>(&mut self, mut fun: impl FnMut(&mut Scope) -> T) -> T {
45 fun(&mut self.scope)
46 }
47
48 #[cfg(feature = "std")]
49 fn with_info<T>(&mut self, fun: impl FnMut(&mut InfoBuilder) -> T) -> T {
50 INFO.with_borrow_mut(fun)
51 }
52
53 #[cfg(not(feature = "std"))]
54 fn with_info<T>(&mut self, mut fun: impl FnMut(&mut InfoBuilder) -> T) -> T {
55 fun(&mut self.info)
56 }
57
58 pub fn register_scalar<C: ScalarArgType>(&mut self, scalar: C) {
60 self.with_info(|info| info.scalars.push(scalar));
61 }
62
63 pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
65 self.with_info(|info| info.scalars.push_raw(bytes, dtype));
66 }
67
68 #[track_caller]
70 pub fn launch<K: CubeKernel>(
71 self,
72 cube_count: CubeCount,
73 kernel: K,
74 client: &ComputeClient<R>,
75 ) {
76 let bindings = self.into_bindings();
77 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
78
79 client.launch(kernel, cube_count, bindings)
80 }
81
82 #[track_caller]
91 pub unsafe fn launch_unchecked<K: CubeKernel>(
92 self,
93 cube_count: CubeCount,
94 kernel: K,
95 client: &ComputeClient<R>,
96 ) {
97 unsafe {
98 let bindings = self.into_bindings();
99 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
100
101 client.launch_unchecked(kernel, cube_count, bindings)
102 }
103 }
104
105 fn into_bindings(mut self) -> KernelArguments {
115 let mut bindings = KernelArguments::new();
116 let address_type = self.address_type;
117 let info = self.with_info(|info| info.finish(address_type));
118
119 bindings.buffers = self.buffers;
120 bindings.tensor_maps = self.tensor_maps;
121 bindings.info = info;
122
123 bindings
124 }
125}
126
127impl<R: Runtime> KernelLauncher<R> {
129 pub fn register_tensor(&mut self, tensor: TensorArg<R>, ty: Type) {
131 if let Some(tensor) = self.process_tensor(tensor, ty) {
132 self.buffers.push(tensor);
133 }
134 }
135
136 fn process_tensor(&mut self, tensor: TensorArg<R>, ty: Type) -> Option<Binding> {
137 let tensor = match tensor {
138 TensorArg::Handle { handle, .. } => handle,
139 TensorArg::Alias { .. } => return None,
140 };
141
142 let elem_size = ty.size();
143 let vectorization = ty.vector_size();
144
145 let buffer_len = tensor.handle.size_in_used() / elem_size as u64;
146 let len = tensor.shape.iter().product::<usize>() / vectorization;
147 let address_type = self.address_type;
148 self.with_info(|info| {
149 info.metadata.register_tensor(
150 tensor.strides.len() as u64,
151 buffer_len,
152 len as u64,
153 tensor.shape.clone(),
154 tensor.strides.clone(),
155 address_type,
156 )
157 });
158 Some(tensor.handle)
159 }
160
161 pub fn register_array(&mut self, array: ArrayArg<R>, ty: Type) {
163 if let Some(tensor) = self.process_array(array, ty) {
164 self.buffers.push(tensor);
165 }
166 }
167
168 fn process_array(&mut self, array: ArrayArg<R>, ty: Type) -> Option<Binding> {
169 let array = match array {
170 ArrayArg::Handle { handle, .. } => handle,
171 ArrayArg::Alias { .. } => return None,
172 };
173
174 let elem_size = ty.size();
175 let vectorization = ty.vector_size();
176
177 let buffer_len = array.handle.size_in_used() / elem_size as u64;
178 let address_type = self.address_type;
179 self.with_info(|info| {
180 info.metadata.register_array(
181 buffer_len,
182 array.length[0] as u64 / vectorization as u64,
183 address_type,
184 )
185 });
186 Some(array.handle)
187 }
188
189 pub fn register_tensor_map<K: TensorMapKind>(&mut self, map: TensorMapArg<R, K>, ty: Type) {
191 let binding = self
192 .process_tensor(map.tensor, ty)
193 .expect("Can't use alias for TensorMap");
194
195 let map = map.metadata.clone();
196 self.tensor_maps.push(TensorMapBinding { binding, map });
197 }
198}
199
200impl<R: Runtime> KernelLauncher<R> {
201 pub fn new(settings: KernelSettings) -> Self {
202 Self {
203 address_type: settings.address_type,
204 settings,
205 buffers: Vec::new(),
206 tensor_maps: Vec::new(),
207 _runtime: PhantomData,
208 #[cfg(not(feature = "std"))]
209 info: InfoBuilder::default(),
210 #[cfg(not(feature = "std"))]
211 scope: Scope::root(false),
212 }
213 }
214}