cubecl_core/post_processing/
checked_io.rs1use alloc::vec::Vec;
2use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
3use cubecl_runtime::server::ExecutionMode;
4
5use crate::{
6 io::{read_tensor_atomic_checked, read_tensor_checked},
7 prelude::{Line, NumericExpand, expand_checked_index_assign},
8};
9
10#[derive(new, Debug)]
11pub struct CheckedIoProcessor {
12 mode: ExecutionMode,
13}
14
15impl Processor for CheckedIoProcessor {
16 fn transform(
17 &self,
18 mut processing: cubecl_ir::ScopeProcessing,
19 allocator: Allocator,
20 ) -> cubecl_ir::ScopeProcessing {
21 if matches!(self.mode, ExecutionMode::Unchecked) {
22 return processing;
23 }
24
25 let mut instructions = Vec::new();
26 core::mem::swap(&mut processing.instructions, &mut instructions);
27
28 for instruction in instructions {
29 if let Operation::Operator(operator) = &instruction.operation {
30 match operator {
31 Operator::Index(op) => {
32 let has_length = op.list.has_length();
33
34 if has_length {
35 let list = ExpandElement::Plain(op.list);
36 let index = ExpandElement::Plain(op.index);
37 let mut scope = Scope::root(false)
38 .with_allocator(allocator.clone())
39 .with_types(processing.typemap.clone());
40 scope.register_type::<NumericExpand<0>>(op.list.storage_type());
41
42 let input = if op.list.ty.is_atomic() {
43 read_tensor_atomic_checked::expand::<NumericExpand<0>>(
48 &mut scope,
49 list.into(),
50 index.into(),
51 op.unroll_factor,
52 )
53 .expand
54 } else {
55 read_tensor_checked::expand::<Line<NumericExpand<0>>>(
56 &mut scope,
57 list.into(),
58 index.into(),
59 op.unroll_factor,
60 )
61 .expand
62 };
63 let tmp_processing = scope.process([]);
64
65 for inst in tmp_processing.instructions {
66 processing.instructions.push(inst);
67 }
68 for var in tmp_processing.variables {
69 processing.variables.push(var);
70 }
71
72 processing
73 .instructions
74 .push(Instruction::new(Operation::Copy(*input), instruction.out()));
75 continue;
76 }
77 }
78 Operator::IndexAssign(op) => {
79 let out = instruction.out();
80
81 if out.has_length() {
82 let mut scope = Scope::root(false)
83 .with_allocator(allocator.clone())
84 .with_types(processing.typemap.clone());
85 expand_checked_index_assign(
86 &mut scope,
87 op.index,
88 op.value,
89 out,
90 op.unroll_factor,
91 );
92
93 let tmp_processing = scope.process([]);
94
95 for inst in tmp_processing.instructions {
96 processing.instructions.push(inst);
97 }
98 for var in tmp_processing.variables {
99 processing.variables.push(var);
100 }
101
102 continue;
103 }
104 }
105 _ => {}
106 }
107 }
108
109 processing.instructions.push(instruction);
111 }
112 processing
113 }
114}