cubecl_core/frontend/container/tensor/
launch.rs1use core::marker::PhantomData;
2
3use cubecl_ir::AddressType;
4use cubecl_runtime::{runtime::Runtime, server::CopyDescriptor};
5use cubecl_zspace::{Shape, Strides};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 compute::{KernelBuilder, KernelLauncher},
10 ir::Id,
11 prelude::{ArrayArg, ArrayBinding, CubePrimitive, LaunchArg, NativeExpand},
12};
13
14use super::Tensor;
15
16#[derive(Debug)]
18pub enum TensorArg<R: Runtime> {
19 Handle {
21 handle: TensorBinding<R>,
23 },
24 Alias {
26 input_pos: usize,
28 strides: Strides,
29 shape: Shape,
30 },
31}
32
33pub struct TensorBinding<R: Runtime> {
36 pub handle: cubecl_runtime::server::Binding,
37 pub strides: Strides,
38 pub shape: Shape,
39 pub runtime: PhantomData<R>,
40}
41
42impl<R: Runtime> Clone for TensorBinding<R> {
43 fn clone(&self) -> Self {
44 Self {
45 handle: self.handle.clone(),
46 strides: self.strides.clone(),
47 shape: self.shape.clone(),
48 runtime: PhantomData,
49 }
50 }
51}
52
53impl<R: Runtime> TensorBinding<R> {
54 pub fn size(&self) -> usize {
55 self.shape.iter().product()
56 }
57
58 pub fn required_address_type(&self, elem_size: usize) -> AddressType {
60 AddressType::from_len(self.handle.size() as usize / elem_size)
61 }
62}
63
64impl<R: Runtime> core::fmt::Debug for TensorBinding<R> {
65 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66 writeln!(
67 f,
68 "TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
69 self.strides, self.shape
70 )
71 }
72}
73
74#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
76pub struct TensorCompilationArg {
77 pub inplace: Option<Id>,
78}
79
80impl<C: CubePrimitive> LaunchArg for Tensor<C> {
81 type RuntimeArg<R: Runtime> = TensorArg<R>;
82 type CompilationArg = TensorCompilationArg;
83
84 fn register<R: Runtime>(
85 arg: Self::RuntimeArg<R>,
86 launcher: &mut KernelLauncher<R>,
87 ) -> Self::CompilationArg {
88 let ty = launcher.with_scope(|scope| C::as_type(scope));
89 let compilation_arg = match &arg {
90 TensorArg::Handle { .. } => TensorCompilationArg { inplace: None },
91 TensorArg::Alias { input_pos, .. } => TensorCompilationArg {
92 inplace: Some(*input_pos as Id),
93 },
94 };
95 launcher.register_tensor(arg, ty);
96 compilation_arg
97 }
98
99 fn expand(_arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> NativeExpand<Tensor<C>> {
100 builder.input_tensor(C::as_type(&builder.scope)).into()
101 }
102 fn expand_output(
103 arg: &Self::CompilationArg,
104 builder: &mut KernelBuilder,
105 ) -> NativeExpand<Tensor<C>> {
106 match arg.inplace {
107 Some(id) => builder.inplace_output(id).into(),
108 None => builder.output_tensor(C::as_type(&builder.scope)).into(),
109 }
110 }
111}
112
113impl<R: Runtime> TensorArg<R> {
114 pub unsafe fn from_raw_parts(
121 handle: cubecl_runtime::server::Handle,
122 strides: Strides,
123 shape: Shape,
124 ) -> Self {
125 unsafe { Self::from_raw_parts_binding(handle.binding(), strides, shape) }
126 }
127
128 pub(crate) unsafe fn from_raw_parts_binding(
129 handle: cubecl_runtime::server::Binding,
130 strides: Strides,
131 shape: Shape,
132 ) -> Self {
133 unsafe {
134 Self::Handle {
135 handle: TensorBinding::from_raw_parts_binding(handle, strides, shape),
136 }
137 }
138 }
139
140 pub fn into_alias(self, position: usize) -> Self {
142 match self {
143 TensorArg::Handle { handle } => handle.into_alias(position),
144 alias @ TensorArg::Alias { .. } => alias,
145 }
146 }
147
148 pub fn size(&self) -> usize {
149 match self {
150 TensorArg::Handle { handle } => handle.size(),
151 TensorArg::Alias { shape, .. } => shape.iter().product(),
152 }
153 }
154
155 pub fn shape(&self) -> &[usize] {
156 match self {
157 TensorArg::Handle { handle } => &handle.shape,
158 TensorArg::Alias { shape, .. } => shape,
159 }
160 }
161
162 pub fn strides(&self) -> &[usize] {
163 match self {
164 TensorArg::Handle { handle } => &handle.strides,
165 TensorArg::Alias { strides, .. } => strides,
166 }
167 }
168}
169
170impl<R: Runtime> TensorArg<R> {
171 pub fn into_array_arg(self) -> ArrayArg<R> {
172 match self {
173 TensorArg::Handle { handle } => {
174 let handle = unsafe {
175 let size = handle.size();
176 ArrayBinding::from_raw_parts_binding(handle.handle, size)
177 };
178 ArrayArg::Handle { handle }
179 }
180 TensorArg::Alias {
181 input_pos, shape, ..
182 } => ArrayArg::Alias {
183 input_pos,
184 length: [shape.iter().product()],
185 },
186 }
187 }
188}
189
190impl<R: Runtime> TensorBinding<R> {
191 pub fn into_tensor_arg(self) -> TensorArg<R> {
193 unsafe { TensorArg::from_raw_parts_binding(self.handle, self.strides, self.shape) }
194 }
195 pub fn into_alias(self, index: usize) -> TensorArg<R> {
197 TensorArg::Alias {
198 input_pos: index,
199 strides: self.strides,
200 shape: self.shape,
201 }
202 }
203 pub fn as_alias(&self, index: usize) -> TensorArg<R> {
205 TensorArg::Alias {
206 input_pos: index,
207 strides: self.strides.clone(),
208 shape: self.shape.clone(),
209 }
210 }
211 pub fn into_array_arg(self) -> ArrayArg<R> {
213 let length = self.shape.iter().product();
214 unsafe { ArrayArg::from_raw_parts_binding(self.handle, length) }
215 }
216
217 pub unsafe fn from_raw_parts(
224 handle: cubecl_runtime::server::Handle,
225 strides: Strides,
226 shape: Shape,
227 ) -> Self {
228 unsafe { Self::from_raw_parts_binding(handle.binding(), strides, shape) }
229 }
230
231 pub unsafe fn from_raw_parts_binding(
238 handle: cubecl_runtime::server::Binding,
239 strides: Strides,
240 shape: Shape,
241 ) -> Self {
242 Self {
243 handle,
244 strides,
245 shape,
246 runtime: PhantomData,
247 }
248 }
249
250 pub fn into_copy_descriptor(self, elem_size: usize) -> CopyDescriptor {
251 CopyDescriptor {
252 handle: self.handle,
253 shape: self.shape,
254 strides: self.strides,
255 elem_size,
256 }
257 }
258}