cubecl_core/frontend/container/array/
launch.rs

1use std::marker::PhantomData;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    Runtime,
7    compute::{KernelBuilder, KernelLauncher},
8    ir::{Id, LineSize, Type},
9    prelude::{
10        ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, TensorHandleRef,
11    },
12};
13
14use super::Array;
15
16#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
17pub struct ArrayCompilationArg {
18    pub inplace: Option<Id>,
19    pub line_size: LineSize,
20}
21
22impl CompilationArg for ArrayCompilationArg {}
23
24/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle).
25pub struct ArrayHandleRef<'a, R: Runtime> {
26    pub handle: &'a cubecl_runtime::server::Handle,
27    pub(crate) length: [usize; 1],
28    pub elem_size: usize,
29    runtime: PhantomData<R>,
30}
31
32pub enum ArrayArg<'a, R: Runtime> {
33    /// The array is passed with an array handle.
34    Handle {
35        /// The array handle.
36        handle: ArrayHandleRef<'a, R>,
37        /// The vectorization factor.
38        line_size: u8,
39    },
40    /// The array is aliasing another input array.
41    Alias {
42        /// The position of the input array.
43        input_pos: usize,
44    },
45}
46
47impl<R: Runtime> ArgSettings<R> for ArrayArg<'_, R> {
48    fn register(&self, launcher: &mut KernelLauncher<R>) {
49        launcher.register_array(self)
50    }
51}
52
53impl<'a, R: Runtime> ArrayArg<'a, R> {
54    /// Create a new array argument.
55    ///
56    /// # Safety
57    ///
58    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
59    pub unsafe fn from_raw_parts<E: CubePrimitive>(
60        handle: &'a cubecl_runtime::server::Handle,
61        length: usize,
62        line_size: u8,
63    ) -> Self {
64        unsafe {
65            ArrayArg::Handle {
66                handle: ArrayHandleRef::from_raw_parts(
67                    handle,
68                    length,
69                    E::size().expect("Element should have a size"),
70                ),
71                line_size,
72            }
73        }
74    }
75
76    /// Create a new array argument with a manual element size in bytes.
77    ///
78    /// # Safety
79    ///
80    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
81    pub unsafe fn from_raw_parts_and_size(
82        handle: &'a cubecl_runtime::server::Handle,
83        length: usize,
84        line_size: u8,
85        elem_size: usize,
86    ) -> Self {
87        unsafe {
88            ArrayArg::Handle {
89                handle: ArrayHandleRef::from_raw_parts(handle, length, elem_size),
90                line_size,
91            }
92        }
93    }
94}
95
96impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
97    /// Create a new array handle reference.
98    ///
99    /// # Safety
100    ///
101    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
102    pub unsafe fn from_raw_parts(
103        handle: &'a cubecl_runtime::server::Handle,
104        length: usize,
105        elem_size: usize,
106    ) -> Self {
107        Self {
108            handle,
109            length: [length],
110            elem_size,
111            runtime: PhantomData,
112        }
113    }
114
115    /// Return the handle as a tensor instead of an array.
116    pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
117        let shape = &self.length;
118
119        TensorHandleRef {
120            handle: self.handle,
121            strides: &[1],
122            shape,
123            elem_size: self.elem_size,
124            runtime: PhantomData,
125        }
126    }
127}
128
129impl<C: CubePrimitive> LaunchArg for Array<C> {
130    type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
131    type CompilationArg = ArrayCompilationArg;
132
133    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
134        match runtime_arg {
135            ArrayArg::Handle { line_size, .. } => ArrayCompilationArg {
136                inplace: None,
137                line_size: *line_size as u32,
138            },
139            ArrayArg::Alias { input_pos } => ArrayCompilationArg {
140                inplace: Some(*input_pos as Id),
141                line_size: 0,
142            },
143        }
144    }
145
146    fn expand(
147        arg: &Self::CompilationArg,
148        builder: &mut KernelBuilder,
149    ) -> ExpandElementTyped<Array<C>> {
150        builder
151            .input_array(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
152            .into()
153    }
154    fn expand_output(
155        arg: &Self::CompilationArg,
156        builder: &mut KernelBuilder,
157    ) -> ExpandElementTyped<Array<C>> {
158        match arg.inplace {
159            Some(id) => builder.inplace_output(id).into(),
160            None => builder
161                .output_array(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
162                .into(),
163        }
164    }
165}