cubecl_core/frontend/container/tensor/
launch.rs

1use std::marker::PhantomData;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    Runtime,
7    compute::{KernelBuilder, KernelLauncher},
8    ir::{Id, LineSize, Type},
9    prelude::{
10        ArgSettings, ArrayArg, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg,
11    },
12};
13
14use super::Tensor;
15
16/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
17#[derive(Debug)]
18pub enum TensorArg<'a, R: Runtime> {
19    /// The tensor is passed with a tensor handle.
20    Handle {
21        /// The tensor handle.
22        handle: TensorHandleRef<'a, R>,
23        /// The vectorization factor.
24        line_size: u8,
25    },
26    /// The tensor is aliasing another input tensor.
27    Alias {
28        /// The position of the input tensor.
29        input_pos: usize,
30    },
31}
32
33/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle),
34/// the strides and the shape.
35pub struct TensorHandleRef<'a, R: Runtime> {
36    pub handle: &'a cubecl_runtime::server::Handle,
37    pub strides: &'a [usize],
38    pub shape: &'a [usize],
39    pub elem_size: usize,
40    pub runtime: PhantomData<R>,
41}
42
43impl<'a, R: Runtime> Clone for TensorHandleRef<'a, R> {
44    fn clone(&self) -> Self {
45        *self
46    }
47}
48
49impl<'a, R: Runtime> Copy for TensorHandleRef<'a, R> {}
50
51impl<R: Runtime> TensorHandleRef<'_, R> {
52    pub fn size(&self) -> usize {
53        self.shape.iter().product()
54    }
55}
56
57impl<R: Runtime> core::fmt::Debug for TensorHandleRef<'_, R> {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        writeln!(
60            f,
61            "TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
62            self.strides, self.shape
63        )
64    }
65}
66
67/// Compilation argument for a [tensor](Tensor).
68#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
69pub struct TensorCompilationArg {
70    pub inplace: Option<Id>,
71    pub line_size: LineSize,
72}
73
74impl CompilationArg for TensorCompilationArg {}
75
76impl<C: CubePrimitive> LaunchArg for Tensor<C> {
77    type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
78    type CompilationArg = TensorCompilationArg;
79
80    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
81        match runtime_arg {
82            TensorArg::Handle { line_size, .. } => TensorCompilationArg {
83                inplace: None,
84                line_size: *line_size as u32,
85            },
86            TensorArg::Alias { input_pos } => TensorCompilationArg {
87                inplace: Some(*input_pos as Id),
88                line_size: 0,
89            },
90        }
91    }
92
93    fn expand(
94        arg: &Self::CompilationArg,
95        builder: &mut KernelBuilder,
96    ) -> ExpandElementTyped<Tensor<C>> {
97        builder
98            .input_tensor(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
99            .into()
100    }
101    fn expand_output(
102        arg: &Self::CompilationArg,
103        builder: &mut KernelBuilder,
104    ) -> ExpandElementTyped<Tensor<C>> {
105        match arg.inplace {
106            Some(id) => builder.inplace_output(id).into(),
107            None => builder
108                .output_tensor(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
109                .into(),
110        }
111    }
112}
113
114impl<'a, R: Runtime> TensorArg<'a, R> {
115    /// Create a new tensor argument specified with its vectorization factor.
116    ///
117    /// # Safety
118    ///
119    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
120    /// out-of-bound reads and writes.
121    pub unsafe fn from_raw_parts<E: CubePrimitive>(
122        handle: &'a cubecl_runtime::server::Handle,
123        strides: &'a [usize],
124        shape: &'a [usize],
125        factor: u8,
126    ) -> Self {
127        unsafe {
128            Self::Handle {
129                handle: TensorHandleRef::from_raw_parts(
130                    handle,
131                    strides,
132                    shape,
133                    E::size().expect("Element should have a size"),
134                ),
135                line_size: factor,
136            }
137        }
138    }
139
140    /// Create a new tensor argument specified with its vectorization factor with a manual element
141    /// size in bytes.
142    ///
143    /// # Safety
144    ///
145    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
146    /// out-of-bound reads and writes.
147    pub unsafe fn from_raw_parts_and_size(
148        handle: &'a cubecl_runtime::server::Handle,
149        strides: &'a [usize],
150        shape: &'a [usize],
151        factor: u8,
152        elem_size: usize,
153    ) -> Self {
154        unsafe {
155            Self::Handle {
156                handle: TensorHandleRef::from_raw_parts(handle, strides, shape, elem_size),
157                line_size: factor,
158            }
159        }
160    }
161
162    /// Create an alias argument.
163    pub fn alias(position: usize) -> Self {
164        Self::Alias {
165            input_pos: position,
166        }
167    }
168}
169
170impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
171    fn register(&self, launcher: &mut KernelLauncher<R>) {
172        launcher.register_tensor(self);
173    }
174}
175
176impl<'a, R: Runtime> TensorHandleRef<'a, R> {
177    /// Convert the handle into a [tensor argument](TensorArg).
178    pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
179        unsafe {
180            TensorArg::from_raw_parts_and_size(
181                self.handle,
182                self.strides,
183                self.shape,
184                vectorisation,
185                self.elem_size,
186            )
187        }
188    }
189    /// Convert the handle into an [array argument](ArrayArg).
190    pub fn as_array_arg(&'a self, line_size: u8) -> ArrayArg<'a, R> {
191        let length = self.shape.iter().product();
192        unsafe { ArrayArg::from_raw_parts_and_size(self.handle, length, line_size, self.elem_size) }
193    }
194    /// Create a handle from raw parts.
195    ///
196    /// # Safety
197    ///
198    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
199    /// out-of-bounds reads and writes.
200    pub unsafe fn from_raw_parts(
201        handle: &'a cubecl_runtime::server::Handle,
202        strides: &'a [usize],
203        shape: &'a [usize],
204        elem_size: usize,
205    ) -> Self {
206        Self {
207            handle,
208            strides,
209            shape,
210            elem_size,
211            runtime: PhantomData,
212        }
213    }
214}