1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use cubecl_common::ExecutionMode;
use cubecl_ir::{Allocator, ExpandElement, Instruction, Operation, Operator, Processor, Scope};
use crate::{
io::{read_tensor_atomic_checked, read_tensor_checked},
prelude::{Line, NumericExpand, expand_checked_index_assign},
};
/// Scope processor that replaces `Index` and `IndexAssign` operators with their checked
/// expansions, depending on the configured execution mode.
///
/// The constructor is generated by the `new` derive.
#[derive(new, Debug)]
pub struct CheckedIoProcessor {
/// Execution mode of the kernel. When it is `ExecutionMode::Unchecked`, `transform`
/// returns its input unchanged.
mode: ExecutionMode,
}
impl Processor for CheckedIoProcessor {
    /// Rewrites `Index` and `IndexAssign` operators whose indexed operand reports a length
    /// into their checked expansions; every other instruction is forwarded untouched.
    ///
    /// When the mode is `ExecutionMode::Unchecked` the input is returned as-is.
    ///
    /// * `processing` - instructions and variables of the scope being processed.
    /// * `allocator` - cloned into each temporary scope used to build an expansion.
    ///
    /// Returns the processing state with the rewritten instruction list and any variables
    /// produced by the temporary scopes appended.
    fn transform(
        &self,
        mut processing: cubecl_ir::ScopeProcessing,
        allocator: Allocator,
    ) -> cubecl_ir::ScopeProcessing {
        // Appends everything a temporary scope produced to `target`: instructions first,
        // then variables.
        fn absorb(target: &mut cubecl_ir::ScopeProcessing, expansion: cubecl_ir::ScopeProcessing) {
            target.instructions.extend(expansion.instructions);
            target.variables.extend(expansion.variables);
        }

        if matches!(self.mode, ExecutionMode::Unchecked) {
            return processing;
        }

        // Take ownership of the current instruction list; `processing.instructions` is
        // left empty and rebuilt in the loop below.
        let pending = core::mem::take(&mut processing.instructions);

        for instruction in pending {
            // `true` when the instruction was replaced by an expansion and must not be
            // forwarded as-is.
            let rewritten = match &instruction.operation {
                Operation::Operator(Operator::Index(op)) if op.list.has_length() => {
                    let list = ExpandElement::Plain(op.list);
                    let index = ExpandElement::Plain(op.index);

                    let mut scope = Scope::root(false).with_allocator(allocator.clone());
                    scope.register_type::<NumericExpand<0>>(op.list.storage_type());

                    let checked = if op.list.ty.is_atomic() {
                        // An atomic read cannot be fully checked because the pointer has
                        // to be valid. Unless the kernel checks manually afterwards the
                        // value may be wrong, but out-of-bounds memory access is at least
                        // avoided.
                        read_tensor_atomic_checked::expand::<NumericExpand<0>>(
                            &mut scope,
                            list.into(),
                            index.into(),
                            op.unroll_factor,
                        )
                        .expand
                    } else {
                        read_tensor_checked::expand::<Line<NumericExpand<0>>>(
                            &mut scope,
                            list.into(),
                            index.into(),
                            op.unroll_factor,
                        )
                        .expand
                    };

                    absorb(&mut processing, scope.process([]));

                    // Copy the checked value into the original instruction's output.
                    let copy = Instruction::new(Operation::Copy(*checked), instruction.out());
                    processing.instructions.push(copy);
                    true
                }
                Operation::Operator(Operator::IndexAssign(op)) => {
                    let target = instruction.out();
                    if target.has_length() {
                        let mut scope = Scope::root(false).with_allocator(allocator.clone());
                        expand_checked_index_assign(
                            &mut scope,
                            op.index,
                            op.value,
                            target,
                            op.unroll_factor,
                        );
                        absorb(&mut processing, scope.process([]));
                        true
                    } else {
                        false
                    }
                }
                _ => false,
            };

            // Nothing to rewrite: keep the instruction exactly as it was.
            if !rewritten {
                processing.instructions.push(instruction);
            }
        }

        processing
    }
}