Skip to main content

oximedia_gpu/
compute_kernels.rs

1//! Pure-Rust SIMD-optimised compute kernels for image processing.
2//!
3//! All algorithms are implemented as scalar loops written in a style that
4//! LLVM / rustc can auto-vectorise (chunk-unrolled, no data-dependent
5//! branches inside hot loops).  No external C/Fortran dependencies are used.
6
7#![allow(clippy::cast_precision_loss)]
8#![allow(clippy::cast_possible_truncation)]
9#![allow(clippy::cast_sign_loss)]
10#![allow(clippy::cast_lossless)]
11
/// Tunable settings carried by a [`ComputeKernel`].
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Whether to enable SIMD-style optimised code paths.
    ///
    /// None of the kernels in this module currently branch on this flag.
    pub use_simd: bool,
    /// Number of threads to use in future parallel extensions.
    ///
    /// The kernels in this module run on the calling thread and do not read it.
    pub thread_count: usize,
    /// Processing chunk size (default 256).  Must be a power of two ≥ 8.
    ///
    /// Within this module only [`ComputeKernel::alpha_composite`] reads it; values
    /// below 8 are raised to 8 there and the result is rounded down to a multiple
    /// of 4.
    pub chunk_size: usize,
}

impl Default for KernelConfig {
    /// SIMD-style paths enabled, one thread, chunk size 256.
    fn default() -> Self {
        let use_simd = true;
        let thread_count = 1;
        let chunk_size = 256;
        Self {
            use_simd,
            thread_count,
            chunk_size,
        }
    }
}
32
33/// Collection of CPU compute kernels for image processing.
34///
35/// Every method is a pure function — `&self` is used only to read
36/// [`KernelConfig`]; no mutable state is kept.
37pub struct ComputeKernel {
38    config: KernelConfig,
39}
40
41impl ComputeKernel {
42    /// Create a new `ComputeKernel` with the given configuration.
43    #[must_use]
44    pub fn new(config: KernelConfig) -> Self {
45        Self { config }
46    }
47
48    /// Create a new `ComputeKernel` with default configuration.
49    #[must_use]
50    pub fn default_config() -> Self {
51        Self::new(KernelConfig::default())
52    }
53
54    /// Return a reference to the current configuration.
55    #[must_use]
56    pub fn config(&self) -> &KernelConfig {
57        &self.config
58    }
59
60    // -----------------------------------------------------------------------
61    // RGBA → YUV 420 (BT.601)
62    // -----------------------------------------------------------------------
63
64    /// Convert packed RGBA to planar YUV 420.
65    ///
66    /// Layout of the returned buffer:
67    /// - Y plane  : `width * height` bytes
68    /// - Cb plane : `(width/2) * (height/2)` bytes
69    /// - Cr plane : `(width/2) * (height/2)` bytes
70    ///
71    /// `rgba` must have length `width * height * 4`.
72    /// Returns `None` if the input length is unexpected.
73    pub fn rgba_to_yuv420(&self, rgba: &[u8], width: u32, height: u32) -> Option<Vec<u8>> {
74        let w = width as usize;
75        let h = height as usize;
76        if rgba.len() != w * h * 4 {
77            return None;
78        }
79
80        let y_size = w * h;
81        let uv_w = (w + 1) / 2;
82        let uv_h = (h + 1) / 2;
83        let uv_size = uv_w * uv_h;
84
85        let mut out = vec![0u8; y_size + 2 * uv_size];
86        let (y_plane, uv_rest) = out.split_at_mut(y_size);
87        let (cb_plane, cr_plane) = uv_rest.split_at_mut(uv_size);
88
89        // Unrolled Y-plane pass: process 8 pixels at a time where possible.
90        let chunks = w * h;
91        let chunk8 = chunks / 8;
92        let remainder = chunks % 8;
93
94        // Helper closure — avoids redundant indexing arithmetic in the loop.
95        let sample_y = |idx: usize| -> u8 {
96            let base = idx * 4;
97            let r = rgba[base] as f32;
98            let g = rgba[base + 1] as f32;
99            let b = rgba[base + 2] as f32;
100            let y = 0.299_f32 * r + 0.587_f32 * g + 0.114_f32 * b;
101            y.round().clamp(0.0, 255.0) as u8
102        };
103
104        for i in 0..chunk8 {
105            let base = i * 8;
106            y_plane[base] = sample_y(base);
107            y_plane[base + 1] = sample_y(base + 1);
108            y_plane[base + 2] = sample_y(base + 2);
109            y_plane[base + 3] = sample_y(base + 3);
110            y_plane[base + 4] = sample_y(base + 4);
111            y_plane[base + 5] = sample_y(base + 5);
112            y_plane[base + 6] = sample_y(base + 6);
113            y_plane[base + 7] = sample_y(base + 7);
114        }
115        let rem_start = chunk8 * 8;
116        for i in 0..remainder {
117            y_plane[rem_start + i] = sample_y(rem_start + i);
118        }
119
120        // Cb / Cr — 2×2 average subsampling.
121        for block_y in 0..uv_h {
122            for block_x in 0..uv_w {
123                let mut sum_cb = 0.0_f32;
124                let mut sum_cr = 0.0_f32;
125                let mut count = 0_u32;
126
127                for dy in 0..2_usize {
128                    let sy = block_y * 2 + dy;
129                    if sy >= h {
130                        continue;
131                    }
132                    for dx in 0..2_usize {
133                        let sx = block_x * 2 + dx;
134                        if sx >= w {
135                            continue;
136                        }
137                        let base = (sy * w + sx) * 4;
138                        let r = rgba[base] as f32;
139                        let g = rgba[base + 1] as f32;
140                        let b = rgba[base + 2] as f32;
141                        sum_cb += -0.168_736_f32 * r - 0.331_264_f32 * g + 0.5_f32 * b + 128.0;
142                        sum_cr += 0.5_f32 * r - 0.418_688_f32 * g - 0.081_312_f32 * b + 128.0;
143                        count += 1;
144                    }
145                }
146
147                let uv_idx = block_y * uv_w + block_x;
148                if count > 0 {
149                    cb_plane[uv_idx] = (sum_cb / count as f32).round().clamp(0.0, 255.0) as u8;
150                    cr_plane[uv_idx] = (sum_cr / count as f32).round().clamp(0.0, 255.0) as u8;
151                }
152            }
153        }
154
155        Some(out)
156    }
157
158    // -----------------------------------------------------------------------
159    // YUV 420 → RGBA
160    // -----------------------------------------------------------------------
161
162    /// Convert planar YUV 420 to packed RGBA.
163    ///
164    /// Expects the same memory layout as produced by `rgba_to_yuv420`.
165    /// Returns `None` if the input length is unexpected.
166    pub fn yuv420_to_rgba(&self, yuv: &[u8], width: u32, height: u32) -> Option<Vec<u8>> {
167        let w = width as usize;
168        let h = height as usize;
169        let y_size = w * h;
170        let uv_w = (w + 1) / 2;
171        let uv_h = (h + 1) / 2;
172        let uv_size = uv_w * uv_h;
173        let expected = y_size + 2 * uv_size;
174
175        if yuv.len() != expected {
176            return None;
177        }
178
179        let y_plane = &yuv[..y_size];
180        let cb_plane = &yuv[y_size..y_size + uv_size];
181        let cr_plane = &yuv[y_size + uv_size..];
182
183        let mut rgba = vec![0u8; w * h * 4];
184
185        // Process 4 pixels at a time for auto-vectorisation.
186        let total_pixels = w * h;
187        let chunk4 = total_pixels / 4;
188        let rem4 = total_pixels % 4;
189
190        let convert_pixel = |pix_idx: usize, out: &mut [u8]| {
191            let py = pix_idx / w;
192            let px = pix_idx % w;
193            let uv_x = px / 2;
194            let uv_y = py / 2;
195            let uv_idx = uv_y * uv_w + uv_x;
196
197            let yv = y_plane[pix_idx] as f32;
198            let cb = cb_plane[uv_idx] as f32 - 128.0;
199            let cr = cr_plane[uv_idx] as f32 - 128.0;
200
201            let r = (yv + 1.402_f32 * cr).round().clamp(0.0, 255.0) as u8;
202            let g = (yv - 0.344_136_f32 * cb - 0.714_136_f32 * cr)
203                .round()
204                .clamp(0.0, 255.0) as u8;
205            let b = (yv + 1.772_f32 * cb).round().clamp(0.0, 255.0) as u8;
206
207            let base = pix_idx * 4;
208            out[base] = r;
209            out[base + 1] = g;
210            out[base + 2] = b;
211            out[base + 3] = 255;
212        };
213
214        for i in 0..chunk4 {
215            let base = i * 4;
216            convert_pixel(base, &mut rgba);
217            convert_pixel(base + 1, &mut rgba);
218            convert_pixel(base + 2, &mut rgba);
219            convert_pixel(base + 3, &mut rgba);
220        }
221        let rem_start = chunk4 * 4;
222        for i in 0..rem4 {
223            convert_pixel(rem_start + i, &mut rgba);
224        }
225
226        Some(rgba)
227    }
228
229    // -----------------------------------------------------------------------
230    // Gaussian blur (separable)
231    // -----------------------------------------------------------------------
232
233    /// Apply a separable Gaussian blur.
234    ///
235    /// Kernel radius is `ceil(3 * sigma)`.  Each pixel channel is treated as
236    /// an independent `f32` sample (grayscale or multi-channel flattened).
237    /// Returns `None` if the input length doesn't match `width * height`.
238    pub fn gaussian_blur(
239        &self,
240        pixels: &[f32],
241        width: u32,
242        height: u32,
243        sigma: f32,
244    ) -> Option<Vec<f32>> {
245        let w = width as usize;
246        let h = height as usize;
247        if pixels.len() != w * h {
248            return None;
249        }
250
251        if sigma <= 0.0 {
252            return Some(pixels.to_vec());
253        }
254
255        let radius = (3.0 * sigma).ceil() as usize;
256        let kernel = build_gaussian_kernel_1d(radius, sigma);
257
258        // --- Horizontal pass ---
259        let mut tmp = vec![0.0_f32; w * h];
260        for row in 0..h {
261            let row_start = row * w;
262            for col in 0..w {
263                let mut acc = 0.0_f32;
264                let mut weight_sum = 0.0_f32;
265                for ki in 0..kernel.len() {
266                    let koff = ki as isize - radius as isize;
267                    let src_col = col as isize + koff;
268                    if src_col >= 0 && src_col < w as isize {
269                        let k = kernel[ki];
270                        acc += pixels[row_start + src_col as usize] * k;
271                        weight_sum += k;
272                    }
273                }
274                tmp[row_start + col] = if weight_sum > 0.0 {
275                    acc / weight_sum
276                } else {
277                    0.0
278                };
279            }
280        }
281
282        // --- Vertical pass ---
283        let mut out = vec![0.0_f32; w * h];
284        for col in 0..w {
285            for row in 0..h {
286                let mut acc = 0.0_f32;
287                let mut weight_sum = 0.0_f32;
288                for ki in 0..kernel.len() {
289                    let koff = ki as isize - radius as isize;
290                    let src_row = row as isize + koff;
291                    if src_row >= 0 && src_row < h as isize {
292                        let k = kernel[ki];
293                        acc += tmp[src_row as usize * w + col] * k;
294                        weight_sum += k;
295                    }
296                }
297                out[row * w + col] = if weight_sum > 0.0 {
298                    acc / weight_sum
299                } else {
300                    0.0
301                };
302            }
303        }
304
305        Some(out)
306    }
307
308    // -----------------------------------------------------------------------
309    // Sobel edge detection
310    // -----------------------------------------------------------------------
311
312    /// Compute Sobel gradient magnitude for a grayscale image.
313    ///
314    /// Input `gray` must have length `width * height`.
315    /// Returns `None` if length mismatch.  Border pixels are set to 0.
316    pub fn sobel_edges(&self, gray: &[f32], width: u32, height: u32) -> Option<Vec<f32>> {
317        let w = width as usize;
318        let h = height as usize;
319        if gray.len() != w * h {
320            return None;
321        }
322
323        let mut out = vec![0.0_f32; w * h];
324
325        // Kernels:
326        // Gx = [[-1, 0, +1], [-2, 0, +2], [-1, 0, +1]]
327        // Gy = [[-1, -2, -1], [0, 0, 0], [+1, +2, +1]]
328        for row in 1..h.saturating_sub(1) {
329            let row_base = row * w;
330            for col in 1..w.saturating_sub(1) {
331                let tl = gray[(row - 1) * w + (col - 1)];
332                let tc = gray[(row - 1) * w + col];
333                let tr = gray[(row - 1) * w + (col + 1)];
334                let ml = gray[row * w + (col - 1)];
335                let mr = gray[row * w + (col + 1)];
336                let bl = gray[(row + 1) * w + (col - 1)];
337                let bc = gray[(row + 1) * w + col];
338                let br = gray[(row + 1) * w + (col + 1)];
339
340                let gx = -tl + tr - 2.0 * ml + 2.0 * mr - bl + br;
341                let gy = -tl - 2.0 * tc - tr + bl + 2.0 * bc + br;
342
343                out[row_base + col] = (gx * gx + gy * gy).sqrt();
344            }
345        }
346
347        Some(out)
348    }
349
350    // -----------------------------------------------------------------------
351    // Histogram equalization
352    // -----------------------------------------------------------------------
353
354    /// Apply histogram equalization to an 8-bit grayscale image.
355    ///
356    /// Returns `None` if `gray.len() != width * height`.
357    pub fn histogram_equalization(&self, gray: &[u8], width: u32, height: u32) -> Option<Vec<u8>> {
358        let n = width as usize * height as usize;
359        if gray.len() != n {
360            return None;
361        }
362
363        // Build histogram.
364        let mut hist = [0u64; 256];
365        for &px in gray {
366            hist[px as usize] += 1;
367        }
368
369        // CDF.
370        let mut cdf = [0u64; 256];
371        cdf[0] = hist[0];
372        for i in 1..256 {
373            cdf[i] = cdf[i - 1] + hist[i];
374        }
375
376        let cdf_min = cdf.iter().copied().find(|&v| v > 0).unwrap_or(0);
377        let total = n as u64;
378
379        // Lookup table.
380        let lut: Vec<u8> = (0..256)
381            .map(|i| {
382                if total == cdf_min {
383                    i as u8
384                } else {
385                    let v = (cdf[i] - cdf_min) as f64 * 255.0 / (total - cdf_min) as f64;
386                    v.round().clamp(0.0, 255.0) as u8
387                }
388            })
389            .collect();
390
391        Some(gray.iter().map(|&px| lut[px as usize]).collect())
392    }
393
394    // -----------------------------------------------------------------------
395    // Otsu thresholding
396    // -----------------------------------------------------------------------
397
398    /// Compute Otsu's optimal threshold and produce a binary image.
399    ///
400    /// Returns `(threshold, binary_image)` or `None` on size mismatch.
401    /// Binary output: 0 for below-threshold, 255 for at/above.
402    pub fn threshold_otsu(&self, gray: &[u8], width: u32, height: u32) -> Option<(u8, Vec<u8>)> {
403        let n = width as usize * height as usize;
404        if gray.len() != n {
405            return None;
406        }
407
408        let mut hist = [0u64; 256];
409        for &px in gray {
410            hist[px as usize] += 1;
411        }
412
413        let total = n as f64;
414        let mut sum_total = 0.0_f64;
415        for i in 0..256_usize {
416            sum_total += i as f64 * hist[i] as f64;
417        }
418
419        let mut sum_b = 0.0_f64;
420        let mut w_b = 0.0_f64;
421        let mut max_var = 0.0_f64;
422        let mut threshold = 0u8;
423
424        for i in 0..256_usize {
425            w_b += hist[i] as f64;
426            if w_b == 0.0 {
427                continue;
428            }
429            let w_f = total - w_b;
430            if w_f == 0.0 {
431                break;
432            }
433
434            sum_b += i as f64 * hist[i] as f64;
435            let m_b = sum_b / w_b;
436            let m_f = (sum_total - sum_b) / w_f;
437            let diff = m_b - m_f;
438            let between_var = w_b * w_f * diff * diff;
439
440            if between_var > max_var {
441                max_var = between_var;
442                threshold = i as u8;
443            }
444        }
445
446        let binary: Vec<u8> = gray
447            .iter()
448            .map(|&px| if px > threshold { 255 } else { 0 })
449            .collect();
450
451        Some((threshold, binary))
452    }
453
454    // -----------------------------------------------------------------------
455    // Alpha compositing (Porter-Duff "over")
456    // -----------------------------------------------------------------------
457
458    /// Composite `fg` over `bg` using the Porter-Duff "over" operator.
459    ///
460    /// Both buffers must be RGBA, length `width * height * 4`.
461    /// Returns `None` on size mismatch.
462    pub fn alpha_composite(
463        &self,
464        fg: &[u8],
465        bg: &[u8],
466        width: u32,
467        height: u32,
468    ) -> Option<Vec<u8>> {
469        let n = width as usize * height as usize;
470        let expected = n * 4;
471        if fg.len() != expected || bg.len() != expected {
472            return None;
473        }
474
475        let mut out = vec![0u8; expected];
476        let chunk_size = self.config.chunk_size.max(8) / 4 * 4; // keep multiple of 4
477
478        let chunks = n / (chunk_size / 4);
479        let rem = n % (chunk_size / 4);
480
481        let composite_pixel = |i: usize, out: &mut [u8]| {
482            let base = i * 4;
483            let fa = fg[base + 3] as f32 / 255.0;
484            let ba = bg[base + 3] as f32 / 255.0;
485            let out_a = fa + ba * (1.0 - fa);
486            if out_a <= 0.0 {
487                return;
488            }
489            let inv_out = 1.0 / out_a;
490            out[base] = ((fg[base] as f32 * fa + bg[base] as f32 * ba * (1.0 - fa)) * inv_out)
491                .round()
492                .clamp(0.0, 255.0) as u8;
493            out[base + 1] = ((fg[base + 1] as f32 * fa + bg[base + 1] as f32 * ba * (1.0 - fa))
494                * inv_out)
495                .round()
496                .clamp(0.0, 255.0) as u8;
497            out[base + 2] = ((fg[base + 2] as f32 * fa + bg[base + 2] as f32 * ba * (1.0 - fa))
498                * inv_out)
499                .round()
500                .clamp(0.0, 255.0) as u8;
501            out[base + 3] = (out_a * 255.0).round().clamp(0.0, 255.0) as u8;
502        };
503
504        let pixels_per_chunk = chunk_size / 4;
505        for c in 0..chunks {
506            let start = c * pixels_per_chunk;
507            for p in 0..pixels_per_chunk {
508                composite_pixel(start + p, &mut out);
509            }
510        }
511        let rem_start = chunks * pixels_per_chunk;
512        for p in 0..rem {
513            composite_pixel(rem_start + p, &mut out);
514        }
515
516        Some(out)
517    }
518
519    // -----------------------------------------------------------------------
520    // Bilinear image scaling
521    // -----------------------------------------------------------------------
522
523    /// Scale an RGBA image using bilinear interpolation.
524    ///
525    /// Input `pixels` must be packed RGBA with length `src_w * src_h * 4`.
526    /// Returns `None` on size mismatch or zero dimensions.
527    pub fn scale_image(
528        &self,
529        pixels: &[u8],
530        src_w: u32,
531        src_h: u32,
532        dst_w: u32,
533        dst_h: u32,
534    ) -> Option<Vec<u8>> {
535        let sw = src_w as usize;
536        let sh = src_h as usize;
537        let dw = dst_w as usize;
538        let dh = dst_h as usize;
539
540        if sw == 0 || sh == 0 || dw == 0 || dh == 0 {
541            return None;
542        }
543        if pixels.len() != sw * sh * 4 {
544            return None;
545        }
546
547        let mut out = vec![0u8; dw * dh * 4];
548
549        let x_scale = sw as f32 / dw as f32;
550        let y_scale = sh as f32 / dh as f32;
551
552        for dy in 0..dh {
553            // Continuous source y coordinate (centre-of-pixel mapping).
554            let src_y = (dy as f32 + 0.5) * y_scale - 0.5;
555            let y0 = (src_y.floor() as isize).clamp(0, sh as isize - 1) as usize;
556            let y1 = (y0 + 1).min(sh - 1);
557            let ty = (src_y - src_y.floor()).max(0.0).min(1.0);
558
559            for dx in 0..dw {
560                let src_x = (dx as f32 + 0.5) * x_scale - 0.5;
561                let x0 = (src_x.floor() as isize).clamp(0, sw as isize - 1) as usize;
562                let x1 = (x0 + 1).min(sw - 1);
563                let tx = (src_x - src_x.floor()).max(0.0).min(1.0);
564
565                let i00 = (y0 * sw + x0) * 4;
566                let i10 = (y0 * sw + x1) * 4;
567                let i01 = (y1 * sw + x0) * 4;
568                let i11 = (y1 * sw + x1) * 4;
569
570                let dst_base = (dy * dw + dx) * 4;
571
572                // Unrolled over 4 channels.
573                out[dst_base] =
574                    bilinear_u8(pixels[i00], pixels[i10], pixels[i01], pixels[i11], tx, ty);
575                out[dst_base + 1] = bilinear_u8(
576                    pixels[i00 + 1],
577                    pixels[i10 + 1],
578                    pixels[i01 + 1],
579                    pixels[i11 + 1],
580                    tx,
581                    ty,
582                );
583                out[dst_base + 2] = bilinear_u8(
584                    pixels[i00 + 2],
585                    pixels[i10 + 2],
586                    pixels[i01 + 2],
587                    pixels[i11 + 2],
588                    tx,
589                    ty,
590                );
591                out[dst_base + 3] = bilinear_u8(
592                    pixels[i00 + 3],
593                    pixels[i10 + 3],
594                    pixels[i01 + 3],
595                    pixels[i11 + 3],
596                    tx,
597                    ty,
598                );
599            }
600        }
601
602        Some(out)
603    }
604}
605
606// ---------------------------------------------------------------------------
607// Internal helpers
608// ---------------------------------------------------------------------------
609
/// Build a normalised 1-D Gaussian kernel with `2 * radius + 1` taps.
///
/// Tap `i` starts out as `exp(-x² / (2σ²))` with `x = i - radius`.  All taps
/// are then divided by their sum so that they add up to one; the division is
/// skipped when that sum is not greater than zero.
fn build_gaussian_kernel_1d(radius: usize, sigma: f32) -> Vec<f32> {
    let two_sigma_sq = 2.0 * sigma * sigma;
    let centre = radius as isize;

    let mut taps: Vec<f32> = (0..2 * radius + 1)
        .map(|i| {
            let x = (i as isize - centre) as f32;
            (-x * x / two_sigma_sq).exp()
        })
        .collect();

    // Left-to-right accumulation, starting from zero.
    let total = taps.iter().fold(0.0_f32, |acc, &v| acc + v);
    if total > 0.0 {
        taps.iter_mut().for_each(|v| *v /= total);
    }
    taps
}
629
/// Bilinear interpolation for a single `u8` channel.
///
/// `c00`/`c10` are the two samples of the upper row and `c01`/`c11` those of
/// the lower row; `tx` blends left → right and `ty` blends upper → lower.  The
/// result is rounded to the nearest integer and clamped to `0..=255`.
#[inline(always)]
fn bilinear_u8(c00: u8, c10: u8, c01: u8, c11: u8, tx: f32, ty: f32) -> u8 {
    // Linear blend from `from` (t = 0) to `to` (t = 1).
    let lerp = |from: f32, to: f32, t: f32| from + (to - from) * t;

    let top = lerp(f32::from(c00), f32::from(c10), tx);
    let bottom = lerp(f32::from(c01), f32::from(c11), tx);
    lerp(top, bottom, ty).round().clamp(0.0, 255.0) as u8
}
641
642// ---------------------------------------------------------------------------
643// Tests
644// ---------------------------------------------------------------------------
645
#[cfg(test)]
mod tests {
    use super::*;

