Skip to main content

oximedia_gpu/
accelerator.rs

1//! `GpuAccelerator` trait and hardware acceleration abstraction.
2//!
3//! This module defines the unified [`GpuAccelerator`] trait that provides a
4//! hardware-agnostic interface for GPU compute operations.  Concrete backends
5//! (Vulkan/WGPU and CPU SIMD) implement the trait so callers never need to
6//! branch on the active backend.
7//!
8//! # Design
9//!
10//! ```text
11//! GpuAccelerator (trait)
12//!    ├── WgpuAccelerator  ── uses wgpu (Vulkan / Metal / DX12 / WebGPU)
13//!    └── CpuAccelerator   ── uses rayon SIMD fallback (always available)
14//! ```
15//!
16//! # Quick Start
17//!
18//! ```no_run
19//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! use oximedia_gpu::accelerator::{AcceleratorBuilder, GpuAccelerator};
21//!
22//! let acc = AcceleratorBuilder::new().build()?;
23//! let name = acc.name();
24//! println!("Using backend: {name}");
25//!
26//! let rgb  = vec![0u8; 1920 * 1080 * 4];
27//! let mut yuv = vec![0u8; 1920 * 1080 * 4];
28//! acc.rgb_to_yuv(&rgb, &mut yuv, 1920, 1080)?;
29//! # Ok(())
30//! # }
31//! ```
32
33#![allow(clippy::cast_possible_truncation)]
34#![allow(clippy::cast_sign_loss)]
35#![allow(clippy::cast_precision_loss)]
36#![allow(clippy::cast_possible_wrap)]
37
38use crate::{GpuError, Result};
39use rayon::prelude::*;
40
41// =============================================================================
42// GpuAccelerator trait
43// =============================================================================
44
45/// Unified interface for hardware-accelerated media operations.
46///
47/// All operations have CPU fallback implementations so that callers do not
48/// need to handle the absence of a GPU.  The trait is object-safe; you can
49/// store it behind `Box<dyn GpuAccelerator>` or `Arc<dyn GpuAccelerator>`.
50pub trait GpuAccelerator: Send + Sync {
51    /// Human-readable backend name (e.g. `"Vulkan"`, `"CPU SIMD"`).
52    fn name(&self) -> &str;
53
54    /// Whether this accelerator uses dedicated GPU hardware.
55    fn is_gpu(&self) -> bool;
56
57    /// Convert packed RGBA data to packed YUVA using BT.601 coefficients.
58    ///
59    /// Both slices must have length `width * height * 4`.
60    ///
61    /// # Errors
62    ///
63    /// Returns [`GpuError::InvalidBufferSize`] if slice lengths do not match
64    /// `width * height * 4`.
65    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
66
67    /// Convert packed YUVA data to packed RGBA using BT.601 coefficients.
68    ///
69    /// Both slices must have length `width * height * 4`.
70    ///
71    /// # Errors
72    ///
73    /// Returns [`GpuError::InvalidBufferSize`] if slice lengths do not match.
74    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
75
76    /// Resize packed RGBA image using bilinear interpolation.
77    ///
78    /// # Errors
79    ///
80    /// Returns an error if buffer sizes are inconsistent with the given
81    /// dimensions.
82    #[allow(clippy::too_many_arguments)]
83    fn scale_bilinear(
84        &self,
85        input: &[u8],
86        src_width: u32,
87        src_height: u32,
88        output: &mut [u8],
89        dst_width: u32,
90        dst_height: u32,
91    ) -> Result<()>;
92
93    /// Apply a separable Gaussian blur to a packed RGBA image.
94    ///
95    /// `sigma` is the standard deviation of the Gaussian kernel in pixels.
96    ///
97    /// # Errors
98    ///
99    /// Returns an error if `input.len() != output.len()` or if either length
100    /// does not equal `width * height * 4`.
101    fn gaussian_blur(
102        &self,
103        input: &[u8],
104        output: &mut [u8],
105        width: u32,
106        height: u32,
107        sigma: f32,
108    ) -> Result<()>;
109
110    /// Detect edges using the Sobel operator on a packed RGBA image.
111    ///
112    /// The output contains per-pixel gradient magnitudes.
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if buffer sizes are inconsistent.
117    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
118
119    /// Sharpen a packed RGBA image using an unsharp mask.
120    ///
121    /// `amount` controls the sharpening strength (typical range 0.0–2.0).
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if buffer sizes are inconsistent.
126    fn sharpen(
127        &self,
128        input: &[u8],
129        output: &mut [u8],
130        width: u32,
131        height: u32,
132        amount: f32,
133    ) -> Result<()>;
134
135    /// Compute the 2-D Type-II DCT on a grid of `f32` values.
136    ///
137    /// Both `width` and `height` must be multiples of 8.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if dimensions are not multiples of 8 or if slice
142    /// lengths do not equal `width * height`.
143    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()>;
144
145    /// Compute the 2-D Type-III IDCT (inverse of [`dct_2d`]).
146    ///
147    /// # Errors
148    ///
149    /// Returns the same errors as [`dct_2d`].
150    ///
151    /// [`dct_2d`]: GpuAccelerator::dct_2d
152    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()>;
153
154    /// Compute the per-pixel absolute difference between two RGBA images.
155    ///
156    /// # Errors
157    ///
158    /// Returns an error if any buffer length differs from `width * height * 4`.
159    fn pixel_diff(
160        &self,
161        a: &[u8],
162        b: &[u8],
163        output: &mut [u8],
164        width: u32,
165        height: u32,
166    ) -> Result<()>;
167
168    /// Compute the mean squared error between two RGBA images.
169    ///
170    /// Returns the average squared per-channel difference (range 0.0–65 025.0).
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if buffer lengths do not equal `width * height * 4`.
175    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64>;
176}
177
178// =============================================================================
179// Buffer-size validation helpers
180// =============================================================================
181
182fn check_rgba_buf(buf: &[u8], width: u32, height: u32, label: &str) -> Result<()> {
183    let expected = (width as usize) * (height as usize) * 4;
184    if buf.len() != expected {
185        return Err(GpuError::InvalidBufferSize {
186            expected,
187            actual: buf.len(),
188        });
189    }
190    let _ = label;
191    Ok(())
192}
193
194fn check_f32_buf(buf: &[f32], width: u32, height: u32) -> Result<()> {
195    let expected = (width as usize) * (height as usize);
196    if buf.len() != expected {
197        return Err(GpuError::InvalidBufferSize {
198            expected,
199            actual: buf.len(),
200        });
201    }
202    Ok(())
203}
204
205// =============================================================================
206// CPU Accelerator
207// =============================================================================
208
/// Pure-CPU accelerator backed by rayon parallel iterators.
///
/// This backend is always available and produces correct (if slower) results
/// when no GPU is present.  It derives `Debug` and `Clone` so that it can be
/// logged and duplicated like any other plain value type.
#[derive(Debug, Clone)]
pub struct CpuAccelerator {
    /// Worker-thread count recorded when the accelerator was constructed.
    num_threads: usize,
}
216
217impl CpuAccelerator {
218    /// Create a CPU accelerator using all available threads.
219    #[must_use]
220    pub fn new() -> Self {
221        Self {
222            num_threads: rayon::current_num_threads(),
223        }
224    }
225
226    /// Number of worker threads.
227    #[must_use]
228    pub fn num_threads(&self) -> usize {
229        self.num_threads
230    }
231
232    // ---- Internal helpers --------------------------------------------------
233
234    fn rgb_to_yuv_impl(input: &[u8], output: &mut [u8]) {
235        const KR: f32 = 0.299;
236        const KG: f32 = 0.587;
237        const KB: f32 = 0.114;
238
239        output
240            .par_chunks_exact_mut(4)
241            .zip(input.par_chunks_exact(4))
242            .for_each(|(out, inp)| {
243                let r = f32::from(inp[0]) / 255.0;
244                let g = f32::from(inp[1]) / 255.0;
245                let b = f32::from(inp[2]) / 255.0;
246
247                let y = KR * r + KG * g + KB * b;
248                let u = (b - y) / (2.0 * (1.0 - KB)) + 0.5;
249                let v = (r - y) / (2.0 * (1.0 - KR)) + 0.5;
250
251                out[0] = (y.clamp(0.0, 1.0) * 255.0) as u8;
252                out[1] = (u.clamp(0.0, 1.0) * 255.0) as u8;
253                out[2] = (v.clamp(0.0, 1.0) * 255.0) as u8;
254                out[3] = inp[3];
255            });
256    }
257
258    fn yuv_to_rgb_impl(input: &[u8], output: &mut [u8]) {
259        const KR: f32 = 0.299;
260        const KG: f32 = 0.587;
261        const KB: f32 = 0.114;
262
263        output
264            .par_chunks_exact_mut(4)
265            .zip(input.par_chunks_exact(4))
266            .for_each(|(out, inp)| {
267                let y = f32::from(inp[0]) / 255.0;
268                let u = f32::from(inp[1]) / 255.0 - 0.5;
269                let v = f32::from(inp[2]) / 255.0 - 0.5;
270
271                let r = y + 2.0 * (1.0 - KR) * v;
272                let b = y + 2.0 * (1.0 - KB) * u;
273                let g = (y - KR * r - KB * b) / KG;
274
275                out[0] = (r.clamp(0.0, 1.0) * 255.0) as u8;
276                out[1] = (g.clamp(0.0, 1.0) * 255.0) as u8;
277                out[2] = (b.clamp(0.0, 1.0) * 255.0) as u8;
278                out[3] = inp[3];
279            });
280    }
281
282    fn scale_bilinear_impl(
283        input: &[u8],
284        src_w: usize,
285        src_h: usize,
286        output: &mut [u8],
287        dst_w: usize,
288        dst_h: usize,
289    ) {
290        let x_ratio = src_w as f32 / dst_w as f32;
291        let y_ratio = src_h as f32 / dst_h as f32;
292
293        output
294            .par_chunks_exact_mut(4)
295            .enumerate()
296            .for_each(|(idx, pixel)| {
297                let dst_x = idx % dst_w;
298                let dst_y = idx / dst_w;
299                if dst_y >= dst_h {
300                    return;
301                }
302                let src_x = (dst_x as f32 + 0.5) * x_ratio - 0.5;
303                let src_y = (dst_y as f32 + 0.5) * y_ratio - 0.5;
304
305                let x0 = (src_x.floor().max(0.0) as usize).min(src_w - 1);
306                let y0 = (src_y.floor().max(0.0) as usize).min(src_h - 1);
307                let x1 = (x0 + 1).min(src_w - 1);
308                let y1 = (y0 + 1).min(src_h - 1);
309
310                let fx = src_x.fract().max(0.0);
311                let fy = src_y.fract().max(0.0);
312
313                for c in 0..4 {
314                    let p00 = f32::from(input[(y0 * src_w + x0) * 4 + c]);
315                    let p10 = f32::from(input[(y0 * src_w + x1) * 4 + c]);
316                    let p01 = f32::from(input[(y1 * src_w + x0) * 4 + c]);
317                    let p11 = f32::from(input[(y1 * src_w + x1) * 4 + c]);
318
319                    let top = p00 * (1.0 - fx) + p10 * fx;
320                    let bot = p01 * (1.0 - fx) + p11 * fx;
321                    pixel[c] = (top * (1.0 - fy) + bot * fy).round().clamp(0.0, 255.0) as u8;
322                }
323            });
324    }
325
326    fn gaussian_blur_impl(
327        input: &[u8],
328        output: &mut [u8],
329        width: usize,
330        height: usize,
331        sigma: f32,
332    ) {
333        let radius = (3.0 * sigma).ceil() as i32;
334        let ksize = (2 * radius + 1) as usize;
335        let two_sigma_sq = 2.0 * sigma * sigma;
336
337        let mut kernel = vec![0.0f32; ksize];
338        let mut sum = 0.0f32;
339        for i in 0..ksize {
340            let x = i as i32 - radius;
341            let v = (-(x * x) as f32 / two_sigma_sq).exp();
342            kernel[i] = v;
343            sum += v;
344        }
345        for v in &mut kernel {
346            *v /= sum;
347        }
348
349        // Horizontal pass → temp
350        let mut temp = vec![0u8; input.len()];
351        temp.par_chunks_exact_mut(4)
352            .enumerate()
353            .for_each(|(i, out)| {
354                let px = i % width;
355                let py = i / width;
356                if py >= height {
357                    return;
358                }
359                for c in 0..4 {
360                    let mut acc = 0.0f32;
361                    for (k, &kw) in kernel.iter().enumerate() {
362                        let sx =
363                            (px as i32 + k as i32 - radius).clamp(0, width as i32 - 1) as usize;
364                        acc += f32::from(input[(py * width + sx) * 4 + c]) * kw;
365                    }
366                    out[c] = acc.round().clamp(0.0, 255.0) as u8;
367                }
368            });
369
370        // Vertical pass → output
371        output
372            .par_chunks_exact_mut(4)
373            .enumerate()
374            .for_each(|(i, out)| {
375                let px = i % width;
376                let py = i / width;
377                if py >= height {
378                    return;
379                }
380                for c in 0..4 {
381                    let mut acc = 0.0f32;
382                    for (k, &kw) in kernel.iter().enumerate() {
383                        let sy =
384                            (py as i32 + k as i32 - radius).clamp(0, height as i32 - 1) as usize;
385                        acc += f32::from(temp[(sy * width + px) * 4 + c]) * kw;
386                    }
387                    out[c] = acc.round().clamp(0.0, 255.0) as u8;
388                }
389            });
390    }
391
392    /// Apply a 3×3 Sobel gradient magnitude filter.
393    fn sobel_impl(input: &[u8], output: &mut [u8], width: usize, height: usize) {
394        // Convert to luminance first, then apply Sobel.
395        let lum: Vec<f32> = input
396            .par_chunks_exact(4)
397            .map(|p| 0.299 * f32::from(p[0]) + 0.587 * f32::from(p[1]) + 0.114 * f32::from(p[2]))
398            .collect();
399
400        output
401            .par_chunks_exact_mut(4)
402            .enumerate()
403            .for_each(|(i, out)| {
404                let x = (i % width) as i32;
405                let y = (i / width) as i32;
406
407                if x == 0 || x == (width as i32 - 1) || y == 0 || y == (height as i32 - 1) {
408                    out.fill(0);
409                    return;
410                }
411
412                // Sobel kernels
413                let gx = -lum[(y - 1) as usize * width + (x - 1) as usize]
414                    - 2.0 * lum[y as usize * width + (x - 1) as usize]
415                    - lum[(y + 1) as usize * width + (x - 1) as usize]
416                    + lum[(y - 1) as usize * width + (x + 1) as usize]
417                    + 2.0 * lum[y as usize * width + (x + 1) as usize]
418                    + lum[(y + 1) as usize * width + (x + 1) as usize];
419
420                let gy = -lum[(y - 1) as usize * width + (x - 1) as usize]
421                    - 2.0 * lum[(y - 1) as usize * width + x as usize]
422                    - lum[(y - 1) as usize * width + (x + 1) as usize]
423                    + lum[(y + 1) as usize * width + (x - 1) as usize]
424                    + 2.0 * lum[(y + 1) as usize * width + x as usize]
425                    + lum[(y + 1) as usize * width + (x + 1) as usize];
426
427                let mag = (gx * gx + gy * gy).sqrt().clamp(0.0, 255.0) as u8;
428                out[0] = mag;
429                out[1] = mag;
430                out[2] = mag;
431                out[3] = input[i * 4 + 3]; // preserve alpha
432            });
433    }
434
435    fn sharpen_impl(input: &[u8], output: &mut [u8], width: usize, height: usize, amount: f32) {
436        // Unsharp mask: output = input + amount * (input - blurred)
437        let mut blurred = vec![0u8; input.len()];
438        Self::gaussian_blur_impl(input, &mut blurred, width, height, 1.0);
439
440        output
441            .par_chunks_exact_mut(4)
442            .zip(input.par_chunks_exact(4))
443            .zip(blurred.par_chunks_exact(4))
444            .for_each(|((out, orig), blur)| {
445                for c in 0..3 {
446                    let o = f32::from(orig[c]);
447                    let b = f32::from(blur[c]);
448                    let sharpened = o + amount * (o - b);
449                    out[c] = sharpened.round().clamp(0.0, 255.0) as u8;
450                }
451                out[3] = orig[3];
452            });
453    }
454
455    /// 1-D Type-II DCT on a slice of length N.
456    fn dct1d(data: &[f32], out: &mut [f32]) {
457        let n = data.len();
458        let nf = n as f32;
459        for k in 0..n {
460            let mut s = 0.0f32;
461            let kf = k as f32;
462            for (j, &v) in data.iter().enumerate() {
463                let angle = std::f32::consts::PI * kf * (2.0 * j as f32 + 1.0) / (2.0 * nf);
464                s += v * angle.cos();
465            }
466            let scale = if k == 0 {
467                (1.0 / nf).sqrt()
468            } else {
469                (2.0 / nf).sqrt()
470            };
471            out[k] = s * scale;
472        }
473    }
474
475    /// 1-D Type-III IDCT (inverse of `dct1d`) on a slice of length N.
476    fn idct1d(data: &[f32], out: &mut [f32]) {
477        let n = data.len();
478        let nf = n as f32;
479        for j in 0..n {
480            let jf = j as f32;
481            let mut s = data[0] / nf.sqrt();
482            for k in 1..n {
483                let scale = (2.0 / nf).sqrt();
484                let angle = std::f32::consts::PI * k as f32 * (2.0 * jf + 1.0) / (2.0 * nf);
485                s += scale * data[k] * angle.cos();
486            }
487            out[j] = s;
488        }
489    }
490}
491
492impl Default for CpuAccelerator {
493    fn default() -> Self {
494        Self::new()
495    }
496}
497
498impl GpuAccelerator for CpuAccelerator {
499    fn name(&self) -> &'static str {
500        "CPU SIMD"
501    }
502
503    fn is_gpu(&self) -> bool {
504        false
505    }
506
507    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
508        check_rgba_buf(input, width, height, "input")?;
509        check_rgba_buf(output, width, height, "output")?;
510        Self::rgb_to_yuv_impl(input, output);
511        Ok(())
512    }
513
514    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
515        check_rgba_buf(input, width, height, "input")?;
516        check_rgba_buf(output, width, height, "output")?;
517        Self::yuv_to_rgb_impl(input, output);
518        Ok(())
519    }
520
521    #[allow(clippy::too_many_arguments)]
522    fn scale_bilinear(
523        &self,
524        input: &[u8],
525        src_width: u32,
526        src_height: u32,
527        output: &mut [u8],
528        dst_width: u32,
529        dst_height: u32,
530    ) -> Result<()> {
531        check_rgba_buf(input, src_width, src_height, "input")?;
532        check_rgba_buf(output, dst_width, dst_height, "output")?;
533        Self::scale_bilinear_impl(
534            input,
535            src_width as usize,
536            src_height as usize,
537            output,
538            dst_width as usize,
539            dst_height as usize,
540        );
541        Ok(())
542    }
543
544    fn gaussian_blur(
545        &self,
546        input: &[u8],
547        output: &mut [u8],
548        width: u32,
549        height: u32,
550        sigma: f32,
551    ) -> Result<()> {
552        check_rgba_buf(input, width, height, "input")?;
553        check_rgba_buf(output, width, height, "output")?;
554        Self::gaussian_blur_impl(input, output, width as usize, height as usize, sigma);
555        Ok(())
556    }
557
558    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
559        check_rgba_buf(input, width, height, "input")?;
560        check_rgba_buf(output, width, height, "output")?;
561        Self::sobel_impl(input, output, width as usize, height as usize);
562        Ok(())
563    }
564
565    fn sharpen(
566        &self,
567        input: &[u8],
568        output: &mut [u8],
569        width: u32,
570        height: u32,
571        amount: f32,
572    ) -> Result<()> {
573        check_rgba_buf(input, width, height, "input")?;
574        check_rgba_buf(output, width, height, "output")?;
575        Self::sharpen_impl(input, output, width as usize, height as usize, amount);
576        Ok(())
577    }
578
579    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
580        check_f32_buf(input, width, height)?;
581        check_f32_buf(output, width, height)?;
582        if width % 8 != 0 || height % 8 != 0 {
583            return Err(GpuError::InvalidDimensions { width, height });
584        }
585
586        let w = width as usize;
587        let h = height as usize;
588
589        // Row-wise DCT
590        let mut row_pass = vec![0.0f32; w * h];
591        for row in 0..h {
592            let src = &input[row * w..(row + 1) * w];
593            let dst = &mut row_pass[row * w..(row + 1) * w];
594            Self::dct1d(src, dst);
595        }
596
597        // Column-wise DCT
598        for col in 0..w {
599            let col_data: Vec<f32> = (0..h).map(|r| row_pass[r * w + col]).collect();
600            let mut col_out = vec![0.0f32; h];
601            Self::dct1d(&col_data, &mut col_out);
602            for (r, &v) in col_out.iter().enumerate() {
603                output[r * w + col] = v;
604            }
605        }
606        Ok(())
607    }
608
609    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
610        check_f32_buf(input, width, height)?;
611        check_f32_buf(output, width, height)?;
612        if width % 8 != 0 || height % 8 != 0 {
613            return Err(GpuError::InvalidDimensions { width, height });
614        }
615
616        let w = width as usize;
617        let h = height as usize;
618
619        // Column-wise IDCT
620        let mut col_pass = vec![0.0f32; w * h];
621        for col in 0..w {
622            let col_data: Vec<f32> = (0..h).map(|r| input[r * w + col]).collect();
623            let mut col_out = vec![0.0f32; h];
624            Self::idct1d(&col_data, &mut col_out);
625            for (r, &v) in col_out.iter().enumerate() {
626                col_pass[r * w + col] = v;
627            }
628        }
629
630        // Row-wise IDCT
631        for row in 0..h {
632            let src = &col_pass[row * w..(row + 1) * w];
633            let dst = &mut output[row * w..(row + 1) * w];
634            Self::idct1d(src, dst);
635        }
636        Ok(())
637    }
638
639    fn pixel_diff(
640        &self,
641        a: &[u8],
642        b: &[u8],
643        output: &mut [u8],
644        width: u32,
645        height: u32,
646    ) -> Result<()> {
647        check_rgba_buf(a, width, height, "a")?;
648        check_rgba_buf(b, width, height, "b")?;
649        check_rgba_buf(output, width, height, "output")?;
650
651        output
652            .par_chunks_exact_mut(4)
653            .zip(a.par_chunks_exact(4))
654            .zip(b.par_chunks_exact(4))
655            .for_each(|((out, pa), pb)| {
656                for c in 0..4 {
657                    out[c] = pa[c].abs_diff(pb[c]);
658                }
659            });
660        Ok(())
661    }
662
663    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64> {
664        check_rgba_buf(a, width, height, "a")?;
665        check_rgba_buf(b, width, height, "b")?;
666
667        let sum_sq: f64 = a
668            .par_chunks_exact(4)
669            .zip(b.par_chunks_exact(4))
670            .map(|(pa, pb)| {
671                (0..4)
672                    .map(|c| {
673                        let d = f64::from(pa[c]) - f64::from(pb[c]);
674                        d * d
675                    })
676                    .sum::<f64>()
677            })
678            .sum();
679
680        let n = f64::from(width) * f64::from(height) * 4.0;
681        Ok(sum_sq / n)
682    }
683}
684
685// =============================================================================
686// WGPU/GPU Accelerator  (delegates to CpuAccelerator where GPU is unavailable)
687// =============================================================================
688
689/// GPU-backed accelerator using wgpu (Vulkan / Metal / DX12 / WebGPU).
690///
691/// If no GPU is available the constructor fails; use [`AcceleratorBuilder`]
692/// which transparently falls back to [`CpuAccelerator`].
693///
694/// All operations currently delegate to the CPU path while the wgpu compute
695/// pipeline is being set up.  The struct is intentionally structured so that
696/// individual operations can be migrated to GPU shaders without changing the
697/// public interface.
698pub struct WgpuAccelerator {
699    device: std::sync::Arc<crate::GpuDevice>,
700    /// CPU fallback for operations not yet ported to shaders.
701    cpu: CpuAccelerator,
702    backend_name: String,
703}
704
705impl WgpuAccelerator {
706    /// Create a `WgpuAccelerator` with automatic device selection.
707    ///
708    /// # Errors
709    ///
710    /// Returns [`GpuError::NoAdapter`] if no GPU is available.
711    pub fn new() -> Result<Self> {
712        let device = crate::GpuDevice::new(None)?;
713        let backend_name = format!("{} GPU", device.info().backend);
714        Ok(Self {
715            device: std::sync::Arc::new(device),
716            cpu: CpuAccelerator::new(),
717            backend_name,
718        })
719    }
720
721    /// Underlying GPU device.
722    #[must_use]
723    pub fn gpu_device(&self) -> &std::sync::Arc<crate::GpuDevice> {
724        &self.device
725    }
726}
727
728impl GpuAccelerator for WgpuAccelerator {
729    fn name(&self) -> &str {
730        &self.backend_name
731    }
732
733    fn is_gpu(&self) -> bool {
734        true
735    }
736
737    // The implementations below use the GPU device for simple operations
738    // and fall back to the CPU for complex shaders not yet implemented.
739
740    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
741        crate::ops::ColorSpaceConversion::rgb_to_yuv(
742            &self.device,
743            input,
744            output,
745            width,
746            height,
747            crate::ops::ColorSpace::BT601,
748        )
749    }
750
751    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
752        crate::ops::ColorSpaceConversion::yuv_to_rgb(
753            &self.device,
754            input,
755            output,
756            width,
757            height,
758            crate::ops::ColorSpace::BT601,
759        )
760    }
761
762    #[allow(clippy::too_many_arguments)]
763    fn scale_bilinear(
764        &self,
765        input: &[u8],
766        src_width: u32,
767        src_height: u32,
768        output: &mut [u8],
769        dst_width: u32,
770        dst_height: u32,
771    ) -> Result<()> {
772        crate::ops::ScaleOperation::scale(
773            &self.device,
774            input,
775            src_width,
776            src_height,
777            output,
778            dst_width,
779            dst_height,
780            crate::ops::ScaleFilter::Bilinear,
781        )
782    }
783
784    fn gaussian_blur(
785        &self,
786        input: &[u8],
787        output: &mut [u8],
788        width: u32,
789        height: u32,
790        sigma: f32,
791    ) -> Result<()> {
792        crate::ops::FilterOperation::gaussian_blur(
793            &self.device,
794            input,
795            output,
796            width,
797            height,
798            sigma,
799        )
800    }
801
802    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
803        crate::ops::FilterOperation::edge_detect(&self.device, input, output, width, height)
804    }
805
806    fn sharpen(
807        &self,
808        input: &[u8],
809        output: &mut [u8],
810        width: u32,
811        height: u32,
812        amount: f32,
813    ) -> Result<()> {
814        crate::ops::FilterOperation::sharpen(&self.device, input, output, width, height, amount)
815    }
816
817    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
818        crate::ops::TransformOperation::dct_2d(&self.device, input, output, width, height)
819    }
820
821    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
822        crate::ops::TransformOperation::idct_2d(&self.device, input, output, width, height)
823    }
824
825    fn pixel_diff(
826        &self,
827        a: &[u8],
828        b: &[u8],
829        output: &mut [u8],
830        width: u32,
831        height: u32,
832    ) -> Result<()> {
833        self.cpu.pixel_diff(a, b, output, width, height)
834    }
835
836    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64> {
837        self.cpu.mse(a, b, width, height)
838    }
839}
840
841// =============================================================================
842// AcceleratorBuilder
843// =============================================================================
844
/// Ergonomic builder for creating a [`GpuAccelerator`].
///
/// Tries GPU first; falls back to CPU automatically.  The builder is a plain
/// pair of flags and derives `Debug` and `Clone` so that a configuration can
/// be logged or reused.
///
/// # Example
///
/// ```no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use oximedia_gpu::accelerator::AcceleratorBuilder;
///
/// let acc = AcceleratorBuilder::new()
///     .prefer_gpu(true)
///     .build()?;
///
/// println!("Active backend: {}", acc.name());
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct AcceleratorBuilder {
    /// Attempt the GPU backend before falling back to the CPU.
    prefer_gpu: bool,
    /// Skip the GPU attempt entirely; takes precedence over `prefer_gpu`.
    force_cpu: bool,
}
867
868impl AcceleratorBuilder {
869    /// Create a new builder with default settings (GPU preferred, CPU fallback).
870    #[must_use]
871    pub fn new() -> Self {
872        Self {
873            prefer_gpu: true,
874            force_cpu: false,
875        }
876    }
877
878    /// Set whether to prefer GPU (default: `true`).
879    #[must_use]
880    pub fn prefer_gpu(mut self, value: bool) -> Self {
881        self.prefer_gpu = value;
882        self
883    }
884
885    /// Force CPU-only mode even if a GPU is available.
886    #[must_use]
887    pub fn force_cpu(mut self, value: bool) -> Self {
888        self.force_cpu = value;
889        self
890    }
891
892    /// Build the accelerator.
893    ///
894    /// Returns `Ok(Box<dyn GpuAccelerator>)`.  Never fails because a CPU
895    /// fallback is always available.
896    ///
897    /// # Errors
898    ///
899    /// This method never returns `Err` in practice (CPU fallback is always
900    /// constructed), but the signature uses `Result` to allow future error
901    /// propagation.
902    pub fn build(self) -> Result<Box<dyn GpuAccelerator>> {
903        if self.force_cpu || !self.prefer_gpu {
904            return Ok(Box::new(CpuAccelerator::new()));
905        }
906
907        match WgpuAccelerator::new() {
908            Ok(gpu) => Ok(Box::new(gpu)),
909            Err(_) => Ok(Box::new(CpuAccelerator::new())),
910        }
911    }
912
913    /// Build a CPU-only accelerator directly.
914    #[must_use]
915    pub fn build_cpu() -> CpuAccelerator {
916        CpuAccelerator::new()
917    }
918}
919
920impl Default for AcceleratorBuilder {
921    fn default() -> Self {
922        Self::new()
923    }
924}
925
926// =============================================================================
927// Tests
928// =============================================================================
929
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `w × h` RGBA buffer with every byte (alpha included) set to `fill`.
    fn make_rgba(w: usize, h: usize, fill: u8) -> Vec<u8> {
        vec![fill; w * h * 4]
    }

    /// Absolute difference between two 8-bit samples, widened so it cannot wrap.
    fn abs_diff(a: u8, b: u8) -> i32 {
        (i32::from(a) - i32::from(b)).abs()
    }

    // ---- CpuAccelerator ----------------------------------------------------

    #[test]
    fn test_cpu_accelerator_name() {
        let acc = CpuAccelerator::new();
        assert_eq!(acc.name(), "CPU SIMD");
        assert!(!acc.is_gpu());
    }

    #[test]
    fn test_cpu_rgb_to_yuv_roundtrip() {
        // A grey pixel (R=G=B) should survive a round-trip with ≤ 3 LSB error.
        let grey = 128u8;
        let input = vec![grey, grey, grey, 255u8];
        let mut yuv = vec![0u8; 4];
        let mut rgb = vec![0u8; 4];

        let acc = CpuAccelerator::new();
        acc.rgb_to_yuv(&input, &mut yuv, 1, 1)
            .expect("RGB to YUV conversion should succeed");
        acc.yuv_to_rgb(&yuv, &mut rgb, 1, 1)
            .expect("YUV to RGB conversion should succeed");

        // Each colour channel should be ≈ 128; the fourth byte is not checked.
        for (channel, &value) in ["R", "G", "B"].iter().zip(&rgb) {
            assert!(abs_diff(value, grey) <= 3, "{channel} mismatch: {value}");
        }
    }

    #[test]
    fn test_cpu_rgb_to_yuv_invalid_size() {
        let acc = CpuAccelerator::new();
        let input = vec![0u8; 5]; // wrong size for a 1×1 RGBA image
        let mut output = vec![0u8; 4];
        assert!(acc.rgb_to_yuv(&input, &mut output, 1, 1).is_err());
    }

    #[test]
    fn test_cpu_rgb_red_pixel() {
        // Red pixel → Y ≈ 0.299 * 255 ≈ 76
        let input = vec![255u8, 0, 0, 255];
        let mut yuv = vec![0u8; 4];
        let acc = CpuAccelerator::new();
        acc.rgb_to_yuv(&input, &mut yuv, 1, 1)
            .expect("RGB to YUV conversion should succeed");
        assert!(
            abs_diff(yuv[0], 76) <= 2,
            "Y for red should be ~76, got {}",
            yuv[0]
        );
    }

    #[test]
    fn test_cpu_scale_bilinear_identity() {
        // Scaling a uniform image (every byte = 200) to its own size should
        // produce a uniform image of roughly the same value.
        let w = 16usize;
        let h = 16usize;
        let input = make_rgba(w, h, 200);
        let mut output = make_rgba(w, h, 0);

        let acc = CpuAccelerator::new();
        acc.scale_bilinear(&input, w as u32, h as u32, &mut output, w as u32, h as u32)
            .expect("operation should succeed in test");

        // No output byte may drop noticeably below the source value.
        for &v in &output {
            assert!(v >= 195, "pixel value {v} too low after identity scale");
        }
    }

    #[test]
    fn test_cpu_scale_bilinear_upsample() {
        let input = make_rgba(2, 2, 255);
        let mut output = make_rgba(4, 4, 0);

        let acc = CpuAccelerator::new();
        acc.scale_bilinear(&input, 2, 2, &mut output, 4, 4)
            .expect("bilinear scaling should succeed");

        // All output pixels should be white (source was all-white).
        for &v in &output {
            assert!(v >= 250, "upsampled pixel {v} not white");
        }
    }

    #[test]
    fn test_cpu_gaussian_blur_preserves_size() {
        let input = make_rgba(8, 8, 128);
        let mut output = make_rgba(8, 8, 0);

        let acc = CpuAccelerator::new();
        acc.gaussian_blur(&input, &mut output, 8, 8, 1.0)
            .expect("gaussian blur should succeed");

        // The output slice is pre-allocated, so this length check cannot fail
        // on its own; the meaningful check here is that the call returned `Ok`.
        assert_eq!(output.len(), input.len());
    }

    #[test]
    fn test_cpu_edge_detect_flat_image() {
        // A flat-colour image has no edges → gradient magnitude ≈ 0
        // (border pixels are excluded).
        let input = make_rgba(16, 16, 200);
        let mut output = make_rgba(16, 16, 0);

        let acc = CpuAccelerator::new();
        acc.edge_detect(&input, &mut output, 16, 16)
            .expect("edge detection should succeed");

        // Interior pixels should be near zero (first byte of each pixel only).
        for row in 1..15usize {
            for col in 1..15usize {
                let idx = (row * 16 + col) * 4;
                assert!(
                    output[idx] < 10,
                    "interior edge pixel {} at ({row},{col}) is non-zero",
                    output[idx]
                );
            }
        }
    }

    #[test]
    fn test_cpu_sharpen_stable_flat() {
        // Sharpening a flat image should leave it unchanged.
        let input = make_rgba(8, 8, 128);
        let mut output = make_rgba(8, 8, 0);

        let acc = CpuAccelerator::new();
        acc.sharpen(&input, &mut output, 8, 8, 1.0)
            .expect("sharpen should succeed");

        // Allow ±3 LSB for accumulation of float rounding.
        for (&o, &i) in output.iter().zip(input.iter()) {
            assert!(
                abs_diff(o, i) <= 3,
                "sharpen changed flat pixel by more than 3"
            );
        }
    }

    #[test]
    fn test_cpu_dct_idct_roundtrip() {
        let w = 8u32;
        let h = 8u32;
        let input: Vec<f32> = (0..(w * h)).map(|i| i as f32).collect();
        let mut dct_out = vec![0.0f32; (w * h) as usize];
        let mut rec = vec![0.0f32; (w * h) as usize];

        let acc = CpuAccelerator::new();
        acc.dct_2d(&input, &mut dct_out, w, h)
            .expect("DCT should succeed");
        acc.idct_2d(&dct_out, &mut rec, w, h)
            .expect("IDCT should succeed");

        for (a, b) in input.iter().zip(rec.iter()) {
            assert!((a - b).abs() < 1e-3, "DCT round-trip error: {a} vs {b}");
        }
    }

    #[test]
    fn test_cpu_dct_invalid_dims() {
        let acc = CpuAccelerator::new();
        let input = vec![0.0f32; 10];
        let mut output = vec![0.0f32; 10];
        // 10 is not a multiple of 8 → error
        assert!(acc.dct_2d(&input, &mut output, 10, 1).is_err());
    }

    #[test]
    fn test_cpu_pixel_diff_self() {
        let img = make_rgba(4, 4, 100);
        // Start from a non-zero buffer so stale bytes would be detected.
        let mut diff = make_rgba(4, 4, 255);

        let acc = CpuAccelerator::new();
        acc.pixel_diff(&img, &img, &mut diff, 4, 4)
            .expect("pixel diff should succeed");

        for &v in &diff {
            assert_eq!(v, 0, "self-diff should be zero");
        }
    }

    #[test]
    fn test_cpu_mse_identical() {
        let img = make_rgba(8, 8, 128);
        let acc = CpuAccelerator::new();
        let mse = acc
            .mse(&img, &img, 8, 8)
            .expect("MSE computation should succeed");
        assert!(
            mse.abs() < 1e-10,
            "MSE of identical images should be 0, got {mse}"
        );
    }

    #[test]
    fn test_cpu_mse_max_error() {
        // Black vs white → max MSE = 255^2 = 65025.
        let a = make_rgba(4, 4, 0);
        let b = make_rgba(4, 4, 255);
        let acc = CpuAccelerator::new();
        let mse = acc
            .mse(&a, &b, 4, 4)
            .expect("MSE computation should succeed");
        assert!(
            (mse - 65025.0).abs() < 1.0,
            "max MSE should be 65025, got {mse}"
        );
    }

    // ---- AcceleratorBuilder ------------------------------------------------

    #[test]
    fn test_builder_force_cpu() {
        let acc = AcceleratorBuilder::new()
            .force_cpu(true)
            .build()
            .expect("accelerator build should succeed");
        assert_eq!(acc.name(), "CPU SIMD");
        assert!(!acc.is_gpu());
    }

    #[test]
    fn test_builder_build_cpu_static() {
        let acc = AcceleratorBuilder::build_cpu();
        assert_eq!(acc.name(), "CPU SIMD");
    }

    #[test]
    #[ignore] // Requires GPU hardware probe; run with --ignored
    fn test_builder_default_builds() {
        // Should never panic even without a GPU.
        let acc = AcceleratorBuilder::new()
            .build()
            .expect("accelerator build should succeed");
        assert!(!acc.name().is_empty());
    }
}