1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
use crate::frame;

/// Scratch buffer of 16 `f32` values with 32-byte alignment, so that the two
/// aligned 256-bit stores in [`fma`] (`_mm256_store_ps`) are valid on it.
#[repr(align(32))]
struct SixteenAlignedF32([f32; 16]);

/// Marker type selecting the FMA-based 16x6 `f32` convolution kernel.
#[derive(Copy, Clone, Debug)]
pub struct SConvFma16x6;

/// Computes one 16x6 tile of products and stores it into `c`.
///
/// For each of the `k` steps, 16 consecutive `f32` values are loaded from `a`
/// (step `i` reads `a[i * 16 .. i * 16 + 16]`). For each of the 6 columns `j`,
/// one scalar is read from `b_tops[j]` at element offset
/// `b_down_offsets[i] >> 2`, broadcast, and accumulated with a fused
/// multiply-add. The shift by 2 suggests the offsets are expressed in bytes
/// (4 bytes per `f32`) -- TODO confirm against the packing code.
///
/// The resulting tile overwrites `c` (it is not added to the previous
/// content): element `(y, x)` is written at `c + y * rsc + x * csc`, with the
/// strides expressed in `f32` elements.
///
/// # Panics
///
/// Panics if `a` is not 32-byte aligned.
///
/// # Safety
///
/// The caller must guarantee that:
/// - the running CPU supports the `fma` target feature;
/// - `a` is valid for reading `k * 16` `f32` values;
/// - `b_tops` is valid for reading 6 pointers and `b_down_offsets` is valid
///   for reading `k` offsets;
/// - for every step `i` and column `j`, `b_tops[j]` displaced by
///   `b_down_offsets[i] >> 2` elements is valid for reading one `f32`;
/// - `c + y * rsc + x * csc` is valid for writing one `f32` for every
///   `y < 16` and `x < 6`.
#[target_feature(enable = "fma")]
unsafe fn fma(
    k: usize,
    a: *const f32,
    b_tops: *const *const f32,
    b_down_offsets: *const isize,
    c: *mut f32,
    rsc: usize,
    csc: usize,
) {
    // The same intrinsics are exported for both 32-bit and 64-bit x86; pick
    // the module matching the target so the kernel builds on either.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // `_mm256_load_ps` below requires 32-byte aligned addresses.
    assert!(a as usize % 32 == 0);

    // Accumulators: for each of the 6 columns, `ab1` holds rows 0..8 and
    // `ab2` holds rows 8..16.
    let mut ab1 = [_mm256_setzero_ps(); 6];
    let mut ab2 = [_mm256_setzero_ps(); 6];

    for i in 0..k {
        // Signed displacement into each B column for this step.
        let down_offset = *b_down_offsets.add(i) >> 2;
        let ar1 = _mm256_load_ps(a.add(i * 16));
        let ar2 = _mm256_load_ps(a.add(i * 16 + 8));
        for (j, (acc1, acc2)) in ab1.iter_mut().zip(ab2.iter_mut()).enumerate() {
            // `down_offset` is signed, hence `offset` rather than `add`.
            let bp = *(*b_tops.add(j)).offset(down_offset);
            let br = _mm256_set1_ps(bp);
            *acc1 = _mm256_fmadd_ps(ar1, br, *acc1);
            *acc2 = _mm256_fmadd_ps(ar2, br, *acc2);
        }
    }

    // Spill each column through an aligned scratch buffer, then scatter it
    // into `c` using the caller-provided strides.
    for (x, (lo, hi)) in ab1.iter().zip(ab2.iter()).enumerate() {
        let mut col = SixteenAlignedF32([0f32; 16]);
        _mm256_store_ps(col.0.as_mut_ptr(), *lo);
        _mm256_store_ps(col.0.as_mut_ptr().add(8), *hi);
        for (y, &value) in col.0.iter().enumerate() {
            *c.add(y * rsc + x * csc) = value;
        }
    }
}

impl frame::conv::ConvKer<f32> for SConvFma16x6 {
    /// Kernel identifier.
    #[inline(always)]
    fn name() -> &'static str {
        "fma"
    }

    /// Number of rows in the tile written by [`Self::kernel`].
    #[inline(always)]
    fn mr() -> usize {
        16
    }

    /// Number of columns in the tile written by [`Self::kernel`].
    #[inline(always)]
    fn nr() -> usize {
        6
    }

    /// Required alignment of `a` in bytes; matches the assertion in `fma`.
    fn alignment_bytes_a() -> usize {
        32
    }

    /// Required alignment of the B data in bytes (one `f32`).
    fn alignment_bytes_b() -> usize {
        4
    }

    /// Runs the FMA kernel; see `fma` for the meaning of each argument.
    #[inline(always)]
    fn kernel(
        k: usize,
        a: *const f32,
        b_tops: *const *const f32,
        b_down_offsets: *const isize,
        c: *mut f32,
        rsc: usize,
        csc: usize,
    ) {
        // SAFETY: this relies on the caller passing pointers that satisfy the
        // contract documented on `fma`, and on this kernel only being used on
        // CPUs that support FMA.
        // NOTE(review): nothing in this function verifies either condition --
        // confirm that kernel selection performs the feature check.
        unsafe { fma(k, a, b_tops, b_down_offsets, c, rsc, csc) }
    }
}

#[cfg(test)]
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"),))]
mod test {
    use super::*;
    use crate::frame::conv::test::*;
    use crate::frame::PackedConv;
    use proptest::*;

    proptest! {
        /// Compares the kernel output against the reference result of the
        /// generated problem, using the summed absolute difference.
        #[test]
        fn conv(pb in strat_conv_1d()) {
            // Skip (as a pass) on machines without FMA support.
            if !is_x86_feature_detected!("fma") {
                return Ok(())
            }
            let (kernel_offsets, data_offsets) = pb.offsets();
            let conv = PackedConv::<SConvFma16x6, f32>::new(pb.co, kernel_offsets, data_offsets);
            let found = pb.run(&conv);
            let expected = pb.expected();
            let dist = found
                .iter()
                .zip(expected.iter())
                .map(|(f, e)| (f - e).abs())
                .sum::<f32>();
            prop_assert!(dist < 0.00001, "Expected: {:?}, found: {:?}", expected, found);
        }
    }
}