1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4};
5use derive_more::Display;
6
7use crate as cubecl;
8use cubecl::prelude::*;
9use cubecl_ir::{ManagedVariable, Variable};
10
// Placeholder element type and vector size used as generic arguments when the
// index-assign helpers in this file are expanded. They are bound to concrete
// values through `scope.register_type::<ElemA>(..)` and
// `scope.register_size::<SizeA>(..)` in the `expand_*` functions below.
define_scalar!(ElemA);
define_size!(SizeA);
13
14#[cube]
16pub fn read_masked<C: CubePrimitive>(mask: bool, list: Slice<C>, index: usize, value: C) -> C {
17 let index = index * usize::cast_from(mask);
18 let input = list.read_unchecked(index);
19
20 select(mask, input, value)
21}
22
23#[cube]
25pub fn read_tensor_checked<C: CubePrimitive + Default + IntoRuntime>(
26 tensor: Tensor<C>,
27 index: usize,
28 #[comptime] unroll_factor: usize,
29) -> C {
30 let len = tensor.buffer_len() * unroll_factor;
31 let in_bounds = index < len;
32 let index = index.min(len - 1);
33
34 select(in_bounds, tensor.read_unchecked(index), C::default())
35}
36
37#[cube]
39pub fn read_tensor_atomic_checked<C: Scalar>(
40 tensor: Tensor<Atomic<C>>,
41 index: usize,
42 #[comptime] unroll_factor: usize,
43) -> Atomic<C> {
44 let index = index.min(tensor.buffer_len() * unroll_factor - 1);
45
46 tensor.read_unchecked(index)
47}
48
49#[cube]
51pub fn read_tensor_validate<C: CubePrimitive + Default + IntoRuntime>(
52 tensor: Tensor<C>,
53 index: usize,
54 #[comptime] unroll_factor: usize,
55 #[comptime] kernel_name: String,
56) -> C {
57 let len = tensor.buffer_len() * unroll_factor;
58 let in_bounds = index < len;
59 if !in_bounds {
60 print_oob::<Tensor<C>>(kernel_name, OobKind::Read, index, len, &tensor);
61 }
62
63 let index = index.min(len - 1);
64
65 select(in_bounds, tensor.read_unchecked(index), C::default())
66}
67
68#[cube]
70pub fn read_tensor_atomic_validate<C: Scalar>(
71 tensor: Tensor<Atomic<C>>,
72 index: usize,
73 #[comptime] unroll_factor: usize,
74 #[comptime] kernel_name: String,
75) -> Atomic<C> {
76 let len = tensor.buffer_len() * unroll_factor;
77 if index >= len {
78 print_oob::<Tensor<Atomic<C>>>(kernel_name, OobKind::Read, index, len, &tensor);
79 }
80 let index = index.min(tensor.buffer_len() * unroll_factor - 1);
81
82 tensor.read_unchecked(index)
83}
84
85#[cube]
86fn checked_index_assign<E: Scalar, N: Size>(
87 index: usize,
88 value: Vector<E, N>,
89 out: &mut Array<Vector<E, N>>,
90 #[comptime] has_buffer_len: bool,
91 #[comptime] unroll_factor: usize,
92) {
93 let array_len = if has_buffer_len {
94 out.buffer_len()
95 } else {
96 out.len()
97 };
98
99 if index < array_len * unroll_factor {
100 unsafe { out.index_assign_unchecked(index, value) };
101 }
102}
103
104#[cube]
105fn validate_index_assign<E: Scalar, N: Size>(
106 index: usize,
107 value: Vector<E, N>,
108 out: &mut Array<Vector<E, N>>,
109 #[comptime] has_buffer_len: bool,
110 #[comptime] unroll_factor: usize,
111 #[comptime] kernel_name: String,
112) {
113 let array_len = if has_buffer_len {
114 out.buffer_len()
115 } else {
116 out.len()
117 };
118 let len = array_len * unroll_factor;
119
120 if index < len {
121 unsafe { out.index_assign_unchecked(index, value) };
122 } else {
123 print_oob::<Array<Vector<E, N>>>(kernel_name, OobKind::Write, index, len, out);
124 }
125}
126
/// Kind of access reported by `print_oob`.
///
/// The `Display` output is interpolated into the validation message.
#[derive(Display)]
enum OobKind {
    /// Out-of-bounds read; displayed as `read`.
    #[display("read")]
    Read,
    /// Out-of-bounds write; displayed as `write`.
    #[display("write")]
    Write,
}
134
/// Emits a debug print describing an out-of-bounds access on `buffer`.
///
/// `kernel_name` and `kind` are comptime values interpolated into the format
/// string together with the buffer's name (see `name_of_var`). `index` and
/// `len` are passed as the arguments for the two `%u` placeholders.
///
/// `Out` must have an expand type convertible into a [`Variable`] so the
/// buffer can be looked up by `name_of_var`.
#[cube]
#[allow(unused)]
fn print_oob<Out: CubeType<ExpandType: Into<Variable>>>(
    #[comptime] kernel_name: String,
    #[comptime] kind: OobKind,
    index: usize,
    len: usize,
    buffer: &Out,
) {
    intrinsic!(|scope| {
        // Resolve a readable name for the buffer before building the message.
        let name = name_of_var(scope, buffer.into());
        debug_print_expand!(
            scope,
            alloc::format!(
                "[VALIDATION {kernel_name}]: Encountered OOB {kind} in {name} at %u, length is %u\n"
            ),
            index,
            len
        );
    })
}
156
157fn name_of_var(scope: &Scope, var: Variable) -> Cow<'static, str> {
158 let debug_name = scope.debug.variable_names.borrow().get(&var).cloned();
159 debug_name.unwrap_or_else(|| var.to_string().into())
160}
161
162#[allow(missing_docs)]
163pub fn expand_checked_index_assign(
164 scope: &mut Scope,
165 lhs: Variable,
166 rhs: Variable,
167 out: Variable,
168 unroll_factor: usize,
169) {
170 scope.register_type::<ElemA>(rhs.ty.storage_type());
171 scope.register_size::<SizeA>(rhs.ty.vector_size());
172 checked_index_assign::expand::<ElemA, SizeA>(
173 scope,
174 ManagedVariable::Plain(lhs).into(),
175 ManagedVariable::Plain(rhs).into(),
176 ManagedVariable::Plain(out).into(),
177 out.has_buffer_length(),
178 unroll_factor,
179 );
180}
181
182#[allow(missing_docs)]
183pub fn expand_validate_index_assign(
184 scope: &mut Scope,
185 lhs: Variable,
186 rhs: Variable,
187 out: Variable,
188 unroll_factor: usize,
189 kernel_name: &str,
190) {
191 scope.register_type::<ElemA>(rhs.ty.storage_type());
192 scope.register_size::<SizeA>(rhs.ty.vector_size());
193 validate_index_assign::expand::<ElemA, SizeA>(
194 scope,
195 ManagedVariable::Plain(lhs).into(),
196 ManagedVariable::Plain(rhs).into(),
197 ManagedVariable::Plain(out).into(),
198 out.has_buffer_length(),
199 unroll_factor,
200 kernel_name.to_string(),
201 );
202}