Skip to main content

cubecl_core/post_processing/checked_io.rs

1use alloc::{string::String, vec::Vec};
2use cubecl_ir::{Allocator, Instruction, ManagedVariable, Operation, Operator, Processor, Scope};
3use cubecl_runtime::server::ExecutionMode;
4
5use crate::{
6    define_scalar, define_size,
7    io::{
8        expand_checked_index_assign, expand_validate_index_assign, read_tensor_atomic_checked,
9        read_tensor_atomic_validate, read_tensor_checked, read_tensor_validate,
10    },
11    prelude::Vector,
12};
13
// Generic stand-ins for the element type and vector size used as type arguments of the
// read expansions below. Each rewrite binds them to the indexed list's storage type and
// vector size with `register_type` / `register_size` on its scope before expanding.
define_scalar!(ElemA);
define_size!(SizeA);

/// Post-processor that rewrites indexed reads (`Operator::Index`) and indexed writes
/// (`Operator::IndexAssign`) on variables that have a length, according to `mode`.
#[derive(new, Debug)]
pub struct CheckedIoProcessor {
    /// Selects the rewrite: `Checked` and `Validate` replace accesses with expanded
    /// bounds-aware code, `Unchecked` leaves the instructions untouched.
    mode: ExecutionMode,
    /// Forwarded to the `*_validate` expansions; only read in `Validate` mode.
    kernel_name: String,
}
22
23impl Processor for CheckedIoProcessor {
24    fn transform(
25        &self,
26        processing: cubecl_ir::ScopeProcessing,
27        allocator: Allocator,
28    ) -> cubecl_ir::ScopeProcessing {
29        match self.mode {
30            ExecutionMode::Checked => self.transform_checked(processing, allocator),
31            ExecutionMode::Unchecked => processing,
32            ExecutionMode::Validate => self.transform_validate(processing, allocator),
33        }
34    }
35}
36
37impl CheckedIoProcessor {
38    fn transform_checked(
39        &self,
40        mut processing: cubecl_ir::ScopeProcessing,
41        allocator: Allocator,
42    ) -> cubecl_ir::ScopeProcessing {
43        let mut instructions = Vec::new();
44        core::mem::swap(&mut processing.instructions, &mut instructions);
45
46        for instruction in instructions {
47            if let Operation::Operator(operator) = &instruction.operation {
48                match operator {
49                    Operator::Index(op) => {
50                        let has_length = op.list.has_length();
51
52                        if has_length {
53                            let list = ManagedVariable::Plain(op.list);
54                            let index = ManagedVariable::Plain(op.index);
55                            let mut scope = Scope::root(false)
56                                .with_allocator(allocator.clone())
57                                .with_types(processing.typemap.clone());
58                            scope.register_type::<ElemA>(op.list.storage_type());
59                            scope.register_size::<SizeA>(op.list.vector_size());
60
61                            let input = if op.list.ty.is_atomic() {
62                                // Atomic can't really be checked, since the pointer needs to be
63                                // valid, so the kernel will probably not output the correct value if
64                                // not manually checked later, but will at least avoid out-of-bounds
65                                // memory access.
66                                read_tensor_atomic_checked::expand::<ElemA>(
67                                    &mut scope,
68                                    list.into(),
69                                    index.into(),
70                                    op.unroll_factor,
71                                )
72                                .expand
73                            } else {
74                                read_tensor_checked::expand::<Vector<ElemA, SizeA>>(
75                                    &mut scope,
76                                    list.into(),
77                                    index.into(),
78                                    op.unroll_factor,
79                                )
80                                .expand
81                            };
82                            let tmp_processing = scope.process([]);
83
84                            for inst in tmp_processing.instructions {
85                                processing.instructions.push(inst);
86                            }
87                            for var in tmp_processing.variables {
88                                processing.variables.push(var);
89                            }
90
91                            processing
92                                .instructions
93                                .push(Instruction::new(Operation::Copy(*input), instruction.out()));
94                            continue;
95                        }
96                    }
97                    Operator::IndexAssign(op) => {
98                        let out = instruction.out();
99
100                        if out.has_length() {
101                            let mut scope = Scope::root(false)
102                                .with_allocator(allocator.clone())
103                                .with_types(processing.typemap.clone());
104                            expand_checked_index_assign(
105                                &mut scope,
106                                op.index,
107                                op.value,
108                                out,
109                                op.unroll_factor,
110                            );
111
112                            let tmp_processing = scope.process([]);
113
114                            for inst in tmp_processing.instructions {
115                                processing.instructions.push(inst);
116                            }
117                            for var in tmp_processing.variables {
118                                processing.variables.push(var);
119                            }
120
121                            continue;
122                        }
123                    }
124                    _ => {}
125                }
126            }
127
128            // When we have nothing to do.
129            processing.instructions.push(instruction);
130        }
131        processing
132    }
133
134    fn transform_validate(
135        &self,
136        mut processing: cubecl_ir::ScopeProcessing,
137        allocator: Allocator,
138    ) -> cubecl_ir::ScopeProcessing {
139        let mut instructions = Vec::new();
140        core::mem::swap(&mut processing.instructions, &mut instructions);
141
142        for instruction in instructions {
143            if let Operation::Operator(operator) = &instruction.operation {
144                match operator {
145                    Operator::Index(op) => {
146                        let has_length = op.list.has_length();
147
148                        if has_length {
149                            let list = ManagedVariable::Plain(op.list);
150                            let index = ManagedVariable::Plain(op.index);
151                            let mut scope = Scope::root(false)
152                                .with_allocator(allocator.clone())
153                                .with_types(processing.typemap.clone());
154                            scope.register_type::<ElemA>(op.list.storage_type());
155                            scope.register_size::<SizeA>(op.list.vector_size());
156
157                            let input = if op.list.ty.is_atomic() {
158                                // Atomic can't really be checked, since the pointer needs to be
159                                // valid, so the kernel will probably not output the correct value if
160                                // not manually checked later, but will at least avoid out-of-bounds
161                                // memory access.
162                                read_tensor_atomic_validate::expand::<ElemA>(
163                                    &mut scope,
164                                    list.into(),
165                                    index.into(),
166                                    op.unroll_factor,
167                                    self.kernel_name.clone(),
168                                )
169                                .expand
170                            } else {
171                                read_tensor_validate::expand::<Vector<ElemA, SizeA>>(
172                                    &mut scope,
173                                    list.into(),
174                                    index.into(),
175                                    op.unroll_factor,
176                                    self.kernel_name.clone(),
177                                )
178                                .expand
179                            };
180                            let tmp_processing = scope.process([]);
181
182                            for inst in tmp_processing.instructions {
183                                processing.instructions.push(inst);
184                            }
185                            for var in tmp_processing.variables {
186                                processing.variables.push(var);
187                            }
188
189                            processing
190                                .instructions
191                                .push(Instruction::new(Operation::Copy(*input), instruction.out()));
192                            continue;
193                        }
194                    }
195                    Operator::IndexAssign(op) => {
196                        let out = instruction.out();
197
198                        if out.has_length() {
199                            let mut scope = Scope::root(false)
200                                .with_allocator(allocator.clone())
201                                .with_types(processing.typemap.clone());
202                            expand_validate_index_assign(
203                                &mut scope,
204                                op.index,
205                                op.value,
206                                out,
207                                op.unroll_factor,
208                                &self.kernel_name,
209                            );
210
211                            let tmp_processing = scope.process([]);
212
213                            for inst in tmp_processing.instructions {
214                                processing.instructions.push(inst);
215                            }
216                            for var in tmp_processing.variables {
217                                processing.variables.push(var);
218                            }
219
220                            continue;
221                        }
222                    }
223                    _ => {}
224                }
225            }
226
227            // When we have nothing to do.
228            processing.instructions.push(instruction);
229        }
230        processing
231    }
232}