// cubecl_core/post_processing/checked_io.rs

1use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
2use cubecl_runtime::server::ExecutionMode;
3
4use crate::{
5    io::{read_tensor_atomic_checked, read_tensor_checked},
6    prelude::{Line, NumericExpand, expand_checked_index_assign},
7};
8
9#[derive(new, Debug)]
10pub struct CheckedIoProcessor {
11    mode: ExecutionMode,
12}
13
14impl Processor for CheckedIoProcessor {
15    fn transform(
16        &self,
17        mut processing: cubecl_ir::ScopeProcessing,
18        allocator: Allocator,
19    ) -> cubecl_ir::ScopeProcessing {
20        if matches!(self.mode, ExecutionMode::Unchecked) {
21            return processing;
22        }
23
24        let mut instructions = Vec::new();
25        core::mem::swap(&mut processing.instructions, &mut instructions);
26
27        for instruction in instructions {
28            if let Operation::Operator(operator) = &instruction.operation {
29                match operator {
30                    Operator::Index(op) => {
31                        let has_length = op.list.has_length();
32
33                        if has_length {
34                            let list = ExpandElement::Plain(op.list);
35                            let index = ExpandElement::Plain(op.index);
36                            let mut scope = Scope::root(false)
37                                .with_allocator(allocator.clone())
38                                .with_types(processing.typemap.clone());
39                            scope.register_type::<NumericExpand<0>>(op.list.storage_type());
40
41                            let input = if op.list.ty.is_atomic() {
42                                // Atomic can't really be checked, since the pointer needs to be
43                                // valid, so the kernel will probably not output the correct value if
44                                // not manually checked later, but will at least avoid out-of-bounds
45                                // memory access.
46                                read_tensor_atomic_checked::expand::<NumericExpand<0>>(
47                                    &mut scope,
48                                    list.into(),
49                                    index.into(),
50                                    op.unroll_factor,
51                                )
52                                .expand
53                            } else {
54                                read_tensor_checked::expand::<Line<NumericExpand<0>>>(
55                                    &mut scope,
56                                    list.into(),
57                                    index.into(),
58                                    op.unroll_factor,
59                                )
60                                .expand
61                            };
62                            let tmp_processing = scope.process([]);
63
64                            for inst in tmp_processing.instructions {
65                                processing.instructions.push(inst);
66                            }
67                            for var in tmp_processing.variables {
68                                processing.variables.push(var);
69                            }
70
71                            processing
72                                .instructions
73                                .push(Instruction::new(Operation::Copy(*input), instruction.out()));
74                            continue;
75                        }
76                    }
77                    Operator::IndexAssign(op) => {
78                        let out = instruction.out();
79
80                        if out.has_length() {
81                            let mut scope = Scope::root(false)
82                                .with_allocator(allocator.clone())
83                                .with_types(processing.typemap.clone());
84                            expand_checked_index_assign(
85                                &mut scope,
86                                op.index,
87                                op.value,
88                                out,
89                                op.unroll_factor,
90                            );
91
92                            let tmp_processing = scope.process([]);
93
94                            for inst in tmp_processing.instructions {
95                                processing.instructions.push(inst);
96                            }
97                            for var in tmp_processing.variables {
98                                processing.variables.push(var);
99                            }
100
101                            continue;
102                        }
103                    }
104                    _ => {}
105                }
106            }
107
108            // When we have nothing to do.
109            processing.instructions.push(instruction);
110        }
111        processing
112    }
113}