// cubecl_core/frontend/container/array/launch.rs

1use core::marker::PhantomData;
2
3use cubecl_runtime::runtime::Runtime;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    compute::{KernelBuilder, KernelLauncher},
8    ir::Id,
9    prelude::{CubePrimitive, LaunchArg, NativeExpand, TensorBinding},
10};
11
12use super::Array;
13
14#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
15pub struct ArrayCompilationArg {
16    pub inplace: Option<Id>,
17}
18
19/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle).
20pub struct ArrayBinding<R: Runtime> {
21    pub handle: cubecl_runtime::server::Binding,
22    pub(crate) length: [usize; 1],
23    runtime: PhantomData<R>,
24}
25
26pub enum ArrayArg<R: Runtime> {
27    /// The array is passed with an array handle.
28    Handle {
29        /// The array handle.
30        handle: ArrayBinding<R>,
31    },
32    /// The array is aliasing another input array.
33    Alias {
34        /// The position of the input array.
35        input_pos: usize,
36        /// The length of the underlying handle
37        length: [usize; 1],
38    },
39}
40
41impl<R: Runtime> ArrayArg<R> {
42    /// Create a new array argument.
43    ///
44    /// # Safety
45    ///
46    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
47    pub unsafe fn from_raw_parts(handle: cubecl_runtime::server::Handle, length: usize) -> Self {
48        unsafe {
49            ArrayArg::Handle {
50                handle: ArrayBinding::from_raw_parts(handle, length),
51            }
52        }
53    }
54    /// Create a new array argument from a binding.
55    ///
56    /// # Safety
57    ///
58    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
59    pub unsafe fn from_raw_parts_binding(
60        binding: cubecl_runtime::server::Binding,
61        length: usize,
62    ) -> Self {
63        unsafe {
64            ArrayArg::Handle {
65                handle: ArrayBinding::from_raw_parts_binding(binding, length),
66            }
67        }
68    }
69
70    pub fn size(&self) -> usize {
71        match self {
72            ArrayArg::Handle { handle } => handle.length[0],
73            ArrayArg::Alias { length, .. } => length[0],
74        }
75    }
76
77    pub fn shape(&self) -> &[usize] {
78        match self {
79            ArrayArg::Handle { handle } => &handle.length,
80            ArrayArg::Alias { length, .. } => length,
81        }
82    }
83}
84
85impl<R: Runtime> ArrayBinding<R> {
86    /// Create a new array handle reference.
87    ///
88    /// # Safety
89    ///
90    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
91    pub unsafe fn from_raw_parts(handle: cubecl_runtime::server::Handle, length: usize) -> Self {
92        unsafe { Self::from_raw_parts_binding(handle.binding(), length) }
93    }
94
95    /// Create a new array handle reference.
96    ///
97    /// # Safety
98    ///
99    /// Specifying the wrong length or size, may lead to out-of-bounds reads and writes.
100    pub unsafe fn from_raw_parts_binding(
101        handle: cubecl_runtime::server::Binding,
102        length: usize,
103    ) -> Self {
104        Self {
105            handle,
106            length: [length],
107            runtime: PhantomData,
108        }
109    }
110
111    /// Return the handle as a tensor instead of an array.
112    pub fn into_tensor(self) -> TensorBinding<R> {
113        let shape = self.length.into();
114
115        TensorBinding {
116            handle: self.handle,
117            strides: [1].into(),
118            shape,
119            runtime: PhantomData,
120        }
121    }
122}
123
124impl<C: CubePrimitive> LaunchArg for Array<C> {
125    type RuntimeArg<R: Runtime> = ArrayArg<R>;
126    type CompilationArg = ArrayCompilationArg;
127
128    fn register<R: Runtime>(
129        arg: Self::RuntimeArg<R>,
130        launcher: &mut KernelLauncher<R>,
131    ) -> Self::CompilationArg {
132        let ty = launcher.with_scope(|scope| C::as_type(scope));
133        let compilation_arg = match &arg {
134            ArrayArg::Handle { .. } => ArrayCompilationArg { inplace: None },
135            ArrayArg::Alias { input_pos, .. } => ArrayCompilationArg {
136                inplace: Some(*input_pos as Id),
137            },
138        };
139        launcher.register_array(arg, ty);
140        compilation_arg
141    }
142
143    fn expand(_arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> NativeExpand<Array<C>> {
144        let ty = C::as_type(&builder.scope);
145        builder.input_array(ty).into()
146    }
147    fn expand_output(
148        arg: &Self::CompilationArg,
149        builder: &mut KernelBuilder,
150    ) -> NativeExpand<Array<C>> {
151        match arg.inplace {
152            Some(id) => builder.inplace_output(id).into(),
153            None => builder.output_array(C::as_type(&builder.scope)).into(),
154        }
155    }
156}