    /// Kernel with the default configuration, shared by all tests.
    fn make_kernel() -> ComputeKernel {
        ComputeKernel::default_config()
    }

    /// Build a packed RGBA buffer of `pixel_count` identical pixels.
    fn solid_rgba(pixel_count: usize, px: [u8; 4]) -> Vec<u8> {
        (0..pixel_count).flat_map(|_| px).collect()
    }

    // --- rgba_to_yuv420 ---

    #[test]
    fn test_rgba_to_yuv420_size() {
        let kernel = make_kernel();
        // 4×4 image: each pixel = [100, 149, 237, 255]
        let rgba = solid_rgba(16, [100, 149, 237, 255]);
        let yuv = kernel
            .rgba_to_yuv420(&rgba, 4, 4)
            .expect("conversion failed");
        assert_eq!(yuv.len(), 4 * 4 + 2 * 2 * 2); // Y + 2*(2*2)
    }

    #[test]
    fn test_rgba_to_yuv420_odd_dimensions() {
        let kernel = make_kernel();
        // 3×3 image: chroma planes round up to 2×2 each.
        let rgba = solid_rgba(9, [10, 20, 30, 255]);
        let yuv = kernel
            .rgba_to_yuv420(&rgba, 3, 3)
            .expect("conversion failed");
        assert_eq!(yuv.len(), 3 * 3 + 2 * 2 * 2);
    }

