cubecl_core/post_processing/
checked_io.rs1use cubecl_common::ExecutionMode;
2use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
3
4use crate::{
5 io::{read_tensor_atomic_checked, read_tensor_checked},
6 prelude::{Line, NumericExpand, expand_checked_index_assign},
7};
8
9#[derive(new, Debug)]
10pub struct CheckedIoProcessor {
11 mode: ExecutionMode,
12}
13
14impl Processor for CheckedIoProcessor {
15 fn transform(
16 &self,
17 mut processing: cubecl_ir::ScopeProcessing,
18 allocator: Allocator,
19 ) -> cubecl_ir::ScopeProcessing {
20 if matches!(self.mode, ExecutionMode::Unchecked) {
21 return processing;
22 }
23
24 let mut instructions = Vec::new();
25 core::mem::swap(&mut processing.instructions, &mut instructions);
26
27 for instruction in instructions {
28 if let Operation::Operator(operator) = &instruction.operation {
29 match operator {
30 Operator::Index(op) => {
31 let has_length = op.list.has_length();
32
33 if has_length {
34 let list = ExpandElement::Plain(op.list);
35 let index = ExpandElement::Plain(op.index);
36 let mut scope = Scope::root(false).with_allocator(allocator.clone());
37 scope.register_type::<NumericExpand<0>>(op.list.storage_type());
38
39 let input = if op.list.ty.is_atomic() {
40 read_tensor_atomic_checked::expand::<NumericExpand<0>>(
45 &mut scope,
46 list.into(),
47 index.into(),
48 op.unroll_factor,
49 )
50 .expand
51 } else {
52 read_tensor_checked::expand::<Line<NumericExpand<0>>>(
53 &mut scope,
54 list.into(),
55 index.into(),
56 op.unroll_factor,
57 )
58 .expand
59 };
60 let tmp_processing = scope.process([]);
61
62 for inst in tmp_processing.instructions {
63 processing.instructions.push(inst);
64 }
65 for var in tmp_processing.variables {
66 processing.variables.push(var);
67 }
68
69 processing
70 .instructions
71 .push(Instruction::new(Operation::Copy(*input), instruction.out()));
72 continue;
73 }
74 }
75 Operator::IndexAssign(op) => {
76 let out = instruction.out();
77
78 if out.has_length() {
79 let mut scope = Scope::root(false).with_allocator(allocator.clone());
80 expand_checked_index_assign(
81 &mut scope,
82 op.index,
83 op.value,
84 out,
85 op.unroll_factor,
86 );
87
88 let tmp_processing = scope.process([]);
89
90 for inst in tmp_processing.instructions {
91 processing.instructions.push(inst);
92 }
93 for var in tmp_processing.variables {
94 processing.variables.push(var);
95 }
96
97 continue;
98 }
99 }
100 _ => {}
101 }
102 }
103
104 processing.instructions.push(instruction);
106 }
107 processing
108 }
109}