Skip to main content

cubecl_core/post_processing/
checked_io.rs

1use alloc::vec::Vec;
2use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
3use cubecl_runtime::server::ExecutionMode;
4
5use crate::{
6    io::{read_tensor_atomic_checked, read_tensor_checked},
7    prelude::{Line, NumericExpand, expand_checked_index_assign},
8};
9
10#[derive(new, Debug)]
11pub struct CheckedIoProcessor {
12    mode: ExecutionMode,
13}
14
15impl Processor for CheckedIoProcessor {
16    fn transform(
17        &self,
18        mut processing: cubecl_ir::ScopeProcessing,
19        allocator: Allocator,
20    ) -> cubecl_ir::ScopeProcessing {
21        if matches!(self.mode, ExecutionMode::Unchecked) {
22            return processing;
23        }
24
25        let mut instructions = Vec::new();
26        core::mem::swap(&mut processing.instructions, &mut instructions);
27
28        for instruction in instructions {
29            if let Operation::Operator(operator) = &instruction.operation {
30                match operator {
31                    Operator::Index(op) => {
32                        let has_length = op.list.has_length();
33
34                        if has_length {
35                            let list = ExpandElement::Plain(op.list);
36                            let index = ExpandElement::Plain(op.index);
37                            let mut scope = Scope::root(false)
38                                .with_allocator(allocator.clone())
39                                .with_types(processing.typemap.clone());
40                            scope.register_type::<NumericExpand<0>>(op.list.storage_type());
41
42                            let input = if op.list.ty.is_atomic() {
43                                // Atomic can't really be checked, since the pointer needs to be
44                                // valid, so the kernel will probably not output the correct value if
45                                // not manually checked later, but will at least avoid out-of-bounds
46                                // memory access.
47                                read_tensor_atomic_checked::expand::<NumericExpand<0>>(
48                                    &mut scope,
49                                    list.into(),
50                                    index.into(),
51                                    op.unroll_factor,
52                                )
53                                .expand
54                            } else {
55                                read_tensor_checked::expand::<Line<NumericExpand<0>>>(
56                                    &mut scope,
57                                    list.into(),
58                                    index.into(),
59                                    op.unroll_factor,
60                                )
61                                .expand
62                            };
63                            let tmp_processing = scope.process([]);
64
65                            for inst in tmp_processing.instructions {
66                                processing.instructions.push(inst);
67                            }
68                            for var in tmp_processing.variables {
69                                processing.variables.push(var);
70                            }
71
72                            processing
73                                .instructions
74                                .push(Instruction::new(Operation::Copy(*input), instruction.out()));
75                            continue;
76                        }
77                    }
78                    Operator::IndexAssign(op) => {
79                        let out = instruction.out();
80
81                        if out.has_length() {
82                            let mut scope = Scope::root(false)
83                                .with_allocator(allocator.clone())
84                                .with_types(processing.typemap.clone());
85                            expand_checked_index_assign(
86                                &mut scope,
87                                op.index,
88                                op.value,
89                                out,
90                                op.unroll_factor,
91                            );
92
93                            let tmp_processing = scope.process([]);
94
95                            for inst in tmp_processing.instructions {
96                                processing.instructions.push(inst);
97                            }
98                            for var in tmp_processing.variables {
99                                processing.variables.push(var);
100                            }
101
102                            continue;
103                        }
104                    }
105                    _ => {}
106                }
107            }
108
109            // When we have nothing to do.
110            processing.instructions.push(instruction);
111        }
112        processing
113    }
114}