    #[test]
    fn test_rgba_to_yuv420_invalid_size() {
        let kernel = make_kernel();
        let rgba = vec![0u8; 10]; // wrong size
        assert!(kernel.rgba_to_yuv420(&rgba, 4, 4).is_none());
    }

    #[test]
    fn test_rgba_to_yuv420_white_pixel() {
        let kernel = make_kernel();
        // 2×2 white pixels
        let rgba = solid_rgba(4, [255, 255, 255, 255]);
        let yuv = kernel
            .rgba_to_yuv420(&rgba, 2, 2)
            .expect("conversion failed");
        // Y for white ≈ 255
        assert!(yuv[0] > 230, "Y for white should be ≈ 255, got {}", yuv[0]);
    }

    #[test]
    fn test_rgba_to_yuv420_black_pixel() {
        let kernel = make_kernel();
        // 2×2 black pixels
        let rgba = solid_rgba(4, [0, 0, 0, 255]);
        let yuv = kernel
            .rgba_to_yuv420(&rgba, 2, 2)
            .expect("conversion failed");
        assert_eq!(yuv[0], 0, "Y for black should be 0");
    }

    // --- yuv420_to_rgba ---

    #[test]
    fn test_yuv420_roundtrip() {
        let kernel = make_kernel();
        // Build a simple 4×4 RGBA image (mid-grey).
        let rgba_in = solid_rgba(16, [128, 128, 128, 255]);
        let yuv = kernel.rgba_to_yuv420(&rgba_in, 4, 4).expect("to_yuv");
        let rgba_out = kernel.yuv420_to_rgba(&yuv, 4, 4).expect("to_rgba");
        assert_eq!(rgba_out.len(), rgba_in.len());
        // Round-trip: every channel (not just red) should be within ±4 due to
        // quantisation.
        for (idx, (&before, &after)) in rgba_in.iter().zip(&rgba_out).enumerate() {
            let diff = (i32::from(before) - i32::from(after)).unsigned_abs();
            assert!(diff <= 4, "byte {idx}: channel diff too large: {diff}");
        }
    }

