// yscv_kernels/ops/conv.rs
1use rayon::{ThreadPool, prelude::*};
2use yscv_tensor::{AlignedVec, Tensor, TensorError};
3
4use super::super::error::KernelError;
5use super::config::{
6    Conv2dPlan, Conv2dSpec, DepthwiseConv2dPlan, DepthwiseConv2dSpec, ParallelElementwiseConfig,
7    SeparableConv2dKernels, SeparableConv2dSpec, should_parallelize_len,
8};
9
10pub fn conv2d_nhwc_with_config_and_pool(
11    input: &Tensor,
12    kernel: &Tensor,
13    bias: Option<&Tensor>,
14    spec: Conv2dSpec,
15    config: ParallelElementwiseConfig,
16    thread_pool: Option<&ThreadPool>,
17) -> Result<Tensor, KernelError> {
18    let plan = build_conv2d_plan(input, kernel, bias, spec)?;
19
20    // Direct 3×3 microkernel for small inputs (no im2col overhead).
21    // For small spatial sizes the im2col copy + BLAS call setup dominates; a
22    // direct SIMD kernel that walks the input in-place is significantly faster.
23    // WHY 4096: below this, im2col memory copy (~KH*KW*C*HW floats) + BLAS setup exceeds direct conv cost.
24    #[cfg(target_arch = "aarch64")]
25    if plan.kernel_h == 3
26        && plan.kernel_w == 3
27        && plan.batch == 1
28        && !cfg!(miri)
29        && std::arch::is_aarch64_feature_detected!("neon")
30        && plan.out_h * plan.out_w < 4096
31    {
32        #[allow(unsafe_code)]
33        let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);
34        #[allow(unsafe_code)]
35        unsafe {
36            conv2d_3x3_direct_neon(
37                input.data(),
38                kernel.data(),
39                &mut output,
40                plan.in_w,
41                plan.in_channels,
42                plan.out_channels,
43                plan.out_h,
44                plan.out_w,
45                plan.stride_h,
46                plan.stride_w,
47            );
48        }
49        if let Some(b) = bias {
50            let bd = b.data();
51            for i in 0..plan.out_h * plan.out_w {
52                for c in 0..plan.out_channels {
53                    output[i * plan.out_channels + c] += bd[c];
54                }
55            }
56        }
57        return Tensor::from_aligned(vec![1, plan.out_h, plan.out_w, plan.out_channels], output)
58            .map_err(Into::into);
59    }
60
61    #[cfg(target_arch = "x86_64")]
62    if plan.kernel_h == 3
63        && plan.kernel_w == 3
64        && plan.batch == 1
65        && !cfg!(miri)
66        && is_x86_feature_detected!("avx")
67        && is_x86_feature_detected!("fma")
68        && plan.out_h * plan.out_w < 4096
69    {
70        #[allow(unsafe_code)]
71        let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);
72        #[allow(unsafe_code)]
73        unsafe {
74            conv2d_3x3_direct_avx(
75                input.data(),
76                kernel.data(),
77                &mut output,
78                plan.in_w,
79                plan.in_channels,
80                plan.out_channels,
81                plan.out_h,
82                plan.out_w,
83                plan.stride_h,
84                plan.stride_w,
85            );
86        }
87        if let Some(b) = bias {
88            let bd = b.data();
89            for i in 0..plan.out_h * plan.out_w {
90                for c in 0..plan.out_channels {
91                    output[i * plan.out_channels + c] += bd[c];
92                }
93            }
94        }
95        return Tensor::from_aligned(vec![1, plan.out_h, plan.out_w, plan.out_channels], output)
96            .map_err(Into::into);
97    }
98
99    #[cfg(target_arch = "x86_64")]
100    if plan.kernel_h == 3
101        && plan.kernel_w == 3
102        && plan.batch == 1
103        && !cfg!(miri)
104        && is_x86_feature_detected!("fma")
105        && plan.out_h * plan.out_w < 4096
106    {
107        #[allow(unsafe_code)]
108        let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);
109        #[allow(unsafe_code)]
110        unsafe {
111            conv2d_3x3_direct_sse(
112                input.data(),
113                kernel.data(),
114                &mut output,
115                plan.in_w,
116                plan.in_channels,
117                plan.out_channels,
118                plan.out_h,
119                plan.out_w,
120                plan.stride_h,
121                plan.stride_w,
122            );
123        }
124        if let Some(b) = bias {
125            let bd = b.data();
126            for i in 0..plan.out_h * plan.out_w {
127                for c in 0..plan.out_channels {
128                    output[i * plan.out_channels + c] += bd[c];
129                }
130            }
131        }
132        return Tensor::from_aligned(vec![1, plan.out_h, plan.out_w, plan.out_channels], output)
133            .map_err(Into::into);
134    }
135
136    // Fast path: im2col + BLAS sgemm for single-batch convolutions with enough output
137    // positions to amortise BLAS/im2col overhead.
138    #[cfg(feature = "blas")]
139    if !cfg!(miri) && plan.batch == 1 {
140        return conv2d_im2col_gemm(&plan, input.data(), kernel.data(), bias.map(Tensor::data));
141    }
142
143    let input_data = input.data();
144    let kernel_data = kernel.data();
145    let bias_data = bias.map(Tensor::data);
146    let out_row_len = plan.out_w * plan.out_channels;
147    if plan.output_len == 0 || out_row_len == 0 {
148        return Tensor::from_vec(
149            vec![plan.batch, plan.out_h, plan.out_w, plan.out_channels],
150            vec![],
151        )
152        .map_err(Into::into);
153    }
154
155    // SAFETY: `conv2d_nhwc_row` writes every element in each output row.
156    #[allow(unsafe_code)]
157    let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);
158
159    if should_parallelize_len(plan.output_len, config.min_parallel_elements, thread_pool) {
160        let mut work = || {
161            output
162                .par_chunks_mut(out_row_len)
163                .enumerate()
164                .for_each(|(row_idx, out_row)| {
165                    conv2d_nhwc_row(input_data, kernel_data, bias_data, plan, row_idx, out_row);
166                });
167        };
168        if let Some(pool) = thread_pool {
169            pool.install(work);
170        } else {
171            work();
172        }
173    } else {
174        for (row_idx, out_row) in output.chunks_mut(out_row_len).enumerate() {
175            conv2d_nhwc_row(input_data, kernel_data, bias_data, plan, row_idx, out_row);
176        }
177    }
178
179    Tensor::from_aligned(
180        vec![plan.batch, plan.out_h, plan.out_w, plan.out_channels],
181        output,
182    )
183    .map_err(Into::into)
184}
185
186pub fn depthwise_conv2d_nhwc_with_config_and_pool(
187    input: &Tensor,
188    kernel: &Tensor,
189    bias: Option<&Tensor>,
190    spec: DepthwiseConv2dSpec,
191    config: ParallelElementwiseConfig,
192    thread_pool: Option<&ThreadPool>,
193) -> Result<Tensor, KernelError> {
194    let plan = build_depthwise_conv2d_plan(input, kernel, bias, spec)?;
195    let input_data = input.data();
196    let kernel_data = kernel.data();
197    let bias_data = bias.map(Tensor::data);
198    let out_row_len = plan.out_w * plan.out_channels;
199    if plan.output_len == 0 || out_row_len == 0 {
200        return Tensor::from_aligned(
201            vec![plan.batch, plan.out_h, plan.out_w, plan.out_channels],
202            AlignedVec::<f32>::calloc(plan.output_len),
203        )
204        .map_err(Into::into);
205    }
206
207    let mut output = AlignedVec::<f32>::uninitialized(plan.output_len);
208
209    if should_parallelize_len(plan.output_len, config.min_parallel_elements, thread_pool) {
210        let mut work = || {
211            output
212                .par_chunks_mut(out_row_len)
213                .enumerate()
214                .for_each(|(row_idx, out_row)| {
215                    depthwise_conv2d_nhwc_row(
216                        input_data,
217                        kernel_data,
218                        bias_data,
219                        plan,
220                        row_idx,
221                        out_row,
222                    );
223                });
224        };
225        if let Some(pool) = thread_pool {
226            pool.install(work);
227        } else {
228            work();
229        }
230    } else {
231        for (row_idx, out_row) in output.chunks_mut(out_row_len).enumerate() {
232            depthwise_conv2d_nhwc_row(input_data, kernel_data, bias_data, plan, row_idx, out_row);
233        }
234    }
235
236    Tensor::from_aligned(
237        vec![plan.batch, plan.out_h, plan.out_w, plan.out_channels],
238        output,
239    )
240    .map_err(Into::into)
241}
242
243pub fn separable_conv2d_nhwc_with_config_and_pool(
244    input: &Tensor,
245    kernels: SeparableConv2dKernels<'_>,
246    spec: SeparableConv2dSpec,
247    config: ParallelElementwiseConfig,
248    thread_pool: Option<&ThreadPool>,
249) -> Result<Tensor, KernelError> {
250    if kernels.pointwise_kernel.rank() != 4
251        || kernels.pointwise_kernel.shape()[0] != 1
252        || kernels.pointwise_kernel.shape()[1] != 1
253    {
254        return Err(KernelError::InvalidSeparablePointwiseKernelShape {
255            pointwise_shape: kernels.pointwise_kernel.shape().to_vec(),
256        });
257    }
258
259    let depthwise_out = depthwise_conv2d_nhwc_with_config_and_pool(
260        input,
261        kernels.depthwise_kernel,
262        kernels.depthwise_bias,
263        DepthwiseConv2dSpec {
264            stride_h: spec.stride_h,
265            stride_w: spec.stride_w,
266        },
267        config,
268        thread_pool,
269    )?;
270
271    conv2d_nhwc_with_config_and_pool(
272        &depthwise_out,
273        kernels.pointwise_kernel,
274        kernels.pointwise_bias,
275        Conv2dSpec {
276            stride_h: 1,
277            stride_w: 1,
278        },
279        config,
280        thread_pool,
281    )
282}
283
284fn build_conv2d_plan(
285    input: &Tensor,
286    kernel: &Tensor,
287    bias: Option<&Tensor>,
288    spec: Conv2dSpec,
289) -> Result<Conv2dPlan, KernelError> {
290    let stride_h = spec.stride_h;
291    let stride_w = spec.stride_w;
292    if input.rank() != 4 || kernel.rank() != 4 {
293        return Err(KernelError::InvalidConvRank {
294            input_rank: input.rank(),
295            kernel_rank: kernel.rank(),
296        });
297    }
298    if stride_h == 0 || stride_w == 0 {
299        return Err(KernelError::InvalidConvParameters {
300            kernel_h: kernel.shape()[0],
301            kernel_w: kernel.shape()[1],
302            stride_h,
303            stride_w,
304        });
305    }
306
307    let batch = input.shape()[0];
308    let in_h = input.shape()[1];
309    let in_w = input.shape()[2];
310    let in_channels = input.shape()[3];
311    let kernel_h = kernel.shape()[0];
312    let kernel_w = kernel.shape()[1];
313    let kernel_in_channels = kernel.shape()[2];
314    let out_channels = kernel.shape()[3];
315
316    if kernel_h == 0 || kernel_w == 0 {
317        return Err(KernelError::InvalidConvParameters {
318            kernel_h,
319            kernel_w,
320            stride_h,
321            stride_w,
322        });
323    }
324    if kernel_in_channels != in_channels {
325        return Err(KernelError::ConvChannelMismatch {
326            input_channels: in_channels,
327            kernel_in_channels,
328        });
329    }
330    if kernel_h > in_h || kernel_w > in_w {
331        return Err(KernelError::ConvKernelLargerThanInput {
332            input_h: in_h,
333            input_w: in_w,
334            kernel_h,
335            kernel_w,
336        });
337    }
338    if let Some(bias_tensor) = bias
339        && (bias_tensor.rank() != 1 || bias_tensor.shape()[0] != out_channels)
340    {
341        return Err(KernelError::ConvBiasShapeMismatch {
342            bias_shape: bias_tensor.shape().to_vec(),
343            out_channels,
344        });
345    }
346
347    let out_h = (in_h - kernel_h) / stride_h + 1;
348    let out_w = (in_w - kernel_w) / stride_w + 1;
349    let output_len = batch
350        .checked_mul(out_h)
351        .and_then(|v| v.checked_mul(out_w))
352        .and_then(|v| v.checked_mul(out_channels))
353        .ok_or_else(|| {
354            KernelError::Tensor(TensorError::SizeOverflow {
355                shape: vec![batch, out_h, out_w, out_channels],
356            })
357        })?;
358
359    Ok(Conv2dPlan {
360        batch,
361        in_h,
362        in_w,
363        in_channels,
364        out_h,
365        out_w,
366        out_channels,
367        kernel_h,
368        kernel_w,
369        stride_h,
370        stride_w,
371        output_len,
372    })
373}
374
375fn build_depthwise_conv2d_plan(
376    input: &Tensor,
377    kernel: &Tensor,
378    bias: Option<&Tensor>,
379    spec: DepthwiseConv2dSpec,
380) -> Result<DepthwiseConv2dPlan, KernelError> {
381    let stride_h = spec.stride_h;
382    let stride_w = spec.stride_w;
383    if input.rank() != 4 || kernel.rank() != 4 {
384        return Err(KernelError::InvalidDepthwiseConvRank {
385            input_rank: input.rank(),
386            kernel_rank: kernel.rank(),
387        });
388    }
389    if stride_h == 0 || stride_w == 0 {
390        return Err(KernelError::InvalidDepthwiseConvParameters {
391            kernel_h: kernel.shape()[0],
392            kernel_w: kernel.shape()[1],
393            stride_h,
394            stride_w,
395        });
396    }
397
398    let batch = input.shape()[0];
399    let in_h = input.shape()[1];
400    let in_w = input.shape()[2];
401    let channels = input.shape()[3];
402    let kernel_h = kernel.shape()[0];
403    let kernel_w = kernel.shape()[1];
404    let kernel_channels = kernel.shape()[2];
405    let depth_multiplier = kernel.shape()[3];
406
407    if kernel_h == 0 || kernel_w == 0 || depth_multiplier == 0 {
408        return Err(KernelError::InvalidDepthwiseConvParameters {
409            kernel_h,
410            kernel_w,
411            stride_h,
412            stride_w,
413        });
414    }
415    if kernel_channels != channels {
416        return Err(KernelError::DepthwiseConvChannelMismatch {
417            input_channels: channels,
418            kernel_channels,
419        });
420    }
421    if kernel_h > in_h || kernel_w > in_w {
422        return Err(KernelError::DepthwiseConvKernelLargerThanInput {
423            input_h: in_h,
424            input_w: in_w,
425            kernel_h,
426            kernel_w,
427        });
428    }
429
430    let out_channels = channels.checked_mul(depth_multiplier).ok_or_else(|| {
431        KernelError::Tensor(TensorError::SizeOverflow {
432            shape: vec![channels, depth_multiplier],
433        })
434    })?;
435    if let Some(bias_tensor) = bias
436        && (bias_tensor.rank() != 1 || bias_tensor.shape()[0] != out_channels)
437    {
438        return Err(KernelError::DepthwiseConvBiasShapeMismatch {
439            bias_shape: bias_tensor.shape().to_vec(),
440            out_channels,
441        });
442    }
443
444    let out_h = (in_h - kernel_h) / stride_h + 1;
445    let out_w = (in_w - kernel_w) / stride_w + 1;
446    let output_len = batch
447        .checked_mul(out_h)
448        .and_then(|v| v.checked_mul(out_w))
449        .and_then(|v| v.checked_mul(out_channels))
450        .ok_or_else(|| {
451            KernelError::Tensor(TensorError::SizeOverflow {
452                shape: vec![batch, out_h, out_w, out_channels],
453            })
454        })?;
455
456    Ok(DepthwiseConv2dPlan {
457        batch,
458        in_h,
459        in_w,
460        channels,
461        depth_multiplier,
462        out_h,
463        out_w,
464        out_channels,
465        kernel_h,
466        kernel_w,
467        stride_h,
468        stride_w,
469        output_len,
470    })
471}
472
473fn conv2d_nhwc_row(
474    input: &[f32],
475    kernel: &[f32],
476    bias: Option<&[f32]>,
477    plan: Conv2dPlan,
478    row_idx: usize,
479    out_row: &mut [f32],
480) {
481    let batch_idx = row_idx / plan.out_h;
482    let out_y = row_idx % plan.out_h;
483    let in_y0 = out_y * plan.stride_h;
484    let batch_input_base = batch_idx * plan.in_h * plan.in_w * plan.in_channels;
485
486    for out_x in 0..plan.out_w {
487        let in_x0 = out_x * plan.stride_w;
488        let out_cell_base = out_x * plan.out_channels;
489        let out_slice = &mut out_row[out_cell_base..out_cell_base + plan.out_channels];
490
491        // Initialize with bias
492        if let Some(bias_values) = bias {
493            out_slice.copy_from_slice(&bias_values[..plan.out_channels]);
494        } else {
495            out_slice.fill(0.0);
496        }
497
498        // Accumulate: iterate over kernel window, broadcast input, FMA across out_channels
499        for ky in 0..plan.kernel_h {
500            let in_y = in_y0 + ky;
501            let input_row_base = batch_input_base + (in_y * plan.in_w + in_x0) * plan.in_channels;
502            let kernel_row_base = ky * plan.kernel_w * plan.in_channels * plan.out_channels;
503
504            for kx in 0..plan.kernel_w {
505                let input_pixel_base = input_row_base + kx * plan.in_channels;
506                let kernel_pixel_base = kernel_row_base + kx * plan.in_channels * plan.out_channels;
507
508                for in_channel in 0..plan.in_channels {
509                    let input_val = input[input_pixel_base + in_channel];
510                    let k_base = kernel_pixel_base + in_channel * plan.out_channels;
511                    // SIMD: broadcast input_val, multiply-add across out_channels
512                    conv_fma_row(
513                        out_slice,
514                        &kernel[k_base..k_base + plan.out_channels],
515                        input_val,
516                    );
517                }
518            }
519        }
520    }
521}
522
/// Direct 3×3 convolution microkernel — no im2col overhead.
/// For each output pixel, load 3×3×C_in input values and multiply with kernel.
/// Accumulate C_out output channels using SIMD FMA.
///
/// When stride_w == 1, processes two adjacent output pixels at a time.
/// Adjacent pixels at (ox, ox+1) share input columns: pixel ox uses columns
/// [ix, ix+1, ix+2] and pixel ox+1 uses [ix+1, ix+2, ix+3]. The middle two
/// columns are shared, saving ~33% of input loads.
///
/// Bias is NOT applied here; every output element is overwritten with the
/// bias-free sum and the caller folds the bias in afterwards.
///
/// # Safety
/// Callers must guarantee the slice sizes checked by the `debug_assert!`s
/// below: `input` covers at least `((out_h-1)*stride_h + 3) * w * c_in`
/// elements, `kernel` at least `3*3*c_in*c_out`, and `output` at least
/// `out_h*out_w*c_out`. The interior uses `get_unchecked` and raw-pointer
/// loads/stores, so undersized slices are UB in release builds.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn conv2d_3x3_direct_neon(
    input: &[f32],      // [H, W, C_in] NHWC (batch-dim already stripped)
    kernel: &[f32],     // [3, 3, C_in, C_out]
    output: &mut [f32], // [out_H, out_W, C_out]
    w: usize,
    c_in: usize,
    c_out: usize,
    out_h: usize,
    out_w: usize,
    stride_h: usize,
    stride_w: usize,
) {
    use std::arch::aarch64::*;

    // Bounds: the deepest input row touched is (out_h-1)*stride_h + 2, so all
    // reads — including the paired path's extra column at ix_base+3, which
    // stays < w for "valid" 3×3 output widths with stride_w == 1 — fall below
    // ((out_h-1)*stride_h + 3) * w * c_in.
    debug_assert!(
        input.len() >= ((out_h.saturating_sub(1)) * stride_h + 3) * w * c_in,
        "conv2d_3x3_direct_neon: input too small"
    );
    debug_assert!(
        output.len() >= out_h * out_w * c_out,
        "conv2d_3x3_direct_neon: output too small"
    );
    debug_assert!(
        kernel.len() >= 3 * 3 * c_in * c_out,
        "conv2d_3x3_direct_neon: kernel too small"
    );

    for oy in 0..out_h {
        let iy_base = oy * stride_h;

        // When stride_w == 1 and we have at least 2 output pixels remaining,
        // process pairs of adjacent ox positions. For kernel column kx:
        //   pixel at ox   reads input column (ix_base + kx)
        //   pixel at ox+1 reads input column (ix_base + 1 + kx) = (ix_base + kx + 1)
        // So across kx=0,1,2 the pair reads columns ix_base..ix_base+3,
        // and columns ix_base+1 and ix_base+2 are shared.
        let mut ox = 0usize;
        if stride_w == 1 {
            while ox + 2 <= out_w {
                let ix_base = ox; // stride_w == 1
                let out_off_a = (oy * out_w + ox) * c_out;
                let out_off_b = out_off_a + c_out;

                // Tier 1: 8 output channels per iteration — two 4-lane
                // accumulators per pixel (a0/a1 for pixel A, b0/b1 for B).
                let mut co = 0;
                while co + 8 <= c_out {
                    let mut acc_a0 = vdupq_n_f32(0.0);
                    let mut acc_a1 = vdupq_n_f32(0.0);
                    let mut acc_b0 = vdupq_n_f32(0.0);
                    let mut acc_b1 = vdupq_n_f32(0.0);

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        // Load input values for 4 adjacent columns: ix_base..ix_base+3
                        // pixel A uses cols 0,1,2; pixel B uses cols 1,2,3
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            // kernel weights for kx=0,1,2 (lo = lanes co..co+4,
                            // hi = lanes co+4..co+8)
                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0_lo = vld1q_f32(kernel.as_ptr().add(k0_off));
                            let kw0_hi = vld1q_f32(kernel.as_ptr().add(k0_off + 4));
                            let kw1_lo = vld1q_f32(kernel.as_ptr().add(k1_off));
                            let kw1_hi = vld1q_f32(kernel.as_ptr().add(k1_off + 4));
                            let kw2_lo = vld1q_f32(kernel.as_ptr().add(k2_off));
                            let kw2_hi = vld1q_f32(kernel.as_ptr().add(k2_off + 4));

                            // Pixel A: in0*k0 + in1*k1 + in2*k2
                            let va0 = vdupq_n_f32(in0);
                            let va1 = vdupq_n_f32(in1);
                            let va2 = vdupq_n_f32(in2);
                            acc_a0 = vfmaq_f32(acc_a0, va0, kw0_lo);
                            acc_a1 = vfmaq_f32(acc_a1, va0, kw0_hi);
                            acc_a0 = vfmaq_f32(acc_a0, va1, kw1_lo);
                            acc_a1 = vfmaq_f32(acc_a1, va1, kw1_hi);
                            acc_a0 = vfmaq_f32(acc_a0, va2, kw2_lo);
                            acc_a1 = vfmaq_f32(acc_a1, va2, kw2_hi);

                            // Pixel B: in1*k0 + in2*k1 + in3*k2
                            // (reuses the broadcast va1/va2 from pixel A)
                            let vb3 = vdupq_n_f32(in3);
                            acc_b0 = vfmaq_f32(acc_b0, va1, kw0_lo);
                            acc_b1 = vfmaq_f32(acc_b1, va1, kw0_hi);
                            acc_b0 = vfmaq_f32(acc_b0, va2, kw1_lo);
                            acc_b1 = vfmaq_f32(acc_b1, va2, kw1_hi);
                            acc_b0 = vfmaq_f32(acc_b0, vb3, kw2_lo);
                            acc_b1 = vfmaq_f32(acc_b1, vb3, kw2_hi);
                        }
                    }

                    vst1q_f32(output.as_mut_ptr().add(out_off_a + co), acc_a0);
                    vst1q_f32(output.as_mut_ptr().add(out_off_a + co + 4), acc_a1);
                    vst1q_f32(output.as_mut_ptr().add(out_off_b + co), acc_b0);
                    vst1q_f32(output.as_mut_ptr().add(out_off_b + co + 4), acc_b1);
                    co += 8;
                }

                // Tier 2: 4 channels per iteration (one register per pixel).
                while co + 4 <= c_out {
                    let mut acc_a = vdupq_n_f32(0.0);
                    let mut acc_b = vdupq_n_f32(0.0);

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0 = vld1q_f32(kernel.as_ptr().add(k0_off));
                            let kw1 = vld1q_f32(kernel.as_ptr().add(k1_off));
                            let kw2 = vld1q_f32(kernel.as_ptr().add(k2_off));

                            let va0 = vdupq_n_f32(in0);
                            let va1 = vdupq_n_f32(in1);
                            let va2 = vdupq_n_f32(in2);
                            acc_a = vfmaq_f32(acc_a, va0, kw0);
                            acc_a = vfmaq_f32(acc_a, va1, kw1);
                            acc_a = vfmaq_f32(acc_a, va2, kw2);

                            let vb3 = vdupq_n_f32(in3);
                            acc_b = vfmaq_f32(acc_b, va1, kw0);
                            acc_b = vfmaq_f32(acc_b, va2, kw1);
                            acc_b = vfmaq_f32(acc_b, vb3, kw2);
                        }
                    }

                    vst1q_f32(output.as_mut_ptr().add(out_off_a + co), acc_a);
                    vst1q_f32(output.as_mut_ptr().add(out_off_b + co), acc_b);
                    co += 4;
                }

                // Tier 3: scalar tail for the last (c_out mod 4) channels.
                while co < c_out {
                    let mut acc_a = 0.0f32;
                    let mut acc_b = 0.0f32;
                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = input[(row_base + ix_base) * c_in + ci];
                            let in1 = input[(row_base + ix_base + 1) * c_in + ci];
                            let in2 = input[(row_base + ix_base + 2) * c_in + ci];
                            let in3 = input[(row_base + ix_base + 3) * c_in + ci];
                            let k0 = kernel[ky * 3 * c_in * c_out + ci * c_out + co];
                            let k1 = kernel[(ky * 3 + 1) * c_in * c_out + ci * c_out + co];
                            let k2 = kernel[(ky * 3 + 2) * c_in * c_out + ci * c_out + co];
                            acc_a += in0 * k0 + in1 * k1 + in2 * k2;
                            acc_b += in1 * k0 + in2 * k1 + in3 * k2;
                        }
                    }
                    *output.get_unchecked_mut(out_off_a + co) = acc_a;
                    *output.get_unchecked_mut(out_off_b + co) = acc_b;
                    co += 1;
                }

                ox += 2;
            }
        }

        // Handle remaining single pixels (odd out_w or stride_w != 1).
        while ox < out_w {
            let ix_base = ox * stride_w;
            let out_off = (oy * out_w + ox) * c_out;

            // Same three-tier channel loop as above, for a single pixel.
            let mut co = 0;
            while co + 8 <= c_out {
                let mut acc0 = vdupq_n_f32(0.0);
                let mut acc1 = vdupq_n_f32(0.0);

                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        let k_base = (ky * 3 + kx) * c_in * c_out;

                        for ci in 0..c_in {
                            let iv = vdupq_n_f32(*input.get_unchecked(in_off + ci));
                            let koff = k_base + ci * c_out + co;
                            acc0 = vfmaq_f32(acc0, iv, vld1q_f32(kernel.as_ptr().add(koff)));
                            acc1 = vfmaq_f32(acc1, iv, vld1q_f32(kernel.as_ptr().add(koff + 4)));
                        }
                    }
                }

                vst1q_f32(output.as_mut_ptr().add(out_off + co), acc0);
                vst1q_f32(output.as_mut_ptr().add(out_off + co + 4), acc1);
                co += 8;
            }

            while co + 4 <= c_out {
                let mut acc = vdupq_n_f32(0.0);
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        for ci in 0..c_in {
                            let iv = vdupq_n_f32(*input.get_unchecked(in_off + ci));
                            acc = vfmaq_f32(
                                acc,
                                iv,
                                vld1q_f32(
                                    kernel
                                        .as_ptr()
                                        .add((ky * 3 + kx) * c_in * c_out + ci * c_out + co),
                                ),
                            );
                        }
                    }
                }
                vst1q_f32(output.as_mut_ptr().add(out_off + co), acc);
                co += 4;
            }

            // Handle remaining channels scalar
            while co < c_out {
                let mut acc = 0.0f32;
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        for ci in 0..c_in {
                            acc += input[(iy * w + ix) * c_in + ci]
                                * kernel[(ky * 3 + kx) * c_in * c_out + ci * c_out + co];
                        }
                    }
                }
                *output.get_unchecked_mut(out_off + co) = acc;
                co += 1;
            }

            ox += 1;
        }
    }
}
779
/// Direct 3×3 convolution microkernel for x86_64 with AVX-256 + FMA.
/// Processes 8 output channels per iteration (vs 4 for SSE), doubling
/// throughput on the inner c_out loop. Falls back to scalar for tail channels.
///
/// Strategy: when `stride_w == 1`, two horizontally adjacent output pixels
/// ("A" and "B") are computed together — their 3×3 windows share two input
/// columns, so the four input scalars `in0..in3` loaded per (ky, ci) pair
/// feed both accumulators. The kernel layout [3, 3, C_in, C_out] keeps the
/// C_out axis contiguous, so per-channel weight fetches are plain vector
/// loads.
///
/// # Safety
/// Caller must guarantee:
/// - the running CPU supports AVX and FMA (check with
///   `is_x86_feature_detected!` before calling);
/// - `input` covers every valid 3×3 window: at least
///   `((out_h - 1) * stride_h + 3) * w * c_in` elements, with
///   `w >= (out_w - 1) * stride_w + 3` ("valid" convolution, no padding);
/// - `kernel` holds exactly `3 * 3 * c_in * c_out` elements;
/// - `output` holds at least `out_h * out_w * c_out` elements.
/// Indexing is unchecked (`get_unchecked` / raw-pointer loads), so any
/// violation is undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "fma")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn conv2d_3x3_direct_avx(
    input: &[f32],      // [H, W, C_in] NHWC (batch-dim already stripped)
    kernel: &[f32],     // [3, 3, C_in, C_out]
    output: &mut [f32], // [out_H, out_W, C_out]
    w: usize,
    c_in: usize,
    c_out: usize,
    out_h: usize,
    out_w: usize,
    stride_h: usize,
    stride_w: usize,
) {
    use std::arch::x86_64::*;

    for oy in 0..out_h {
        let iy_base = oy * stride_h;

        let mut ox = 0usize;
        if stride_w == 1 {
            // Paired-pixel path: only valid when adjacent outputs overlap by
            // two input columns (i.e. unit horizontal stride).
            while ox + 2 <= out_w {
                let ix_base = ox; // stride_w == 1
                let out_off_a = (oy * out_w + ox) * c_out;
                let out_off_b = out_off_a + c_out;

                // Process 16 output channels per iteration (2x AVX-256 registers)
                let mut co = 0;
                while co + 16 <= c_out {
                    let mut acc_a0 = _mm256_setzero_ps();
                    let mut acc_a1 = _mm256_setzero_ps();
                    let mut acc_b0 = _mm256_setzero_ps();
                    let mut acc_b1 = _mm256_setzero_ps();

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            // Four consecutive input columns: columns 0..3 cover
                            // pixel A's window (0..2) and pixel B's window (1..3).
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0_lo = _mm256_loadu_ps(kernel.as_ptr().add(k0_off));
                            let kw0_hi = _mm256_loadu_ps(kernel.as_ptr().add(k0_off + 8));
                            let kw1_lo = _mm256_loadu_ps(kernel.as_ptr().add(k1_off));
                            let kw1_hi = _mm256_loadu_ps(kernel.as_ptr().add(k1_off + 8));
                            let kw2_lo = _mm256_loadu_ps(kernel.as_ptr().add(k2_off));
                            let kw2_hi = _mm256_loadu_ps(kernel.as_ptr().add(k2_off + 8));

                            // Pixel A: in0*k0 + in1*k1 + in2*k2
                            let va0 = _mm256_set1_ps(in0);
                            let va1 = _mm256_set1_ps(in1);
                            let va2 = _mm256_set1_ps(in2);
                            acc_a0 = _mm256_fmadd_ps(va0, kw0_lo, acc_a0);
                            acc_a1 = _mm256_fmadd_ps(va0, kw0_hi, acc_a1);
                            acc_a0 = _mm256_fmadd_ps(va1, kw1_lo, acc_a0);
                            acc_a1 = _mm256_fmadd_ps(va1, kw1_hi, acc_a1);
                            acc_a0 = _mm256_fmadd_ps(va2, kw2_lo, acc_a0);
                            acc_a1 = _mm256_fmadd_ps(va2, kw2_hi, acc_a1);

                            // Pixel B: in1*k0 + in2*k1 + in3*k2
                            let vb3 = _mm256_set1_ps(in3);
                            acc_b0 = _mm256_fmadd_ps(va1, kw0_lo, acc_b0);
                            acc_b1 = _mm256_fmadd_ps(va1, kw0_hi, acc_b1);
                            acc_b0 = _mm256_fmadd_ps(va2, kw1_lo, acc_b0);
                            acc_b1 = _mm256_fmadd_ps(va2, kw1_hi, acc_b1);
                            acc_b0 = _mm256_fmadd_ps(vb3, kw2_lo, acc_b0);
                            acc_b1 = _mm256_fmadd_ps(vb3, kw2_hi, acc_b1);
                        }
                    }

                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_a + co), acc_a0);
                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_a + co + 8), acc_a1);
                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_b + co), acc_b0);
                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_b + co + 8), acc_b1);
                    co += 16;
                }

                // Process 8 output channels with single AVX register pair
                while co + 8 <= c_out {
                    let mut acc_a = _mm256_setzero_ps();
                    let mut acc_b = _mm256_setzero_ps();

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0 = _mm256_loadu_ps(kernel.as_ptr().add(k0_off));
                            let kw1 = _mm256_loadu_ps(kernel.as_ptr().add(k1_off));
                            let kw2 = _mm256_loadu_ps(kernel.as_ptr().add(k2_off));

                            let va0 = _mm256_set1_ps(in0);
                            let va1 = _mm256_set1_ps(in1);
                            let va2 = _mm256_set1_ps(in2);
                            acc_a = _mm256_fmadd_ps(va0, kw0, acc_a);
                            acc_a = _mm256_fmadd_ps(va1, kw1, acc_a);
                            acc_a = _mm256_fmadd_ps(va2, kw2, acc_a);

                            let vb3 = _mm256_set1_ps(in3);
                            acc_b = _mm256_fmadd_ps(va1, kw0, acc_b);
                            acc_b = _mm256_fmadd_ps(va2, kw1, acc_b);
                            acc_b = _mm256_fmadd_ps(vb3, kw2, acc_b);
                        }
                    }

                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_a + co), acc_a);
                    _mm256_storeu_ps(output.as_mut_ptr().add(out_off_b + co), acc_b);
                    co += 8;
                }

                // Scalar tail for remaining channels
                while co < c_out {
                    let mut acc_a = 0.0f32;
                    let mut acc_b = 0.0f32;
                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = input[(row_base + ix_base) * c_in + ci];
                            let in1 = input[(row_base + ix_base + 1) * c_in + ci];
                            let in2 = input[(row_base + ix_base + 2) * c_in + ci];
                            let in3 = input[(row_base + ix_base + 3) * c_in + ci];
                            let k0 = kernel[ky * 3 * c_in * c_out + ci * c_out + co];
                            let k1 = kernel[(ky * 3 + 1) * c_in * c_out + ci * c_out + co];
                            let k2 = kernel[(ky * 3 + 2) * c_in * c_out + ci * c_out + co];
                            acc_a += in0 * k0 + in1 * k1 + in2 * k2;
                            acc_b += in1 * k0 + in2 * k1 + in3 * k2;
                        }
                    }
                    *output.get_unchecked_mut(out_off_a + co) = acc_a;
                    *output.get_unchecked_mut(out_off_b + co) = acc_b;
                    co += 1;
                }

                ox += 2;
            }
        }

        // Handle remaining single pixels (odd out_w or stride_w != 1).
        while ox < out_w {
            let ix_base = ox * stride_w;
            let out_off = (oy * out_w + ox) * c_out;

            let mut co = 0;
            while co + 8 <= c_out {
                let mut acc = _mm256_setzero_ps();

                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        let k_base = (ky * 3 + kx) * c_in * c_out;

                        for ci in 0..c_in {
                            // Broadcast one input scalar against 8 output-channel weights.
                            let iv = _mm256_set1_ps(*input.get_unchecked(in_off + ci));
                            let koff = k_base + ci * c_out + co;
                            acc = _mm256_fmadd_ps(
                                iv,
                                _mm256_loadu_ps(kernel.as_ptr().add(koff)),
                                acc,
                            );
                        }
                    }
                }

                _mm256_storeu_ps(output.as_mut_ptr().add(out_off + co), acc);
                co += 8;
            }

            // Handle remaining channels scalar
            while co < c_out {
                let mut acc = 0.0f32;
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        for ci in 0..c_in {
                            acc += input[(iy * w + ix) * c_in + ci]
                                * kernel[(ky * 3 + kx) * c_in * c_out + ci * c_out + co];
                        }
                    }
                }
                *output.get_unchecked_mut(out_off + co) = acc;
                co += 1;
            }

            ox += 1;
        }
    }
}
987
/// Direct 3×3 convolution microkernel for x86_64 with SSE + FMA.
/// Mirrors the NEON implementation: processes two adjacent output pixels at a
/// time when stride_w == 1, sharing overlapping input columns to save loads.
///
/// Works in 4-lane XMM registers: the wide path handles 8 output channels per
/// iteration (two register pairs), then 4, then a scalar tail.
///
/// # Safety
/// Caller must guarantee:
/// - the running CPU supports SSE and FMA (check with
///   `is_x86_feature_detected!` before calling);
/// - `input` covers every valid 3×3 window: at least
///   `((out_h - 1) * stride_h + 3) * w * c_in` elements, with
///   `w >= (out_w - 1) * stride_w + 3` ("valid" convolution, no padding);
/// - `kernel` holds exactly `3 * 3 * c_in * c_out` elements;
/// - `output` holds at least `out_h * out_w * c_out` elements.
/// Indexing is unchecked (`get_unchecked` / raw-pointer loads), so any
/// violation is undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse", enable = "fma")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn conv2d_3x3_direct_sse(
    input: &[f32],      // [H, W, C_in] NHWC (batch-dim already stripped)
    kernel: &[f32],     // [3, 3, C_in, C_out]
    output: &mut [f32], // [out_H, out_W, C_out]
    w: usize,
    c_in: usize,
    c_out: usize,
    out_h: usize,
    out_w: usize,
    stride_h: usize,
    stride_w: usize,
) {
    // NOTE(review): this inner cfg is redundant — the whole fn is already
    // gated on x86_64 above.
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    for oy in 0..out_h {
        let iy_base = oy * stride_h;

        let mut ox = 0usize;
        if stride_w == 1 {
            // Paired-pixel path: adjacent outputs share two input columns.
            while ox + 2 <= out_w {
                let ix_base = ox; // stride_w == 1
                let out_off_a = (oy * out_w + ox) * c_out;
                let out_off_b = out_off_a + c_out;

                // 8 output channels per iteration (two XMM register pairs).
                let mut co = 0;
                while co + 8 <= c_out {
                    let mut acc_a0 = _mm_setzero_ps();
                    let mut acc_a1 = _mm_setzero_ps();
                    let mut acc_b0 = _mm_setzero_ps();
                    let mut acc_b1 = _mm_setzero_ps();

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            // Columns 0..3: A uses 0..2, B uses 1..3.
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0_lo = _mm_loadu_ps(kernel.as_ptr().add(k0_off));
                            let kw0_hi = _mm_loadu_ps(kernel.as_ptr().add(k0_off + 4));
                            let kw1_lo = _mm_loadu_ps(kernel.as_ptr().add(k1_off));
                            let kw1_hi = _mm_loadu_ps(kernel.as_ptr().add(k1_off + 4));
                            let kw2_lo = _mm_loadu_ps(kernel.as_ptr().add(k2_off));
                            let kw2_hi = _mm_loadu_ps(kernel.as_ptr().add(k2_off + 4));

                            // Pixel A: in0*k0 + in1*k1 + in2*k2
                            let va0 = _mm_set1_ps(in0);
                            let va1 = _mm_set1_ps(in1);
                            let va2 = _mm_set1_ps(in2);
                            acc_a0 = _mm_fmadd_ps(va0, kw0_lo, acc_a0);
                            acc_a1 = _mm_fmadd_ps(va0, kw0_hi, acc_a1);
                            acc_a0 = _mm_fmadd_ps(va1, kw1_lo, acc_a0);
                            acc_a1 = _mm_fmadd_ps(va1, kw1_hi, acc_a1);
                            acc_a0 = _mm_fmadd_ps(va2, kw2_lo, acc_a0);
                            acc_a1 = _mm_fmadd_ps(va2, kw2_hi, acc_a1);

                            // Pixel B: in1*k0 + in2*k1 + in3*k2
                            let vb3 = _mm_set1_ps(in3);
                            acc_b0 = _mm_fmadd_ps(va1, kw0_lo, acc_b0);
                            acc_b1 = _mm_fmadd_ps(va1, kw0_hi, acc_b1);
                            acc_b0 = _mm_fmadd_ps(va2, kw1_lo, acc_b0);
                            acc_b1 = _mm_fmadd_ps(va2, kw1_hi, acc_b1);
                            acc_b0 = _mm_fmadd_ps(vb3, kw2_lo, acc_b0);
                            acc_b1 = _mm_fmadd_ps(vb3, kw2_hi, acc_b1);
                        }
                    }

                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_a + co), acc_a0);
                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_a + co + 4), acc_a1);
                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_b + co), acc_b0);
                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_b + co + 4), acc_b1);
                    co += 8;
                }

                // 4 output channels with a single register per pixel.
                while co + 4 <= c_out {
                    let mut acc_a = _mm_setzero_ps();
                    let mut acc_b = _mm_setzero_ps();

                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = *input.get_unchecked((row_base + ix_base) * c_in + ci);
                            let in1 = *input.get_unchecked((row_base + ix_base + 1) * c_in + ci);
                            let in2 = *input.get_unchecked((row_base + ix_base + 2) * c_in + ci);
                            let in3 = *input.get_unchecked((row_base + ix_base + 3) * c_in + ci);

                            let k0_off = ky * 3 * c_in * c_out + ci * c_out + co;
                            let k1_off = (ky * 3 + 1) * c_in * c_out + ci * c_out + co;
                            let k2_off = (ky * 3 + 2) * c_in * c_out + ci * c_out + co;
                            let kw0 = _mm_loadu_ps(kernel.as_ptr().add(k0_off));
                            let kw1 = _mm_loadu_ps(kernel.as_ptr().add(k1_off));
                            let kw2 = _mm_loadu_ps(kernel.as_ptr().add(k2_off));

                            let va0 = _mm_set1_ps(in0);
                            let va1 = _mm_set1_ps(in1);
                            let va2 = _mm_set1_ps(in2);
                            acc_a = _mm_fmadd_ps(va0, kw0, acc_a);
                            acc_a = _mm_fmadd_ps(va1, kw1, acc_a);
                            acc_a = _mm_fmadd_ps(va2, kw2, acc_a);

                            let vb3 = _mm_set1_ps(in3);
                            acc_b = _mm_fmadd_ps(va1, kw0, acc_b);
                            acc_b = _mm_fmadd_ps(va2, kw1, acc_b);
                            acc_b = _mm_fmadd_ps(vb3, kw2, acc_b);
                        }
                    }

                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_a + co), acc_a);
                    _mm_storeu_ps(output.as_mut_ptr().add(out_off_b + co), acc_b);
                    co += 4;
                }

                // Scalar tail for remaining channels.
                while co < c_out {
                    let mut acc_a = 0.0f32;
                    let mut acc_b = 0.0f32;
                    for ky in 0..3 {
                        let iy = iy_base + ky;
                        let row_base = iy * w;
                        for ci in 0..c_in {
                            let in0 = input[(row_base + ix_base) * c_in + ci];
                            let in1 = input[(row_base + ix_base + 1) * c_in + ci];
                            let in2 = input[(row_base + ix_base + 2) * c_in + ci];
                            let in3 = input[(row_base + ix_base + 3) * c_in + ci];
                            let k0 = kernel[ky * 3 * c_in * c_out + ci * c_out + co];
                            let k1 = kernel[(ky * 3 + 1) * c_in * c_out + ci * c_out + co];
                            let k2 = kernel[(ky * 3 + 2) * c_in * c_out + ci * c_out + co];
                            acc_a += in0 * k0 + in1 * k1 + in2 * k2;
                            acc_b += in1 * k0 + in2 * k1 + in3 * k2;
                        }
                    }
                    *output.get_unchecked_mut(out_off_a + co) = acc_a;
                    *output.get_unchecked_mut(out_off_b + co) = acc_b;
                    co += 1;
                }

                ox += 2;
            }
        }

        // Handle remaining single pixels (odd out_w or stride_w != 1).
        while ox < out_w {
            let ix_base = ox * stride_w;
            let out_off = (oy * out_w + ox) * c_out;

            let mut co = 0;
            while co + 8 <= c_out {
                let mut acc0 = _mm_setzero_ps();
                let mut acc1 = _mm_setzero_ps();

                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        let k_base = (ky * 3 + kx) * c_in * c_out;

                        for ci in 0..c_in {
                            // Broadcast one input scalar against 8 channel weights.
                            let iv = _mm_set1_ps(*input.get_unchecked(in_off + ci));
                            let koff = k_base + ci * c_out + co;
                            acc0 = _mm_fmadd_ps(iv, _mm_loadu_ps(kernel.as_ptr().add(koff)), acc0);
                            acc1 =
                                _mm_fmadd_ps(iv, _mm_loadu_ps(kernel.as_ptr().add(koff + 4)), acc1);
                        }
                    }
                }

                _mm_storeu_ps(output.as_mut_ptr().add(out_off + co), acc0);
                _mm_storeu_ps(output.as_mut_ptr().add(out_off + co + 4), acc1);
                co += 8;
            }

            while co + 4 <= c_out {
                let mut acc = _mm_setzero_ps();
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        let in_off = (iy * w + ix) * c_in;
                        for ci in 0..c_in {
                            let iv = _mm_set1_ps(*input.get_unchecked(in_off + ci));
                            acc = _mm_fmadd_ps(
                                iv,
                                _mm_loadu_ps(
                                    kernel
                                        .as_ptr()
                                        .add((ky * 3 + kx) * c_in * c_out + ci * c_out + co),
                                ),
                                acc,
                            );
                        }
                    }
                }
                _mm_storeu_ps(output.as_mut_ptr().add(out_off + co), acc);
                co += 4;
            }

            // Handle remaining channels scalar
            while co < c_out {
                let mut acc = 0.0f32;
                for ky in 0..3 {
                    for kx in 0..3 {
                        let iy = iy_base + ky;
                        let ix = ix_base + kx;
                        for ci in 0..c_in {
                            acc += input[(iy * w + ix) * c_in + ci]
                                * kernel[(ky * 3 + kx) * c_in * c_out + ci * c_out + co];
                        }
                    }
                }
                *output.get_unchecked_mut(out_off + co) = acc;
                co += 1;
            }

            ox += 1;
        }
    }
}
1218
1219/// FMA: out[i] += kernel[i] * input_val, SIMD-accelerated
1220#[allow(unsafe_code)]
1221fn conv_fma_row(out: &mut [f32], kernel: &[f32], input_val: f32) {
1222    let len = out.len();
1223    debug_assert_eq!(len, kernel.len());
1224
1225    if cfg!(miri) || len < 4 {
1226        for i in 0..len {
1227            out[i] += kernel[i] * input_val;
1228        }
1229        return;
1230    }
1231
1232    #[cfg(target_arch = "aarch64")]
1233    {
1234        if std::arch::is_aarch64_feature_detected!("neon") {
1235            unsafe { conv_fma_neon(out, kernel, input_val) };
1236            return;
1237        }
1238    }
1239
1240    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1241    {
1242        if std::is_x86_feature_detected!("avx") {
1243            unsafe { conv_fma_avx(out, kernel, input_val) };
1244            return;
1245        }
1246        if std::is_x86_feature_detected!("sse") {
1247            unsafe { conv_fma_sse(out, kernel, input_val) };
1248            return;
1249        }
1250    }
1251
1252    for i in 0..len {
1253        out[i] += kernel[i] * input_val;
1254    }
1255}
1256
/// NEON body of `conv_fma_row`: `out[i] += kernel[i] * input_val` using
/// 4-lane fused multiply-adds (`vfmaq_f32`), with a scalar loop for the
/// final `len % 4` elements.
///
/// # Safety
/// Caller must ensure NEON is available at runtime and that
/// `kernel.len() >= out.len()` — the first `out.len()` kernel elements are
/// read through raw pointers without bounds checks, so a shorter `kernel`
/// is undefined behavior.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn conv_fma_neon(out: &mut [f32], kernel: &[f32], input_val: f32) {
    use std::arch::aarch64::*;
    let len = out.len();
    let op = out.as_mut_ptr();
    let kp = kernel.as_ptr();
    // Broadcast the scalar once; each iteration fuses load + FMA + store.
    let v_input = vdupq_n_f32(input_val);
    let mut i = 0usize;
    while i + 4 <= len {
        let o = vld1q_f32(op.add(i));
        let k = vld1q_f32(kp.add(i));
        vst1q_f32(op.add(i), vfmaq_f32(o, k, v_input));
        i += 4;
    }
    // Scalar tail for the last 0..3 elements.
    while i < len {
        *op.add(i) += *kp.add(i) * input_val;
        i += 1;
    }
}
1278
/// SSE body of `conv_fma_row`: `out[i] += kernel[i] * input_val` in 4-lane
/// chunks with a scalar tail. Note that despite the `fma` in the name this
/// uses separate `_mm_mul_ps` + `_mm_add_ps` (not a fused multiply-add
/// instruction), which is why only the `sse` target feature is required.
///
/// # Safety
/// Caller must ensure SSE is available at runtime and that
/// `kernel.len() >= out.len()` — kernel elements are read through raw
/// pointers without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn conv_fma_sse(out: &mut [f32], kernel: &[f32], input_val: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = out.len();
    let op = out.as_mut_ptr();
    let kp = kernel.as_ptr();
    // Broadcast the scalar once and stream through 4 floats per iteration.
    let v_input = _mm_set1_ps(input_val);
    let mut i = 0usize;
    while i + 4 <= len {
        let o = _mm_loadu_ps(op.add(i));
        let k = _mm_loadu_ps(kp.add(i));
        _mm_storeu_ps(op.add(i), _mm_add_ps(o, _mm_mul_ps(k, v_input)));
        i += 4;
    }
    // Scalar tail for the last 0..3 elements.
    while i < len {
        *op.add(i) += *kp.add(i) * input_val;
        i += 1;
    }
}
1303
/// AVX body of `conv_fma_row`: `out[i] += kernel[i] * input_val` in 8-lane
/// chunks (unfused `_mm256_mul_ps` + `_mm256_add_ps`, so only `avx` is
/// required). Any remainder shorter than 8 elements is delegated to the
/// SSE variant.
///
/// # Safety
/// Caller must ensure AVX is available at runtime and that
/// `kernel.len() >= out.len()` — kernel elements are read through raw
/// pointers without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn conv_fma_avx(out: &mut [f32], kernel: &[f32], input_val: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = out.len();
    let op = out.as_mut_ptr();
    let kp = kernel.as_ptr();
    let v_input = _mm256_set1_ps(input_val);
    let mut i = 0usize;
    while i + 8 <= len {
        let o = _mm256_loadu_ps(op.add(i));
        let k = _mm256_loadu_ps(kp.add(i));
        _mm256_storeu_ps(op.add(i), _mm256_add_ps(o, _mm256_mul_ps(k, v_input)));
        i += 8;
    }
    // Tail (< 8 elements): reuse the SSE path on the remaining slices.
    if i < len {
        conv_fma_sse(&mut out[i..], &kernel[i..], input_val);
    }
}
1327
/// 3D convolution: input [B, D, H, W, C_in], kernel [KD, KH, KW, C_in, C_out], output [B, OD, OH, OW, C_out]
/// Supports padding and stride in all 3 dimensions.
///
/// Zero-padding `(pad_d, pad_h, pad_w)` is applied on both sides of each
/// spatial axis; output extents follow `out = (in + 2*pad - k) / stride + 1`.
/// Returns `(output, output_shape)` where `output_shape` is
/// `[B, OD, OH, OW, C_out]`.
///
/// # Panics
/// Panics if either shape is not rank-5, the channel counts disagree, any
/// stride is zero, a buffer length does not match its shape, or the kernel
/// exceeds the padded input in any spatial dimension.
pub fn conv3d(
    input: &[f32],
    input_shape: &[usize], // [B, D, H, W, C_in]
    kernel: &[f32],
    kernel_shape: &[usize],         // [KD, KH, KW, C_in, C_out]
    stride: (usize, usize, usize),  // (d, h, w)
    padding: (usize, usize, usize), // (d, h, w)
) -> (Vec<f32>, Vec<usize>) {
    assert_eq!(
        input_shape.len(),
        5,
        "input_shape must be [B, D, H, W, C_in]"
    );
    assert_eq!(
        kernel_shape.len(),
        5,
        "kernel_shape must be [KD, KH, KW, C_in, C_out]"
    );

    let (batch, in_d, in_h, in_w, c_in) = (
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
        input_shape[4],
    );
    let (kd, kh, kw, k_cin, c_out) = (
        kernel_shape[0],
        kernel_shape[1],
        kernel_shape[2],
        kernel_shape[3],
        kernel_shape[4],
    );
    let (stride_d, stride_h, stride_w) = stride;
    let (pad_d, pad_h, pad_w) = padding;

    assert_eq!(c_in, k_cin, "input C_in must match kernel C_in");
    assert!(
        stride_d > 0 && stride_h > 0 && stride_w > 0,
        "strides must be positive"
    );
    assert_eq!(input.len(), batch * in_d * in_h * in_w * c_in);
    assert_eq!(kernel.len(), kd * kh * kw * c_in * c_out);
    // Guard the output-extent arithmetic below: without this check,
    // `in + 2*pad - k` underflows when the kernel is larger than the padded
    // input — a panic in debug builds, but a silent wrap-around (huge bogus
    // output shape / allocation) in release builds.
    assert!(
        in_d + 2 * pad_d >= kd && in_h + 2 * pad_h >= kh && in_w + 2 * pad_w >= kw,
        "kernel must not exceed padded input in any spatial dimension"
    );

    let out_d = (in_d + 2 * pad_d - kd) / stride_d + 1;
    let out_h = (in_h + 2 * pad_h - kh) / stride_h + 1;
    let out_w = (in_w + 2 * pad_w - kw) / stride_w + 1;

    let output_shape = vec![batch, out_d, out_h, out_w, c_out];
    let out_spatial = out_d * out_h * out_w;
    let output_len = batch * out_spatial * c_out;

    // im2col + BLAS path: reshape 3D conv into matrix multiply
    //   im2col: [out_spatial, kd*kh*kw*c_in]
    //   kernel reshaped: [kd*kh*kw*c_in, c_out]
    //   output = im2col @ kernel_2d → [out_spatial, c_out]
    // The whole block is cfg-gated so builds without the `blas` feature
    // never reference `blas_sgemm` (previously only the boolean flag was
    // gated while the call site was compiled unconditionally).
    #[cfg(feature = "blas")]
    {
        if !cfg!(miri) && batch == 1 {
            let k_spatial = kd * kh * kw;
            let col_k = k_spatial * c_in; // im2col column length
            let mut output = vec![0.0f32; output_len];
            let in_hwc = in_h * in_w * c_in;
            let in_wc = in_w * c_in;

            for b in 0..batch {
                let b_in = b * in_d * in_hwc;
                // Build the im2col matrix: one row per output voxel,
                // zero-initialized so padded taps can simply be skipped.
                let mut col = vec![0.0f32; out_spatial * col_k];
                let mut row = 0;
                for od in 0..out_d {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let mut col_idx = 0;
                            for fd in 0..kd {
                                let id_raw = od * stride_d + fd;
                                for fh in 0..kh {
                                    let ih_raw = oh * stride_h + fh;
                                    for fw in 0..kw {
                                        let iw_raw = ow * stride_w + fw;
                                        let in_bounds = id_raw >= pad_d
                                            && id_raw - pad_d < in_d
                                            && ih_raw >= pad_h
                                            && ih_raw - pad_h < in_h
                                            && iw_raw >= pad_w
                                            && iw_raw - pad_w < in_w;
                                        if in_bounds {
                                            let id = id_raw - pad_d;
                                            let ih = ih_raw - pad_h;
                                            let iw = iw_raw - pad_w;
                                            let base =
                                                b_in + id * in_hwc + ih * in_wc + iw * c_in;
                                            col[row * col_k + col_idx
                                                ..row * col_k + col_idx + c_in]
                                                .copy_from_slice(&input[base..base + c_in]);
                                        }
                                        // else: padding zeros (already zeroed)
                                        col_idx += c_in;
                                    }
                                }
                            }
                            row += 1;
                        }
                    }
                }

                // BLAS: output[b] = col @ kernel_2d
                let b_out = b * out_spatial * c_out;
                super::matmul::blas_sgemm(
                    &col,
                    kernel,
                    &mut output[b_out..b_out + out_spatial * c_out],
                    out_spatial,
                    col_k,
                    c_out,
                );
            }
            return (output, output_shape);
        }
    }

    // Fallback: naive 7-nested-loop implementation. Padded (out-of-bounds)
    // taps are skipped, which is equivalent to multiplying by zero.
    let mut output = vec![0.0f32; output_len];
    let in_dhwc = in_d * in_h * in_w * c_in;
    let in_hwc = in_h * in_w * c_in;
    let in_wc = in_w * c_in;
    let k_hwcico = kh * kw * c_in * c_out;
    let k_wcico = kw * c_in * c_out;
    let k_cico = c_in * c_out;
    let out_dhwco = out_d * out_h * out_w * c_out;
    let out_hwco = out_h * out_w * c_out;
    let out_wco = out_w * c_out;

    for b in 0..batch {
        let b_in = b * in_dhwc;
        let b_out = b * out_dhwco;
        for od in 0..out_d {
            for oh in 0..out_h {
                for ow in 0..out_w {
                    let out_base = b_out + od * out_hwco + oh * out_wco + ow * c_out;
                    for fd in 0..kd {
                        let id = od * stride_d + fd;
                        if id < pad_d || id - pad_d >= in_d {
                            continue;
                        }
                        let id = id - pad_d;
                        for fh in 0..kh {
                            let ih = oh * stride_h + fh;
                            if ih < pad_h || ih - pad_h >= in_h {
                                continue;
                            }
                            let ih = ih - pad_h;
                            for fw in 0..kw {
                                let iw = ow * stride_w + fw;
                                if iw < pad_w || iw - pad_w >= in_w {
                                    continue;
                                }
                                let iw = iw - pad_w;
                                let in_base = b_in + id * in_hwc + ih * in_wc + iw * c_in;
                                let k_base = fd * k_hwcico + fh * k_wcico + fw * k_cico;
                                for ci in 0..c_in {
                                    let input_val = input[in_base + ci];
                                    let k_offset = k_base + ci * c_out;
                                    // Inner axis is C_out: contiguous in both
                                    // kernel and output, so this vectorizes well.
                                    for co in 0..c_out {
                                        output[out_base + co] += input_val * kernel[k_offset + co];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    (output, output_shape)
}
1507
1508fn depthwise_conv2d_nhwc_row(
1509    input: &[f32],
1510    kernel: &[f32],
1511    bias: Option<&[f32]>,
1512    plan: DepthwiseConv2dPlan,
1513    row_idx: usize,
1514    out_row: &mut [f32],
1515) {
1516    let batch_idx = row_idx / plan.out_h;
1517    let out_y = row_idx % plan.out_h;
1518    let in_y0 = out_y * plan.stride_h;
1519    let batch_input_base = batch_idx * plan.in_h * plan.in_w * plan.channels;
1520
1521    for out_x in 0..plan.out_w {
1522        let in_x0 = out_x * plan.stride_w;
1523        let out_cell_base = out_x * plan.out_channels;
1524
1525        for out_channel in 0..plan.out_channels {
1526            let mut acc = bias.map_or(0.0, |bias_values| bias_values[out_channel]);
1527            let in_channel = out_channel / plan.depth_multiplier;
1528            let depth_index = out_channel % plan.depth_multiplier;
1529
1530            for ky in 0..plan.kernel_h {
1531                let in_y = in_y0 + ky;
1532                let input_row_base = batch_input_base + (in_y * plan.in_w + in_x0) * plan.channels;
1533                let kernel_row_base = ky * plan.kernel_w * plan.channels * plan.depth_multiplier;
1534
1535                for kx in 0..plan.kernel_w {
1536                    let input_value = input[input_row_base + kx * plan.channels + in_channel];
1537                    let kernel_index = kernel_row_base
1538                        + kx * plan.channels * plan.depth_multiplier
1539                        + in_channel * plan.depth_multiplier
1540                        + depth_index;
1541                    acc += input_value * kernel[kernel_index];
1542                }
1543            }
1544
1545            out_row[out_cell_base + out_channel] = acc;
1546        }
1547    }
1548}
1549
1550// ---------------------------------------------------------------------------
1551// im2col + BLAS GEMM fast path for conv2d
1552// ---------------------------------------------------------------------------
1553
1554/// Flatten each [kH, kW, C_in] input patch into a row of the im2col matrix.
1555///
1556/// Output `col` has shape [out_h * out_w, kH * kW * C_in] (row-major).
1557/// The input is NHWC layout (batch dimension already stripped by caller).
1558#[cfg(feature = "blas")]
fn im2col_nhwc(
    input: &[f32],
    in_w: usize,
    c: usize,
    kh: usize,
    kw: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
    col: &mut [f32],
) {
    // One im2col row per output pixel; each row holds one flattened
    // [kH, kW, C] input patch, so its length is kh * kw * c.
    let patch_len = kh * kw * c;
    let mut row = 0;
    for oy in 0..out_h {
        let iy_base = oy * stride_h;
        for ox in 0..out_w {
            let ix_base = ox * stride_w;
            // `dst` walks the row contiguously: the (ky, kx) patch element
            // lands at offset (ky * kw + kx) * c, which is exactly c per step.
            let mut dst = row * patch_len;
            for ky in 0..kh {
                let iy = iy_base + ky;
                for kx in 0..kw {
                    // NHWC: the C innermost channels of one pixel are
                    // contiguous, so a single memcpy moves the whole pixel.
                    let src = (iy * in_w + ix_base + kx) * c;
                    col[dst..dst + c].copy_from_slice(&input[src..src + c]);
                    dst += c;
                }
            }
            row += 1;
        }
    }
}
1588
/// Conv2d via im2col + BLAS sgemm.
///
/// im2col matrix: [M, K] where M = out_h*out_w, K = kH*kW*C_in
/// kernel (already contiguous in NHWC): [K, N] where N = C_out
/// output: [M, N] which maps directly to [1, out_h, out_w, C_out]
#[cfg(feature = "blas")]
fn conv2d_im2col_gemm(
    plan: &Conv2dPlan,
    input: &[f32],
    kernel: &[f32],
    bias: Option<&[f32]>,
) -> Result<Tensor, KernelError> {
    let oh = plan.out_h;
    let ow = plan.out_w;
    // GEMM dimensions: rows = M, patch = K, chans = N.
    let rows = oh * ow;
    let patch = plan.kernel_h * plan.kernel_w * plan.in_channels;
    let chans = plan.out_channels;

    // Build the im2col matrix. Padding is not needed here: build_conv2d_plan
    // derives the output dimensions without padding, so every patch fits.
    // SAFETY: im2col_nhwc writes every element of `col` before it is read.
    #[allow(unsafe_code)]
    let mut col = AlignedVec::<f32>::uninitialized(rows * patch);
    im2col_nhwc(
        input,
        plan.in_w,
        plan.in_channels,
        plan.kernel_h,
        plan.kernel_w,
        plan.stride_h,
        plan.stride_w,
        oh,
        ow,
        &mut col,
    );

    // GEMM: col[rows, patch] @ kernel[patch, chans] -> output[rows, chans]
    // blas_sgemm uses beta=0, so C is overwritten — no zeroing needed.
    // SAFETY: blas_sgemm with beta=0 writes every element of `output`.
    #[allow(unsafe_code)]
    let mut output = AlignedVec::<f32>::uninitialized(rows * chans);
    super::matmul::blas_sgemm(&col, kernel, &mut output, rows, patch, chans);

    // Broadcast the per-channel bias across every output pixel.
    if let Some(bias) = bias {
        for row in 0..rows {
            let base = row * chans;
            for ch in 0..chans {
                output[base + ch] += bias[ch];
            }
        }
    }

    Tensor::from_aligned(vec![1, oh, ow, chans], output).map_err(Into::into)
}