cubecl_core/post_processing/
checked_io.rs1use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
2use cubecl_runtime::server::ExecutionMode;
3
4use crate::{
5 io::{read_tensor_atomic_checked, read_tensor_checked},
6 prelude::{Line, NumericExpand, expand_checked_index_assign},
7};
8
9#[derive(new, Debug)]
10pub struct CheckedIoProcessor {
11 mode: ExecutionMode,
12}
13
14impl Processor for CheckedIoProcessor {
15 fn transform(
16 &self,
17 mut processing: cubecl_ir::ScopeProcessing,
18 allocator: Allocator,
19 ) -> cubecl_ir::ScopeProcessing {
20 if matches!(self.mode, ExecutionMode::Unchecked) {
21 return processing;
22 }
23
24 let mut instructions = Vec::new();
25 core::mem::swap(&mut processing.instructions, &mut instructions);
26
27 for instruction in instructions {
28 if let Operation::Operator(operator) = &instruction.operation {
29 match operator {
30 Operator::Index(op) => {
31 let has_length = op.list.has_length();
32
33 if has_length {
34 let list = ExpandElement::Plain(op.list);
35 let index = ExpandElement::Plain(op.index);
36 let mut scope = Scope::root(false)
37 .with_allocator(allocator.clone())
38 .with_types(processing.typemap.clone());
39 scope.register_type::<NumericExpand<0>>(op.list.storage_type());
40
41 let input = if op.list.ty.is_atomic() {
42 read_tensor_atomic_checked::expand::<NumericExpand<0>>(
47 &mut scope,
48 list.into(),
49 index.into(),
50 op.unroll_factor,
51 )
52 .expand
53 } else {
54 read_tensor_checked::expand::<Line<NumericExpand<0>>>(
55 &mut scope,
56 list.into(),
57 index.into(),
58 op.unroll_factor,
59 )
60 .expand
61 };
62 let tmp_processing = scope.process([]);
63
64 for inst in tmp_processing.instructions {
65 processing.instructions.push(inst);
66 }
67 for var in tmp_processing.variables {
68 processing.variables.push(var);
69 }
70
71 processing
72 .instructions
73 .push(Instruction::new(Operation::Copy(*input), instruction.out()));
74 continue;
75 }
76 }
77 Operator::IndexAssign(op) => {
78 let out = instruction.out();
79
80 if out.has_length() {
81 let mut scope = Scope::root(false)
82 .with_allocator(allocator.clone())
83 .with_types(processing.typemap.clone());
84 expand_checked_index_assign(
85 &mut scope,
86 op.index,
87 op.value,
88 out,
89 op.unroll_factor,
90 );
91
92 let tmp_processing = scope.process([]);
93
94 for inst in tmp_processing.instructions {
95 processing.instructions.push(inst);
96 }
97 for var in tmp_processing.variables {
98 processing.variables.push(var);
99 }
100
101 continue;
102 }
103 }
104 _ => {}
105 }
106 }
107
108 processing.instructions.push(instruction);
110 }
111 processing
112 }
113}