    #[test]
    fn test_yuv420_to_rgba_invalid_size() {
        let kernel = make_kernel();
        let bad = vec![0u8; 5];
        assert!(kernel.yuv420_to_rgba(&bad, 4, 4).is_none());
    }

    // --- gaussian_blur ---

    #[test]
    fn test_gaussian_blur_flat_image() {
        let kernel = make_kernel();
        let pixels = vec![1.0_f32; 8 * 8];
        let blurred = kernel.gaussian_blur(&pixels, 8, 8, 1.0).expect("blur");
        assert_eq!(blurred.len(), pixels.len());
        // Blurring a constant image should leave it unchanged.
        for &v in &blurred {
            assert!((v - 1.0).abs() < 1e-4, "expected ~1.0 got {v}");
        }
    }

    #[test]
    fn test_gaussian_blur_zero_sigma() {
        let kernel = make_kernel();
        let pixels = vec![0.5_f32; 4 * 4];
        let out = kernel.gaussian_blur(&pixels, 4, 4, 0.0).expect("blur");
        // sigma=0 → identity
        for &v in &out {
            assert!((v - 0.5).abs() < 1e-5);
        }
    }

    #[test]
    fn test_gaussian_blur_invalid_size() {
        let kernel = make_kernel();
        let pixels = vec![0.0_f32; 3];
        assert!(kernel.gaussian_blur(&pixels, 4, 4, 1.0).is_none());
    }

