#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
pub mod transforms {
#[rustfmt::skip]
pub const G: [[f32; 3]; 4] = [
[ 1.0, 0.0, 0.0 ],
[ 0.5, 0.5, 0.5 ],
[ 0.5, -0.5, 0.5 ],
[ 0.0, 0.0, 1.0 ],
];
#[rustfmt::skip]
pub const G_T: [[f32; 4]; 3] = [
[ 1.0, 0.5, 0.5, 0.0 ],
[ 0.0, 0.5, -0.5, 0.0 ],
[ 0.0, 0.5, 0.5, 1.0 ],
];
#[rustfmt::skip]
pub const B_T: [[f32; 4]; 4] = [
[ 1.0, 0.0, -1.0, 0.0 ],
[ 0.0, 1.0, 1.0, 0.0 ],
[ 0.0, -1.0, 1.0, 0.0 ],
[ 0.0, 1.0, 0.0, -1.0 ],
];
#[rustfmt::skip]
pub const B: [[f32; 4]; 4] = [
[ 1.0, 0.0, 0.0, 0.0 ],
[ 0.0, 1.0, -1.0, 1.0 ],
[-1.0, 1.0, 1.0, 0.0 ],
[ 0.0, 0.0, 0.0, -1.0 ],
];
#[rustfmt::skip]
pub const A_T: [[f32; 4]; 2] = [
[ 1.0, 1.0, 1.0, 0.0 ],
[ 0.0, 1.0, -1.0, -1.0 ],
];
#[rustfmt::skip]
pub const A: [[f32; 2]; 4] = [
[ 1.0, 0.0 ],
[ 1.0, 1.0 ],
[ 1.0, -1.0 ],
[ 0.0, -1.0 ],
];
}
pub fn transform_filter(filter: &[f32; 9]) -> [f32; 16] {
let g = [
[filter[0], filter[1], filter[2]],
[filter[3], filter[4], filter[5]],
[filter[6], filter[7], filter[8]],
];
let mut gg = [[0.0f32; 3]; 4];
for i in 0..4 {
for j in 0..3 {
for k in 0..3 {
gg[i][j] += transforms::G[i][k] * g[k][j];
}
}
}
let mut u = [0.0f32; 16];
for i in 0..4 {
for j in 0..4 {
let mut sum = 0.0f32;
for k in 0..3 {
sum += gg[i][k] * transforms::G_T[k][j];
}
u[i * 4 + j] = sum;
}
}
u
}
pub fn transform_input(tile: &[f32; 16]) -> [f32; 16] {
let d = [
[tile[0], tile[1], tile[2], tile[3]],
[tile[4], tile[5], tile[6], tile[7]],
[tile[8], tile[9], tile[10], tile[11]],
[tile[12], tile[13], tile[14], tile[15]],
];
let mut btd = [[0.0f32; 4]; 4];
for i in 0..4 {
for j in 0..4 {
for k in 0..4 {
btd[i][j] += transforms::B_T[i][k] * d[k][j];
}
}
}
let mut v = [0.0f32; 16];
for i in 0..4 {
for j in 0..4 {
let mut sum = 0.0f32;
for k in 0..4 {
sum += btd[i][k] * transforms::B[k][j];
}
v[i * 4 + j] = sum;
}
}
v
}
pub fn transform_output(m: &[f32; 16]) -> [f32; 4] {
let m_mat = [
[m[0], m[1], m[2], m[3]],
[m[4], m[5], m[6], m[7]],
[m[8], m[9], m[10], m[11]],
[m[12], m[13], m[14], m[15]],
];
let mut atm = [[0.0f32; 4]; 2];
for i in 0..2 {
for j in 0..4 {
for k in 0..4 {
atm[i][j] += transforms::A_T[i][k] * m_mat[k][j];
}
}
}
let mut y = [0.0f32; 4];
for i in 0..2 {
for j in 0..2 {
let mut sum = 0.0f32;
for k in 0..4 {
sum += atm[i][k] * transforms::A[k][j];
}
y[i * 2 + j] = sum;
}
}
y
}
pub fn winograd_multiply(u: &[f32; 16], v: &[f32; 16]) -> [f32; 16] {
let mut m = [0.0f32; 16];
for i in 0..16 {
m[i] = u[i] * v[i];
}
m
}
#[derive(Debug, Clone)]
pub struct WinogradFilterCache {
pub filters: Vec<f32>,
pub out_channels: usize,
pub in_channels: usize,
}
impl WinogradFilterCache {
pub fn new(filters: &[f32], out_channels: usize, in_channels: usize) -> Self {
let mut transformed = vec![0.0f32; out_channels * in_channels * 16];
for oc in 0..out_channels {
for ic in 0..in_channels {
let filter_offset = (oc * in_channels + ic) * 9;
let mut filter_3x3 = [0.0f32; 9];
filter_3x3.copy_from_slice(&filters[filter_offset..filter_offset + 9]);
let transformed_filter = transform_filter(&filter_3x3);
let cache_offset = (oc * in_channels + ic) * 16;
transformed[cache_offset..cache_offset + 16].copy_from_slice(&transformed_filter);
}
}
Self {
filters: transformed,
out_channels,
in_channels,
}
}
#[inline]
pub fn get(&self, out_c: usize, in_c: usize) -> &[f32] {
let offset = (out_c * self.in_channels + in_c) * 16;
&self.filters[offset..offset + 16]
}
}
pub fn conv_3x3_winograd(
input: &[f32],
filter_cache: &WinogradFilterCache,
output: &mut [f32],
h: usize,
w: usize,
padding: usize,
) {
let in_c = filter_cache.in_channels;
let out_c = filter_cache.out_channels;
let out_h = (h + 2 * padding - 2) / 2;
let out_w = (w + 2 * padding - 2) / 2;
for oh_tile in 0..out_h {
for ow_tile in 0..out_w {
let oh0 = oh_tile * 2;
let ow0 = ow_tile * 2;
let mut tile_output = vec![0.0f32; out_c * 4];
for ic in 0..in_c {
let mut input_tile = [0.0f32; 16];
for ti in 0..4 {
for tj in 0..4 {
let ih = (oh0 + ti) as isize - padding as isize;
let iw = (ow0 + tj) as isize - padding as isize;
if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
let idx = (ih as usize * w + iw as usize) * in_c + ic;
input_tile[ti * 4 + tj] = input[idx];
}
}
}
let v = transform_input(&input_tile);
for oc in 0..out_c {
let u = filter_cache.get(oc, ic);
let mut m = [0.0f32; 16];
for i in 0..16 {
m[i] = u[i] * v[i];
}
let y = transform_output(&m);
for i in 0..4 {
tile_output[oc * 4 + i] += y[i];
}
}
}
for oi in 0..2 {
for oj in 0..2 {
let oh = oh0 + oi;
let ow = ow0 + oj;
if oh < out_h * 2 && ow < out_w * 2 {
for oc in 0..out_c {
let out_idx = (oh * out_w * 2 + ow) * out_c + oc;
if out_idx < output.len() {
output[out_idx] = tile_output[oc * 4 + oi * 2 + oj];
}
}
}
}
}
}
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn transform_input_avx2(tiles: &[[f32; 16]; 4]) -> [[f32; 16]; 4] {
let mut result = [[0.0f32; 16]; 4];
for t in 0..4 {
result[t] = transform_input(&tiles[t]);
}
result
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn transform_output_avx2(m_tiles: &[[f32; 16]; 4]) -> [[f32; 4]; 4] {
let mut result = [[0.0f32; 4]; 4];
for t in 0..4 {
result[t] = transform_output(&m_tiles[t]);
}
result
}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn transform_input_avx2(_tiles: &[[f32; 16]; 4]) -> [[f32; 16]; 4] {
[[0.0f32; 16]; 4]
}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn transform_output_avx2(_m_tiles: &[[f32; 16]; 4]) -> [[f32; 4]; 4] {
[[0.0f32; 4]; 4]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_filter_transform_roundtrip() {
let filter = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
let transformed = transform_filter(&filter);
assert!(transformed[5].abs() > 0.1 || transformed[6].abs() > 0.1);
}
#[test]
fn test_input_transform() {
let tile = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
];
let transformed = transform_input(&tile);
let sum: f32 = transformed.iter().map(|x| x.abs()).sum();
assert!(sum > 0.0);
}
#[test]
fn test_output_transform() {
let m = [
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
];
let output = transform_output(&m);
assert_eq!(output.len(), 4);
}
#[test]
fn test_winograd_filter_cache() {
let filters = vec![1.0, 0.0, -1.0, 2.0, 0.0, -2.0, 1.0, 0.0, -1.0];
let cache = WinogradFilterCache::new(&filters, 1, 1);
assert_eq!(cache.filters.len(), 16);
assert_eq!(cache.out_channels, 1);
assert_eq!(cache.in_channels, 1);
}
#[test]
fn test_winograd_identity_conv() {
let filters = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
let cache = WinogradFilterCache::new(&filters, 1, 1);
let input = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
];
let mut output = vec![0.0f32; 16];
conv_3x3_winograd(&input, &cache, &mut output, 4, 4, 1);
let output_sum: f32 = output.iter().sum();
assert!(output_sum.abs() > 0.0);
}
}