cubecl_core/frontend/container/array/
launch.rs1use core::marker::PhantomData;
2
3use cubecl_runtime::runtime::Runtime;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 compute::{KernelBuilder, KernelLauncher},
8 ir::Id,
9 prelude::{CubePrimitive, LaunchArg, NativeExpand, TensorBinding},
10};
11
12use super::Array;
13
14#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
15pub struct ArrayCompilationArg {
16 pub inplace: Option<Id>,
17}
18
19pub struct ArrayBinding<R: Runtime> {
21 pub handle: cubecl_runtime::server::Binding,
22 pub(crate) length: [usize; 1],
23 runtime: PhantomData<R>,
24}
25
26pub enum ArrayArg<R: Runtime> {
27 Handle {
29 handle: ArrayBinding<R>,
31 },
32 Alias {
34 input_pos: usize,
36 length: [usize; 1],
38 },
39}
40
41impl<R: Runtime> ArrayArg<R> {
42 pub unsafe fn from_raw_parts(handle: cubecl_runtime::server::Handle, length: usize) -> Self {
48 unsafe {
49 ArrayArg::Handle {
50 handle: ArrayBinding::from_raw_parts(handle, length),
51 }
52 }
53 }
54 pub unsafe fn from_raw_parts_binding(
60 binding: cubecl_runtime::server::Binding,
61 length: usize,
62 ) -> Self {
63 unsafe {
64 ArrayArg::Handle {
65 handle: ArrayBinding::from_raw_parts_binding(binding, length),
66 }
67 }
68 }
69
70 pub fn size(&self) -> usize {
71 match self {
72 ArrayArg::Handle { handle } => handle.length[0],
73 ArrayArg::Alias { length, .. } => length[0],
74 }
75 }
76
77 pub fn shape(&self) -> &[usize] {
78 match self {
79 ArrayArg::Handle { handle } => &handle.length,
80 ArrayArg::Alias { length, .. } => length,
81 }
82 }
83}
84
85impl<R: Runtime> ArrayBinding<R> {
86 pub unsafe fn from_raw_parts(handle: cubecl_runtime::server::Handle, length: usize) -> Self {
92 unsafe { Self::from_raw_parts_binding(handle.binding(), length) }
93 }
94
95 pub unsafe fn from_raw_parts_binding(
101 handle: cubecl_runtime::server::Binding,
102 length: usize,
103 ) -> Self {
104 Self {
105 handle,
106 length: [length],
107 runtime: PhantomData,
108 }
109 }
110
111 pub fn into_tensor(self) -> TensorBinding<R> {
113 let shape = self.length.into();
114
115 TensorBinding {
116 handle: self.handle,
117 strides: [1].into(),
118 shape,
119 runtime: PhantomData,
120 }
121 }
122}
123
124impl<C: CubePrimitive> LaunchArg for Array<C> {
125 type RuntimeArg<R: Runtime> = ArrayArg<R>;
126 type CompilationArg = ArrayCompilationArg;
127
128 fn register<R: Runtime>(
129 arg: Self::RuntimeArg<R>,
130 launcher: &mut KernelLauncher<R>,
131 ) -> Self::CompilationArg {
132 let ty = launcher.with_scope(|scope| C::as_type(scope));
133 let compilation_arg = match &arg {
134 ArrayArg::Handle { .. } => ArrayCompilationArg { inplace: None },
135 ArrayArg::Alias { input_pos, .. } => ArrayCompilationArg {
136 inplace: Some(*input_pos as Id),
137 },
138 };
139 launcher.register_array(arg, ty);
140 compilation_arg
141 }
142
143 fn expand(_arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> NativeExpand<Array<C>> {
144 let ty = C::as_type(&builder.scope);
145 builder.input_array(ty).into()
146 }
147 fn expand_output(
148 arg: &Self::CompilationArg,
149 builder: &mut KernelBuilder,
150 ) -> NativeExpand<Array<C>> {
151 match arg.inplace {
152 Some(id) => builder.inplace_output(id).into(),
153 None => builder.output_array(C::as_type(&builder.scope)).into(),
154 }
155 }
156}