cubecl_core/frontend/container/array/
launch.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    compute::{KernelBuilder, KernelLauncher},
7    ir::{Id, Item, Vectorization},
8    prelude::{
9        ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
10        TensorHandleRef,
11    },
12    Runtime,
13};
14
15use super::Array;
16
17#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
18pub struct ArrayCompilationArg {
19    pub inplace: Option<Id>,
20    pub vectorisation: Vectorization,
21}
22
23impl CompilationArg for ArrayCompilationArg {}
24
25/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle).
26pub struct ArrayHandleRef<'a, R: Runtime> {
27    pub handle: &'a cubecl_runtime::server::Handle,
28    pub(crate) length: [usize; 1],
29    pub elem_size: usize,
30    runtime: PhantomData<R>,
31}
32
33impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
34    type CompilationArg = ArrayCompilationArg;
35
36    fn expand(
37        arg: &Self::CompilationArg,
38        builder: &mut KernelBuilder,
39    ) -> ExpandElementTyped<Array<C>> {
40        builder
41            .input_array(Item::vectorized(
42                C::as_elem(&builder.context),
43                arg.vectorisation,
44            ))
45            .into()
46    }
47    fn expand_output(
48        arg: &Self::CompilationArg,
49        builder: &mut KernelBuilder,
50    ) -> ExpandElementTyped<Array<C>> {
51        match arg.inplace {
52            Some(id) => builder.inplace_output(id).into(),
53            None => builder
54                .output_array(Item::vectorized(
55                    C::as_elem(&builder.context),
56                    arg.vectorisation,
57                ))
58                .into(),
59        }
60    }
61}
62
63pub enum ArrayArg<'a, R: Runtime> {
64    /// The array is passed with an array handle.
65    Handle {
66        /// The array handle.
67        handle: ArrayHandleRef<'a, R>,
68        /// The vectorization factor.
69        vectorization_factor: u8,
70    },
71    /// The array is aliasing another input array.
72    Alias {
73        /// The position of the input array.
74        input_pos: usize,
75    },
76}
77
78impl<R: Runtime> ArgSettings<R> for ArrayArg<'_, R> {
79    fn register(&self, launcher: &mut KernelLauncher<R>) {
80        launcher.register_array(self)
81    }
82}
83
84impl<'a, R: Runtime> ArrayArg<'a, R> {
85    /// Create a new array argument.
86    ///
87    /// # Safety
88    ///
89    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
90    pub unsafe fn from_raw_parts<E: CubePrimitive>(
91        handle: &'a cubecl_runtime::server::Handle,
92        length: usize,
93        vectorization_factor: u8,
94    ) -> Self {
95        ArrayArg::Handle {
96            handle: ArrayHandleRef::from_raw_parts(
97                handle,
98                length,
99                E::size().expect("Element should have a size"),
100            ),
101            vectorization_factor,
102        }
103    }
104
105    /// Create a new array argument with a manual element size in bytes.
106    ///
107    /// # Safety
108    ///
109    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
110    pub unsafe fn from_raw_parts_and_size(
111        handle: &'a cubecl_runtime::server::Handle,
112        length: usize,
113        vectorization_factor: u8,
114        elem_size: usize,
115    ) -> Self {
116        ArrayArg::Handle {
117            handle: ArrayHandleRef::from_raw_parts(handle, length, elem_size),
118            vectorization_factor,
119        }
120    }
121}
122
123impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
124    /// Create a new array handle reference.
125    ///
126    /// # Safety
127    ///
128    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
129    pub unsafe fn from_raw_parts(
130        handle: &'a cubecl_runtime::server::Handle,
131        length: usize,
132        elem_size: usize,
133    ) -> Self {
134        Self {
135            handle,
136            length: [length],
137            elem_size,
138            runtime: PhantomData,
139        }
140    }
141
142    /// Return the handle as a tensor instead of an array.
143    pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
144        let shape = &self.length;
145
146        TensorHandleRef {
147            handle: self.handle,
148            strides: &[1],
149            shape,
150            elem_size: self.elem_size,
151            runtime: PhantomData,
152        }
153    }
154}
155
156impl<C: CubePrimitive> LaunchArg for Array<C> {
157    type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
158
159    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
160        match runtime_arg {
161            ArrayArg::Handle {
162                vectorization_factor,
163                ..
164            } => ArrayCompilationArg {
165                inplace: None,
166                vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()),
167            },
168            ArrayArg::Alias { input_pos } => ArrayCompilationArg {
169                inplace: Some(*input_pos as Id),
170                vectorisation: Vectorization::None,
171            },
172        }
173    }
174}