    // --- sobel_edges ---

    #[test]
    fn test_sobel_flat_image_is_zero() {
        let kernel = make_kernel();
        let gray = vec![0.5_f32; 8 * 8];
        let edges = kernel.sobel_edges(&gray, 8, 8).expect("sobel");
        // Interior pixels of a flat image → gradient = 0.
        for row in 1..7_usize {
            for col in 1..7_usize {
                let v = edges[row * 8 + col];
                assert!(v.abs() < 1e-5, "expected 0 at ({row},{col}), got {v}");
            }
        }
    }

    #[test]
    fn test_sobel_vertical_edge() {
        let kernel = make_kernel();
        // Left half = 0, right half = 1 → strong vertical edge in the middle.
        let mut gray = vec![0.0_f32; 8 * 8];
        for row in 0..8_usize {
            for col in 4..8_usize {
                gray[row * 8 + col] = 1.0;
            }
        }
        let edges = kernel.sobel_edges(&gray, 8, 8).expect("sobel");
        // Column 3 sits directly left of the step: gx = 1 + 2 + 1, gy = 0.
        let mid_val = edges[3 * 8 + 3];
        assert!(
            (mid_val - 4.0).abs() < 1e-6,
            "expected gradient 4 at boundary, got {mid_val}"
        );
        // Border pixels are always zero.
        assert_eq!(edges[0], 0.0);
        assert_eq!(edges[3 * 8], 0.0);
    }

