1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
use bytemuck::{try_cast_slice, try_cast_slice_mut};
use crate::TensorDesc;
/// A simple single-threaded reference N-D (grouped) convolution for f32 tensors.
///
/// All tensors arrive as raw byte buffers that are reinterpreted as `f32` slices.
/// Element offsets are computed from the strides returned by
/// `TensorDesc::compute_strides` (presumably dense row-major — confirm in `TensorDesc`).
///
/// Layouts:
/// * `src`:    `[N, C, D1, D2, ...]`
/// * `weight`: `[M, C / group, K1, K2, ...]`
/// * `dst`:    `[N, M, O1, O2, ...]` — every element is overwritten
/// * `bias`:   indexed by output channel `m` when present
///
/// The input coordinate along spatial axis `i` is
/// `o_i * stride[i] - pads_begin[i] + k_i * dilation[i]`. Taps that fall outside the
/// input are skipped, which is equivalent to zero padding.
///
/// # Panics
/// * if a byte buffer cannot be cast to `f32` (see `bytemuck::try_cast_slice`);
/// * if `group` is zero or does not divide both `C` and `M`;
/// * if `stride`, `pads_begin` or `dilation` are shorter than the spatial rank, or the
///   dims are inconsistent with the buffer lengths (out-of-bounds indexing).
pub fn f32_f32_f32_f32_cpu(
    src_dims: &[i64],
    weight_dims: &[i64],
    dst_dims: &[i64],
    src_bytes: &[u8],
    weight_bytes: &[u8],
    bias_bytes: Option<&[u8]>,
    dst_ptr: &mut [u8],
    stride: &[i64],
    pads_begin: &[i64],
    dilation: &[i64],
    group: usize,
) {
    // Reinterpret the raw byte buffers as f32 slices.
    let src_f: &[f32] = try_cast_slice(src_bytes).expect("src bytes not f32");
    let weight_f: &[f32] = try_cast_slice(weight_bytes).expect("weight bytes not f32");
    let dst_f: &mut [f32] = try_cast_slice_mut(dst_ptr).expect("dst bytes not f32");

    let n = src_dims[0] as usize;
    let c = src_dims[1] as usize;
    let m = weight_dims[0] as usize;
    // Both the input and output channel counts must split evenly into `group` groups.
    if group == 0 || !c.is_multiple_of(group) || !m.is_multiple_of(group) {
        panic!(
            "f32_cpu: unsupported group configuration: group={}, C={}, M={}",
            group, c, m
        );
    }
    let m_per_group = m / group;
    let c_per_group = c / group;
    let spatial_rank = src_dims.len() - 2;

    let src_strides = TensorDesc::compute_strides(src_dims);
    let dst_strides = TensorDesc::compute_strides(dst_dims);
    let weight_strides = TensorDesc::compute_strides(weight_dims);
    // Linear element offset of a multi-index under the given strides.
    let offset = |idxs: &[usize], strides: &[usize]| -> usize {
        idxs.iter().zip(strides.iter()).map(|(i, s)| i * s).sum()
    };

    let bias_f: Option<&[f32]> =
        bias_bytes.map(|b| try_cast_slice(b).expect("bias bytes not f32"));

    let out_spatial = &dst_dims[2..];
    // A zero-sized output spatial axis means the output has no elements. The odometer
    // loop below always visits the all-zero index before checking extents, so bail out
    // here instead of writing past the end of an empty `dst`.
    if out_spatial.iter().any(|&d| d <= 0) {
        return;
    }

    // Decompose each flat kernel index into its spatial multi-index once, instead of
    // once per (output position, input channel). Order matches the flat index order,
    // so the floating-point accumulation order is unchanged.
    let kernel_spatial = &weight_dims[2..];
    let kernel_elems: i64 = kernel_spatial.iter().product();
    let kernel_taps: Vec<Vec<i64>> = (0..kernel_elems)
        .map(|k_idx| {
            let mut rem = k_idx;
            let mut k_multi = vec![0i64; spatial_rank];
            for d in (0..spatial_rank).rev() {
                let dim = kernel_spatial[d];
                k_multi[d] = rem % dim;
                rem /= dim;
            }
            k_multi
        })
        .collect();

    // Scratch index buffers reused for every output element. Every slot is rewritten
    // before each use.
    let mut out_index = vec![0usize; spatial_rank];
    let mut dst_idxs = vec![0usize; 2 + spatial_rank];
    let mut src_idxs = vec![0usize; 2 + spatial_rank];
    let mut w_idxs = vec![0usize; 2 + spatial_rank];

    for ni in 0..n {
        for mi in 0..m {
            // Output channel `mi` only reads the input channels of its own group.
            let c_start = (mi / m_per_group) * c_per_group;
            let c_end = c_start + c_per_group;

            out_index.fill(0);
            loop {
                dst_idxs[0] = ni;
                dst_idxs[1] = mi;
                dst_idxs[2..].copy_from_slice(&out_index);
                let dst_off = offset(&dst_idxs, &dst_strides);

                let mut acc: f32 = 0.0;
                for ci in c_start..c_end {
                    for k_multi in &kernel_taps {
                        // in_pos = out_pos * stride - pad_begin + k * dilation
                        src_idxs[0] = ni;
                        src_idxs[1] = ci;
                        let mut in_bounds = true;
                        for (i, (&out_v, &kpos)) in out_index.iter().zip(k_multi).enumerate() {
                            let in_pos =
                                (out_v as i64) * stride[i] - pads_begin[i] + kpos * dilation[i];
                            if in_pos < 0 || in_pos >= src_dims[2 + i] {
                                in_bounds = false;
                                break;
                            }
                            src_idxs[2 + i] = in_pos as usize;
                        }
                        // Out-of-range taps contribute nothing (implicit zero padding).
                        if !in_bounds {
                            continue;
                        }
                        let src_off = offset(&src_idxs, &src_strides);

                        // weight index: [mi, channel within group, k_multi...]
                        w_idxs[0] = mi;
                        w_idxs[1] = ci - c_start;
                        for (w, &km) in w_idxs[2..].iter_mut().zip(k_multi) {
                            *w = km as usize;
                        }
                        let w_off = offset(&w_idxs, &weight_strides);

                        acc += src_f[src_off] * weight_f[w_off];
                    }
                }
                if let Some(bf) = bias_f {
                    acc += bf[mi];
                }
                dst_f[dst_off] = acc;

                // Advance the output multi-index like an odometer (last axis fastest).
                // A carry out of the outermost axis means every position was visited.
                let mut wrapped = true;
                for i in (0..spatial_rank).rev() {
                    out_index[i] += 1;
                    if out_index[i] < out_spatial[i] as usize {
                        wrapped = false;
                        break;
                    }
                    out_index[i] = 0;
                }
                if wrapped {
                    break;
                }
            }
        }
    }
}