cubecl_core/frontend/container/array/
launch.rs1use std::{marker::PhantomData, num::NonZero};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 compute::{KernelBuilder, KernelLauncher},
7 ir::{Id, Item, Vectorization},
8 prelude::{
9 ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
10 TensorHandleRef,
11 },
12 Runtime,
13};
14
15use super::Array;
16
17#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
18pub struct ArrayCompilationArg {
19 pub inplace: Option<Id>,
20 pub vectorisation: Vectorization,
21}
22
23impl CompilationArg for ArrayCompilationArg {}
24
25pub struct ArrayHandleRef<'a, R: Runtime> {
27 pub handle: &'a cubecl_runtime::server::Handle,
28 pub(crate) length: [usize; 1],
29 pub elem_size: usize,
30 runtime: PhantomData<R>,
31}
32
33impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
34 type CompilationArg = ArrayCompilationArg;
35
36 fn expand(
37 arg: &Self::CompilationArg,
38 builder: &mut KernelBuilder,
39 ) -> ExpandElementTyped<Array<C>> {
40 builder
41 .input_array(Item::vectorized(
42 C::as_elem(&builder.context),
43 arg.vectorisation,
44 ))
45 .into()
46 }
47 fn expand_output(
48 arg: &Self::CompilationArg,
49 builder: &mut KernelBuilder,
50 ) -> ExpandElementTyped<Array<C>> {
51 match arg.inplace {
52 Some(id) => builder.inplace_output(id).into(),
53 None => builder
54 .output_array(Item::vectorized(
55 C::as_elem(&builder.context),
56 arg.vectorisation,
57 ))
58 .into(),
59 }
60 }
61}
62
63pub enum ArrayArg<'a, R: Runtime> {
64 Handle {
66 handle: ArrayHandleRef<'a, R>,
68 vectorization_factor: u8,
70 },
71 Alias {
73 input_pos: usize,
75 },
76}
77
78impl<R: Runtime> ArgSettings<R> for ArrayArg<'_, R> {
79 fn register(&self, launcher: &mut KernelLauncher<R>) {
80 launcher.register_array(self)
81 }
82}
83
84impl<'a, R: Runtime> ArrayArg<'a, R> {
85 pub unsafe fn from_raw_parts<E: CubePrimitive>(
91 handle: &'a cubecl_runtime::server::Handle,
92 length: usize,
93 vectorization_factor: u8,
94 ) -> Self {
95 ArrayArg::Handle {
96 handle: ArrayHandleRef::from_raw_parts(
97 handle,
98 length,
99 E::size().expect("Element should have a size"),
100 ),
101 vectorization_factor,
102 }
103 }
104
105 pub unsafe fn from_raw_parts_and_size(
111 handle: &'a cubecl_runtime::server::Handle,
112 length: usize,
113 vectorization_factor: u8,
114 elem_size: usize,
115 ) -> Self {
116 ArrayArg::Handle {
117 handle: ArrayHandleRef::from_raw_parts(handle, length, elem_size),
118 vectorization_factor,
119 }
120 }
121}
122
123impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
124 pub unsafe fn from_raw_parts(
130 handle: &'a cubecl_runtime::server::Handle,
131 length: usize,
132 elem_size: usize,
133 ) -> Self {
134 Self {
135 handle,
136 length: [length],
137 elem_size,
138 runtime: PhantomData,
139 }
140 }
141
142 pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
144 let shape = &self.length;
145
146 TensorHandleRef {
147 handle: self.handle,
148 strides: &[1],
149 shape,
150 elem_size: self.elem_size,
151 runtime: PhantomData,
152 }
153 }
154}
155
156impl<C: CubePrimitive> LaunchArg for Array<C> {
157 type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
158
159 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
160 match runtime_arg {
161 ArrayArg::Handle {
162 vectorization_factor,
163 ..
164 } => ArrayCompilationArg {
165 inplace: None,
166 vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()),
167 },
168 ArrayArg::Alias { input_pos } => ArrayCompilationArg {
169 inplace: Some(*input_pos as Id),
170 vectorisation: Vectorization::None,
171 },
172 }
173 }
174}