    #[test]
    fn test_sobel_tiny_image_is_zero() {
        let kernel = make_kernel();
        // A 2×2 image has no interior pixels at all.
        let edges = kernel
            .sobel_edges(&[0.0, 1.0, 1.0, 0.0], 2, 2)
            .expect("sobel");
        assert_eq!(edges, vec![0.0_f32; 4]);
    }

    #[test]
    fn test_sobel_invalid_size() {
        let kernel = make_kernel();
        assert!(kernel.sobel_edges(&[0.0_f32; 3], 4, 4).is_none());
    }

    // --- histogram_equalization ---

    #[test]
    fn test_histogram_equalization_constant() {
        let kernel = make_kernel();
        let gray = vec![100u8; 4 * 4];
        let out = kernel.histogram_equalization(&gray, 4, 4).expect("eq");
        // With constant input there is nothing to stretch: identity mapping.
        assert_eq!(out, gray);
    }

    #[test]
    fn test_histogram_equalization_two_levels() {
        let kernel = make_kernel();
        // Already spanning the full range: stays at the extremes.
        let gray = vec![0u8, 0, 255, 255];
        let out = kernel.histogram_equalization(&gray, 2, 2).expect("eq");
        assert_eq!(out, vec![0u8, 0, 255, 255]);
    }

