Skip to main content

cubecl_core/
io.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4};
5use derive_more::Display;
6
7use crate as cubecl;
8use cubecl::prelude::*;
9use cubecl_ir::{ManagedVariable, Variable};
10
// Generic arguments used when the `#[cube]` store helpers below are expanded
// directly from raw IR `Variable`s. `expand_checked_index_assign` and
// `expand_validate_index_assign` bind them to the concrete storage type and
// vector size of the stored value via `scope.register_type::<ElemA>` /
// `scope.register_size::<SizeA>` before calling the helper's `expand`.
define_scalar!(ElemA);
define_size!(SizeA);
13
14/// Returns the value at `index` in `list` if `condition` is `true`, otherwise returns `value`.
15#[cube]
16pub fn read_masked<C: CubePrimitive>(mask: bool, list: Slice<C>, index: usize, value: C) -> C {
17    let index = index * usize::cast_from(mask);
18    let input = list.read_unchecked(index);
19
20    select(mask, input, value)
21}
22
23/// Returns the value at `index` in tensor within bounds.
24#[cube]
25pub fn read_tensor_checked<C: CubePrimitive + Default + IntoRuntime>(
26    tensor: Tensor<C>,
27    index: usize,
28    #[comptime] unroll_factor: usize,
29) -> C {
30    let len = tensor.buffer_len() * unroll_factor;
31    let in_bounds = index < len;
32    let index = index.min(len - 1);
33
34    select(in_bounds, tensor.read_unchecked(index), C::default())
35}
36
37/// Returns the value at `index` in tensor within bounds.
38#[cube]
39pub fn read_tensor_atomic_checked<C: Scalar>(
40    tensor: Tensor<Atomic<C>>,
41    index: usize,
42    #[comptime] unroll_factor: usize,
43) -> Atomic<C> {
44    let index = index.min(tensor.buffer_len() * unroll_factor - 1);
45
46    tensor.read_unchecked(index)
47}
48
49/// Returns the value at `index` in tensor within bounds.
50#[cube]
51pub fn read_tensor_validate<C: CubePrimitive + Default + IntoRuntime>(
52    tensor: Tensor<C>,
53    index: usize,
54    #[comptime] unroll_factor: usize,
55    #[comptime] kernel_name: String,
56) -> C {
57    let len = tensor.buffer_len() * unroll_factor;
58    let in_bounds = index < len;
59    if !in_bounds {
60        print_oob::<Tensor<C>>(kernel_name, OobKind::Read, index, len, &tensor);
61    }
62
63    let index = index.min(len - 1);
64
65    select(in_bounds, tensor.read_unchecked(index), C::default())
66}
67
68/// Returns the value at `index` in tensor within bounds.
69#[cube]
70pub fn read_tensor_atomic_validate<C: Scalar>(
71    tensor: Tensor<Atomic<C>>,
72    index: usize,
73    #[comptime] unroll_factor: usize,
74    #[comptime] kernel_name: String,
75) -> Atomic<C> {
76    let len = tensor.buffer_len() * unroll_factor;
77    if index >= len {
78        print_oob::<Tensor<Atomic<C>>>(kernel_name, OobKind::Read, index, len, &tensor);
79    }
80    let index = index.min(tensor.buffer_len() * unroll_factor - 1);
81
82    tensor.read_unchecked(index)
83}
84
85#[cube]
86fn checked_index_assign<E: Scalar, N: Size>(
87    index: usize,
88    value: Vector<E, N>,
89    out: &mut Array<Vector<E, N>>,
90    #[comptime] has_buffer_len: bool,
91    #[comptime] unroll_factor: usize,
92) {
93    let array_len = if has_buffer_len {
94        out.buffer_len()
95    } else {
96        out.len()
97    };
98
99    if index < array_len * unroll_factor {
100        unsafe { out.index_assign_unchecked(index, value) };
101    }
102}
103
104#[cube]
105fn validate_index_assign<E: Scalar, N: Size>(
106    index: usize,
107    value: Vector<E, N>,
108    out: &mut Array<Vector<E, N>>,
109    #[comptime] has_buffer_len: bool,
110    #[comptime] unroll_factor: usize,
111    #[comptime] kernel_name: String,
112) {
113    let array_len = if has_buffer_len {
114        out.buffer_len()
115    } else {
116        out.len()
117    };
118    let len = array_len * unroll_factor;
119
120    if index < len {
121        unsafe { out.index_assign_unchecked(index, value) };
122    } else {
123        print_oob::<Array<Vector<E, N>>>(kernel_name, OobKind::Write, index, len, out);
124    }
125}
126
127#[derive(Display)]
128enum OobKind {
129    #[display("read")]
130    Read,
131    #[display("write")]
132    Write,
133}
134
/// Emits a debug print reporting an out-of-bounds access.
///
/// The message is formatted at expansion time with:
/// - `kernel_name`;
/// - the access `kind`, rendered via `Display` as "read" or "write";
/// - the buffer's name, as returned by `name_of_var`: its recorded debug name, or else
///   the variable's `Display` form.
///
/// The runtime `index` and `len` are passed as arguments for the two `%u` placeholders.
///
/// NOTE(review): `kernel_name` and the buffer name are spliced directly into the
/// printf-style format string. A `%` in either would be interpreted as a format
/// specifier. Presumably such names never contain `%` — confirm, or escape as `%%`.
#[cube]
// NOTE(review): presumably silences unused-parameter warnings, since the parameters
// are only consumed inside the `intrinsic!` expansion — confirm.
#[allow(unused)]
fn print_oob<Out: CubeType<ExpandType: Into<Variable>>>(
    #[comptime] kernel_name: String,
    #[comptime] kind: OobKind,
    index: usize,
    len: usize,
    buffer: &Out,
) {
    intrinsic!(|scope| {
        // Resolve a human-readable name for the buffer variable.
        let name = name_of_var(scope, buffer.into());
        debug_print_expand!(
            scope,
            alloc::format!(
                "[VALIDATION {kernel_name}]: Encountered OOB {kind} in {name} at %u, length is %u\n"
            ),
            index,
            len
        );
    })
}
156
157fn name_of_var(scope: &Scope, var: Variable) -> Cow<'static, str> {
158    let debug_name = scope.debug.variable_names.borrow().get(&var).cloned();
159    debug_name.unwrap_or_else(|| var.to_string().into())
160}
161
162#[allow(missing_docs)]
163pub fn expand_checked_index_assign(
164    scope: &mut Scope,
165    lhs: Variable,
166    rhs: Variable,
167    out: Variable,
168    unroll_factor: usize,
169) {
170    scope.register_type::<ElemA>(rhs.ty.storage_type());
171    scope.register_size::<SizeA>(rhs.ty.vector_size());
172    checked_index_assign::expand::<ElemA, SizeA>(
173        scope,
174        ManagedVariable::Plain(lhs).into(),
175        ManagedVariable::Plain(rhs).into(),
176        ManagedVariable::Plain(out).into(),
177        out.has_buffer_length(),
178        unroll_factor,
179    );
180}
181
182#[allow(missing_docs)]
183pub fn expand_validate_index_assign(
184    scope: &mut Scope,
185    lhs: Variable,
186    rhs: Variable,
187    out: Variable,
188    unroll_factor: usize,
189    kernel_name: &str,
190) {
191    scope.register_type::<ElemA>(rhs.ty.storage_type());
192    scope.register_size::<SizeA>(rhs.ty.vector_size());
193    validate_index_assign::expand::<ElemA, SizeA>(
194        scope,
195        ManagedVariable::Plain(lhs).into(),
196        ManagedVariable::Plain(rhs).into(),
197        ManagedVariable::Plain(out).into(),
198        out.has_buffer_length(),
199        unroll_factor,
200        kernel_name.to_string(),
201    );
202}