cubecl_core/post_processing/
checked_io.rs1use alloc::{string::String, vec::Vec};
2use cubecl_ir::{Allocator, Instruction, ManagedVariable, Operation, Operator, Processor, Scope};
3use cubecl_runtime::server::ExecutionMode;
4
5use crate::{
6 define_scalar, define_size,
7 io::{
8 expand_checked_index_assign, expand_validate_index_assign, read_tensor_atomic_checked,
9 read_tensor_atomic_validate, read_tensor_checked, read_tensor_validate,
10 },
11 prelude::Vector,
12};
13
14define_scalar!(ElemA);
15define_size!(SizeA);
16
/// IR post-processor that rewrites `Index`/`IndexAssign` operations according
/// to the kernel's [`ExecutionMode`]: `Checked` and `Validate` replace raw
/// accesses on lists with a known length by the bounds-handling read/write
/// helpers, while `Unchecked` leaves the IR untouched.
#[derive(new, Debug)]
pub struct CheckedIoProcessor {
    // Selects which transformation (if any) `transform` applies.
    mode: ExecutionMode,
    // Kernel name forwarded to the validating helpers in `Validate` mode
    // (presumably surfaced in the out-of-bounds diagnostic — confirm in the
    // `read_tensor_*_validate` implementations).
    kernel_name: String,
}
22
23impl Processor for CheckedIoProcessor {
24 fn transform(
25 &self,
26 processing: cubecl_ir::ScopeProcessing,
27 allocator: Allocator,
28 ) -> cubecl_ir::ScopeProcessing {
29 match self.mode {
30 ExecutionMode::Checked => self.transform_checked(processing, allocator),
31 ExecutionMode::Unchecked => processing,
32 ExecutionMode::Validate => self.transform_validate(processing, allocator),
33 }
34 }
35}
36
37impl CheckedIoProcessor {
38 fn transform_checked(
39 &self,
40 mut processing: cubecl_ir::ScopeProcessing,
41 allocator: Allocator,
42 ) -> cubecl_ir::ScopeProcessing {
43 let mut instructions = Vec::new();
44 core::mem::swap(&mut processing.instructions, &mut instructions);
45
46 for instruction in instructions {
47 if let Operation::Operator(operator) = &instruction.operation {
48 match operator {
49 Operator::Index(op) => {
50 let has_length = op.list.has_length();
51
52 if has_length {
53 let list = ManagedVariable::Plain(op.list);
54 let index = ManagedVariable::Plain(op.index);
55 let mut scope = Scope::root(false)
56 .with_allocator(allocator.clone())
57 .with_types(processing.typemap.clone());
58 scope.register_type::<ElemA>(op.list.storage_type());
59 scope.register_size::<SizeA>(op.list.vector_size());
60
61 let input = if op.list.ty.is_atomic() {
62 read_tensor_atomic_checked::expand::<ElemA>(
67 &mut scope,
68 list.into(),
69 index.into(),
70 op.unroll_factor,
71 )
72 .expand
73 } else {
74 read_tensor_checked::expand::<Vector<ElemA, SizeA>>(
75 &mut scope,
76 list.into(),
77 index.into(),
78 op.unroll_factor,
79 )
80 .expand
81 };
82 let tmp_processing = scope.process([]);
83
84 for inst in tmp_processing.instructions {
85 processing.instructions.push(inst);
86 }
87 for var in tmp_processing.variables {
88 processing.variables.push(var);
89 }
90
91 processing
92 .instructions
93 .push(Instruction::new(Operation::Copy(*input), instruction.out()));
94 continue;
95 }
96 }
97 Operator::IndexAssign(op) => {
98 let out = instruction.out();
99
100 if out.has_length() {
101 let mut scope = Scope::root(false)
102 .with_allocator(allocator.clone())
103 .with_types(processing.typemap.clone());
104 expand_checked_index_assign(
105 &mut scope,
106 op.index,
107 op.value,
108 out,
109 op.unroll_factor,
110 );
111
112 let tmp_processing = scope.process([]);
113
114 for inst in tmp_processing.instructions {
115 processing.instructions.push(inst);
116 }
117 for var in tmp_processing.variables {
118 processing.variables.push(var);
119 }
120
121 continue;
122 }
123 }
124 _ => {}
125 }
126 }
127
128 processing.instructions.push(instruction);
130 }
131 processing
132 }
133
134 fn transform_validate(
135 &self,
136 mut processing: cubecl_ir::ScopeProcessing,
137 allocator: Allocator,
138 ) -> cubecl_ir::ScopeProcessing {
139 let mut instructions = Vec::new();
140 core::mem::swap(&mut processing.instructions, &mut instructions);
141
142 for instruction in instructions {
143 if let Operation::Operator(operator) = &instruction.operation {
144 match operator {
145 Operator::Index(op) => {
146 let has_length = op.list.has_length();
147
148 if has_length {
149 let list = ManagedVariable::Plain(op.list);
150 let index = ManagedVariable::Plain(op.index);
151 let mut scope = Scope::root(false)
152 .with_allocator(allocator.clone())
153 .with_types(processing.typemap.clone());
154 scope.register_type::<ElemA>(op.list.storage_type());
155 scope.register_size::<SizeA>(op.list.vector_size());
156
157 let input = if op.list.ty.is_atomic() {
158 read_tensor_atomic_validate::expand::<ElemA>(
163 &mut scope,
164 list.into(),
165 index.into(),
166 op.unroll_factor,
167 self.kernel_name.clone(),
168 )
169 .expand
170 } else {
171 read_tensor_validate::expand::<Vector<ElemA, SizeA>>(
172 &mut scope,
173 list.into(),
174 index.into(),
175 op.unroll_factor,
176 self.kernel_name.clone(),
177 )
178 .expand
179 };
180 let tmp_processing = scope.process([]);
181
182 for inst in tmp_processing.instructions {
183 processing.instructions.push(inst);
184 }
185 for var in tmp_processing.variables {
186 processing.variables.push(var);
187 }
188
189 processing
190 .instructions
191 .push(Instruction::new(Operation::Copy(*input), instruction.out()));
192 continue;
193 }
194 }
195 Operator::IndexAssign(op) => {
196 let out = instruction.out();
197
198 if out.has_length() {
199 let mut scope = Scope::root(false)
200 .with_allocator(allocator.clone())
201 .with_types(processing.typemap.clone());
202 expand_validate_index_assign(
203 &mut scope,
204 op.index,
205 op.value,
206 out,
207 op.unroll_factor,
208 &self.kernel_name,
209 );
210
211 let tmp_processing = scope.process([]);
212
213 for inst in tmp_processing.instructions {
214 processing.instructions.push(inst);
215 }
216 for var in tmp_processing.variables {
217 processing.variables.push(var);
218 }
219
220 continue;
221 }
222 }
223 _ => {}
224 }
225 }
226
227 processing.instructions.push(instruction);
229 }
230 processing
231 }
232}