Skip to main content

locus_core/simd/
sampler.rs

1//! SIMD-vectorized image sampling kernels.
2use crate::image::ImageView;
3use multiversion::multiversion;
4
5/// Vectorized bilinear interpolation for 8 points simultaneously.
6///
7/// # Safety
8/// This function uses `_mm256_i32gather_epi32` to fetch 8-bit pixels by performing
9/// 32-bit unaligned loads. This requires the input image buffer to have at least
10/// **3 bytes of padding** at the end to avoid out-of-bounds reads when sampling
11/// pixels near the bottom-right corner.
12#[multiversion(targets("x86_64+avx2+fma", "aarch64+neon"))]
13pub fn sample_bilinear_v8(img: &ImageView, x: &[f32; 8], y: &[f32; 8], out: &mut [f32; 8]) {
14    #[cfg(all(
15        target_arch = "x86_64",
16        target_feature = "avx2",
17        target_feature = "fma"
18    ))]
19    if img.has_simd_padding() {
20        unsafe {
21            use std::arch::x86_64::*;
22
23            let vx = _mm256_loadu_ps(x.as_ptr());
24            let vy = _mm256_loadu_ps(y.as_ptr());
25
26            // Offset by -0.5 to match ImageView::sample_bilinear logic (pixel centers)
27            let half = _mm256_set1_ps(0.5);
28            let vx = _mm256_sub_ps(vx, half);
29            let vy = _mm256_sub_ps(vy, half);
30
31            // Integer parts (clamped to 0..width-2, 0..height-2 to ensure 2x2 neighborhood is safe)
32            let zero = _mm256_setzero_ps();
33            let max_x = _mm256_set1_ps((img.width as f32 - 2.0).max(0.0));
34            let max_y = _mm256_set1_ps((img.height as f32 - 2.0).max(0.0));
35
36            let vx_clamped = _mm256_max_ps(zero, _mm256_min_ps(vx, max_x));
37            let vy_clamped = _mm256_max_ps(zero, _mm256_min_ps(vy, max_y));
38
39            let vx_floor = _mm256_floor_ps(vx_clamped);
40            let vy_floor = _mm256_floor_ps(vy_clamped);
41
42            let vix = _mm256_cvtps_epi32(vx_floor);
43            let viy = _mm256_cvtps_epi32(vy_floor);
44
45            // Fractional parts (weights)
46            let fx = _mm256_sub_ps(vx_clamped, vx_floor);
47            let fy = _mm256_sub_ps(vy_clamped, vy_floor);
48
49            let one = _mm256_set1_ps(1.0);
50            let inv_fx = _mm256_sub_ps(one, fx);
51            let inv_fy = _mm256_sub_ps(one, fy);
52
53            // 1D memory offsets: idx = y_int * stride + x_int
54            let stride = _mm256_set1_epi32(img.stride as i32);
55            let idx_tl = _mm256_add_epi32(_mm256_mullo_epi32(viy, stride), vix);
56            let idx_tr = _mm256_add_epi32(idx_tl, _mm256_set1_epi32(1));
57            let idx_bl = _mm256_add_epi32(idx_tl, stride);
58            let idx_br = _mm256_add_epi32(idx_bl, _mm256_set1_epi32(1));
59
60            // Gather 4 surrounding pixels (as 32-bit ints, then convert to f32)
61            let base_ptr = img.data.as_ptr() as *const i32;
62
63            let gather_to_f32 = |offsets: __m256i| -> __m256 {
64                let gathered = _mm256_i32gather_epi32(base_ptr, offsets, 1);
65                let masked = _mm256_and_si256(gathered, _mm256_set1_epi32(0xFF));
66                _mm256_cvtepi32_ps(masked)
67            };
68
69            let v_tl = gather_to_f32(idx_tl);
70            let v_tr = gather_to_f32(idx_tr);
71            let v_bl = gather_to_f32(idx_bl);
72            let v_br = gather_to_f32(idx_br);
73
74            // Bilinear interpolation using FMA
75            // I = (1-fx)(1-fy)TL + fx(1-fy)TR + (1-fx)fy*BL + fx*fy*BR
76
77            let w_tl = _mm256_mul_ps(inv_fx, inv_fy);
78            let w_tr = _mm256_mul_ps(fx, inv_fy);
79            let w_bl = _mm256_mul_ps(inv_fx, fy);
80            let w_br = _mm256_mul_ps(fx, fy);
81
82            let mut res = _mm256_mul_ps(w_tl, v_tl);
83            res = _mm256_fmadd_ps(w_tr, v_tr, res);
84            res = _mm256_fmadd_ps(w_bl, v_bl, res);
85            res = _mm256_fmadd_ps(w_br, v_br, res);
86
87            _mm256_storeu_ps(out.as_mut_ptr(), res);
88            return;
89        }
90    }
91
92    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
93    #[allow(unsafe_code)]
94    // SAFETY: NEON intrinsics are safe on aarch64 with neon feature.
95    // Buffer bounds are checked via ImageView.
96    unsafe {
97        use std::arch::aarch64::*;
98
99        // NEON processes 4 floats at a time (float32x4_t)
100        // We need to do it twice for 8 points.
101        for chunk in 0..2 {
102            let offset = chunk * 4;
103            let vx = vld1q_f32(x.as_ptr().add(offset));
104            let vy = vld1q_f32(y.as_ptr().add(offset));
105
106            // Offset by -0.5
107            let half = vdupq_n_f32(0.5);
108            let vx = vsubq_f32(vx, half);
109            let vy = vsubq_f32(vy, half);
110
111            // Clamp to bounds
112            let zero = vdupq_n_f32(0.0);
113            let max_x = vdupq_n_f32((img.width as f32 - 2.0).max(0.0));
114            let max_y = vdupq_n_f32((img.height as f32 - 2.0).max(0.0));
115
116            let vx_clamped = vmaxq_f32(zero, vminq_f32(vx, max_x));
117            let vy_clamped = vmaxq_f32(zero, vminq_f32(vy, max_y));
118
119            // Integer parts (floor)
120            // NEON floor: vrndmq_f32
121            let vx_floor = vrndmq_f32(vx_clamped);
122            let vy_floor = vrndmq_f32(vy_clamped);
123
124            let vix = vcvtq_s32_f32(vx_floor);
125            let viy = vcvtq_s32_f32(vy_floor);
126
127            // Fractional parts
128            let fx = vsubq_f32(vx_clamped, vx_floor);
129            let fy = vsubq_f32(vy_clamped, vy_floor);
130
131            let one = vdupq_n_f32(1.0);
132            let inv_fx = vsubq_f32(one, fx);
133            let inv_fy = vsubq_f32(one, fy);
134
135            // Fetch pixels (No gather, use scalar loads)
136            let mut pix_tl = [0.0f32; 4];
137            let mut pix_tr = [0.0f32; 4];
138            let mut pix_bl = [0.0f32; 4];
139            let mut pix_br = [0.0f32; 4];
140
141            let mut vix_arr = [0i32; 4];
142            let mut viy_arr = [0i32; 4];
143            vst1q_s32(vix_arr.as_mut_ptr(), vix);
144            vst1q_s32(viy_arr.as_mut_ptr(), viy);
145
146            for i in 0..4 {
147                let px = vix_arr[i] as usize;
148                let py = viy_arr[i] as usize;
149                let stride = img.stride;
150                let base = py * stride + px;
151                pix_tl[i] = f32::from(img.data[base]);
152                pix_tr[i] = f32::from(img.data[base + 1]);
153                pix_bl[i] = f32::from(img.data[base + stride]);
154                pix_br[i] = f32::from(img.data[base + stride + 1]);
155            }
156
157            let v_tl = vld1q_f32(pix_tl.as_ptr());
158            let v_tr = vld1q_f32(pix_tr.as_ptr());
159            let v_bl = vld1q_f32(pix_bl.as_ptr());
160            let v_br = vld1q_f32(pix_br.as_ptr());
161
162            // Bilinear interpolation
163            let w_tl = vmulq_f32(inv_fx, inv_fy);
164            let w_tr = vmulq_f32(fx, inv_fy);
165            let w_bl = vmulq_f32(inv_fx, fy);
166            let w_br = vmulq_f32(fx, fy);
167
168            let mut res = vmulq_f32(w_tl, v_tl);
169            res = vfmaq_f32(res, w_tr, v_tr);
170            res = vfmaq_f32(res, w_bl, v_bl);
171            res = vfmaq_f32(res, w_br, v_br);
172
173            vst1q_f32(out.as_mut_ptr().add(offset), res);
174        }
175        return;
176    }
177
178    #[cfg(not(any(
179        all(
180            target_arch = "x86_64",
181            target_feature = "avx2",
182            target_feature = "fma"
183        ),
184        all(target_arch = "aarch64", target_feature = "neon")
185    )))]
186    {
187        // Fallback: Scalar
188        for i in 0..8 {
189            out[i] = img.sample_bilinear(f64::from(x[i]), f64::from(y[i])) as f32;
190        }
191    }
192}