cubecl_core/frontend/container/array/
launch.rs1use std::marker::PhantomData;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 Runtime,
7 compute::{KernelBuilder, KernelLauncher},
8 ir::{Id, LineSize, Type},
9 prelude::{
10 ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, TensorHandleRef,
11 },
12};
13
14use super::Array;
15
16#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
17pub struct ArrayCompilationArg {
18 pub inplace: Option<Id>,
19 pub line_size: LineSize,
20}
21
22impl CompilationArg for ArrayCompilationArg {}
23
24pub struct ArrayHandleRef<'a, R: Runtime> {
26 pub handle: &'a cubecl_runtime::server::Handle,
27 pub(crate) length: [usize; 1],
28 pub elem_size: usize,
29 runtime: PhantomData<R>,
30}
31
32pub enum ArrayArg<'a, R: Runtime> {
33 Handle {
35 handle: ArrayHandleRef<'a, R>,
37 line_size: u8,
39 },
40 Alias {
42 input_pos: usize,
44 },
45}
46
47impl<R: Runtime> ArgSettings<R> for ArrayArg<'_, R> {
48 fn register(&self, launcher: &mut KernelLauncher<R>) {
49 launcher.register_array(self)
50 }
51}
52
53impl<'a, R: Runtime> ArrayArg<'a, R> {
54 pub unsafe fn from_raw_parts<E: CubePrimitive>(
60 handle: &'a cubecl_runtime::server::Handle,
61 length: usize,
62 line_size: u8,
63 ) -> Self {
64 unsafe {
65 ArrayArg::Handle {
66 handle: ArrayHandleRef::from_raw_parts(
67 handle,
68 length,
69 E::size().expect("Element should have a size"),
70 ),
71 line_size,
72 }
73 }
74 }
75
76 pub unsafe fn from_raw_parts_and_size(
82 handle: &'a cubecl_runtime::server::Handle,
83 length: usize,
84 line_size: u8,
85 elem_size: usize,
86 ) -> Self {
87 unsafe {
88 ArrayArg::Handle {
89 handle: ArrayHandleRef::from_raw_parts(handle, length, elem_size),
90 line_size,
91 }
92 }
93 }
94}
95
96impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
97 pub unsafe fn from_raw_parts(
103 handle: &'a cubecl_runtime::server::Handle,
104 length: usize,
105 elem_size: usize,
106 ) -> Self {
107 Self {
108 handle,
109 length: [length],
110 elem_size,
111 runtime: PhantomData,
112 }
113 }
114
115 pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
117 let shape = &self.length;
118
119 TensorHandleRef {
120 handle: self.handle,
121 strides: &[1],
122 shape,
123 elem_size: self.elem_size,
124 runtime: PhantomData,
125 }
126 }
127}
128
129impl<C: CubePrimitive> LaunchArg for Array<C> {
130 type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
131 type CompilationArg = ArrayCompilationArg;
132
133 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
134 match runtime_arg {
135 ArrayArg::Handle { line_size, .. } => ArrayCompilationArg {
136 inplace: None,
137 line_size: *line_size as u32,
138 },
139 ArrayArg::Alias { input_pos } => ArrayCompilationArg {
140 inplace: Some(*input_pos as Id),
141 line_size: 0,
142 },
143 }
144 }
145
146 fn expand(
147 arg: &Self::CompilationArg,
148 builder: &mut KernelBuilder,
149 ) -> ExpandElementTyped<Array<C>> {
150 builder
151 .input_array(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
152 .into()
153 }
154 fn expand_output(
155 arg: &Self::CompilationArg,
156 builder: &mut KernelBuilder,
157 ) -> ExpandElementTyped<Array<C>> {
158 match arg.inplace {
159 Some(id) => builder.inplace_output(id).into(),
160 None => builder
161 .output_array(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
162 .into(),
163 }
164 }
165}