cubecl_core/frontend/container/array/
launch.rs1use std::{marker::PhantomData, num::NonZero};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 Runtime,
7 compute::{KernelBuilder, KernelLauncher},
8 ir::{Id, Item, Vectorization},
9 prelude::{
10 ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
11 TensorHandleRef,
12 },
13};
14
15use super::Array;
16
17#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
18pub struct ArrayCompilationArg {
19 pub inplace: Option<Id>,
20 pub vectorisation: Vectorization,
21}
22
23impl CompilationArg for ArrayCompilationArg {}
24
25pub struct ArrayHandleRef<'a, R: Runtime> {
27 pub handle: &'a cubecl_runtime::server::Handle,
28 pub(crate) length: [usize; 1],
29 pub elem_size: usize,
30 runtime: PhantomData<R>,
31}
32
33impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
34 type CompilationArg = ArrayCompilationArg;
35
36 fn expand(
37 arg: &Self::CompilationArg,
38 builder: &mut KernelBuilder,
39 ) -> ExpandElementTyped<Array<C>> {
40 builder
41 .input_array(Item::vectorized(
42 C::as_elem(&builder.context),
43 arg.vectorisation,
44 ))
45 .into()
46 }
47 fn expand_output(
48 arg: &Self::CompilationArg,
49 builder: &mut KernelBuilder,
50 ) -> ExpandElementTyped<Array<C>> {
51 match arg.inplace {
52 Some(id) => builder.inplace_output(id).into(),
53 None => builder
54 .output_array(Item::vectorized(
55 C::as_elem(&builder.context),
56 arg.vectorisation,
57 ))
58 .into(),
59 }
60 }
61}
62
63pub enum ArrayArg<'a, R: Runtime> {
64 Handle {
66 handle: ArrayHandleRef<'a, R>,
68 vectorization_factor: u8,
70 },
71 Alias {
73 input_pos: usize,
75 },
76}
77
78impl<R: Runtime> ArgSettings<R> for ArrayArg<'_, R> {
79 fn register(&self, launcher: &mut KernelLauncher<R>) {
80 launcher.register_array(self)
81 }
82}
83
84impl<'a, R: Runtime> ArrayArg<'a, R> {
85 pub unsafe fn from_raw_parts<E: CubePrimitive>(
91 handle: &'a cubecl_runtime::server::Handle,
92 length: usize,
93 vectorization_factor: u8,
94 ) -> Self {
95 unsafe {
96 ArrayArg::Handle {
97 handle: ArrayHandleRef::from_raw_parts(
98 handle,
99 length,
100 E::size().expect("Element should have a size"),
101 ),
102 vectorization_factor,
103 }
104 }
105 }
106
107 pub unsafe fn from_raw_parts_and_size(
113 handle: &'a cubecl_runtime::server::Handle,
114 length: usize,
115 vectorization_factor: u8,
116 elem_size: usize,
117 ) -> Self {
118 unsafe {
119 ArrayArg::Handle {
120 handle: ArrayHandleRef::from_raw_parts(handle, length, elem_size),
121 vectorization_factor,
122 }
123 }
124 }
125}
126
127impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
128 pub unsafe fn from_raw_parts(
134 handle: &'a cubecl_runtime::server::Handle,
135 length: usize,
136 elem_size: usize,
137 ) -> Self {
138 Self {
139 handle,
140 length: [length],
141 elem_size,
142 runtime: PhantomData,
143 }
144 }
145
146 pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
148 let shape = &self.length;
149
150 TensorHandleRef {
151 handle: self.handle,
152 strides: &[1],
153 shape,
154 elem_size: self.elem_size,
155 runtime: PhantomData,
156 }
157 }
158}
159
160impl<C: CubePrimitive> LaunchArg for Array<C> {
161 type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
162
163 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
164 match runtime_arg {
165 ArrayArg::Handle {
166 vectorization_factor,
167 ..
168 } => ArrayCompilationArg {
169 inplace: None,
170 vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()),
171 },
172 ArrayArg::Alias { input_pos } => ArrayCompilationArg {
173 inplace: Some(*input_pos as Id),
174 vectorisation: Vectorization::None,
175 },
176 }
177 }
178}