cubecl_core/frontend/container/tensor/
launch.rs1use std::marker::PhantomData;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 Runtime,
7 compute::{KernelBuilder, KernelLauncher},
8 ir::{Id, LineSize, Type},
9 prelude::{
10 ArgSettings, ArrayArg, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg,
11 },
12};
13
14use super::Tensor;
15
/// Launch-time argument describing a [`Tensor`] passed to a kernel.
///
/// Either a concrete handle to bind, or an alias that reuses an already
/// registered input (turned into an in-place id by `compilation_arg`).
#[derive(Debug)]
pub enum TensorArg<'a, R: Runtime> {
    /// The tensor is bound from a borrowed handle view.
    Handle {
        /// Borrowed view (handle, strides, shape, element size) of the tensor.
        handle: TensorHandleRef<'a, R>,
        /// Line (vectorization) width; forwarded to the compilation argument
        /// and used to build the tensor's line type.
        line_size: u8,
    },
    /// The tensor aliases an input tensor instead of being bound separately.
    Alias {
        /// Position of the aliased input; becomes the in-place id of the output.
        input_pos: usize,
    },
}
32
/// Borrowed view of a tensor: a server handle plus the shape and strides that
/// describe its layout, and the size of one element.
pub struct TensorHandleRef<'a, R: Runtime> {
    /// Server handle to the underlying memory.
    pub handle: &'a cubecl_runtime::server::Handle,
    /// Stride of each dimension.
    pub strides: &'a [usize],
    /// Extent of each dimension; `size()` is the product of these.
    pub shape: &'a [usize],
    /// Size of one element, as reported by `CubePrimitive::size()` when built
    /// via `TensorArg::from_raw_parts` (presumably in bytes — TODO confirm).
    pub elem_size: usize,
    /// Marker tying the view to runtime `R` without storing a value of it.
    pub runtime: PhantomData<R>,
}
42
43impl<'a, R: Runtime> Clone for TensorHandleRef<'a, R> {
44 fn clone(&self) -> Self {
45 *self
46 }
47}
48
49impl<'a, R: Runtime> Copy for TensorHandleRef<'a, R> {}
50
51impl<R: Runtime> TensorHandleRef<'_, R> {
52 pub fn size(&self) -> usize {
53 self.shape.iter().product()
54 }
55}
56
57impl<R: Runtime> core::fmt::Debug for TensorHandleRef<'_, R> {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 writeln!(
60 f,
61 "TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
62 self.strides, self.shape
63 )
64 }
65}
66
/// Compile-time description of a tensor kernel argument, produced by
/// `<Tensor<C> as LaunchArg>::compilation_arg` and read by `expand` /
/// `expand_output`.
///
/// Derives `Hash`/`Eq` and serde traits, presumably so it can take part in
/// kernel cache keys — NOTE(review): confirm.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub struct TensorCompilationArg {
    /// `Some(id)` when the output is expanded in place over the input with
    /// this id (from `TensorArg::Alias`); `None` for a bound handle.
    pub inplace: Option<Id>,
    /// Line (vectorization) width used to build the tensor type. Set to `0`
    /// for aliases; the in-place path of `expand_output` does not read it.
    pub line_size: LineSize,
}

impl CompilationArg for TensorCompilationArg {}
75
76impl<C: CubePrimitive> LaunchArg for Tensor<C> {
77 type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
78 type CompilationArg = TensorCompilationArg;
79
80 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
81 match runtime_arg {
82 TensorArg::Handle { line_size, .. } => TensorCompilationArg {
83 inplace: None,
84 line_size: *line_size as u32,
85 },
86 TensorArg::Alias { input_pos } => TensorCompilationArg {
87 inplace: Some(*input_pos as Id),
88 line_size: 0,
89 },
90 }
91 }
92
93 fn expand(
94 arg: &Self::CompilationArg,
95 builder: &mut KernelBuilder,
96 ) -> ExpandElementTyped<Tensor<C>> {
97 builder
98 .input_tensor(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
99 .into()
100 }
101 fn expand_output(
102 arg: &Self::CompilationArg,
103 builder: &mut KernelBuilder,
104 ) -> ExpandElementTyped<Tensor<C>> {
105 match arg.inplace {
106 Some(id) => builder.inplace_output(id).into(),
107 None => builder
108 .output_tensor(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
109 .into(),
110 }
111 }
112}
113
114impl<'a, R: Runtime> TensorArg<'a, R> {
115 pub unsafe fn from_raw_parts<E: CubePrimitive>(
122 handle: &'a cubecl_runtime::server::Handle,
123 strides: &'a [usize],
124 shape: &'a [usize],
125 factor: u8,
126 ) -> Self {
127 unsafe {
128 Self::Handle {
129 handle: TensorHandleRef::from_raw_parts(
130 handle,
131 strides,
132 shape,
133 E::size().expect("Element should have a size"),
134 ),
135 line_size: factor,
136 }
137 }
138 }
139
140 pub unsafe fn from_raw_parts_and_size(
148 handle: &'a cubecl_runtime::server::Handle,
149 strides: &'a [usize],
150 shape: &'a [usize],
151 factor: u8,
152 elem_size: usize,
153 ) -> Self {
154 unsafe {
155 Self::Handle {
156 handle: TensorHandleRef::from_raw_parts(handle, strides, shape, elem_size),
157 line_size: factor,
158 }
159 }
160 }
161
162 pub fn alias(position: usize) -> Self {
164 Self::Alias {
165 input_pos: position,
166 }
167 }
168}
169
170impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
171 fn register(&self, launcher: &mut KernelLauncher<R>) {
172 launcher.register_tensor(self);
173 }
174}
175
176impl<'a, R: Runtime> TensorHandleRef<'a, R> {
177 pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
179 unsafe {
180 TensorArg::from_raw_parts_and_size(
181 self.handle,
182 self.strides,
183 self.shape,
184 vectorisation,
185 self.elem_size,
186 )
187 }
188 }
189 pub fn as_array_arg(&'a self, line_size: u8) -> ArrayArg<'a, R> {
191 let length = self.shape.iter().product();
192 unsafe { ArrayArg::from_raw_parts_and_size(self.handle, length, line_size, self.elem_size) }
193 }
194 pub unsafe fn from_raw_parts(
201 handle: &'a cubecl_runtime::server::Handle,
202 strides: &'a [usize],
203 shape: &'a [usize],
204 elem_size: usize,
205 ) -> Self {
206 Self {
207 handle,
208 strides,
209 shape,
210 elem_size,
211 runtime: PhantomData,
212 }
213 }
214}