1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
use crate::{LineMode, launch::LineSizeStrategy};
use cubecl::{
prelude::*, std::tensor::is_contiguous, tensor_line_size_parallel,
tensor_line_size_perpendicular,
};
/// Calculate the number of planes in a cube.
///
/// When `num_cpu_cores` is provided (CPU targets), the plane count is the
/// number of cores, capped by the number of working units. Otherwise the
/// result is the largest power of two not exceeding
/// `working_units / plane_dim`, capped at 8 planes.
pub fn calculate_plane_count_per_cube(
    working_units: usize,
    plane_dim: u32,
    num_cpu_cores: Option<u32>,
) -> u32 {
    if let Some(cores) = num_cpu_cores {
        return cores.min(working_units as u32);
    }
    // Hard cap on the plane count; must stay a power of two.
    const MAX_PLANES: u32 = 8;
    // At least one plane, even when there are fewer units than a full plane.
    let upper_bound = (working_units / plane_dim as usize).max(1);
    // Round down to a power of two by working in log2 space.
    let exponent = upper_bound.ilog2().min(MAX_PLANES.ilog2());
    1u32 << exponent
}
/// Determine the line (vector) sizes to use for the input and output tensors
/// of a reduction along `axis`.
///
/// Returns `(line_size_input, line_size_output)`. The output line size is 1
/// unless the output layout allows reusing the input line size
/// (perpendicular mode) or the strategy explicitly requests output
/// vectorization (parallel mode over a contiguous input).
pub fn generate_line_size<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    output: &TensorHandleRef<R>,
    axis: usize,
    dtype: StorageType,
    line_mode: LineMode,
    strategy: &LineSizeStrategy,
) -> (usize, usize) {
    let line_size_input = match line_mode {
        LineMode::Parallel => tensor_line_size_parallel(
            client.io_optimized_line_sizes(dtype.size()),
            input.shape,
            input.strides,
            axis,
        ),
        LineMode::Perpendicular => {
            // To compute the maximum line size we can use,
            // we first sort both the input and output axes by increasing strides.
            // As example, consider
            // input shape = [2, 4, 6, 8]
            // input stride = [1, 16, 64, 2]
            // output shape = [2, 1, 6, 8]
            // output stride = [1, 1, 2, 12]
            // axis = 1
            //
            // then we have
            // input sorted axis = [0, 3, 1, 2]
            // output sorted axis = [0, 1, 2, 3]
            //
            // From that point, we look at all the axes before the target axis in the sorted input.
            // That is [0, 3] in the example.
            // In the output, we remove the target axis leading to [0, 2, 3] in the example.
            //
            // In order to use perpendicular line, we are limited by the number of entries that are both
            // contiguous in the input and output. This is obtained by taking the head of each list until they are different.
            // In the above example, only the 0 axis is contiguous in both tensors, but if the output sorted axes were [0, 1, 3, 2] instead,
            // both the 0 and 3 axes would be contiguous in the two tensors.
            // The corresponding number of entries is the product of the shape for the contiguous axes.
            // In the example, it is simply 2.
            //
            // This gives us an upper bound on the line size we can use.
            // Then, we use the regular method to find the best line size that matches the device capacities.
            let mut input_axis_and_strides = input.strides.iter().enumerate().collect::<Vec<_>>();
            input_axis_and_strides.sort_by_key(|(_, stride)| *stride);
            let input_sorted_axis = input_axis_and_strides
                .into_iter()
                .map(|(a, _)| a)
                .take_while(|a| *a != axis);
            let mut output_axis_and_strides = output.strides.iter().enumerate().collect::<Vec<_>>();
            output_axis_and_strides.sort_by_key(|(_, stride)| *stride);
            let output_sorted_axis = output_axis_and_strides
                .into_iter()
                .filter_map(|(a, _)| (a != axis).then_some(a));
            let max_line_size = input_sorted_axis
                .zip(output_sorted_axis)
                // Stop at the first mismatch: once the sorted axis lists
                // diverge, later axes are no longer contiguous in both
                // tensors even if they happen to line up again, so they must
                // not contribute to the bound.
                .take_while(|(i, o)| i == o)
                .map(|(i, _)| output.shape[i])
                .product();
            match client.properties().hardware.num_cpu_cores.is_some() {
                true => {
                    // On CPU we benefit from bigger line size, which increases the number of
                    // consecutive loads from global memory on perpendicular reduce.
                    // R::supported_line_sizes() was always arbitrary, review this and find alternate
                    // algorithm. For now it replicates existing behaviour.
                    let supported_line_sizes = client.io_optimized_line_sizes(1).filter(|size| {
                        *size <= max_line_size && max_line_size.is_multiple_of(*size)
                    });
                    tensor_line_size_perpendicular(
                        supported_line_sizes,
                        input.shape,
                        input.strides,
                        axis,
                    )
                }
                false => {
                    let supported_line_sizes =
                        client
                            .io_optimized_line_sizes(dtype.size())
                            .filter(|&size| {
                                size <= max_line_size && max_line_size.is_multiple_of(size)
                            });
                    tensor_line_size_perpendicular(
                        supported_line_sizes,
                        input.shape,
                        input.strides,
                        axis,
                    )
                }
            }
        }
    };
    let mut line_size_output = 1;
    if line_size_input > 1 && line_mode == LineMode::Perpendicular {
        // The output can reuse the input line size only when the trailing
        // axes are laid out contiguously and the axis after the reduced one
        // divides evenly into lines.
        // TODO: this can be improved.
        let rank = output.strides.len();
        let is_contiguous = is_contiguous(&output.shape[axis..rank], &output.strides[axis..rank])
            && output.strides[rank - 1] == 1;
        let shape = output.shape.get(axis + 1).copied().unwrap_or(1);
        if is_contiguous && shape.is_multiple_of(line_size_input) {
            line_size_output = line_size_input;
        }
    }
    if strategy.parallel_output_vectorization
        && line_mode == LineMode::Parallel
        && line_size_input > 1
        && is_contiguous(input.shape, input.strides)
        && axis == input.shape.len() - 1
    {
        // Pick the largest supported line size that evenly divides the total
        // number of reductions (i.e. the number of output elements).
        let supported_line_sizes = client.io_optimized_line_sizes(dtype.size());
        let num_reduce = output.shape.iter().copied().product::<usize>();
        line_size_output = supported_line_sizes
            .filter(|&line_size| num_reduce.is_multiple_of(line_size))
            .max()
            .unwrap_or(1);
    }
    (line_size_input, line_size_output)
}