    #[test]
    fn test_histogram_equalization_invalid_size() {
        let kernel = make_kernel();
        assert!(kernel.histogram_equalization(&[0u8; 3], 4, 4).is_none());
    }

    // --- threshold_otsu ---

    #[test]
    fn test_threshold_otsu_bimodal() {
        let kernel = make_kernel();
        // Two classes: 50 pixels at value 30 (dark), 50 pixels at value 200 (bright).
        // Otsu's threshold will be 30 (the dark class value) because the maximum
        // between-class variance is first reached there and later levels only tie.
        // With `px > threshold` classification: 30 → 0 (bg), 200 → 255 (fg).
        let mut gray = vec![30u8; 50];
        gray.extend_from_slice(&[200u8; 50]);
        let (thresh, binary) = kernel.threshold_otsu(&gray, 10, 10).expect("otsu");
        assert_eq!(thresh, 30, "threshold should sit at the dark-class value");
        // Dark pixels should map to background (0), bright to foreground (255).
        assert!(
            binary[..50].iter().all(|&v| v == 0),
            "dark pixels (value 30) should be background"
        );
        assert!(
            binary[50..].iter().all(|&v| v == 255),
            "bright pixels (value 200) should be foreground"
        );
    }

    #[test]
    fn test_threshold_otsu_invalid_size() {
        let kernel = make_kernel();
        assert!(kernel.threshold_otsu(&[0u8; 3], 4, 4).is_none());
    }

