cubecl_core/frontend/container/array/
launch.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    Runtime,
7    compute::{KernelBuilder, KernelLauncher},
8    ir::{Id, Item, Vectorization},
9    prelude::{
10        ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
11        TensorHandleRef,
12    },
13};
14
15use super::Array;
16
17#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
18pub struct ArrayCompilationArg {
19    pub inplace: Option<Id>,
20    pub vectorisation: Vectorization,
21}
22
23impl CompilationArg for ArrayCompilationArg {}
24
25/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle).
26pub struct ArrayHandleRef<'a, R: Runtime> {
27    pub handle: &'a cubecl_runtime::server::Handle,
28    pub(crate) length: [usize; 1],
29    pub elem_size: usize,
30    runtime: PhantomData<R>,
31}
32
33impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
34    type CompilationArg = ArrayCompilationArg;
35
36    fn expand(
37        arg: &Self::CompilationArg,
38        builder: &mut KernelBuilder,
39    ) -> ExpandElementTyped<Array<C>> {
40        builder
41            .input_array(Item::vectorized(
42                C::as_elem(&builder.context),
43                arg.vectorisation,
44            ))
45            .into()
46    }
47    fn expand_output(
48        arg: &Self::CompilationArg,
49        builder: &mut KernelBuilder,
50    ) -> ExpandElementTyped<Array<C>> {
51        match arg.inplace {
52            Some(id) => builder.inplace_output(id).into(),
53            None => builder
54                .output_array(Item::vectorized(
55                    C::as_elem(&builder.context),
56                    arg.vectorisation,
57                ))
58                .into(),
59        }
60    }
61}
62
63pub enum ArrayArg<'a, R: Runtime> {
64    /// The array is passed with an array handle.
65    Handle {
66        /// The array handle.
67        handle: ArrayHandleRef<'a, R>,
68        /// The vectorization factor.
69        vectorization_factor: u8,
70    },
71    /// The array is aliasing another input array.
72    Alias {
73        /// The position of the input array.
74        input_pos: usize,
75    },
76}
77
78impl<R: Runtime> ArgSettings<R> for ArrayArg<'_, R> {
79    fn register(&self, launcher: &mut KernelLauncher<R>) {
80        launcher.register_array(self)
81    }
82}
83
84impl<'a, R: Runtime> ArrayArg<'a, R> {
85    /// Create a new array argument.
86    ///
87    /// # Safety
88    ///
89    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
90    pub unsafe fn from_raw_parts<E: CubePrimitive>(
91        handle: &'a cubecl_runtime::server::Handle,
92        length: usize,
93        vectorization_factor: u8,
94    ) -> Self {
95        unsafe {
96            ArrayArg::Handle {
97                handle: ArrayHandleRef::from_raw_parts(
98                    handle,
99                    length,
100                    E::size().expect("Element should have a size"),
101                ),
102                vectorization_factor,
103            }
104        }
105    }
106
107    /// Create a new array argument with a manual element size in bytes.
108    ///
109    /// # Safety
110    ///
111    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
112    pub unsafe fn from_raw_parts_and_size(
113        handle: &'a cubecl_runtime::server::Handle,
114        length: usize,
115        vectorization_factor: u8,
116        elem_size: usize,
117    ) -> Self {
118        unsafe {
119            ArrayArg::Handle {
120                handle: ArrayHandleRef::from_raw_parts(handle, length, elem_size),
121                vectorization_factor,
122            }
123        }
124    }
125}
126
127impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
128    /// Create a new array handle reference.
129    ///
130    /// # Safety
131    ///
132    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
133    pub unsafe fn from_raw_parts(
134        handle: &'a cubecl_runtime::server::Handle,
135        length: usize,
136        elem_size: usize,
137    ) -> Self {
138        Self {
139            handle,
140            length: [length],
141            elem_size,
142            runtime: PhantomData,
143        }
144    }
145
146    /// Return the handle as a tensor instead of an array.
147    pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
148        let shape = &self.length;
149
150        TensorHandleRef {
151            handle: self.handle,
152            strides: &[1],
153            shape,
154            elem_size: self.elem_size,
155            runtime: PhantomData,
156        }
157    }
158}
159
160impl<C: CubePrimitive> LaunchArg for Array<C> {
161    type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
162
163    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
164        match runtime_arg {
165            ArrayArg::Handle {
166                vectorization_factor,
167                ..
168            } => ArrayCompilationArg {
169                inplace: None,
170                vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()),
171            },
172            ArrayArg::Alias { input_pos } => ArrayCompilationArg {
173                inplace: Some(*input_pos as Id),
174                vectorisation: Vectorization::None,
175            },
176        }
177    }
178}