cubecl_core/post_processing/
checked_io.rs

1use cubecl_common::ExecutionMode;
2use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
3
4use crate::{
5    io::{read_tensor_atomic_checked, read_tensor_checked},
6    prelude::{Line, NumericExpand, expand_checked_index_assign},
7};
8
9#[derive(new)]
10pub struct CheckedIoProcessor {
11    mode: ExecutionMode,
12}
13
14impl Processor for CheckedIoProcessor {
15    fn transform(
16        &self,
17        mut processing: cubecl_ir::ScopeProcessing,
18        allocator: Allocator,
19    ) -> cubecl_ir::ScopeProcessing {
20        if matches!(self.mode, ExecutionMode::Unchecked) {
21            return processing;
22        }
23
24        let mut instructions = Vec::new();
25        core::mem::swap(&mut processing.instructions, &mut instructions);
26
27        for instruction in instructions {
28            if let Operation::Operator(operator) = &instruction.operation {
29                match operator {
30                    Operator::Index(op) => {
31                        let has_length = op.list.has_length();
32
33                        if has_length {
34                            let list = ExpandElement::Plain(op.list);
35                            let index = ExpandElement::Plain(op.index);
36                            let mut scope = Scope::root(false).with_allocator(allocator.clone());
37                            scope.register_elem::<NumericExpand<0>>(op.list.elem());
38
39                            let input = if op.list.elem().is_atomic() {
40                                // Atomic can't really be checked, since the pointer needs to be
41                                // valid, so the kernel will probably not output the correct value if
42                                // not manually checked later, but will at least avoid out-of-bounds
43                                // memory access.
44                                read_tensor_atomic_checked::expand::<NumericExpand<0>>(
45                                    &mut scope,
46                                    list.into(),
47                                    index.into(),
48                                )
49                                .expand
50                            } else {
51                                read_tensor_checked::expand::<Line<NumericExpand<0>>>(
52                                    &mut scope,
53                                    list.into(),
54                                    index.into(),
55                                )
56                                .expand
57                            };
58                            let tmp_processing = scope.process([]);
59
60                            for inst in tmp_processing.instructions {
61                                processing.instructions.push(inst);
62                            }
63                            for var in tmp_processing.variables {
64                                processing.variables.push(var);
65                            }
66
67                            processing
68                                .instructions
69                                .push(Instruction::new(Operation::Copy(*input), instruction.out()));
70                            continue;
71                        }
72                    }
73                    Operator::IndexAssign(op) => {
74                        let out = instruction.out();
75
76                        if out.has_length() {
77                            let mut scope = Scope::root(false).with_allocator(allocator.clone());
78                            expand_checked_index_assign(&mut scope, op.index, op.value, out);
79
80                            let tmp_processing = scope.process([]);
81
82                            for inst in tmp_processing.instructions {
83                                processing.instructions.push(inst);
84                            }
85                            for var in tmp_processing.variables {
86                                processing.variables.push(var);
87                            }
88
89                            continue;
90                        }
91                    }
92                    _ => {}
93                }
94            }
95
96            // When we have nothing to do.
97            processing.instructions.push(instruction);
98        }
99        processing
100    }
101}