Skip to main content

oximedia_gpu/
accelerator.rs

1//! `GpuAccelerator` trait and hardware acceleration abstraction.
2//!
3//! This module defines the unified [`GpuAccelerator`] trait that provides a
4//! hardware-agnostic interface for GPU compute operations.  Concrete backends
5//! (Vulkan/WGPU and CPU SIMD) implement the trait so callers never need to
6//! branch on the active backend.
7//!
8//! # Design
9//!
10//! ```text
11//! GpuAccelerator (trait)
12//!    ├── WgpuAccelerator  ── uses wgpu (Vulkan / Metal / DX12 / WebGPU)
//!    └── CpuAccelerator   ── uses rayon thread-parallel CPU fallback (always available)
14//! ```
15//!
16//! # Quick Start
17//!
18//! ```no_run
19//! use oximedia_gpu::accelerator::{AcceleratorBuilder, GpuAccelerator};
20//!
21//! let acc = AcceleratorBuilder::new().build().unwrap();
22//! let name = acc.name();
23//! println!("Using backend: {name}");
24//!
25//! let rgb  = vec![0u8; 1920 * 1080 * 4];
26//! let mut yuv = vec![0u8; 1920 * 1080 * 4];
27//! acc.rgb_to_yuv(&rgb, &mut yuv, 1920, 1080).unwrap();
28//! ```
29
30#![allow(clippy::cast_possible_truncation)]
31#![allow(clippy::cast_sign_loss)]
32#![allow(clippy::cast_precision_loss)]
33#![allow(clippy::cast_possible_wrap)]
34
35use crate::{GpuError, Result};
36use rayon::prelude::*;
37
38// =============================================================================
39// GpuAccelerator trait
40// =============================================================================
41
42/// Unified interface for hardware-accelerated media operations.
43///
44/// All operations have CPU fallback implementations so that callers do not
45/// need to handle the absence of a GPU.  The trait is object-safe; you can
46/// store it behind `Box<dyn GpuAccelerator>` or `Arc<dyn GpuAccelerator>`.
47pub trait GpuAccelerator: Send + Sync {
48    /// Human-readable backend name (e.g. `"Vulkan"`, `"CPU SIMD"`).
49    fn name(&self) -> &str;
50
51    /// Whether this accelerator uses dedicated GPU hardware.
52    fn is_gpu(&self) -> bool;
53
54    /// Convert packed RGBA data to packed YUVA using BT.601 coefficients.
55    ///
56    /// Both slices must have length `width * height * 4`.
57    ///
58    /// # Errors
59    ///
60    /// Returns [`GpuError::InvalidBufferSize`] if slice lengths do not match
61    /// `width * height * 4`.
62    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
63
64    /// Convert packed YUVA data to packed RGBA using BT.601 coefficients.
65    ///
66    /// Both slices must have length `width * height * 4`.
67    ///
68    /// # Errors
69    ///
70    /// Returns [`GpuError::InvalidBufferSize`] if slice lengths do not match.
71    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
72
73    /// Resize packed RGBA image using bilinear interpolation.
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if buffer sizes are inconsistent with the given
78    /// dimensions.
79    #[allow(clippy::too_many_arguments)]
80    fn scale_bilinear(
81        &self,
82        input: &[u8],
83        src_width: u32,
84        src_height: u32,
85        output: &mut [u8],
86        dst_width: u32,
87        dst_height: u32,
88    ) -> Result<()>;
89
90    /// Apply a separable Gaussian blur to a packed RGBA image.
91    ///
92    /// `sigma` is the standard deviation of the Gaussian kernel in pixels.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if `input.len() != output.len()` or if either length
97    /// does not equal `width * height * 4`.
98    fn gaussian_blur(
99        &self,
100        input: &[u8],
101        output: &mut [u8],
102        width: u32,
103        height: u32,
104        sigma: f32,
105    ) -> Result<()>;
106
107    /// Detect edges using the Sobel operator on a packed RGBA image.
108    ///
109    /// The output contains per-pixel gradient magnitudes.
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if buffer sizes are inconsistent.
114    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()>;
115
116    /// Sharpen a packed RGBA image using an unsharp mask.
117    ///
118    /// `amount` controls the sharpening strength (typical range 0.0–2.0).
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if buffer sizes are inconsistent.
123    fn sharpen(
124        &self,
125        input: &[u8],
126        output: &mut [u8],
127        width: u32,
128        height: u32,
129        amount: f32,
130    ) -> Result<()>;
131
132    /// Compute the 2-D Type-II DCT on a grid of `f32` values.
133    ///
134    /// Both `width` and `height` must be multiples of 8.
135    ///
136    /// # Errors
137    ///
138    /// Returns an error if dimensions are not multiples of 8 or if slice
139    /// lengths do not equal `width * height`.
140    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()>;
141
142    /// Compute the 2-D Type-III IDCT (inverse of [`dct_2d`]).
143    ///
144    /// # Errors
145    ///
146    /// Returns the same errors as [`dct_2d`].
147    ///
148    /// [`dct_2d`]: GpuAccelerator::dct_2d
149    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()>;
150
151    /// Compute the per-pixel absolute difference between two RGBA images.
152    ///
153    /// # Errors
154    ///
155    /// Returns an error if any buffer length differs from `width * height * 4`.
156    fn pixel_diff(
157        &self,
158        a: &[u8],
159        b: &[u8],
160        output: &mut [u8],
161        width: u32,
162        height: u32,
163    ) -> Result<()>;
164
165    /// Compute the mean squared error between two RGBA images.
166    ///
167    /// Returns the average squared per-channel difference (range 0.0–65 025.0).
168    ///
169    /// # Errors
170    ///
171    /// Returns an error if buffer lengths do not equal `width * height * 4`.
172    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64>;
173}
174
175// =============================================================================
176// Buffer-size validation helpers
177// =============================================================================
178
179fn check_rgba_buf(buf: &[u8], width: u32, height: u32, label: &str) -> Result<()> {
180    let expected = (width as usize) * (height as usize) * 4;
181    if buf.len() != expected {
182        return Err(GpuError::InvalidBufferSize {
183            expected,
184            actual: buf.len(),
185        });
186    }
187    let _ = label;
188    Ok(())
189}
190
191fn check_f32_buf(buf: &[f32], width: u32, height: u32) -> Result<()> {
192    let expected = (width as usize) * (height as usize);
193    if buf.len() != expected {
194        return Err(GpuError::InvalidBufferSize {
195            expected,
196            actual: buf.len(),
197        });
198    }
199    Ok(())
200}
201
202// =============================================================================
203// CPU Accelerator
204// =============================================================================
205
206/// Pure-CPU accelerator backed by rayon parallel iterators.
207///
208/// This implementation is always available and provides correct (if slower)
209/// results even when no GPU is present.
210pub struct CpuAccelerator {
211    num_threads: usize,
212}
213
214impl CpuAccelerator {
215    /// Create a CPU accelerator using all available threads.
216    #[must_use]
217    pub fn new() -> Self {
218        Self {
219            num_threads: rayon::current_num_threads(),
220        }
221    }
222
223    /// Number of worker threads.
224    #[must_use]
225    pub fn num_threads(&self) -> usize {
226        self.num_threads
227    }
228
229    // ---- Internal helpers --------------------------------------------------
230
231    fn rgb_to_yuv_impl(input: &[u8], output: &mut [u8]) {
232        const KR: f32 = 0.299;
233        const KG: f32 = 0.587;
234        const KB: f32 = 0.114;
235
236        output
237            .par_chunks_exact_mut(4)
238            .zip(input.par_chunks_exact(4))
239            .for_each(|(out, inp)| {
240                let r = f32::from(inp[0]) / 255.0;
241                let g = f32::from(inp[1]) / 255.0;
242                let b = f32::from(inp[2]) / 255.0;
243
244                let y = KR * r + KG * g + KB * b;
245                let u = (b - y) / (2.0 * (1.0 - KB)) + 0.5;
246                let v = (r - y) / (2.0 * (1.0 - KR)) + 0.5;
247
248                out[0] = (y.clamp(0.0, 1.0) * 255.0) as u8;
249                out[1] = (u.clamp(0.0, 1.0) * 255.0) as u8;
250                out[2] = (v.clamp(0.0, 1.0) * 255.0) as u8;
251                out[3] = inp[3];
252            });
253    }
254
255    fn yuv_to_rgb_impl(input: &[u8], output: &mut [u8]) {
256        const KR: f32 = 0.299;
257        const KG: f32 = 0.587;
258        const KB: f32 = 0.114;
259
260        output
261            .par_chunks_exact_mut(4)
262            .zip(input.par_chunks_exact(4))
263            .for_each(|(out, inp)| {
264                let y = f32::from(inp[0]) / 255.0;
265                let u = f32::from(inp[1]) / 255.0 - 0.5;
266                let v = f32::from(inp[2]) / 255.0 - 0.5;
267
268                let r = y + 2.0 * (1.0 - KR) * v;
269                let b = y + 2.0 * (1.0 - KB) * u;
270                let g = (y - KR * r - KB * b) / KG;
271
272                out[0] = (r.clamp(0.0, 1.0) * 255.0) as u8;
273                out[1] = (g.clamp(0.0, 1.0) * 255.0) as u8;
274                out[2] = (b.clamp(0.0, 1.0) * 255.0) as u8;
275                out[3] = inp[3];
276            });
277    }
278
279    fn scale_bilinear_impl(
280        input: &[u8],
281        src_w: usize,
282        src_h: usize,
283        output: &mut [u8],
284        dst_w: usize,
285        dst_h: usize,
286    ) {
287        let x_ratio = src_w as f32 / dst_w as f32;
288        let y_ratio = src_h as f32 / dst_h as f32;
289
290        output
291            .par_chunks_exact_mut(4)
292            .enumerate()
293            .for_each(|(idx, pixel)| {
294                let dst_x = idx % dst_w;
295                let dst_y = idx / dst_w;
296                if dst_y >= dst_h {
297                    return;
298                }
299                let src_x = (dst_x as f32 + 0.5) * x_ratio - 0.5;
300                let src_y = (dst_y as f32 + 0.5) * y_ratio - 0.5;
301
302                let x0 = (src_x.floor().max(0.0) as usize).min(src_w - 1);
303                let y0 = (src_y.floor().max(0.0) as usize).min(src_h - 1);
304                let x1 = (x0 + 1).min(src_w - 1);
305                let y1 = (y0 + 1).min(src_h - 1);
306
307                let fx = src_x.fract().max(0.0);
308                let fy = src_y.fract().max(0.0);
309
310                for c in 0..4 {
311                    let p00 = f32::from(input[(y0 * src_w + x0) * 4 + c]);
312                    let p10 = f32::from(input[(y0 * src_w + x1) * 4 + c]);
313                    let p01 = f32::from(input[(y1 * src_w + x0) * 4 + c]);
314                    let p11 = f32::from(input[(y1 * src_w + x1) * 4 + c]);
315
316                    let top = p00 * (1.0 - fx) + p10 * fx;
317                    let bot = p01 * (1.0 - fx) + p11 * fx;
318                    pixel[c] = (top * (1.0 - fy) + bot * fy).round().clamp(0.0, 255.0) as u8;
319                }
320            });
321    }
322
323    fn gaussian_blur_impl(
324        input: &[u8],
325        output: &mut [u8],
326        width: usize,
327        height: usize,
328        sigma: f32,
329    ) {
330        let radius = (3.0 * sigma).ceil() as i32;
331        let ksize = (2 * radius + 1) as usize;
332        let two_sigma_sq = 2.0 * sigma * sigma;
333
334        let mut kernel = vec![0.0f32; ksize];
335        let mut sum = 0.0f32;
336        for i in 0..ksize {
337            let x = i as i32 - radius;
338            let v = (-(x * x) as f32 / two_sigma_sq).exp();
339            kernel[i] = v;
340            sum += v;
341        }
342        for v in &mut kernel {
343            *v /= sum;
344        }
345
346        // Horizontal pass → temp
347        let mut temp = vec![0u8; input.len()];
348        temp.par_chunks_exact_mut(4)
349            .enumerate()
350            .for_each(|(i, out)| {
351                let px = i % width;
352                let py = i / width;
353                if py >= height {
354                    return;
355                }
356                for c in 0..4 {
357                    let mut acc = 0.0f32;
358                    for (k, &kw) in kernel.iter().enumerate() {
359                        let sx =
360                            (px as i32 + k as i32 - radius).clamp(0, width as i32 - 1) as usize;
361                        acc += f32::from(input[(py * width + sx) * 4 + c]) * kw;
362                    }
363                    out[c] = acc.round().clamp(0.0, 255.0) as u8;
364                }
365            });
366
367        // Vertical pass → output
368        output
369            .par_chunks_exact_mut(4)
370            .enumerate()
371            .for_each(|(i, out)| {
372                let px = i % width;
373                let py = i / width;
374                if py >= height {
375                    return;
376                }
377                for c in 0..4 {
378                    let mut acc = 0.0f32;
379                    for (k, &kw) in kernel.iter().enumerate() {
380                        let sy =
381                            (py as i32 + k as i32 - radius).clamp(0, height as i32 - 1) as usize;
382                        acc += f32::from(temp[(sy * width + px) * 4 + c]) * kw;
383                    }
384                    out[c] = acc.round().clamp(0.0, 255.0) as u8;
385                }
386            });
387    }
388
389    /// Apply a 3×3 Sobel gradient magnitude filter.
390    fn sobel_impl(input: &[u8], output: &mut [u8], width: usize, height: usize) {
391        // Convert to luminance first, then apply Sobel.
392        let lum: Vec<f32> = input
393            .par_chunks_exact(4)
394            .map(|p| 0.299 * f32::from(p[0]) + 0.587 * f32::from(p[1]) + 0.114 * f32::from(p[2]))
395            .collect();
396
397        output
398            .par_chunks_exact_mut(4)
399            .enumerate()
400            .for_each(|(i, out)| {
401                let x = (i % width) as i32;
402                let y = (i / width) as i32;
403
404                if x == 0 || x == (width as i32 - 1) || y == 0 || y == (height as i32 - 1) {
405                    out.fill(0);
406                    return;
407                }
408
409                // Sobel kernels
410                let gx = -lum[(y - 1) as usize * width + (x - 1) as usize]
411                    - 2.0 * lum[y as usize * width + (x - 1) as usize]
412                    - lum[(y + 1) as usize * width + (x - 1) as usize]
413                    + lum[(y - 1) as usize * width + (x + 1) as usize]
414                    + 2.0 * lum[y as usize * width + (x + 1) as usize]
415                    + lum[(y + 1) as usize * width + (x + 1) as usize];
416
417                let gy = -lum[(y - 1) as usize * width + (x - 1) as usize]
418                    - 2.0 * lum[(y - 1) as usize * width + x as usize]
419                    - lum[(y - 1) as usize * width + (x + 1) as usize]
420                    + lum[(y + 1) as usize * width + (x - 1) as usize]
421                    + 2.0 * lum[(y + 1) as usize * width + x as usize]
422                    + lum[(y + 1) as usize * width + (x + 1) as usize];
423
424                let mag = (gx * gx + gy * gy).sqrt().clamp(0.0, 255.0) as u8;
425                out[0] = mag;
426                out[1] = mag;
427                out[2] = mag;
428                out[3] = input[i * 4 + 3]; // preserve alpha
429            });
430    }
431
432    fn sharpen_impl(input: &[u8], output: &mut [u8], width: usize, height: usize, amount: f32) {
433        // Unsharp mask: output = input + amount * (input - blurred)
434        let mut blurred = vec![0u8; input.len()];
435        Self::gaussian_blur_impl(input, &mut blurred, width, height, 1.0);
436
437        output
438            .par_chunks_exact_mut(4)
439            .zip(input.par_chunks_exact(4))
440            .zip(blurred.par_chunks_exact(4))
441            .for_each(|((out, orig), blur)| {
442                for c in 0..3 {
443                    let o = f32::from(orig[c]);
444                    let b = f32::from(blur[c]);
445                    let sharpened = o + amount * (o - b);
446                    out[c] = sharpened.round().clamp(0.0, 255.0) as u8;
447                }
448                out[3] = orig[3];
449            });
450    }
451
452    /// 1-D Type-II DCT on a slice of length N.
453    fn dct1d(data: &[f32], out: &mut [f32]) {
454        let n = data.len();
455        let nf = n as f32;
456        for k in 0..n {
457            let mut s = 0.0f32;
458            let kf = k as f32;
459            for (j, &v) in data.iter().enumerate() {
460                let angle = std::f32::consts::PI * kf * (2.0 * j as f32 + 1.0) / (2.0 * nf);
461                s += v * angle.cos();
462            }
463            let scale = if k == 0 {
464                (1.0 / nf).sqrt()
465            } else {
466                (2.0 / nf).sqrt()
467            };
468            out[k] = s * scale;
469        }
470    }
471
472    /// 1-D Type-III IDCT (inverse of `dct1d`) on a slice of length N.
473    fn idct1d(data: &[f32], out: &mut [f32]) {
474        let n = data.len();
475        let nf = n as f32;
476        for j in 0..n {
477            let jf = j as f32;
478            let mut s = data[0] / nf.sqrt();
479            for k in 1..n {
480                let scale = (2.0 / nf).sqrt();
481                let angle = std::f32::consts::PI * k as f32 * (2.0 * jf + 1.0) / (2.0 * nf);
482                s += scale * data[k] * angle.cos();
483            }
484            out[j] = s;
485        }
486    }
487}
488
489impl Default for CpuAccelerator {
490    fn default() -> Self {
491        Self::new()
492    }
493}
494
495impl GpuAccelerator for CpuAccelerator {
496    fn name(&self) -> &'static str {
497        "CPU SIMD"
498    }
499
500    fn is_gpu(&self) -> bool {
501        false
502    }
503
504    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
505        check_rgba_buf(input, width, height, "input")?;
506        check_rgba_buf(output, width, height, "output")?;
507        Self::rgb_to_yuv_impl(input, output);
508        Ok(())
509    }
510
511    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
512        check_rgba_buf(input, width, height, "input")?;
513        check_rgba_buf(output, width, height, "output")?;
514        Self::yuv_to_rgb_impl(input, output);
515        Ok(())
516    }
517
518    #[allow(clippy::too_many_arguments)]
519    fn scale_bilinear(
520        &self,
521        input: &[u8],
522        src_width: u32,
523        src_height: u32,
524        output: &mut [u8],
525        dst_width: u32,
526        dst_height: u32,
527    ) -> Result<()> {
528        check_rgba_buf(input, src_width, src_height, "input")?;
529        check_rgba_buf(output, dst_width, dst_height, "output")?;
530        Self::scale_bilinear_impl(
531            input,
532            src_width as usize,
533            src_height as usize,
534            output,
535            dst_width as usize,
536            dst_height as usize,
537        );
538        Ok(())
539    }
540
541    fn gaussian_blur(
542        &self,
543        input: &[u8],
544        output: &mut [u8],
545        width: u32,
546        height: u32,
547        sigma: f32,
548    ) -> Result<()> {
549        check_rgba_buf(input, width, height, "input")?;
550        check_rgba_buf(output, width, height, "output")?;
551        Self::gaussian_blur_impl(input, output, width as usize, height as usize, sigma);
552        Ok(())
553    }
554
555    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
556        check_rgba_buf(input, width, height, "input")?;
557        check_rgba_buf(output, width, height, "output")?;
558        Self::sobel_impl(input, output, width as usize, height as usize);
559        Ok(())
560    }
561
562    fn sharpen(
563        &self,
564        input: &[u8],
565        output: &mut [u8],
566        width: u32,
567        height: u32,
568        amount: f32,
569    ) -> Result<()> {
570        check_rgba_buf(input, width, height, "input")?;
571        check_rgba_buf(output, width, height, "output")?;
572        Self::sharpen_impl(input, output, width as usize, height as usize, amount);
573        Ok(())
574    }
575
576    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
577        check_f32_buf(input, width, height)?;
578        check_f32_buf(output, width, height)?;
579        if width % 8 != 0 || height % 8 != 0 {
580            return Err(GpuError::InvalidDimensions { width, height });
581        }
582
583        let w = width as usize;
584        let h = height as usize;
585
586        // Row-wise DCT
587        let mut row_pass = vec![0.0f32; w * h];
588        for row in 0..h {
589            let src = &input[row * w..(row + 1) * w];
590            let dst = &mut row_pass[row * w..(row + 1) * w];
591            Self::dct1d(src, dst);
592        }
593
594        // Column-wise DCT
595        for col in 0..w {
596            let col_data: Vec<f32> = (0..h).map(|r| row_pass[r * w + col]).collect();
597            let mut col_out = vec![0.0f32; h];
598            Self::dct1d(&col_data, &mut col_out);
599            for (r, &v) in col_out.iter().enumerate() {
600                output[r * w + col] = v;
601            }
602        }
603        Ok(())
604    }
605
606    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
607        check_f32_buf(input, width, height)?;
608        check_f32_buf(output, width, height)?;
609        if width % 8 != 0 || height % 8 != 0 {
610            return Err(GpuError::InvalidDimensions { width, height });
611        }
612
613        let w = width as usize;
614        let h = height as usize;
615
616        // Column-wise IDCT
617        let mut col_pass = vec![0.0f32; w * h];
618        for col in 0..w {
619            let col_data: Vec<f32> = (0..h).map(|r| input[r * w + col]).collect();
620            let mut col_out = vec![0.0f32; h];
621            Self::idct1d(&col_data, &mut col_out);
622            for (r, &v) in col_out.iter().enumerate() {
623                col_pass[r * w + col] = v;
624            }
625        }
626
627        // Row-wise IDCT
628        for row in 0..h {
629            let src = &col_pass[row * w..(row + 1) * w];
630            let dst = &mut output[row * w..(row + 1) * w];
631            Self::idct1d(src, dst);
632        }
633        Ok(())
634    }
635
636    fn pixel_diff(
637        &self,
638        a: &[u8],
639        b: &[u8],
640        output: &mut [u8],
641        width: u32,
642        height: u32,
643    ) -> Result<()> {
644        check_rgba_buf(a, width, height, "a")?;
645        check_rgba_buf(b, width, height, "b")?;
646        check_rgba_buf(output, width, height, "output")?;
647
648        output
649            .par_chunks_exact_mut(4)
650            .zip(a.par_chunks_exact(4))
651            .zip(b.par_chunks_exact(4))
652            .for_each(|((out, pa), pb)| {
653                for c in 0..4 {
654                    out[c] = pa[c].abs_diff(pb[c]);
655                }
656            });
657        Ok(())
658    }
659
660    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64> {
661        check_rgba_buf(a, width, height, "a")?;
662        check_rgba_buf(b, width, height, "b")?;
663
664        let sum_sq: f64 = a
665            .par_chunks_exact(4)
666            .zip(b.par_chunks_exact(4))
667            .map(|(pa, pb)| {
668                (0..4)
669                    .map(|c| {
670                        let d = f64::from(pa[c]) - f64::from(pb[c]);
671                        d * d
672                    })
673                    .sum::<f64>()
674            })
675            .sum();
676
677        let n = f64::from(width) * f64::from(height) * 4.0;
678        Ok(sum_sq / n)
679    }
680}
681
682// =============================================================================
683// WGPU/GPU Accelerator  (delegates to CpuAccelerator where GPU is unavailable)
684// =============================================================================
685
686/// GPU-backed accelerator using wgpu (Vulkan / Metal / DX12 / WebGPU).
687///
688/// If no GPU is available the constructor fails; use [`AcceleratorBuilder`]
689/// which transparently falls back to [`CpuAccelerator`].
690///
691/// All operations currently delegate to the CPU path while the wgpu compute
692/// pipeline is being set up.  The struct is intentionally structured so that
693/// individual operations can be migrated to GPU shaders without changing the
694/// public interface.
695pub struct WgpuAccelerator {
696    device: std::sync::Arc<crate::GpuDevice>,
697    /// CPU fallback for operations not yet ported to shaders.
698    cpu: CpuAccelerator,
699    backend_name: String,
700}
701
702impl WgpuAccelerator {
703    /// Create a `WgpuAccelerator` with automatic device selection.
704    ///
705    /// # Errors
706    ///
707    /// Returns [`GpuError::NoAdapter`] if no GPU is available.
708    pub fn new() -> Result<Self> {
709        let device = crate::GpuDevice::new(None)?;
710        let backend_name = format!("{} GPU", device.info().backend);
711        Ok(Self {
712            device: std::sync::Arc::new(device),
713            cpu: CpuAccelerator::new(),
714            backend_name,
715        })
716    }
717
718    /// Underlying GPU device.
719    #[must_use]
720    pub fn gpu_device(&self) -> &std::sync::Arc<crate::GpuDevice> {
721        &self.device
722    }
723}
724
725impl GpuAccelerator for WgpuAccelerator {
726    fn name(&self) -> &str {
727        &self.backend_name
728    }
729
730    fn is_gpu(&self) -> bool {
731        true
732    }
733
734    // The implementations below use the GPU device for simple operations
735    // and fall back to the CPU for complex shaders not yet implemented.
736
737    fn rgb_to_yuv(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
738        crate::ops::ColorSpaceConversion::rgb_to_yuv(
739            &self.device,
740            input,
741            output,
742            width,
743            height,
744            crate::ops::ColorSpace::BT601,
745        )
746    }
747
748    fn yuv_to_rgb(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
749        crate::ops::ColorSpaceConversion::yuv_to_rgb(
750            &self.device,
751            input,
752            output,
753            width,
754            height,
755            crate::ops::ColorSpace::BT601,
756        )
757    }
758
759    #[allow(clippy::too_many_arguments)]
760    fn scale_bilinear(
761        &self,
762        input: &[u8],
763        src_width: u32,
764        src_height: u32,
765        output: &mut [u8],
766        dst_width: u32,
767        dst_height: u32,
768    ) -> Result<()> {
769        crate::ops::ScaleOperation::scale(
770            &self.device,
771            input,
772            src_width,
773            src_height,
774            output,
775            dst_width,
776            dst_height,
777            crate::ops::ScaleFilter::Bilinear,
778        )
779    }
780
781    fn gaussian_blur(
782        &self,
783        input: &[u8],
784        output: &mut [u8],
785        width: u32,
786        height: u32,
787        sigma: f32,
788    ) -> Result<()> {
789        crate::ops::FilterOperation::gaussian_blur(
790            &self.device,
791            input,
792            output,
793            width,
794            height,
795            sigma,
796        )
797    }
798
799    fn edge_detect(&self, input: &[u8], output: &mut [u8], width: u32, height: u32) -> Result<()> {
800        crate::ops::FilterOperation::edge_detect(&self.device, input, output, width, height)
801    }
802
803    fn sharpen(
804        &self,
805        input: &[u8],
806        output: &mut [u8],
807        width: u32,
808        height: u32,
809        amount: f32,
810    ) -> Result<()> {
811        crate::ops::FilterOperation::sharpen(&self.device, input, output, width, height, amount)
812    }
813
814    fn dct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
815        crate::ops::TransformOperation::dct_2d(&self.device, input, output, width, height)
816    }
817
818    fn idct_2d(&self, input: &[f32], output: &mut [f32], width: u32, height: u32) -> Result<()> {
819        crate::ops::TransformOperation::idct_2d(&self.device, input, output, width, height)
820    }
821
822    fn pixel_diff(
823        &self,
824        a: &[u8],
825        b: &[u8],
826        output: &mut [u8],
827        width: u32,
828        height: u32,
829    ) -> Result<()> {
830        self.cpu.pixel_diff(a, b, output, width, height)
831    }
832
833    fn mse(&self, a: &[u8], b: &[u8], width: u32, height: u32) -> Result<f64> {
834        self.cpu.mse(a, b, width, height)
835    }
836}
837
838// =============================================================================
839// AcceleratorBuilder
840// =============================================================================
841
842/// Ergonomic builder for creating a [`GpuAccelerator`].
843///
844/// Tries GPU first; falls back to CPU automatically.
845///
846/// # Example
847///
848/// ```no_run
849/// use oximedia_gpu::accelerator::AcceleratorBuilder;
850///
851/// let acc = AcceleratorBuilder::new()
852///     .prefer_gpu(true)
853///     .build()
854///     .unwrap();
855///
856/// println!("Active backend: {}", acc.name());
857/// ```
858pub struct AcceleratorBuilder {
859    prefer_gpu: bool,
860    force_cpu: bool,
861}
862
863impl AcceleratorBuilder {
864    /// Create a new builder with default settings (GPU preferred, CPU fallback).
865    #[must_use]
866    pub fn new() -> Self {
867        Self {
868            prefer_gpu: true,
869            force_cpu: false,
870        }
871    }
872
873    /// Set whether to prefer GPU (default: `true`).
874    #[must_use]
875    pub fn prefer_gpu(mut self, value: bool) -> Self {
876        self.prefer_gpu = value;
877        self
878    }
879
880    /// Force CPU-only mode even if a GPU is available.
881    #[must_use]
882    pub fn force_cpu(mut self, value: bool) -> Self {
883        self.force_cpu = value;
884        self
885    }
886
887    /// Build the accelerator.
888    ///
889    /// Returns `Ok(Box<dyn GpuAccelerator>)`.  Never fails because a CPU
890    /// fallback is always available.
891    ///
892    /// # Errors
893    ///
894    /// This method never returns `Err` in practice (CPU fallback is always
895    /// constructed), but the signature uses `Result` to allow future error
896    /// propagation.
897    pub fn build(self) -> Result<Box<dyn GpuAccelerator>> {
898        if self.force_cpu || !self.prefer_gpu {
899            return Ok(Box::new(CpuAccelerator::new()));
900        }
901
902        match WgpuAccelerator::new() {
903            Ok(gpu) => Ok(Box::new(gpu)),
904            Err(_) => Ok(Box::new(CpuAccelerator::new())),
905        }
906    }
907
908    /// Build a CPU-only accelerator directly.
909    #[must_use]
910    pub fn build_cpu() -> CpuAccelerator {
911        CpuAccelerator::new()
912    }
913}
914
915impl Default for AcceleratorBuilder {
916    fn default() -> Self {
917        Self::new()
918    }
919}
920
921// =============================================================================
922// Tests
923// =============================================================================
924
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w × h` RGBA buffer with every byte (alpha included) set to `fill`.
    fn make_rgba(w: usize, h: usize, fill: u8) -> Vec<u8> {
        vec![fill; w * h * 4]
    }

    /// Asserts that `actual` lies within `tol` of `expected`.
    ///
    /// `#[track_caller]` makes a failure point at the calling test line
    /// rather than at this helper.
    #[track_caller]
    fn assert_near(actual: u8, expected: u8, tol: i32, what: &str) {
        let diff = (i32::from(actual) - i32::from(expected)).abs();
        assert!(
            diff <= tol,
            "{what}: expected {expected} ± {tol}, got {actual}"
        );
    }

    // ---- CpuAccelerator ----------------------------------------------------

    #[test]
    fn test_cpu_accelerator_name() {
        let acc = CpuAccelerator::new();
        assert_eq!(acc.name(), "CPU SIMD");
        assert!(!acc.is_gpu());
    }

    #[test]
    fn test_cpu_rgb_to_yuv_roundtrip() {
        // A grey pixel (R=G=B) should survive an RGB → YUV → RGB round-trip
        // with at most 3 LSB of error per colour channel.
        let grey = 128u8;
        let input = vec![grey, grey, grey, 255u8];
        let mut yuv = vec![0u8; 4];
        let mut rgb = vec![0u8; 4];

        let acc = CpuAccelerator::new();
        acc.rgb_to_yuv(&input, &mut yuv, 1, 1).unwrap();
        acc.yuv_to_rgb(&yuv, &mut rgb, 1, 1).unwrap();

        for (label, &value) in ["R", "G", "B"].iter().zip(&rgb[..3]) {
            assert_near(value, grey, 3, &format!("{label} mismatch"));
        }
    }

    #[test]
    fn test_cpu_rgb_to_yuv_invalid_size() {
        let acc = CpuAccelerator::new();
        let input = vec![0u8; 5]; // not width * height * 4 bytes
        let mut output = vec![0u8; 4];
        assert!(acc.rgb_to_yuv(&input, &mut output, 1, 1).is_err());
    }

    #[test]
    fn test_cpu_scale_bilinear_identity() {
        // Scaling a uniform image (every byte = 200) to the same size should
        // keep it uniform; allow a few LSB of interpolation rounding.
        let w = 16usize;
        let h = 16usize;
        let input = make_rgba(w, h, 200);
        let mut output = make_rgba(w, h, 0);

        let acc = CpuAccelerator::new();
        acc.scale_bilinear(&input, w as u32, h as u32, &mut output, w as u32, h as u32)
            .unwrap();

        for &v in &output {
            assert!(v >= 195, "pixel value {v} too low after identity scale");
        }
    }

    #[test]
    fn test_cpu_scale_bilinear_upsample() {
        let input = make_rgba(2, 2, 255);
        let mut output = make_rgba(4, 4, 0);

        let acc = CpuAccelerator::new();
        acc.scale_bilinear(&input, 2, 2, &mut output, 4, 4).unwrap();

        // The source was all-white, so every interpolated sample must be white too.
        for &v in &output {
            assert!(v >= 250, "upsampled pixel {v} not white");
        }
    }

    #[test]
    fn test_cpu_gaussian_blur_preserves_size() {
        let input = make_rgba(8, 8, 128);
        let mut output = make_rgba(8, 8, 0);

        let acc = CpuAccelerator::new();
        acc.gaussian_blur(&input, &mut output, 8, 8, 1.0).unwrap();
        assert_eq!(output.len(), input.len());
    }

    #[test]
    fn test_cpu_edge_detect_flat_image() {
        // A flat-colour image has no edges → gradient magnitude ≈ 0.
        // Border pixels are excluded from the check.
        let input = make_rgba(16, 16, 200);
        let mut output = make_rgba(16, 16, 0);

        let acc = CpuAccelerator::new();
        acc.edge_detect(&input, &mut output, 16, 16).unwrap();

        for row in 1..15usize {
            for col in 1..15usize {
                let idx = (row * 16 + col) * 4;
                assert!(
                    output[idx] < 10,
                    "interior edge pixel {} at ({row},{col}) is non-zero",
                    output[idx]
                );
            }
        }
    }

    #[test]
    fn test_cpu_sharpen_stable_flat() {
        // Sharpening a flat image should leave it unchanged, up to ±3 LSB of
        // accumulated float rounding.
        let input = make_rgba(8, 8, 128);
        let mut output = make_rgba(8, 8, 0);

        let acc = CpuAccelerator::new();
        acc.sharpen(&input, &mut output, 8, 8, 1.0).unwrap();

        for (&o, &i) in output.iter().zip(input.iter()) {
            assert_near(o, i, 3, "sharpen changed flat pixel");
        }
    }

    #[test]
    fn test_cpu_dct_idct_roundtrip() {
        let w = 8u32;
        let h = 8u32;
        let input: Vec<f32> = (0..(w * h)).map(|i| i as f32).collect();
        let mut dct_out = vec![0.0f32; (w * h) as usize];
        let mut rec = vec![0.0f32; (w * h) as usize];

        let acc = CpuAccelerator::new();
        acc.dct_2d(&input, &mut dct_out, w, h).unwrap();
        acc.idct_2d(&dct_out, &mut rec, w, h).unwrap();

        for (a, b) in input.iter().zip(rec.iter()) {
            assert!((a - b).abs() < 1e-3, "DCT round-trip error: {a} vs {b}");
        }
    }

    #[test]
    fn test_cpu_dct_invalid_dims() {
        let acc = CpuAccelerator::new();
        let input = vec![0.0f32; 10];
        let mut output = vec![0.0f32; 10];
        // A 10×1 plane is not a multiple of the 8×8 block size → error.
        assert!(acc.dct_2d(&input, &mut output, 10, 1).is_err());
    }

    #[test]
    fn test_cpu_pixel_diff_self() {
        let img = make_rgba(4, 4, 100);
        // Pre-fill with 255 so the test also fails if `pixel_diff` never
        // writes the output buffer.
        let mut diff = make_rgba(4, 4, 255);

        let acc = CpuAccelerator::new();
        acc.pixel_diff(&img, &img, &mut diff, 4, 4).unwrap();

        for &v in &diff {
            assert_eq!(v, 0, "self-diff should be zero");
        }
    }

    #[test]
    fn test_cpu_mse_identical() {
        let img = make_rgba(8, 8, 128);
        let acc = CpuAccelerator::new();
        let mse = acc.mse(&img, &img, 8, 8).unwrap();
        assert!(
            mse.abs() < 1e-10,
            "MSE of identical images should be 0, got {mse}"
        );
    }

    #[test]
    fn test_cpu_mse_max_error() {
        // Black vs white: every byte (alpha included) differs by 255, so the
        // MSE is 255² = 65025 whether or not alpha is part of the average.
        let a = make_rgba(4, 4, 0);
        let b = make_rgba(4, 4, 255);
        let acc = CpuAccelerator::new();
        let mse = acc.mse(&a, &b, 4, 4).unwrap();
        assert!(
            (mse - 65025.0).abs() < 1.0,
            "max MSE should be 65025, got {mse}"
        );
    }

    #[test]
    fn test_cpu_rgb_red_pixel() {
        // Pure red → Y ≈ 0.299 × 255 ≈ 76 (BT.601 luma weight for R).
        let input = vec![255u8, 0, 0, 255];
        let mut yuv = vec![0u8; 4];
        let acc = CpuAccelerator::new();
        acc.rgb_to_yuv(&input, &mut yuv, 1, 1).unwrap();
        assert_near(yuv[0], 76, 2, "Y for red");
    }

    // ---- AcceleratorBuilder ------------------------------------------------

    #[test]
    fn test_builder_force_cpu() {
        let acc = AcceleratorBuilder::new().force_cpu(true).build().unwrap();
        assert_eq!(acc.name(), "CPU SIMD");
        assert!(!acc.is_gpu());
    }

    #[test]
    fn test_builder_prefer_gpu_false_uses_cpu() {
        // Declining the GPU must short-circuit to the CPU backend without probing.
        let acc = AcceleratorBuilder::new().prefer_gpu(false).build().unwrap();
        assert_eq!(acc.name(), "CPU SIMD");
        assert!(!acc.is_gpu());
    }

    #[test]
    fn test_builder_default_honours_setters() {
        // `Default` must yield a builder that behaves like `new()`.
        let acc = AcceleratorBuilder::default()
            .force_cpu(true)
            .build()
            .unwrap();
        assert_eq!(acc.name(), "CPU SIMD");
    }

    #[test]
    fn test_builder_build_cpu_static() {
        let acc = AcceleratorBuilder::build_cpu();
        assert_eq!(acc.name(), "CPU SIMD");
    }

    #[test]
    #[ignore] // Requires GPU hardware probe; run with --ignored
    fn test_builder_default_builds() {
        // Should never panic even without a GPU.
        let acc = AcceleratorBuilder::new().build().unwrap();
        assert!(!acc.name().is_empty());
    }
}