    // --- alpha_composite ---

    #[test]
    fn test_alpha_composite_opaque_fg() {
        let kernel = make_kernel();
        // Fully opaque red fg over any bg → output = red.
        let fg = solid_rgba(4, [255, 0, 0, 255]);
        let bg = solid_rgba(4, [0, 0, 255, 255]);
        let out = kernel.alpha_composite(&fg, &bg, 2, 2).expect("composite");
        assert_eq!(out, fg);
    }

    #[test]
    fn test_alpha_composite_transparent_fg() {
        let kernel = make_kernel();
        // Fully transparent fg → output = bg.
        let fg = solid_rgba(4, [255, 0, 0, 0]);
        let bg = solid_rgba(4, [0, 0, 255, 255]);
        let out = kernel.alpha_composite(&fg, &bg, 2, 2).expect("composite");
        assert_eq!(out, bg);
    }

    #[test]
    fn test_alpha_composite_both_transparent() {
        let kernel = make_kernel();
        // Nothing visible in either layer → transparent black.
        let fg = solid_rgba(4, [255, 0, 0, 0]);
        let bg = solid_rgba(4, [0, 0, 255, 0]);
        let out = kernel.alpha_composite(&fg, &bg, 2, 2).expect("composite");
        assert_eq!(out, vec![0u8; 16]);
    }

    #[test]
    fn test_alpha_composite_size_mismatch() {
        let kernel = make_kernel();
        let fg = vec![0u8; 8];
        let bg = vec![0u8; 16];
        assert!(kernel.alpha_composite(&fg, &bg, 2, 2).is_none());
    }

    // --- scale_image ---

    #[test]
    fn test_scale_image_identity() {
        let kernel = make_kernel();
        let pixels: Vec<u8> = (0..16)
            .flat_map(|i: u8| [i * 4, i * 4, i * 4, 255])
            .collect();
        let out = kernel.scale_image(&pixels, 4, 4, 4, 4).expect("scale");
        assert_eq!(out, pixels);
    }

    #[test]
    fn test_scale_image_upscale_size() {
        let kernel = make_kernel();
        let pixels = vec![128u8; 4 * 4 * 4]; // 4×4 grey
        let out = kernel.scale_image(&pixels, 4, 4, 8, 8).expect("scale");
        assert_eq!(out.len(), 8 * 8 * 4);
        // Interpolating between equal samples must not change the value.
        assert!(out.iter().all(|&v| v == 128));
    }

    #[test]
    fn test_scale_image_downscale_constant() {
        let kernel = make_kernel();
        let pixels = vec![128u8; 4 * 4 * 4]; // 4×4 grey
        let out = kernel.scale_image(&pixels, 4, 4, 2, 2).expect("scale");
        assert_eq!(out, vec![128u8; 2 * 2 * 4]);
    }

    #[test]
    fn test_scale_image_zero_dimension() {
        let kernel = make_kernel();
        let pixels = vec![0u8; 4 * 4 * 4];
        assert!(kernel.scale_image(&pixels, 4, 4, 0, 8).is_none());
    }

    #[test]
    fn test_scale_image_invalid_size() {
        let kernel = make_kernel();
        assert!(kernel.scale_image(&[0u8; 3], 4, 4, 2, 2).is_none());
    }
}