cubecl_core/post_processing/
checked_io.rs1use cubecl_common::ExecutionMode;
2use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
3
4use crate::{
5 io::{read_tensor_atomic_checked, read_tensor_checked},
6 prelude::{Line, NumericExpand, expand_checked_index_assign},
7};
8
9#[derive(new)]
10pub struct CheckedIoProcessor {
11 mode: ExecutionMode,
12}
13
14impl Processor for CheckedIoProcessor {
15 fn transform(
16 &self,
17 mut processing: cubecl_ir::ScopeProcessing,
18 allocator: Allocator,
19 ) -> cubecl_ir::ScopeProcessing {
20 if matches!(self.mode, ExecutionMode::Unchecked) {
21 return processing;
22 }
23
24 let mut instructions = Vec::new();
25 core::mem::swap(&mut processing.instructions, &mut instructions);
26
27 for instruction in instructions {
28 if let Operation::Operator(operator) = &instruction.operation {
29 match operator {
30 Operator::Index(op) => {
31 let has_length = op.list.has_length();
32
33 if has_length {
34 let list = ExpandElement::Plain(op.list);
35 let index = ExpandElement::Plain(op.index);
36 let mut scope = Scope::root(false).with_allocator(allocator.clone());
37 scope.register_elem::<NumericExpand<0>>(op.list.elem());
38
39 let input = if op.list.elem().is_atomic() {
40 read_tensor_atomic_checked::expand::<NumericExpand<0>>(
45 &mut scope,
46 list.into(),
47 index.into(),
48 )
49 .expand
50 } else {
51 read_tensor_checked::expand::<Line<NumericExpand<0>>>(
52 &mut scope,
53 list.into(),
54 index.into(),
55 )
56 .expand
57 };
58 let tmp_processing = scope.process([]);
59
60 for inst in tmp_processing.instructions {
61 processing.instructions.push(inst);
62 }
63 for var in tmp_processing.variables {
64 processing.variables.push(var);
65 }
66
67 processing
68 .instructions
69 .push(Instruction::new(Operation::Copy(*input), instruction.out()));
70 continue;
71 }
72 }
73 Operator::IndexAssign(op) => {
74 let out = instruction.out();
75
76 if out.has_length() {
77 let mut scope = Scope::root(false).with_allocator(allocator.clone());
78 expand_checked_index_assign(&mut scope, op.index, op.value, out);
79
80 let tmp_processing = scope.process([]);
81
82 for inst in tmp_processing.instructions {
83 processing.instructions.push(inst);
84 }
85 for var in tmp_processing.variables {
86 processing.variables.push(var);
87 }
88
89 continue;
90 }
91 }
92 _ => {}
93 }
94 }
95
96 processing.instructions.push(instruction);
98 }
99 processing
100 }
101}