Skip to main content

oximedia_gpu/
accelerator.rs

1//! `GpuAccelerator` trait and hardware acceleration abstraction.
2//!
3//! This module defines the unified [`GpuAccelerator`] trait that provides a
4//! hardware-agnostic interface for GPU compute operations.  Concrete backends
5//! (Vulkan/WGPU and CPU SIMD) implement the trait so callers never need to
6//! branch on the active backend.
7//!
8//! # Design
9//!
10//! ```text
11//! GpuAccelerator (trait)
12//!    ├── WgpuAccelerator  ── uses wgpu (Vulkan / Metal / DX12 / WebGPU)
13//!    └── CpuAccelerator   ── uses rayon SIMD fallback (always available)
14//! ```
15//!
16//! # Quick Start
17//!
18//! ```no_run
19//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! use oximedia_gpu::accelerator::{AcceleratorBuilder, GpuAccelerator};
21//!
22//! let acc = AcceleratorBuilder::new().build()?;
23//! let name = acc.name();
24//! println!("Using backend: {name}");
25//!
26//! let rgb  = vec![0u8; 1920 * 1080 * 4];
27//! let mut yuv = vec![0u8; 1920 * 1080 * 4];
28//! acc.rgb_to_yuv(&rgb, &mut yuv, 1920, 1080)?;
29//! # Ok(())
30//! # }
31//! ```
32
33#![allow(clippy::cast_possible_truncation)]
34#![allow(clippy::cast_sign_loss)]
35#![allow(clippy::cast_precision_loss)]
36#![allow(clippy::cast_possible_wrap)]
37
38use crate::{GpuError, Result};
39use rayon::prelude::*;
40
41// =============================================================================
42// GpuAccelerator trait
43// =============================================================================
44
45/// Unified interface for hardware-accelerated media operations.
46///
47/// All operations have CPU fallback implementations so that callers do not
48/// need to handle the absence of a GPU.  The trait is object-safe; you can
49/// store it behind `Box<dyn GpuAccelerator>` or `Arc<dyn GpuAccelerator>`.
50#[cfg(not(target_arch = "wasm32"))]
51pub trait GpuAccelerator: Send + Sync + 'static {
52    /// Human-readable backend name (e.g. `"Vulkan"`, `"CPU SIMD"`).
53    fn name(&self) -> &str;
54
55    /// Whether this accelerator uses dedicated GPU hardware.
56    fn is_gpu(&self) -> bool;
57
58    /// Convert packed RGBA data to packed YUVA using BT.601 coefficients.
59    ///
60    /// Both slices must have length `width * height * 4`.
61    ///
62    /// # Errors
63    ///
64    /// Returns [`GpuError::InvalidBufferSize`] if slice lengths do not match
65    /// `width * height * 4`.
66    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
67
68    /// Convert packed YUVA data to packed RGBA using BT.601 coefficients.
69    ///
70    /// Both slices must have length `width * height * 4`.
71    ///
72    /// # Errors
73    ///
74    /// Returns [`GpuError::InvalidBufferSize`] if slice lengths do not match.
75    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
76
77    /// Resize packed RGBA image using bilinear interpolation.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if buffer sizes are inconsistent with the given
82    /// dimensions.
83    #[allow(clippy::too_many_arguments)]
84    fn scale_bilinear(
85        &self,
86        input: &[u8],
87        src_width: u32,
88        src_height: u32,
89        output: &mut [u8],
90        dst_width: u32,
91        dst_height: u32,
92    ) -> Result<()>;
93
94    /// Apply a separable Gaussian blur to a packed RGBA image.
95    ///
96    /// `sigma` is the standard deviation of the Gaussian kernel in pixels.
97    ///
98    /// # Errors
99    ///
100    /// Returns an error if `input.len() != output.len()` or if either length
101    /// does not equal `width * height * 4`.
102    fn gaussian_blur(
103        &self,
104        input: &[u8],
105        output: &mut [u8],
106        width: u32,
107        height: u32,
108        sigma: f32,
109    ) -> Result<()>;
110
111    /// Detect edges using the Sobel operator on a packed RGBA image.
112    ///
113    /// The output contains per-pixel gradient magnitudes.
114    ///
115    /// # Errors
116    ///
117    /// Returns an error if buffer sizes are inconsistent.
118    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
119
120    /// Sharpen a packed RGBA image using an unsharp mask.
121    ///
122    /// `amount` controls the sharpening strength (typical range 0.0–2.0).
123    ///
124    /// # Errors
125    ///
126    /// Returns an error if buffer sizes are inconsistent.
127    fn sharpen(
128        &self,
129        input: &[u8],
130        output: &mut [u8],
131        width: u32,
132        height: u32,
133        amount: f32,
134    ) -> Result<()>;
135
136    /// Compute the 2-D Type-II DCT on a grid of `f32` values.
137    ///
138    /// Both `width` and `height` must be multiples of 8.
139    ///
140    /// # Errors
141    ///
142    /// Returns an error if dimensions are not multiples of 8 or if slice
143    /// lengths do not equal `width * height`.
144    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()>;
145
146    /// Compute the 2-D Type-III IDCT (inverse of [`dct_2d`]).
147    ///
148    /// # Errors
149    ///
150    /// Returns the same errors as [`dct_2d`].
151    ///
152    /// [`dct_2d`]: GpuAccelerator::dct_2d
153    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()>;
154
155    /// Compute the per-pixel absolute difference between two RGBA images.
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if any buffer length differs from `width * height * 4`.
160    fn pixel_diff(
161        &self,
162        a: &[u8],
163        b: &[u8],
164        output: &mut [u8],
165        width: u32,
166        height: u32,
167    ) -> Result<()>;
168
169    /// Compute the mean squared error between two RGBA images.
170    ///
171    /// Returns the average squared per-channel difference (range 0.0–65 025.0).
172    ///
173    /// # Errors
174    ///
175    /// Returns an error if buffer lengths do not equal `width * height * 4`.
176    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64>;
177}
178
179#[cfg(target_arch = "wasm32")]
180pub trait GpuAccelerator: 'static {
181    /// Human-readable backend name (e.g. `"Vulkan"`, `"CPU SIMD"`).
182    fn name(&self) -> &str;
183
184    /// Whether this accelerator uses dedicated GPU hardware.
185    fn is_gpu(&self) -> bool;
186
187    /// Convert packed RGBA data to packed YUVA using BT.601 coefficients.
188    ///
189    /// Both slices must have length `width * height * 4`.
190    ///
191    /// # Errors
192    ///
193    /// Returns [`GpuError::InvalidBufferSize`] if slice lengths do not match
194    /// `width * height * 4`.
195    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
196
197    /// Convert packed YUVA data to packed RGBA using BT.601 coefficients.
198    ///
199    /// Both slices must have length `width * height * 4`.
200    ///
201    /// # Errors
202    ///
203    /// Returns [`GpuError::InvalidBufferSize`] if slice lengths do not match.
204    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
205
206    /// Resize packed RGBA image using bilinear interpolation.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if buffer sizes are inconsistent with the given
211    /// dimensions.
212    #[allow(clippy::too_many_arguments)]
213    fn scale_bilinear(
214        &self,
215        input: &[u8],
216        src_width: u32,
217        src_height: u32,
218        output: &mut [u8],
219        dst_width: u32,
220        dst_height: u32,
221    ) -> Result<()>;
222
223    /// Apply a separable Gaussian blur to a packed RGBA image.
224    ///
225    /// `sigma` is the standard deviation of the Gaussian kernel in pixels.
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if `input.len() != output.len()` or if either length
230    /// does not equal `width * height * 4`.
231    fn gaussian_blur(
232        &self,
233        input: &[u8],
234        output: &mut [u8],
235        width: u32,
236        height: u32,
237        sigma: f32,
238    ) -> Result<()>;
239
240    /// Detect edges using the Sobel operator on a packed RGBA image.
241    ///
242    /// The output contains per-pixel gradient magnitudes.
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if buffer sizes are inconsistent.
247    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
248
249    /// Sharpen a packed RGBA image using an unsharp mask.
250    ///
251    /// `amount` controls the sharpening strength (typical range 0.0–2.0).
252    ///
253    /// # Errors
254    ///
255    /// Returns an error if buffer sizes are inconsistent.
256    fn sharpen(
257        &self,
258        input: &[u8],
259        output: &mut [u8],
260        width: u32,
261        height: u32,
262        amount: f32,
263    ) -> Result<()>;
264
265    /// Compute the 2-D Type-II DCT on a grid of `f32` values.
266    ///
267    /// Both `width` and `height` must be multiples of 8.
268    ///
269    /// # Errors
270    ///
271    /// Returns an error if dimensions are not multiples of 8 or if slice
272    /// lengths do not equal `width * height`.
273    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()>;
274
275    /// Compute the 2-D Type-III IDCT (inverse of [`dct_2d`]).
276    ///
277    /// # Errors
278    ///
279    /// Returns the same errors as [`dct_2d`].
280    ///
281    /// [`dct_2d`]: GpuAccelerator::dct_2d
282    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()>;
283
284    /// Compute the per-pixel absolute difference between two RGBA images.
285    ///
286    /// # Errors
287    ///
288    /// Returns an error if any buffer length differs from `width * height * 4`.
289    fn pixel_diff(
290        &self,
291        a: &[u8],
292        b: &[u8],
293        output: &mut [u8],
294        width: u32,
295        height: u32,
296    ) -> Result<()>;
297
298    /// Compute the mean squared error between two RGBA images.
299    ///
300    /// Returns the average squared per-channel difference (range 0.0–65 025.0).
301    ///
302    /// # Errors
303    ///
304    /// Returns an error if buffer lengths do not equal `width * height * 4`.
305    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64>;
306}
307
308// =============================================================================
309// Buffer-size validation helpers
310// =============================================================================
311
312fn check_rgba_buf(buf: &[u8], width: u32, height: u32, label: &str) -> Result<()> {
313    let expected = (width as usize) * (height as usize) * 4;
314    if buf.len() != expected {
315        return Err(GpuError::InvalidBufferSize {
316            expected,
317            actual: buf.len(),
318        });
319    }
320    let _ = label;
321    Ok(())
322}
323
324fn check_f32_buf(buf: &[f32], width: u32, height: u32) -> Result<()> {
325    let expected = (width as usize) * (height as usize);
326    if buf.len() != expected {
327        return Err(GpuError::InvalidBufferSize {
328            expected,
329            actual: buf.len(),
330        });
331    }
332    Ok(())
333}
334
335// =============================================================================
336// CPU Accelerator
337// =============================================================================
338
339/// Pure-CPU accelerator backed by rayon parallel iterators.
340///
341/// This implementation is always available and provides correct (if slower)
342/// results even when no GPU is present.
343pub struct CpuAccelerator {
344    num_threads: usize,
345}
346
347impl CpuAccelerator {
348    /// Create a CPU accelerator using all available threads.
349    #[must_use]
350    pub fn new() -> Self {
351        Self {
352            num_threads: rayon::current_num_threads(),
353        }
354    }
355
356    /// Number of worker threads.
357    #[must_use]
358    pub fn num_threads(&self) -> usize {
359        self.num_threads
360    }
361
362    // ---- Internal helpers --------------------------------------------------
363
364    fn rgb_to_yuv_impl(input: &[u8], output: &mut [u8]) {
365        const KR: f32 = 0.299;
366        const KG: f32 = 0.587;
367        const KB: f32 = 0.114;
368
369        output
370            .par_chunks_exact_mut(4)
371            .zip(input.par_chunks_exact(4))
372            .for_each(|(out, inp)| {
373                let r = f32::from(inp[0]) / 255.0;
374                let g = f32::from(inp[1]) / 255.0;
375                let b = f32::from(inp[2]) / 255.0;
376
377                let y = KR * r + KG * g + KB * b;
378                let u = (b - y) / (2.0 * (1.0 - KB)) + 0.5;
379                let v = (r - y) / (2.0 * (1.0 - KR)) + 0.5;
380
381                out[0] = (y.clamp(0.0, 1.0) * 255.0) as u8;
382                out[1] = (u.clamp(0.0, 1.0) * 255.0) as u8;
383                out[2] = (v.clamp(0.0, 1.0) * 255.0) as u8;
384                out[3] = inp[3];
385            });
386    }
387
388    fn yuv_to_rgb_impl(input: &[u8], output: &mut [u8]) {
389        const KR: f32 = 0.299;
390        const KG: f32 = 0.587;
391        const KB: f32 = 0.114;
392
393        output
394            .par_chunks_exact_mut(4)
395            .zip(input.par_chunks_exact(4))
396            .for_each(|(out, inp)| {
397                let y = f32::from(inp[0]) / 255.0;
398                let u = f32::from(inp[1]) / 255.0 - 0.5;
399                let v = f32::from(inp[2]) / 255.0 - 0.5;
400
401                let r = y + 2.0 * (1.0 - KR) * v;
402                let b = y + 2.0 * (1.0 - KB) * u;
403                let g = (y - KR * r - KB * b) / KG;
404
405                out[0] = (r.clamp(0.0, 1.0) * 255.0) as u8;
406                out[1] = (g.clamp(0.0, 1.0) * 255.0) as u8;
407                out[2] = (b.clamp(0.0, 1.0) * 255.0) as u8;
408                out[3] = inp[3];
409            });
410    }
411
412    fn scale_bilinear_impl(
413        input: &[u8],
414        src_w: usize,
415        src_h: usize,
416        output: &mut [u8],
417        dst_w: usize,
418        dst_h: usize,
419    ) {
420        let x_ratio = src_w as f32 / dst_w as f32;
421        let y_ratio = src_h as f32 / dst_h as f32;
422
423        output
424            .par_chunks_exact_mut(4)
425            .enumerate()
426            .for_each(|(idx, pixel)| {
427                let dst_x = idx % dst_w;
428                let dst_y = idx / dst_w;
429                if dst_y >= dst_h {
430                    return;
431                }
432                let src_x = (dst_x as f32 + 0.5) * x_ratio - 0.5;
433                let src_y = (dst_y as f32 + 0.5) * y_ratio - 0.5;
434
435                let x0 = (src_x.floor().max(0.0) as usize).min(src_w - 1);
436                let y0 = (src_y.floor().max(0.0) as usize).min(src_h - 1);
437                let x1 = (x0 + 1).min(src_w - 1);
438                let y1 = (y0 + 1).min(src_h - 1);
439
440                let fx = src_x.fract().max(0.0);
441                let fy = src_y.fract().max(0.0);
442
443                for c in 0..4 {
444                    let p00 = f32::from(input[(y0 * src_w + x0) * 4 + c]);
445                    let p10 = f32::from(input[(y0 * src_w + x1) * 4 + c]);
446                    let p01 = f32::from(input[(y1 * src_w + x0) * 4 + c]);
447                    let p11 = f32::from(input[(y1 * src_w + x1) * 4 + c]);
448
449                    let top = p00 * (1.0 - fx) + p10 * fx;
450                    let bot = p01 * (1.0 - fx) + p11 * fx;
451                    pixel[c] = (top * (1.0 - fy) + bot * fy).round().clamp(0.0, 255.0) as u8;
452                }
453            });
454    }
455
456    fn gaussian_blur_impl(
457        input: &[u8],
458        output: &mut [u8],
459        width: usize,
460        height: usize,
461        sigma: f32,
462    ) {
463        let radius = (3.0 * sigma).ceil() as i32;
464        let ksize = (2 * radius + 1) as usize;
465        let two_sigma_sq = 2.0 * sigma * sigma;
466
467        let mut kernel = vec![0.0f32; ksize];
468        let mut sum = 0.0f32;
469        for i in 0..ksize {
470            let x = i as i32 - radius;
471            let v = (-(x * x) as f32 / two_sigma_sq).exp();
472            kernel[i] = v;
473            sum += v;
474        }
475        for v in &mut kernel {
476            *v /= sum;
477        }
478
479        // Horizontal pass → temp
480        let mut temp = vec![0u8; input.len()];
481        temp.par_chunks_exact_mut(4)
482            .enumerate()
483            .for_each(|(i, out)| {
484                let px = i % width;
485                let py = i / width;
486                if py >= height {
487                    return;
488                }
489                for c in 0..4 {
490                    let mut acc = 0.0f32;
491                    for (k, &kw) in kernel.iter().enumerate() {
492                        let sx =
493                            (px as i32 + k as i32 - radius).clamp(0, width as i32 - 1) as usize;
494                        acc += f32::from(input[(py * width + sx) * 4 + c]) * kw;
495                    }
496                    out[c] = acc.round().clamp(0.0, 255.0) as u8;
497                }
498            });
499
500        // Vertical pass → output
501        output
502            .par_chunks_exact_mut(4)
503            .enumerate()
504            .for_each(|(i, out)| {
505                let px = i % width;
506                let py = i / width;
507                if py >= height {
508                    return;
509                }
510                for c in 0..4 {
511                    let mut acc = 0.0f32;
512                    for (k, &kw) in kernel.iter().enumerate() {
513                        let sy =
514                            (py as i32 + k as i32 - radius).clamp(0, height as i32 - 1) as usize;
515                        acc += f32::from(temp[(sy * width + px) * 4 + c]) * kw;
516                    }
517                    out[c] = acc.round().clamp(0.0, 255.0) as u8;
518                }
519            });
520    }
521
522    /// Apply a 3×3 Sobel gradient magnitude filter.
523    fn sobel_impl(input: &[u8], output: &mut [u8], width: usize, height: usize) {
524        // Convert to luminance first, then apply Sobel.
525        let lum: Vec<f32> = input
526            .par_chunks_exact(4)
527            .map(|p| 0.299 * f32::from(p[0]) + 0.587 * f32::from(p[1]) + 0.114 * f32::from(p[2]))
528            .collect();
529
530        output
531            .par_chunks_exact_mut(4)
532            .enumerate()
533            .for_each(|(i, out)| {
534                let x = (i % width) as i32;
535                let y = (i / width) as i32;
536
537                if x == 0 || x == (width as i32 - 1) || y == 0 || y == (height as i32 - 1) {
538                    out.fill(0);
539                    return;
540                }
541
542                // Sobel kernels
543                let gx = -lum[(y - 1) as usize * width + (x - 1) as usize]
544                    - 2.0 * lum[y as usize * width + (x - 1) as usize]
545                    - lum[(y + 1) as usize * width + (x - 1) as usize]
546                    + lum[(y - 1) as usize * width + (x + 1) as usize]
547                    + 2.0 * lum[y as usize * width + (x + 1) as usize]
548                    + lum[(y + 1) as usize * width + (x + 1) as usize];
549
550                let gy = -lum[(y - 1) as usize * width + (x - 1) as usize]
551                    - 2.0 * lum[(y - 1) as usize * width + x as usize]
552                    - lum[(y - 1) as usize * width + (x + 1) as usize]
553                    + lum[(y + 1) as usize * width + (x - 1) as usize]
554                    + 2.0 * lum[(y + 1) as usize * width + x as usize]
555                    + lum[(y + 1) as usize * width + (x + 1) as usize];
556
557                let mag = (gx * gx + gy * gy).sqrt().clamp(0.0, 255.0) as u8;
558                out[0] = mag;
559                out[1] = mag;
560                out[2] = mag;
561                out[3] = input[i * 4 + 3]; // preserve alpha
562            });
563    }
564
565    fn sharpen_impl(input: &[u8], output: &mut [u8], width: usize, height: usize, amount: f32) {
566        // Unsharp mask: output = input + amount * (input - blurred)
567        let mut blurred = vec![0u8; input.len()];
568        Self::gaussian_blur_impl(input, &mut blurred, width, height, 1.0);
569
570        output
571            .par_chunks_exact_mut(4)
572            .zip(input.par_chunks_exact(4))
573            .zip(blurred.par_chunks_exact(4))
574            .for_each(|((out, orig), blur)| {
575                for c in 0..3 {
576                    let o = f32::from(orig[c]);
577                    let b = f32::from(blur[c]);
578                    let sharpened = o + amount * (o - b);
579                    out[c] = sharpened.round().clamp(0.0, 255.0) as u8;
580                }
581                out[3] = orig[3];
582            });
583    }
584
585    /// 1-D Type-II DCT on a slice of length N.
586    fn dct1d(data: &[f32], out: &mut [f32]) {
587        let n = data.len();
588        let nf = n as f32;
589        for k in 0..n {
590            let mut s = 0.0f32;
591            let kf = k as f32;
592            for (j, &v) in data.iter().enumerate() {
593                let angle = std::f32::consts::PI * kf * (2.0 * j as f32 + 1.0) / (2.0 * nf);
594                s += v * angle.cos();
595            }
596            let scale = if k == 0 {
597                (1.0 / nf).sqrt()
598            } else {
599                (2.0 / nf).sqrt()
600            };
601            out[k] = s * scale;
602        }
603    }
604
605    /// 1-D Type-III IDCT (inverse of `dct1d`) on a slice of length N.
606    fn idct1d(data: &[f32], out: &mut [f32]) {
607        let n = data.len();
608        let nf = n as f32;
609        for j in 0..n {
610            let jf = j as f32;
611            let mut s = data[0] / nf.sqrt();
612            for k in 1..n {
613                let scale = (2.0 / nf).sqrt();
614                let angle = std::f32::consts::PI * k as f32 * (2.0 * jf + 1.0) / (2.0 * nf);
615                s += scale * data[k] * angle.cos();
616            }
617            out[j] = s;
618        }
619    }
620}
621
622impl Default for CpuAccelerator {
623    fn default() -> Self {
624        Self::new()
625    }
626}
627
628impl GpuAccelerator for CpuAccelerator {
629    fn name(&self) -> &'static str {
630        "CPU SIMD"
631    }
632
633    fn is_gpu(&self) -> bool {
634        false
635    }
636
637    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
638        check_rgba_buf(input, width, height, "input")?;
639        check_rgba_buf(output, width, height, "output")?;
640        Self::rgb_to_yuv_impl(input, output);
641        Ok(())
642    }
643
644    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
645        check_rgba_buf(input, width, height, "input")?;
646        check_rgba_buf(output, width, height, "output")?;
647        Self::yuv_to_rgb_impl(input, output);
648        Ok(())
649    }
650
651    #[allow(clippy::too_many_arguments)]
652    fn scale_bilinear(
653        &self,
654        input: &[u8],
655        src_width: u32,
656        src_height: u32,
657        output: &mut [u8],
658        dst_width: u32,
659        dst_height: u32,
660    ) -> Result<()> {
661        check_rgba_buf(input, src_width, src_height, "input")?;
662        check_rgba_buf(output, dst_width, dst_height, "output")?;
663        Self::scale_bilinear_impl(
664            input,
665            src_width as usize,
666            src_height as usize,
667            output,
668            dst_width as usize,
669            dst_height as usize,
670        );
671        Ok(())
672    }
673
674    fn gaussian_blur(
675        &self,
676        input: &[u8],
677        output: &mut [u8],
678        width: u32,
679        height: u32,
680        sigma: f32,
681    ) -> Result<()> {
682        check_rgba_buf(input, width, height, "input")?;
683        check_rgba_buf(output, width, height, "output")?;
684        Self::gaussian_blur_impl(input, output, width as usize, height as usize, sigma);
685        Ok(())
686    }
687
688    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
689        check_rgba_buf(input, width, height, "input")?;
690        check_rgba_buf(output, width, height, "output")?;
691        Self::sobel_impl(input, output, width as usize, height as usize);
692        Ok(())
693    }
694
695    fn sharpen(
696        &self,
697        input: &[u8],
698        output: &mut [u8],
699        width: u32,
700        height: u32,
701        amount: f32,
702    ) -> Result<()> {
703        check_rgba_buf(input, width, height, "input")?;
704        check_rgba_buf(output, width, height, "output")?;
705        Self::sharpen_impl(input, output, width as usize, height as usize, amount);
706        Ok(())
707    }
708
709    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
710        check_f32_buf(input, width, height)?;
711        check_f32_buf(output, width, height)?;
712        if width % 8 != 0 || height % 8 != 0 {
713            return Err(GpuError::InvalidDimensions { width, height });
714        }
715
716        let w = width as usize;
717        let h = height as usize;
718
719        // Row-wise DCT
720        let mut row_pass = vec![0.0f32; w * h];
721        for row in 0..h {
722            let src = &input[row * w..(row + 1) * w];
723            let dst = &mut row_pass[row * w..(row + 1) * w];
724            Self::dct1d(src, dst);
725        }
726
727        // Column-wise DCT
728        for col in 0..w {
729            let col_data: Vec<f32> = (0..h).map(|r| row_pass[r * w + col]).collect();
730            let mut col_out = vec![0.0f32; h];
731            Self::dct1d(&col_data, &mut col_out);
732            for (r, &v) in col_out.iter().enumerate() {
733                output[r * w + col] = v;
734            }
735        }
736        Ok(())
737    }
738
739    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
740        check_f32_buf(input, width, height)?;
741        check_f32_buf(output, width, height)?;
742        if width % 8 != 0 || height % 8 != 0 {
743            return Err(GpuError::InvalidDimensions { width, height });
744        }
745
746        let w = width as usize;
747        let h = height as usize;
748
749        // Column-wise IDCT
750        let mut col_pass = vec![0.0f32; w * h];
751        for col in 0..w {
752            let col_data: Vec<f32> = (0..h).map(|r| input[r * w + col]).collect();
753            let mut col_out = vec![0.0f32; h];
754            Self::idct1d(&col_data, &mut col_out);
755            for (r, &v) in col_out.iter().enumerate() {
756                col_pass[r * w + col] = v;
757            }
758        }
759
760        // Row-wise IDCT
761        for row in 0..h {
762            let src = &col_pass[row * w..(row + 1) * w];
763            let dst = &mut output[row * w..(row + 1) * w];
764            Self::idct1d(src, dst);
765        }
766        Ok(())
767    }
768
769    fn pixel_diff(
770        &self,
771        a: &[u8],
772        b: &[u8],
773        output: &mut [u8],
774        width: u32,
775        height: u32,
776    ) -> Result<()> {
777        check_rgba_buf(a, width, height, "a")?;
778        check_rgba_buf(b, width, height, "b")?;
779        check_rgba_buf(output, width, height, "output")?;
780
781        output
782            .par_chunks_exact_mut(4)
783            .zip(a.par_chunks_exact(4))
784            .zip(b.par_chunks_exact(4))
785            .for_each(|((out, pa), pb)| {
786                for c in 0..4 {
787                    out[c] = pa[c].abs_diff(pb[c]);
788                }
789            });
790        Ok(())
791    }
792
793    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64> {
794        check_rgba_buf(a, width, height, "a")?;
795        check_rgba_buf(b, width, height, "b")?;
796
797        let sum_sq: f64 = a
798            .par_chunks_exact(4)
799            .zip(b.par_chunks_exact(4))
800            .map(|(pa, pb)| {
801                (0..4)
802                    .map(|c| {
803                        let d = f64::from(pa[c]) - f64::from(pb[c]);
804                        d * d
805                    })
806                    .sum::<f64>()
807            })
808            .sum();
809
810        let n = f64::from(width) * f64::from(height) * 4.0;
811        Ok(sum_sq / n)
812    }
813}
814
815// =============================================================================
816// WGPU/GPU Accelerator  (delegates to CpuAccelerator where GPU is unavailable)
817// =============================================================================
818
819/// GPU-backed accelerator using wgpu (Vulkan / Metal / DX12 / WebGPU).
820///
821/// If no GPU is available the constructor fails; use [`AcceleratorBuilder`]
822/// which transparently falls back to [`CpuAccelerator`].
823///
824/// All operations currently delegate to the CPU path while the wgpu compute
825/// pipeline is being set up.  The struct is intentionally structured so that
826/// individual operations can be migrated to GPU shaders without changing the
827/// public interface.
828pub struct WgpuAccelerator {
829    device: std::sync::Arc<crate::GpuDevice>,
830    /// CPU fallback for operations not yet ported to shaders.
831    cpu: CpuAccelerator,
832    backend_name: String,
833}
834
835impl WgpuAccelerator {
836    /// Create a `WgpuAccelerator` with automatic device selection.
837    ///
838    /// # Errors
839    ///
840    /// Returns [`GpuError::NoAdapter`] if no GPU is available.
841    pub fn new() -> Result<Self> {
842        let device = crate::GpuDevice::new(None)?;
843        let backend_name = format!("{} GPU", device.info().backend);
844        Ok(Self {
845            device: std::sync::Arc::new(device),
846            cpu: CpuAccelerator::new(),
847            backend_name,
848        })
849    }
850
851    /// Underlying GPU device.
852    #[must_use]
853    pub fn gpu_device(&self) -> &std::sync::Arc<crate::GpuDevice> {
854        &self.device
855    }
856}
857
858impl GpuAccelerator for WgpuAccelerator {
859    fn name(&self) -> &str {
860        &self.backend_name
861    }
862
863    fn is_gpu(&self) -> bool {
864        true
865    }
866
867    // The implementations below use the GPU device for simple operations
868    // and fall back to the CPU for complex shaders not yet implemented.
869
870    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
871        crate::ops::ColorSpaceConversion::rgb_to_yuv(
872            &self.device,
873            input,
874            output,
875            width,
876            height,
877            crate::ops::ColorSpace::BT601,
878        )
879    }
880
881    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
882        crate::ops::ColorSpaceConversion::yuv_to_rgb(
883            &self.device,
884            input,
885            output,
886            width,
887            height,
888            crate::ops::ColorSpace::BT601,
889        )
890    }
891
892    #[allow(clippy::too_many_arguments)]
893    fn scale_bilinear(
894        &self,
895        input: &[u8],
896        src_width: u32,
897        src_height: u32,
898        output: &mut [u8],
899        dst_width: u32,
900        dst_height: u32,
901    ) -> Result<()> {
902        crate::ops::ScaleOperation::scale(
903            &self.device,
904            input,
905            src_width,
906            src_height,
907            output,
908            dst_width,
909            dst_height,
910            crate::ops::ScaleFilter::Bilinear,
911        )
912    }
913
914    fn gaussian_blur(
915        &self,
916        input: &[u8],
917        output: &mut [u8],
918        width: u32,
919        height: u32,
920        sigma: f32,
921    ) -> Result<()> {
922        crate::ops::FilterOperation::gaussian_blur(
923            &self.device,
924            input,
925            output,
926            width,
927            height,
928            sigma,
929        )
930    }
931
932    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
933        crate::ops::FilterOperation::edge_detect(&self.device, input, output, width, height)
934    }
935
936    fn sharpen(
937        &self,
938        input: &[u8],
939        output: &mut [u8],
940        width: u32,
941        height: u32,
942        amount: f32,
943    ) -> Result<()> {
944        crate::ops::FilterOperation::sharpen(&self.device, input, output, width, height, amount)
945    }
946
947    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
948        crate::ops::TransformOperation::dct_2d(&self.device, input, output, width, height)
949    }
950
951    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
952        crate::ops::TransformOperation::idct_2d(&self.device, input, output, width, height)
953    }
954
955    fn pixel_diff(
956        &self,
957        a: &[u8],
958        b: &[u8],
959        output: &mut [u8],
960        width: u32,
961        height: u32,
962    ) -> Result<()> {
963        self.cpu.pixel_diff(a, b, output, width, height)
964    }
965
966    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64> {
967        self.cpu.mse(a, b, width, height)
968    }
969}
970
971// =============================================================================
972// AcceleratorBuilder
973// =============================================================================
974
975/// Ergonomic builder for creating a [`GpuAccelerator`].
976///
977/// Tries GPU first; falls back to CPU automatically.
978///
979/// # Example
980///
981/// ```no_run
982/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
983/// use oximedia_gpu::accelerator::AcceleratorBuilder;
984///
985/// let acc = AcceleratorBuilder::new()
986///     .prefer_gpu(true)
987///     .build()?;
988///
989/// println!("Active backend: {}", acc.name());
990/// # Ok(())
991/// # }
992/// ```
993pub struct AcceleratorBuilder {
994    prefer_gpu: bool,
995    force_cpu: bool,
996}
997
998impl AcceleratorBuilder {
999    /// Create a new builder with default settings (GPU preferred, CPU fallback).
1000    #[must_use]
1001    pub fn new() -> Self {
1002        Self {
1003            prefer_gpu: true,
1004            force_cpu: false,
1005        }
1006    }
1007
1008    /// Set whether to prefer GPU (default: `true`).
1009    #[must_use]
1010    pub fn prefer_gpu(mut self, value: bool) -> Self {
1011        self.prefer_gpu = value;
1012        self
1013    }
1014
1015    /// Force CPU-only mode even if a GPU is available.
1016    #[must_use]
1017    pub fn force_cpu(mut self, value: bool) -> Self {
1018        self.force_cpu = value;
1019        self
1020    }
1021
1022    /// Build the accelerator.
1023    ///
1024    /// Returns `Ok(Box<dyn GpuAccelerator>)`.  Never fails because a CPU
1025    /// fallback is always available.
1026    ///
1027    /// # Errors
1028    ///
1029    /// This method never returns `Err` in practice (CPU fallback is always
1030    /// constructed), but the signature uses `Result` to allow future error
1031    /// propagation.
1032    pub fn build(self) -> Result<Box<dyn GpuAccelerator>> {
1033        if self.force_cpu || !self.prefer_gpu {
1034            return Ok(Box::new(CpuAccelerator::new()));
1035        }
1036
1037        match WgpuAccelerator::new() {
1038            Ok(gpu) => Ok(Box::new(gpu)),
1039            Err(_) => Ok(Box::new(CpuAccelerator::new())),
1040        }
1041    }
1042
1043    /// Build a CPU-only accelerator directly.
1044    #[must_use]
1045    pub fn build_cpu() -> CpuAccelerator {
1046        CpuAccelerator::new()
1047    }
1048}
1049
1050impl Default for AcceleratorBuilder {
1051    fn default() -> Self {
1052        Self::new()
1053    }
1054}
1055
1056// =============================================================================
1057// Tests
1058// =============================================================================
1059
1060#[cfg(test)]
1061mod tests {
1062    use super::*;
1063
1064    fn make_rgba(w: usize, h: usize, fill: u8) -> Vec<u8> {
1065        vec![fill; w * h * 4]
1066    }
1067
1068    // ---- CpuAccelerator ----------------------------------------------------
1069
1070    #[test]
1071    fn test_cpu_accelerator_name() {
1072        let acc = CpuAccelerator::new();
1073        assert_eq!(acc.name(), "CPU SIMD");
1074        assert!(!acc.is_gpu());
1075    }
1076
1077    #[test]
1078    fn test_cpu_rgb_to_yuv_roundtrip() {
1079        // A grey pixel (R=G=B) should survive a round-trip with ≤ 2 LSB error.
1080        let grey = 128u8;
1081        let input = vec![grey, grey, grey, 255u8];
1082        let mut yuv = vec![0u8; 4];
1083        let mut rgb = vec![0u8; 4];
1084
1085        let acc = CpuAccelerator::new();
1086        acc.rgb_to_yuv(&input, &mut yuv, 1, 1)
1087            .expect("RGB to YUV conversion should succeed");
1088        acc.yuv_to_rgb(&yuv, &mut rgb, 1, 1)
1089            .expect("YUV to RGB conversion should succeed");
1090
1091        // Grey channel should be ≈ 128
1092        assert!(
1093            (rgb[0] as i32 - grey as i32).abs() <= 3,
1094            "R mismatch: {}",
1095            rgb[0]
1096        );
1097        assert!(
1098            (rgb[1] as i32 - grey as i32).abs() <= 3,
1099            "G mismatch: {}",
1100            rgb[1]
1101        );
1102        assert!(
1103            (rgb[2] as i32 - grey as i32).abs() <= 3,
1104            "B mismatch: {}",
1105            rgb[2]
1106        );
1107    }
1108
1109    #[test]
1110    fn test_cpu_rgb_to_yuv_invalid_size() {
1111        let acc = CpuAccelerator::new();
1112        let input = vec![0u8; 5]; // wrong size
1113        let mut output = vec![0u8; 4];
1114        assert!(acc.rgb_to_yuv(&input, &mut output, 1, 1).is_err());
1115    }
1116
1117    #[test]
1118    fn test_cpu_scale_bilinear_identity() {
1119        // Scaling a uniform white image should produce a uniform white image.
1120        let w = 16usize;
1121        let h = 16usize;
1122        let input = make_rgba(w, h, 200);
1123        let mut output = make_rgba(w, h, 0);
1124
1125        let acc = CpuAccelerator::new();
1126        acc.scale_bilinear(&input, w as u32, h as u32, &mut output, w as u32, h as u32)
1127            .expect("operation should succeed in test");
1128
1129        // All output pixels should still be white.
1130        for &v in &output {
1131            assert!(v >= 195, "pixel value {v} too low after identity scale");
1132        }
1133    }
1134
1135    #[test]
1136    fn test_cpu_scale_bilinear_upsample() {
1137        let input = make_rgba(2, 2, 255);
1138        let mut output = make_rgba(4, 4, 0);
1139
1140        let acc = CpuAccelerator::new();
1141        acc.scale_bilinear(&input, 2, 2, &mut output, 4, 4)
1142            .expect("bilinear scaling should succeed");
1143
1144        // All output pixels should be white (source was all-white).
1145        for &v in &output {
1146            assert!(v >= 250, "upsampled pixel {v} not white");
1147        }
1148    }
1149
1150    #[test]
1151    fn test_cpu_gaussian_blur_preserves_size() {
1152        let input = make_rgba(8, 8, 128);
1153        let mut output = make_rgba(8, 8, 0);
1154
1155        let acc = CpuAccelerator::new();
1156        acc.gaussian_blur(&input, &mut output, 8, 8, 1.0)
1157            .expect("gaussian blur should succeed");
1158        assert_eq!(output.len(), input.len());
1159    }
1160
1161    #[test]
1162    fn test_cpu_edge_detect_flat_image() {
1163        // A flat-colour image has no edges → gradient magnitude ≈ 0
1164        // (border pixels are excluded).
1165        let input = make_rgba(16, 16, 200);
1166        let mut output = make_rgba(16, 16, 0);
1167
1168        let acc = CpuAccelerator::new();
1169        acc.edge_detect(&input, &mut output, 16, 16)
1170            .expect("edge detection should succeed");
1171
1172        // Interior pixels should be near zero.
1173        for row in 1..15usize {
1174            for col in 1..15usize {
1175                let idx = (row * 16 + col) * 4;
1176                assert!(
1177                    output[idx] < 10,
1178                    "interior edge pixel {} at ({row},{col}) is non-zero",
1179                    output[idx]
1180                );
1181            }
1182        }
1183    }
1184
1185    #[test]
1186    fn test_cpu_sharpen_stable_flat() {
1187        // Sharpening a flat image should leave it unchanged.
1188        let input = make_rgba(8, 8, 128);
1189        let mut output = make_rgba(8, 8, 0);
1190
1191        let acc = CpuAccelerator::new();
1192        acc.sharpen(&input, &mut output, 8, 8, 1.0)
1193            .expect("sharpen should succeed");
1194
1195        // Allow ±2 LSB for accumulation of float rounding.
1196        for (&o, &i) in output.iter().zip(input.iter()) {
1197            assert!(
1198                (o as i32 - i as i32).abs() <= 3,
1199                "sharpen changed flat pixel by more than 3"
1200            );
1201        }
1202    }
1203
1204    #[test]
1205    fn test_cpu_dct_idct_roundtrip() {
1206        let w = 8u32;
1207        let h = 8u32;
1208        let input: Vec<f32> = (0..(w * h)).map(|i| i as f32).collect();
1209        let mut dct_out = vec![0.0f32; (w * h) as usize];
1210        let mut rec = vec![0.0f32; (w * h) as usize];
1211
1212        let acc = CpuAccelerator::new();
1213        acc.dct_2d(&input, &mut dct_out, w, h)
1214            .expect("DCT should succeed");
1215        acc.idct_2d(&dct_out, &mut rec, w, h)
1216            .expect("DCT should succeed");
1217
1218        for (a, b) in input.iter().zip(rec.iter()) {
1219            assert!((a - b).abs() < 1e-3, "DCT round-trip error: {a} vs {b}");
1220        }
1221    }
1222
1223    #[test]
1224    fn test_cpu_dct_invalid_dims() {
1225        let acc = CpuAccelerator::new();
1226        let input = vec![0.0f32; 10];
1227        let mut output = vec![0.0f32; 10];
1228        // 10 is not a multiple of 8 → error
1229        assert!(acc.dct_2d(&input, &mut output, 10, 1).is_err());
1230    }
1231
1232    #[test]
1233    fn test_cpu_pixel_diff_self() {
1234        let img = make_rgba(4, 4, 100);
1235        let mut diff = make_rgba(4, 4, 255);
1236
1237        let acc = CpuAccelerator::new();
1238        acc.pixel_diff(&img, &img, &mut diff, 4, 4)
1239            .expect("pixel diff should succeed");
1240
1241        for &v in &diff {
1242            assert_eq!(v, 0, "self-diff should be zero");
1243        }
1244    }
1245
1246    #[test]
1247    fn test_cpu_mse_identical() {
1248        let img = make_rgba(8, 8, 128);
1249        let acc = CpuAccelerator::new();
1250        let mse = acc
1251            .mse(&img, &img, 8, 8)
1252            .expect("MSE computation should succeed");
1253        assert!(
1254            mse.abs() < 1e-10,
1255            "MSE of identical images should be 0, got {mse}"
1256        );
1257    }
1258
1259    #[test]
1260    fn test_cpu_mse_max_error() {
1261        // Black vs white → max MSE = 255^2 = 65025.
1262        let a = make_rgba(4, 4, 0);
1263        let b = make_rgba(4, 4, 255);
1264        let acc = CpuAccelerator::new();
1265        let mse = acc
1266            .mse(&a, &b, 4, 4)
1267            .expect("MSE computation should succeed");
1268        assert!(
1269            (mse - 65025.0).abs() < 1.0,
1270            "max MSE should be 65025, got {mse}"
1271        );
1272    }
1273
1274    // ---- AcceleratorBuilder ------------------------------------------------
1275
1276    #[test]
1277    fn test_builder_force_cpu() {
1278        let acc = AcceleratorBuilder::new()
1279            .force_cpu(true)
1280            .build()
1281            .expect("accelerator build should succeed");
1282        assert_eq!(acc.name(), "CPU SIMD");
1283        assert!(!acc.is_gpu());
1284    }
1285
1286    #[test]
1287    fn test_builder_build_cpu_static() {
1288        let acc = AcceleratorBuilder::build_cpu();
1289        assert_eq!(acc.name(), "CPU SIMD");
1290    }
1291
1292    #[test]
1293    #[ignore] // Requires GPU hardware probe; run with --ignored
1294    fn test_builder_default_builds() {
1295        // Should never panic even without a GPU.
1296        let acc = AcceleratorBuilder::new()
1297            .build()
1298            .expect("accelerator build should succeed");
1299        assert!(!acc.name().is_empty());
1300    }
1301
1302    #[test]
1303    fn test_cpu_rgb_red_pixel() {
1304        // Red pixel → Y ≈ 0.299 * 255 ≈ 76
1305        let input = vec![255u8, 0, 0, 255];
1306        let mut yuv = vec![0u8; 4];
1307        let acc = CpuAccelerator::new();
1308        acc.rgb_to_yuv(&input, &mut yuv, 1, 1)
1309            .expect("RGB to YUV conversion should succeed");
1310        assert!(
1311            (yuv[0] as i32 - 76).abs() <= 2,
1312            "Y for red should be ~76, got {}",
1313            yuv[0]
1314        );
1315    }
1316}