cubecl_core/frontend/container/tensor/
launch.rs1use core::marker::PhantomData;
2
3use cubecl_ir::AddressType;
4use cubecl_runtime::{runtime::Runtime, server::CopyDescriptor};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 compute::{KernelBuilder, KernelLauncher},
9 ir::{Id, LineSize, Type},
10 prelude::{
11 ArgSettings, ArrayArg, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg,
12 },
13};
14
15use super::Tensor;
16
17#[derive(Debug)]
19pub enum TensorArg<'a, R: Runtime> {
20 Handle {
22 handle: TensorHandleRef<'a, R>,
24 line_size: LineSize,
26 },
27 Alias {
29 input_pos: usize,
31 },
32}
33
34pub struct TensorHandleRef<'a, R: Runtime> {
37 pub handle: &'a cubecl_runtime::server::Handle,
38 pub strides: &'a [usize],
39 pub shape: &'a [usize],
40 pub elem_size: usize,
41 pub runtime: PhantomData<R>,
42}
43
44impl<'a, R: Runtime> Clone for TensorHandleRef<'a, R> {
45 fn clone(&self) -> Self {
46 *self
47 }
48}
49
50impl<'a, R: Runtime> Copy for TensorHandleRef<'a, R> {}
51
52impl<R: Runtime> TensorHandleRef<'_, R> {
53 pub fn size(&self) -> usize {
54 self.shape.iter().product()
55 }
56
57 pub fn required_address_type(&self) -> AddressType {
59 AddressType::from_len(self.handle.size() as usize / self.elem_size)
60 }
61}
62
63impl<R: Runtime> core::fmt::Debug for TensorHandleRef<'_, R> {
64 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
65 writeln!(
66 f,
67 "TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
68 self.strides, self.shape
69 )
70 }
71}
72
73#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
75pub struct TensorCompilationArg {
76 pub inplace: Option<Id>,
77 pub line_size: LineSize,
78}
79
80impl CompilationArg for TensorCompilationArg {}
81
82impl<C: CubePrimitive> LaunchArg for Tensor<C> {
83 type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
84 type CompilationArg = TensorCompilationArg;
85
86 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
87 match runtime_arg {
88 TensorArg::Handle { line_size, .. } => TensorCompilationArg {
89 inplace: None,
90 line_size: *line_size as LineSize,
91 },
92 TensorArg::Alias { input_pos } => TensorCompilationArg {
93 inplace: Some(*input_pos as Id),
94 line_size: 0,
95 },
96 }
97 }
98
99 fn expand(
100 arg: &Self::CompilationArg,
101 builder: &mut KernelBuilder,
102 ) -> ExpandElementTyped<Tensor<C>> {
103 builder
104 .input_tensor(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
105 .into()
106 }
107 fn expand_output(
108 arg: &Self::CompilationArg,
109 builder: &mut KernelBuilder,
110 ) -> ExpandElementTyped<Tensor<C>> {
111 match arg.inplace {
112 Some(id) => builder.inplace_output(id).into(),
113 None => builder
114 .output_tensor(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
115 .into(),
116 }
117 }
118}
119
120impl<'a, R: Runtime> TensorArg<'a, R> {
121 pub unsafe fn from_raw_parts<E: CubePrimitive>(
128 handle: &'a cubecl_runtime::server::Handle,
129 strides: &'a [usize],
130 shape: &'a [usize],
131 factor: LineSize,
132 ) -> Self {
133 unsafe {
134 Self::Handle {
135 handle: TensorHandleRef::from_raw_parts(
136 handle,
137 strides,
138 shape,
139 E::size().expect("Element should have a size"),
140 ),
141 line_size: factor,
142 }
143 }
144 }
145
146 pub unsafe fn from_raw_parts_and_size(
154 handle: &'a cubecl_runtime::server::Handle,
155 strides: &'a [usize],
156 shape: &'a [usize],
157 factor: LineSize,
158 elem_size: usize,
159 ) -> Self {
160 unsafe {
161 Self::Handle {
162 handle: TensorHandleRef::from_raw_parts(handle, strides, shape, elem_size),
163 line_size: factor,
164 }
165 }
166 }
167
168 pub fn alias(position: usize) -> Self {
170 Self::Alias {
171 input_pos: position,
172 }
173 }
174}
175
176impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
177 fn register(&self, launcher: &mut KernelLauncher<R>) {
178 launcher.register_tensor(self);
179 }
180}
181
182impl<'a, R: Runtime> TensorHandleRef<'a, R> {
183 pub fn as_tensor_arg(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
185 unsafe {
186 TensorArg::from_raw_parts_and_size(
187 self.handle,
188 self.strides,
189 self.shape,
190 line_size,
191 self.elem_size,
192 )
193 }
194 }
195 pub fn as_array_arg(&'a self, line_size: LineSize) -> ArrayArg<'a, R> {
197 let length = self.shape.iter().product();
198 unsafe { ArrayArg::from_raw_parts_and_size(self.handle, length, line_size, self.elem_size) }
199 }
200 pub unsafe fn from_raw_parts(
207 handle: &'a cubecl_runtime::server::Handle,
208 strides: &'a [usize],
209 shape: &'a [usize],
210 elem_size: usize,
211 ) -> Self {
212 Self {
213 handle,
214 strides,
215 shape,
216 elem_size,
217 runtime: PhantomData,
218 }
219 }
220
221 pub fn as_copy_descriptor(&self) -> CopyDescriptor<'_> {
222 CopyDescriptor {
223 binding: self.handle.clone().binding(),
224 shape: self.shape,
225 strides: self.strides,
226 elem_size: self.elem_size,
227 }
228 }
229}