1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use bytemuck::{try_cast_slice, try_cast_slice_mut};
use crate::TensorDesc;
/// N-D max pooling f32 CPU implementation. No indices are produced.
/// N-D max pooling f32 CPU implementation. No indices are produced.
///
/// Both tensors are addressed as `[N, C, spatial...]` and element offsets are
/// derived from `TensorDesc::compute_strides`.
///
/// For an output position `(n, c, o_1..o_k)`, the window covers the input positions
/// `o_i * stride[i] - pads_begin[i] + k_i * dilation[i]` for `k_i in 0..kernel[i]`.
/// Positions outside `[0, src_dims[2 + i])` are skipped (implicit padding).
/// `pads_begin` may be shorter than the spatial rank; missing entries count as 0.
///
/// Output values:
/// - A window with no in-bounds input position (fully padded, or a zero-sized
///   kernel dim) is written as `0.0`.
/// - Comparison uses `>`, so NaN inputs never replace the running maximum.
///   A window whose in-bounds values are all NaN therefore yields `-inf`.
/// - An output with a zero-sized spatial dim has no elements and nothing is written.
///
/// # Panics
/// - if `src_bytes` / `dst_ptr` cannot be viewed as `f32` slices;
/// - if either tensor has fewer than two dims, the ranks differ, or N/C differ;
/// - if `kernel`, `stride` or `dilation` has fewer entries than the spatial rank;
/// - if N, C, an output dim or a kernel dim is negative;
/// - on out-of-range indexing if the dims/strides do not fit the buffers.
pub fn f32_f32_cpu(
    src_dims: &[i64],
    dst_dims: &[i64],
    src_bytes: &[u8],
    dst_ptr: &mut [u8],
    kernel: &[i64],
    stride: &[i64],
    pads_begin: &[i64],
    dilation: &[i64],
) {
    /// Advances a row-major mixed-radix counter in place (last digit fastest).
    /// Returns `false` once the counter wraps past its final value, at which
    /// point every digit has been reset to 0.
    fn advance(index: &mut [usize], limits: &[usize]) -> bool {
        for (digit, &limit) in index.iter_mut().zip(limits).rev() {
            *digit += 1;
            if *digit < limit {
                return true;
            }
            *digit = 0;
        }
        false
    }

    let src_f: &[f32] = try_cast_slice(src_bytes).expect("src bytes not f32");
    let dst_f: &mut [f32] = try_cast_slice_mut(dst_ptr).expect("dst bytes not f32");
    // Layouts: src [N, C, D1..], dst [N, C, O1..]
    assert!(
        src_dims.len() >= 2 && dst_dims.len() >= 2,
        "MaxPool: dims too small"
    );
    let spatial_rank = src_dims.len() - 2;
    assert_eq!(
        dst_dims.len(),
        src_dims.len(),
        "MaxPool: src/dst rank mismatch"
    );
    // src's N and C are used to address dst, so they must agree.
    assert!(
        dst_dims[..2] == src_dims[..2],
        "MaxPool: src/dst N or C mismatch"
    );
    assert!(
        kernel.len() >= spatial_rank
            && stride.len() >= spatial_rank
            && dilation.len() >= spatial_rank,
        "MaxPool: kernel/stride/dilation shorter than spatial rank"
    );

    // Reject negative extents instead of letting `as usize` wrap them.
    let dim = |d: i64| usize::try_from(d).expect("MaxPool: negative dimension or kernel size");
    let n = dim(src_dims[0]);
    let c = dim(src_dims[1]);
    let src_spatial = &src_dims[2..];
    let out_counts: Vec<usize> = dst_dims[2..].iter().map(|&d| dim(d)).collect();
    let kernel_counts: Vec<usize> = kernel[..spatial_rank].iter().map(|&k| dim(k)).collect();

    // An output with a zero-sized spatial dim has no elements. The counter loop
    // below always visits at least one position, so bail out before it.
    if out_counts.contains(&0) {
        return;
    }
    // A zero-sized kernel dim makes every window empty: each output is written
    // as a fully padded window (0.0).
    let empty_window = kernel_counts.contains(&0);

    let src_strides = TensorDesc::compute_strides(src_dims);
    let dst_strides = TensorDesc::compute_strides(dst_dims);
    let offset = |idxs: &[usize], strides: &[usize]| -> usize {
        idxs.iter().zip(strides.iter()).map(|(i, s)| i * s).sum()
    };

    // Scratch buffers reused across all iterations (no per-element allocation).
    // `out_index` and `k_multi` start at all-zeros and are reset to all-zeros by
    // `advance` whenever it wraps, so each (n, c) and each window starts clean.
    let mut dst_idxs = vec![0usize; 2 + spatial_rank];
    let mut src_idxs = vec![0usize; 2 + spatial_rank];
    let mut out_index = vec![0usize; spatial_rank];
    let mut k_multi = vec![0usize; spatial_rank];

    for ni in 0..n {
        for ci in 0..c {
            dst_idxs[0] = ni;
            dst_idxs[1] = ci;
            src_idxs[0] = ni;
            src_idxs[1] = ci;
            // Visit every output position via mixed-radix counting.
            loop {
                dst_idxs[2..].copy_from_slice(&out_index);
                let dst_off = offset(&dst_idxs, &dst_strides);

                let mut max_val = f32::NEG_INFINITY;
                let mut found = false;
                if !empty_window {
                    // Visit every kernel element via mixed-radix counting.
                    loop {
                        // Map (output position, kernel offset) to an input position.
                        // src_idxs[2..] is only read when every axis was in bounds,
                        // i.e. fully rewritten for this kernel element.
                        let mut in_bounds = true;
                        for (i, (&o, &k)) in out_index.iter().zip(&k_multi).enumerate() {
                            let pad = pads_begin.get(i).copied().unwrap_or(0);
                            let in_pos = o as i64 * stride[i] - pad + k as i64 * dilation[i];
                            if in_pos < 0 || in_pos >= src_spatial[i] {
                                in_bounds = false;
                                break;
                            }
                            src_idxs[2 + i] = in_pos as usize;
                        }
                        if in_bounds {
                            let val = src_f[offset(&src_idxs, &src_strides)];
                            if val > max_val {
                                max_val = val;
                            }
                            found = true;
                        }
                        if !advance(&mut k_multi, &kernel_counts) {
                            break;
                        }
                    }
                }
                // Fully padded / empty window -> 0.0, otherwise the window maximum.
                dst_f[dst_off] = if found { max_val } else { 0.0 };

                if !advance(&mut out_index, &out_counts) {
                    break;
                }
            }
        }
    }
}