cubecl_core/post_processing/
checked_io.rs

1use cubecl_common::ExecutionMode;
2use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
3
4use crate::{
5    io::{read_tensor_atomic_checked, read_tensor_checked},
6    prelude::{Line, NumericExpand, expand_checked_index_assign},
7};
8
/// Post-processing pass that swaps indexed reads and writes for their checked expansions.
///
/// The pass is a no-op when `mode` is `ExecutionMode::Unchecked`; see the `Processor`
/// implementation for the rewriting rules.
#[derive(new, Debug)]
pub struct CheckedIoProcessor {
    /// Execution mode that decides whether the pass rewrites anything.
    mode: ExecutionMode,
}
13
14impl Processor for CheckedIoProcessor {
15    fn transform(
16        &self,
17        mut processing: cubecl_ir::ScopeProcessing,
18        allocator: Allocator,
19    ) -> cubecl_ir::ScopeProcessing {
20        if matches!(self.mode, ExecutionMode::Unchecked) {
21            return processing;
22        }
23
24        let mut instructions = Vec::new();
25        core::mem::swap(&mut processing.instructions, &mut instructions);
26
27        for instruction in instructions {
28            if let Operation::Operator(operator) = &instruction.operation {
29                match operator {
30                    Operator::Index(op) => {
31                        let has_length = op.list.has_length();
32
33                        if has_length {
34                            let list = ExpandElement::Plain(op.list);
35                            let index = ExpandElement::Plain(op.index);
36                            let mut scope = Scope::root(false).with_allocator(allocator.clone());
37                            scope.register_type::<NumericExpand<0>>(op.list.storage_type());
38
39                            let input = if op.list.ty.is_atomic() {
40                                // Atomic can't really be checked, since the pointer needs to be
41                                // valid, so the kernel will probably not output the correct value if
42                                // not manually checked later, but will at least avoid out-of-bounds
43                                // memory access.
44                                read_tensor_atomic_checked::expand::<NumericExpand<0>>(
45                                    &mut scope,
46                                    list.into(),
47                                    index.into(),
48                                    op.unroll_factor,
49                                )
50                                .expand
51                            } else {
52                                read_tensor_checked::expand::<Line<NumericExpand<0>>>(
53                                    &mut scope,
54                                    list.into(),
55                                    index.into(),
56                                    op.unroll_factor,
57                                )
58                                .expand
59                            };
60                            let tmp_processing = scope.process([]);
61
62                            for inst in tmp_processing.instructions {
63                                processing.instructions.push(inst);
64                            }
65                            for var in tmp_processing.variables {
66                                processing.variables.push(var);
67                            }
68
69                            processing
70                                .instructions
71                                .push(Instruction::new(Operation::Copy(*input), instruction.out()));
72                            continue;
73                        }
74                    }
75                    Operator::IndexAssign(op) => {
76                        let out = instruction.out();
77
78                        if out.has_length() {
79                            let mut scope = Scope::root(false).with_allocator(allocator.clone());
80                            expand_checked_index_assign(
81                                &mut scope,
82                                op.index,
83                                op.value,
84                                out,
85                                op.unroll_factor,
86                            );
87
88                            let tmp_processing = scope.process([]);
89
90                            for inst in tmp_processing.instructions {
91                                processing.instructions.push(inst);
92                            }
93                            for var in tmp_processing.variables {
94                                processing.variables.push(var);
95                            }
96
97                            continue;
98                        }
99                    }
100                    _ => {}
101                }
102            }
103
104            // When we have nothing to do.
105            processing.instructions.push(instruction);
106        }
107        processing
108    }
109}