cubecl_core/frontend/container/tensor/
launch.rs1use std::{marker::PhantomData, num::NonZero};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 Runtime,
7 compute::{KernelBuilder, KernelLauncher},
8 ir::{Id, Item, Vectorization},
9 prelude::{
10 ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
11 },
12};
13
14use super::Tensor;
15
16#[derive(Debug)]
18pub enum TensorArg<'a, R: Runtime> {
19 Handle {
21 handle: TensorHandleRef<'a, R>,
23 vectorization_factor: u8,
25 },
26 Alias {
28 input_pos: usize,
30 },
31}
32
33pub struct TensorHandleRef<'a, R: Runtime> {
36 pub handle: &'a cubecl_runtime::server::Handle,
37 pub strides: &'a [usize],
38 pub shape: &'a [usize],
39 pub elem_size: usize,
40 pub runtime: PhantomData<R>,
41}
42
43impl<R: Runtime> TensorHandleRef<'_, R> {
44 pub fn size(&self) -> usize {
45 self.shape.iter().product()
46 }
47}
48
49impl<R: Runtime> core::fmt::Debug for TensorHandleRef<'_, R> {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 writeln!(
52 f,
53 "TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
54 self.strides, self.shape
55 )
56 }
57}
58
59#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
61pub struct TensorCompilationArg {
62 pub inplace: Option<Id>,
63 pub vectorisation: Vectorization,
64}
65
66impl CompilationArg for TensorCompilationArg {}
67
68impl<C: CubePrimitive> LaunchArgExpand for Tensor<C> {
69 type CompilationArg = TensorCompilationArg;
70
71 fn expand(
72 arg: &Self::CompilationArg,
73 builder: &mut KernelBuilder,
74 ) -> ExpandElementTyped<Tensor<C>> {
75 builder
76 .input_tensor(Item::vectorized(
77 C::as_elem(&builder.context),
78 arg.vectorisation,
79 ))
80 .into()
81 }
82 fn expand_output(
83 arg: &Self::CompilationArg,
84 builder: &mut KernelBuilder,
85 ) -> ExpandElementTyped<Tensor<C>> {
86 match arg.inplace {
87 Some(id) => builder.inplace_output(id).into(),
88 None => builder
89 .output_tensor(Item::vectorized(
90 C::as_elem(&builder.context),
91 arg.vectorisation,
92 ))
93 .into(),
94 }
95 }
96}
97
98impl<C: CubePrimitive> LaunchArg for Tensor<C> {
99 type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
100
101 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
102 match runtime_arg {
103 TensorArg::Handle {
104 vectorization_factor,
105 ..
106 } => TensorCompilationArg {
107 inplace: None,
108 vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()),
109 },
110 TensorArg::Alias { input_pos } => TensorCompilationArg {
111 inplace: Some(*input_pos as Id),
112 vectorisation: Vectorization::None,
113 },
114 }
115 }
116}
117
118impl<'a, R: Runtime> TensorArg<'a, R> {
119 pub unsafe fn from_raw_parts<E: CubePrimitive>(
126 handle: &'a cubecl_runtime::server::Handle,
127 strides: &'a [usize],
128 shape: &'a [usize],
129 factor: u8,
130 ) -> Self {
131 unsafe {
132 Self::Handle {
133 handle: TensorHandleRef::from_raw_parts(
134 handle,
135 strides,
136 shape,
137 E::size().expect("Element should have a size"),
138 ),
139 vectorization_factor: factor,
140 }
141 }
142 }
143
144 pub unsafe fn from_raw_parts_and_size(
152 handle: &'a cubecl_runtime::server::Handle,
153 strides: &'a [usize],
154 shape: &'a [usize],
155 factor: u8,
156 elem_size: usize,
157 ) -> Self {
158 unsafe {
159 Self::Handle {
160 handle: TensorHandleRef::from_raw_parts(handle, strides, shape, elem_size),
161 vectorization_factor: factor,
162 }
163 }
164 }
165
166 pub fn alias(position: usize) -> Self {
168 Self::Alias {
169 input_pos: position,
170 }
171 }
172}
173
174impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
175 fn register(&self, launcher: &mut KernelLauncher<R>) {
176 launcher.register_tensor(self);
177 }
178}
179
180impl<'a, R: Runtime> TensorHandleRef<'a, R> {
181 pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
183 unsafe {
184 TensorArg::from_raw_parts_and_size(
185 self.handle,
186 self.strides,
187 self.shape,
188 vectorisation,
189 self.elem_size,
190 )
191 }
192 }
193 pub unsafe fn from_raw_parts(
200 handle: &'a cubecl_runtime::server::Handle,
201 strides: &'a [usize],
202 shape: &'a [usize],
203 elem_size: usize,
204 ) -> Self {
205 Self {
206 handle,
207 strides,
208 shape,
209 elem_size,
210 runtime: PhantomData,
211 }
212 }
213}