Skip to main content

cubecl_core/compute/
launcher.rs

1use alloc::{boxed::Box, collections::BTreeMap, vec, vec::Vec};
2use core::marker::PhantomData;
3
4use crate::prelude::{ArrayArg, TensorArg, TensorMapArg, TensorMapKind};
5use crate::{CubeScalar, KernelSettings};
6use crate::{MetadataBuilder, Runtime};
7#[cfg(feature = "std")]
8use core::cell::RefCell;
9use cubecl_ir::{AddressType, StorageType};
10use cubecl_runtime::server::{Binding, CubeCount, LaunchError, ScalarBinding, TensorMapBinding};
11use cubecl_runtime::{
12    client::ComputeClient,
13    kernel::{CubeKernel, KernelTask},
14    server::Bindings,
15};
16
/// Prepare a kernel for [launch](KernelLauncher::launch).
///
/// Kernel arguments (tensors, tensor maps, arrays and scalars) are registered one at a time and
/// converted into [`Bindings`] when the kernel is launched.
pub struct KernelLauncher<R: Runtime> {
    /// Registered tensor, array and tensor-map bindings, together with their metadata.
    tensors: TensorState<R>,
    /// Registered scalar arguments, grouped by storage type.
    scalars: ScalarState,
    /// Settings for this launch; `KernelLauncher::new` also reads the address type from here.
    pub settings: KernelSettings,
    /// Ties the launcher to runtime `R` without storing a value of it.
    runtime: PhantomData<R>,
}
24
25impl<R: Runtime> KernelLauncher<R> {
26    /// Register a tensor to be launched.
27    pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
28        self.tensors.push_tensor(tensor);
29    }
30
31    /// Register a mapped tensor to be launched.
32    pub fn register_tensor_map<K: TensorMapKind>(&mut self, tensor: &TensorMapArg<'_, R, K>) {
33        self.tensors.push_tensor_map(tensor);
34    }
35
36    /// Register an input array to be launched.
37    pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
38        self.tensors.push_array(array);
39    }
40
41    /// Register a scalar to be launched.
42    pub fn register_scalar<C: CubeScalar>(&mut self, scalar: C) {
43        self.scalars.push(scalar);
44    }
45
46    /// Register a scalar to be launched from raw data.
47    pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
48        self.scalars.push_raw(bytes, dtype);
49    }
50
51    /// Launch the kernel.
52    #[track_caller]
53    pub fn launch<K: CubeKernel>(
54        self,
55        cube_count: CubeCount,
56        kernel: K,
57        client: &ComputeClient<R>,
58    ) -> Result<(), LaunchError> {
59        let bindings = self.into_bindings();
60        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
61
62        client.launch(kernel, cube_count, bindings)
63    }
64
65    /// Launch the kernel without check bounds.
66    ///
67    /// # Safety
68    ///
69    /// The kernel must not:
70    /// - Contain any out of bounds reads or writes. Doing so is immediate UB.
71    /// - Contain any loops that never terminate. These may be optimized away entirely or cause
72    ///   other unpredictable behaviour.
73    #[track_caller]
74    pub unsafe fn launch_unchecked<K: CubeKernel>(
75        self,
76        cube_count: CubeCount,
77        kernel: K,
78        client: &ComputeClient<R>,
79    ) -> Result<(), LaunchError> {
80        unsafe {
81            let bindings = self.into_bindings();
82            let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
83
84            client.launch_unchecked(kernel, cube_count, bindings)
85        }
86    }
87
88    /// We need to create the bindings in the same order they are defined in the compilation step.
89    ///
90    /// The function [`crate::KernelIntegrator::integrate`] stars by registering the input tensors followed
91    /// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars
92    /// are registered in the same order they are added. This is why we store the scalar data type
93    /// in the `scalar_order` vector, so that we can register them in the same order.
94    ///
95    /// Also returns an ordered list of constant bindings. The ordering between constants and tensors
96    /// is up to the runtime.
97    fn into_bindings(self) -> Bindings {
98        let mut bindings = Bindings::new();
99
100        self.tensors.register(&mut bindings);
101        self.scalars.register(&mut bindings);
102
103        bindings
104    }
105}
106
// Metadata builder used by `TensorState::with_metadata` when the `std` feature is enabled.
// Without `std`, each `TensorState::Some` carries its own builder instead.
//
// This builder is per-thread, not per-launcher. Arguments registered on two launchers that are
// being filled in an interleaved way on the same thread therefore end up in the same builder.
// NOTE(review): presumably `MetadataBuilder::finish` clears the builder for the next launch.
// Confirm this, and whether a launcher dropped before launching can leave stale entries behind.
#[cfg(feature = "std")]
std::thread_local! {
    static METADATA: RefCell<MetadataBuilder> = RefCell::new(MetadataBuilder::default());
}
111
/// Handles the tensor state.
///
/// A launcher starts in `Empty`. The state switches to `Some` the first time a buffer or tensor
/// map is pushed; without `std`, it also switches when metadata is first written.
pub enum TensorState<R: Runtime> {
    /// No tensor is registered yet.
    Empty { addr_type: AddressType },
    /// The registered tensors.
    Some {
        /// Buffer bindings of registered tensors and arrays, in registration order.
        buffers: Vec<Binding>,
        /// Bindings of registered tensor maps, in registration order.
        tensor_maps: Vec<TensorMapBinding>,
        /// Address type passed to the metadata builder.
        addr_type: AddressType,
        runtime: PhantomData<R>,
        /// Per-launcher metadata builder, used when the thread-local `METADATA` builder
        /// (`std` feature) is not available.
        #[cfg(not(feature = "std"))]
        metadata: MetadataBuilder,
    },
}
126
/// Handles the scalar state of an element type
///
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
#[derive(Default, Clone)]
pub struct ScalarState {
    /// Concatenated raw bytes of the registered scalars, keyed by storage type. Using a
    /// `BTreeMap` means the groups are always visited in sorted key order.
    data: BTreeMap<StorageType, ScalarValues>,
}
134
/// Raw bytes of all scalars that share one storage type, concatenated in push order.
///
/// The storage type itself is not stored here; it is the key under which these bytes are kept.
pub type ScalarValues = Vec<u8>;
137
138impl<R: Runtime> TensorState<R> {
139    fn maybe_init(&mut self) {
140        if let TensorState::Empty { addr_type } = self {
141            *self = TensorState::Some {
142                buffers: Vec::new(),
143                tensor_maps: Vec::new(),
144                addr_type: *addr_type,
145                runtime: PhantomData,
146                #[cfg(not(feature = "std"))]
147                metadata: MetadataBuilder::default(),
148            };
149        }
150    }
151
152    #[cfg(feature = "std")]
153    fn with_metadata<T>(&mut self, fun: impl FnMut(&mut MetadataBuilder) -> T) -> T {
154        METADATA.with_borrow_mut(fun)
155    }
156
157    #[cfg(not(feature = "std"))]
158    fn with_metadata<T>(&mut self, mut fun: impl FnMut(&mut MetadataBuilder) -> T) -> T {
159        self.maybe_init();
160        let TensorState::Some { metadata, .. } = self else {
161            panic!("Should be init");
162        };
163        fun(metadata)
164    }
165
166    fn buffers(&mut self) -> &mut Vec<Binding> {
167        self.maybe_init();
168        let TensorState::Some { buffers, .. } = self else {
169            panic!("Should be init");
170        };
171        buffers
172    }
173
174    fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
175        self.maybe_init();
176        let TensorState::Some { tensor_maps, .. } = self else {
177            panic!("Should be init");
178        };
179        tensor_maps
180    }
181
182    fn address_type(&self) -> AddressType {
183        match self {
184            TensorState::Empty { addr_type } => *addr_type,
185            TensorState::Some { addr_type, .. } => *addr_type,
186        }
187    }
188
189    /// Push a new input tensor to the state.
190    pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
191        if let Some(tensor) = self.process_tensor(tensor) {
192            self.buffers().push(tensor);
193        }
194    }
195
196    fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
197        let (tensor, vectorization) = match tensor {
198            TensorArg::Handle {
199                handle,
200                line_size: vectorization_factor,
201                ..
202            } => (handle, vectorization_factor),
203            TensorArg::Alias { .. } => return None,
204        };
205
206        let elem_size = tensor.elem_size * *vectorization;
207        let buffer_len = tensor.handle.size() / elem_size as u64;
208        let len = tensor.shape.iter().product::<usize>() / *vectorization;
209        let address_type = self.address_type();
210        self.with_metadata(|meta| {
211            meta.register_tensor(
212                tensor.strides.len() as u64,
213                buffer_len,
214                len as u64,
215                tensor.shape,
216                tensor.strides,
217                address_type,
218            )
219        });
220        Some(tensor.handle.clone().binding())
221    }
222
223    /// Push a new input array to the state.
224    pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
225        if let Some(tensor) = self.process_array(array) {
226            self.buffers().push(tensor);
227        }
228    }
229
230    fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
231        let (array, vectorization) = match array {
232            ArrayArg::Handle {
233                handle,
234                line_size: vectorization_factor,
235                ..
236            } => (handle, vectorization_factor),
237            ArrayArg::Alias { .. } => return None,
238        };
239
240        let elem_size = array.elem_size * *vectorization;
241        let buffer_len = array.handle.size() / elem_size as u64;
242        let address_type = self.address_type();
243        self.with_metadata(|meta| {
244            meta.register_array(
245                buffer_len,
246                array.length[0] as u64 / *vectorization as u64,
247                address_type,
248            )
249        });
250        Some(array.handle.clone().binding())
251    }
252
253    /// Push a new tensor to the state.
254    pub fn push_tensor_map<K: TensorMapKind>(&mut self, map: &TensorMapArg<'_, R, K>) {
255        let binding = self
256            .process_tensor(&map.tensor)
257            .expect("Can't use alias for TensorMap");
258
259        let map = map.metadata.clone();
260        self.tensor_maps().push(TensorMapBinding { binding, map });
261    }
262
263    fn register(mut self, bindings_global: &mut Bindings) {
264        let metadata = matches!(self, Self::Some { .. }).then(|| {
265            let addr_type = self.address_type();
266            self.with_metadata(|meta| meta.finish(addr_type))
267        });
268        if let Self::Some {
269            buffers,
270            tensor_maps,
271            ..
272        } = self
273        {
274            let metadata = metadata.unwrap();
275
276            bindings_global.buffers = buffers;
277            bindings_global.tensor_maps = tensor_maps;
278            bindings_global.metadata = metadata;
279        }
280    }
281}
282
283impl ScalarState {
284    /// Add a new scalar value to the state.
285    pub fn push<T: CubeScalar>(&mut self, val: T) {
286        let val = [val];
287        let bytes = T::as_bytes(&val);
288        self.data
289            .entry(T::cube_type())
290            .or_default()
291            .extend(bytes.iter().copied());
292    }
293
294    /// Add a new raw value to the state.
295    pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
296        self.data
297            .entry(dtype)
298            .or_default()
299            .extend(bytes.iter().copied());
300    }
301
302    fn register(&self, bindings: &mut Bindings) {
303        for (ty, values) in self.data.iter() {
304            let len = values.len() / ty.size();
305            let len_u64 = len.div_ceil(size_of::<u64>() / ty.size());
306
307            let mut data = vec![0; len_u64];
308            let slice = bytemuck::cast_slice_mut::<u64, u8>(&mut data);
309            slice[0..values.len()].copy_from_slice(values);
310            bindings
311                .scalars
312                .insert(*ty, ScalarBinding::new(*ty, len, data));
313        }
314    }
315}
316
317impl<R: Runtime> KernelLauncher<R> {
318    pub fn new(settings: KernelSettings) -> Self {
319        Self {
320            tensors: TensorState::Empty {
321                addr_type: settings.address_type,
322            },
323            scalars: Default::default(),
324            settings,
325            runtime: PhantomData,
326        }
327    }
328}