Skip to main content

oximedia_gpu/
pipeline_stages.rs

1//! GPU-style image processing pipeline stage abstraction.
2//!
3//! This module defines the [`ImagePipelineStage`] trait and a set of built-in
4//! concrete stages.  Stages are chained in an [`ImageComputePipeline`] that
5//! validates pixel-format compatibility before execution.
6//!
7//! # Example
8//! ```no_run
9//! use oximedia_gpu::pipeline_stages::{
10//!     ImageComputePipeline, GrayscaleStage, GaussianBlurStage, SobelStage,
11//! };
12//!
13//! let mut pipeline = ImageComputePipeline::new(4, 4);
14//! pipeline.add_stage(Box::new(GrayscaleStage)).expect("add grayscale");
15//! pipeline.add_stage(Box::new(GaussianBlurStage { sigma: 1.0 })).expect("add blur");
16//! pipeline.add_stage(Box::new(SobelStage)).expect("add sobel");
17//!
18//! let rgba: Vec<u8> = (0..16).flat_map(|i: u8| [i * 4, i * 4, i * 4, 255]).collect();
19//! let result = pipeline.execute(&rgba).expect("execute");
20//! assert_eq!(result.len(), 4 * 4); // Gray8 output
21//! ```
22
23#![allow(clippy::cast_precision_loss)]
24#![allow(clippy::cast_possible_truncation)]
25#![allow(clippy::cast_sign_loss)]
26#![allow(clippy::cast_lossless)]
27
28// ---------------------------------------------------------------------------
29// PixelFormat
30// ---------------------------------------------------------------------------
31
/// Pixel-format descriptor used to type-check pipeline stage connections.
///
/// Every stage reports the format it consumes and the format it produces;
/// [`ImageComputePipeline`] compares these descriptors when stages are chained.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PixelFormat {
    /// 8-bit RGBA, 4 bytes per pixel.
    Rgba8,
    /// 8-bit RGB, 3 bytes per pixel.
    Rgb8,
    /// 8-bit grayscale, 1 byte per pixel.
    Gray8,
    /// Planar YCbCr 4:2:0.
    Yuv420,
    /// Planar YCbCr 4:2:2.
    Yuv422,
    /// 8-bit BGRA, 4 bytes per pixel.
    Bgra8,
    /// 32-bit float RGBA, 16 bytes per pixel.
    F32Rgba,
}

impl PixelFormat {
    /// Short display label for the format, e.g. `"RGBA8"`.
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            PixelFormat::Rgba8 => "RGBA8",
            PixelFormat::Bgra8 => "BGRA8",
            PixelFormat::Rgb8 => "RGB8",
            PixelFormat::Gray8 => "Gray8",
            PixelFormat::Yuv420 => "YUV420",
            PixelFormat::Yuv422 => "YUV422",
            PixelFormat::F32Rgba => "F32RGBA",
        }
    }

    /// Storage size of one pixel in bytes.
    ///
    /// Exact for the packed formats. For the planar YCbCr formats this is only
    /// an approximation: 4:2:0 really averages 1.5 bytes and is rounded up to 2.
    #[must_use]
    pub fn bytes_per_pixel(self) -> usize {
        match self {
            PixelFormat::F32Rgba => 16,
            PixelFormat::Rgba8 | PixelFormat::Bgra8 => 4,
            PixelFormat::Rgb8 => 3,
            PixelFormat::Yuv420 | PixelFormat::Yuv422 => 2,
            PixelFormat::Gray8 => 1,
        }
    }
}

impl std::fmt::Display for PixelFormat {
    /// Writes the same label returned by [`PixelFormat::name`].
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.name();
        out.write_str(label)
    }
}
86
87// ---------------------------------------------------------------------------
88// ImagePipelineStage trait
89// ---------------------------------------------------------------------------
90
91/// A single processing stage in an [`ImageComputePipeline`].
92///
93/// Implementations must be `Send + Sync` so that pipelines can eventually
94/// be executed in parallel contexts.
95pub trait ImagePipelineStage: Send + Sync {
96    /// Human-readable stage name.
97    fn name(&self) -> &str;
98
99    /// Transform `input` and return the resulting buffer.
100    fn process(&self, input: &[u8], width: u32, height: u32) -> Vec<u8>;
101
102    /// Pixel format this stage expects as input.
103    fn input_format(&self) -> PixelFormat;
104
105    /// Pixel format this stage produces as output.
106    fn output_format(&self) -> PixelFormat;
107}
108
109// ---------------------------------------------------------------------------
110// Built-in stages
111// ---------------------------------------------------------------------------
112
113// --- GrayscaleStage ---------------------------------------------------------
114
115/// Convert RGBA to grayscale using the BT.601 luma formula.
116///
117/// Input: [`PixelFormat::Rgba8`], Output: [`PixelFormat::Gray8`].
118pub struct GrayscaleStage;
119
120impl ImagePipelineStage for GrayscaleStage {
121    fn name(&self) -> &str {
122        "Grayscale"
123    }
124
125    fn process(&self, input: &[u8], width: u32, height: u32) -> Vec<u8> {
126        let n = width as usize * height as usize;
127        let expected = n * 4;
128        if input.len() != expected {
129            return Vec::new();
130        }
131        let mut out = Vec::with_capacity(n);
132        // Process 4 pixels per iteration for auto-vectorisation.
133        let chunk4 = n / 4;
134        let rem = n % 4;
135        for i in 0..chunk4 {
136            let base = i * 16;
137            for offset in [0_usize, 4, 8, 12] {
138                let b = base + offset;
139                let y = luma_bt601(input[b], input[b + 1], input[b + 2]);
140                out.push(y);
141            }
142        }
143        let rem_start = chunk4 * 16;
144        for p in 0..rem {
145            let b = rem_start + p * 4;
146            out.push(luma_bt601(input[b], input[b + 1], input[b + 2]));
147        }
148        out
149    }
150
151    fn input_format(&self) -> PixelFormat {
152        PixelFormat::Rgba8
153    }
154    fn output_format(&self) -> PixelFormat {
155        PixelFormat::Gray8
156    }
157}
158
159// --- GaussianBlurStage ------------------------------------------------------
160
161/// Apply a separable Gaussian blur.
162///
163/// Supports both [`PixelFormat::Gray8`] and [`PixelFormat::Rgba8`] input
164/// (configured at construction).  The `input_fmt` field selects the mode;
165/// for Gray8 each byte is a single luma sample, for Rgba8 each channel is
166/// blurred independently.
167pub struct GaussianBlurStage {
168    /// Standard deviation of the Gaussian kernel.
169    pub sigma: f32,
170}
171
172impl ImagePipelineStage for GaussianBlurStage {
173    fn name(&self) -> &str {
174        "GaussianBlur"
175    }
176
177    fn process(&self, input: &[u8], width: u32, height: u32) -> Vec<u8> {
178        let w = width as usize;
179        let h = height as usize;
180        let n = w * h;
181
182        // Determine channel count from input length.
183        let channels = if input.len() == n * 4 {
184            4usize
185        } else if input.len() == n {
186            1
187        } else {
188            return Vec::new();
189        };
190
191        if self.sigma <= 0.0 {
192            return input.to_vec();
193        }
194
195        let radius = (3.0 * self.sigma).ceil() as usize;
196        let kernel = build_1d_gaussian(radius, self.sigma);
197
198        // Separate into per-channel f32 planes, blur each, recombine.
199        let mut out = vec![0u8; input.len()];
200        for ch in 0..channels {
201            let plane: Vec<f32> = (0..n).map(|i| input[i * channels + ch] as f32).collect();
202            let blurred = gaussian_pass_2d(&plane, w, h, &kernel, radius);
203            for i in 0..n {
204                out[i * channels + ch] = blurred[i].round().clamp(0.0, 255.0) as u8;
205            }
206        }
207        out
208    }
209
210    fn input_format(&self) -> PixelFormat {
211        PixelFormat::Gray8
212    }
213    fn output_format(&self) -> PixelFormat {
214        PixelFormat::Gray8
215    }
216}
217
218// --- SobelStage -------------------------------------------------------------
219
220/// Detect edges using the Sobel gradient magnitude operator.
221///
222/// Input: [`PixelFormat::Gray8`], Output: [`PixelFormat::Gray8`].
223pub struct SobelStage;
224
225impl ImagePipelineStage for SobelStage {
226    fn name(&self) -> &str {
227        "Sobel"
228    }
229
230    fn process(&self, input: &[u8], width: u32, height: u32) -> Vec<u8> {
231        let w = width as usize;
232        let h = height as usize;
233        if input.len() != w * h {
234            return Vec::new();
235        }
236
237        let gray_f: Vec<f32> = input.iter().map(|&b| b as f32).collect();
238        let mut out = vec![0u8; w * h];
239
240        for row in 1..h.saturating_sub(1) {
241            let rb = row * w;
242            for col in 1..w.saturating_sub(1) {
243                let tl = gray_f[(row - 1) * w + col - 1];
244                let tc = gray_f[(row - 1) * w + col];
245                let tr = gray_f[(row - 1) * w + col + 1];
246                let ml = gray_f[row * w + col - 1];
247                let mr = gray_f[row * w + col + 1];
248                let bl = gray_f[(row + 1) * w + col - 1];
249                let bc = gray_f[(row + 1) * w + col];
250                let br = gray_f[(row + 1) * w + col + 1];
251
252                let gx = -tl + tr - 2.0 * ml + 2.0 * mr - bl + br;
253                let gy = -tl - 2.0 * tc - tr + bl + 2.0 * bc + br;
254                let mag = (gx * gx + gy * gy).sqrt();
255                out[rb + col] = mag.round().clamp(0.0, 255.0) as u8;
256            }
257        }
258        out
259    }
260
261    fn input_format(&self) -> PixelFormat {
262        PixelFormat::Gray8
263    }
264    fn output_format(&self) -> PixelFormat {
265        PixelFormat::Gray8
266    }
267}
268
269// --- ThresholdStage ---------------------------------------------------------
270
271/// Binary threshold: pixels ≥ threshold → 255, otherwise 0.
272///
273/// Input: [`PixelFormat::Gray8`], Output: [`PixelFormat::Gray8`].
274pub struct ThresholdStage {
275    /// Threshold value (inclusive).
276    pub threshold: u8,
277}
278
279impl ImagePipelineStage for ThresholdStage {
280    fn name(&self) -> &str {
281        "Threshold"
282    }
283
284    fn process(&self, input: &[u8], width: u32, height: u32) -> Vec<u8> {
285        if input.len() != width as usize * height as usize {
286            return Vec::new();
287        }
288        input
289            .iter()
290            .map(|&px| if px >= self.threshold { 255 } else { 0 })
291            .collect()
292    }
293
294    fn input_format(&self) -> PixelFormat {
295        PixelFormat::Gray8
296    }
297    fn output_format(&self) -> PixelFormat {
298        PixelFormat::Gray8
299    }
300}
301
302// --- ColorConvertStage ------------------------------------------------------
303
304/// Convert between pixel formats.
305///
306/// Currently supported conversions:
307/// - `Rgba8 → Gray8` (BT.601 luma)
308/// - `Rgba8 → Bgra8` (channel swap)
309/// - `Bgra8 → Rgba8` (channel swap)
310/// - `Gray8 → Rgba8` (broadcast luma, alpha = 255)
311/// - Identity (same format → passthrough)
312///
313/// All other combinations return an empty vector.
314pub struct ColorConvertStage {
315    /// Source pixel format.
316    pub from: PixelFormat,
317    /// Target pixel format.
318    pub to: PixelFormat,
319}
320
321impl ImagePipelineStage for ColorConvertStage {
322    fn name(&self) -> &str {
323        "ColorConvert"
324    }
325
326    fn process(&self, input: &[u8], width: u32, height: u32) -> Vec<u8> {
327        let n = width as usize * height as usize;
328
329        if self.from == self.to {
330            return input.to_vec();
331        }
332
333        match (self.from, self.to) {
334            (PixelFormat::Rgba8, PixelFormat::Gray8) => {
335                if input.len() != n * 4 {
336                    return Vec::new();
337                }
338                (0..n)
339                    .map(|i| {
340                        let b = i * 4;
341                        luma_bt601(input[b], input[b + 1], input[b + 2])
342                    })
343                    .collect()
344            }
345            (PixelFormat::Gray8, PixelFormat::Rgba8) => {
346                if input.len() != n {
347                    return Vec::new();
348                }
349                input.iter().flat_map(|&px| [px, px, px, 255]).collect()
350            }
351            (PixelFormat::Rgba8, PixelFormat::Bgra8) | (PixelFormat::Bgra8, PixelFormat::Rgba8) => {
352                if input.len() != n * 4 {
353                    return Vec::new();
354                }
355                let mut out = input.to_vec();
356                for i in 0..n {
357                    let b = i * 4;
358                    out.swap(b, b + 2); // swap R and B
359                }
360                out
361            }
362            _ => Vec::new(), // unsupported conversion
363        }
364    }
365
366    fn input_format(&self) -> PixelFormat {
367        self.from
368    }
369    fn output_format(&self) -> PixelFormat {
370        self.to
371    }
372}
373
374// --- OverlayStage -----------------------------------------------------------
375
376/// Composite a pre-loaded overlay image over the input frame.
377///
378/// The overlay is blended using `alpha` as a uniform opacity multiplier
379/// (0.0 = invisible, 1.0 = opaque overlay) in addition to the overlay's
380/// own alpha channel.
381///
382/// Input: [`PixelFormat::Rgba8`], Output: [`PixelFormat::Rgba8`].
383pub struct OverlayStage {
384    /// Overlay image data (RGBA8, same dimensions as the pipeline).
385    pub overlay: Vec<u8>,
386    /// Uniform opacity for the overlay (0.0 – 1.0).
387    pub alpha: f32,
388}
389
390impl ImagePipelineStage for OverlayStage {
391    fn name(&self) -> &str {
392        "Overlay"
393    }
394
395    fn process(&self, input: &[u8], width: u32, height: u32) -> Vec<u8> {
396        let n = width as usize * height as usize;
397        let expected = n * 4;
398        if input.len() != expected || self.overlay.len() != expected {
399            return input.to_vec();
400        }
401
402        let alpha_clamp = self.alpha.clamp(0.0, 1.0);
403        let mut out = vec![0u8; expected];
404
405        for i in 0..n {
406            let b = i * 4;
407            let bg_r = input[b] as f32;
408            let bg_g = input[b + 1] as f32;
409            let bg_b = input[b + 2] as f32;
410            let bg_a = input[b + 3] as f32 / 255.0;
411
412            let ov_r = self.overlay[b] as f32;
413            let ov_g = self.overlay[b + 1] as f32;
414            let ov_b = self.overlay[b + 2] as f32;
415            let ov_a = (self.overlay[b + 3] as f32 / 255.0) * alpha_clamp;
416
417            // Porter-Duff "over".
418            let out_a = ov_a + bg_a * (1.0 - ov_a);
419            if out_a <= 0.0 {
420                continue;
421            }
422            let inv = 1.0 / out_a;
423            out[b] = ((ov_r * ov_a + bg_r * bg_a * (1.0 - ov_a)) * inv)
424                .round()
425                .clamp(0.0, 255.0) as u8;
426            out[b + 1] = ((ov_g * ov_a + bg_g * bg_a * (1.0 - ov_a)) * inv)
427                .round()
428                .clamp(0.0, 255.0) as u8;
429            out[b + 2] = ((ov_b * ov_a + bg_b * bg_a * (1.0 - ov_a)) * inv)
430                .round()
431                .clamp(0.0, 255.0) as u8;
432            out[b + 3] = (out_a * 255.0).round().clamp(0.0, 255.0) as u8;
433        }
434        out
435    }
436
437    fn input_format(&self) -> PixelFormat {
438        PixelFormat::Rgba8
439    }
440    fn output_format(&self) -> PixelFormat {
441        PixelFormat::Rgba8
442    }
443}
444
445// ---------------------------------------------------------------------------
446// ImageComputePipeline
447// ---------------------------------------------------------------------------
448
449/// A linear sequence of [`ImagePipelineStage`]s.
450///
451/// Before execution, the pipeline validates that each stage's output format
452/// matches the next stage's input format.
453pub struct ImageComputePipeline {
454    stages: Vec<Box<dyn ImagePipelineStage>>,
455    /// Width in pixels for all frames processed by this pipeline.
456    pub width: u32,
457    /// Height in pixels for all frames processed by this pipeline.
458    pub height: u32,
459}
460
461impl ImageComputePipeline {
462    /// Create an empty pipeline for frames of size `width × height`.
463    #[must_use]
464    pub fn new(width: u32, height: u32) -> Self {
465        Self {
466            stages: Vec::new(),
467            width,
468            height,
469        }
470    }
471
472    /// Append a stage to the pipeline.
473    ///
474    /// Validates that the new stage's input format matches the previous
475    /// stage's output format.  If the pipeline is empty any format is
476    /// accepted.
477    ///
478    /// # Errors
479    ///
480    /// Returns an error string describing the format mismatch.
481    pub fn add_stage(&mut self, stage: Box<dyn ImagePipelineStage>) -> Result<(), String> {
482        if let Some(prev) = self.stages.last() {
483            let prev_out = prev.output_format();
484            let next_in = stage.input_format();
485            if prev_out != next_in {
486                return Err(format!(
487                    "Format mismatch between '{}' (output: {}) and '{}' (input: {})",
488                    prev.name(),
489                    prev_out,
490                    stage.name(),
491                    next_in,
492                ));
493            }
494        }
495        self.stages.push(stage);
496        Ok(())
497    }
498
499    /// Run all stages in sequence and return the final output.
500    ///
501    /// The `input` slice is passed to the first stage.  Each subsequent
502    /// stage receives the output of the previous stage.
503    ///
504    /// # Errors
505    ///
506    /// Returns an error string if any stage produces an empty output
507    /// (which indicates a dimension mismatch at runtime).
508    pub fn execute(&self, input: &[u8]) -> Result<Vec<u8>, String> {
509        if self.stages.is_empty() {
510            return Ok(input.to_vec());
511        }
512
513        let mut current: Vec<u8> = input.to_vec();
514        for stage in &self.stages {
515            let next = stage.process(&current, self.width, self.height);
516            if next.is_empty() {
517                return Err(format!(
518                    "Stage '{}' returned empty output (possible dimension mismatch)",
519                    stage.name()
520                ));
521            }
522            current = next;
523        }
524        Ok(current)
525    }
526
527    /// Number of stages in this pipeline.
528    #[must_use]
529    pub fn stage_count(&self) -> usize {
530        self.stages.len()
531    }
532
533    /// Validate format compatibility for all consecutive stage pairs.
534    ///
535    /// # Errors
536    ///
537    /// Returns the first format mismatch found, if any.
538    pub fn validate(&self) -> Result<(), String> {
539        for pair in self.stages.windows(2) {
540            let a = &pair[0];
541            let b = &pair[1];
542            if a.output_format() != b.input_format() {
543                return Err(format!(
544                    "Stage '{}' outputs {} but '{}' expects {}",
545                    a.name(),
546                    a.output_format(),
547                    b.name(),
548                    b.input_format(),
549                ));
550            }
551        }
552        Ok(())
553    }
554
555    /// Names of all stages, in execution order.
556    #[must_use]
557    pub fn stage_names(&self) -> Vec<&str> {
558        self.stages.iter().map(|s| s.name()).collect()
559    }
560}
561
562impl std::fmt::Debug for ImageComputePipeline {
563    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
564        f.debug_struct("ImageComputePipeline")
565            .field("width", &self.width)
566            .field("height", &self.height)
567            .field("stage_count", &self.stages.len())
568            .field("stages", &self.stage_names())
569            .finish()
570    }
571}
572
573// ---------------------------------------------------------------------------
574// Internal helpers
575// ---------------------------------------------------------------------------
576
/// BT.601 luma (`0.299 R + 0.587 G + 0.114 B`) of an 8-bit RGB triple,
/// rounded to the nearest integer and saturated to `0..=255`.
#[inline]
fn luma_bt601(r: u8, g: u8, b: u8) -> u8 {
    let (red, green, blue) = (f32::from(r), f32::from(g), f32::from(b));
    let weighted = 0.299_f32 * red + 0.587_f32 * green + 0.114_f32 * blue;
    weighted.round().clamp(0.0, 255.0) as u8
}
582
/// Build a 1-D Gaussian kernel with `2 * radius + 1` taps centred on index
/// `radius`.
///
/// Tap `i` starts as `exp(-(i - radius)² / (2σ²))`. The taps are then divided
/// by their sum so they add up to 1. If the sum is not greater than zero
/// (e.g. NaN when `sigma` is 0), the taps are returned unnormalised.
fn build_1d_gaussian(radius: usize, sigma: f32) -> Vec<f32> {
    let denom = 2.0 * sigma * sigma;
    let centre = radius as isize;
    let mut taps = Vec::with_capacity(2 * radius + 1);
    for offset in -centre..=centre {
        let x = offset as f32;
        taps.push((-x * x / denom).exp());
    }
    let total: f32 = taps.iter().sum();
    if total > 0.0 {
        for tap in &mut taps {
            *tap /= total;
        }
    }
    taps
}
598
/// Blur a single `w × h` plane with `kernel`, first along rows, then along
/// columns.
///
/// `kernel` must have `2 * radius + 1` taps centred on index `radius`. Taps
/// that fall outside the frame are skipped and the remaining weights are
/// renormalised, so edge pixels are not darkened.
fn gaussian_pass_2d(plane: &[f32], w: usize, h: usize, kernel: &[f32], radius: usize) -> Vec<f32> {
    // Horizontal pass.
    let mut tmp = vec![0.0_f32; w * h];
    for row in 0..h {
        let rs = row * w;
        for col in 0..w {
            tmp[rs + col] = weighted_tap_sum(kernel, radius, col, w, |src| plane[rs + src]);
        }
    }
    // Vertical pass. Walking row-major keeps the writes contiguous (the old
    // column-major walk strided by `w` on every write) without changing any
    // per-pixel arithmetic.
    let mut out = vec![0.0_f32; w * h];
    for row in 0..h {
        let rs = row * w;
        for col in 0..w {
            out[rs + col] = weighted_tap_sum(kernel, radius, row, h, |src| tmp[src * w + col]);
        }
    }
    out
}

/// Apply `kernel` around position `center` of a line of `len` samples read
/// through `sample`. Out-of-range taps are skipped and the result is divided
/// by the weight that landed in range (0.0 if none did).
fn weighted_tap_sum(
    kernel: &[f32],
    radius: usize,
    center: usize,
    len: usize,
    sample: impl Fn(usize) -> f32,
) -> f32 {
    let (mut acc, mut wsum) = (0.0_f32, 0.0_f32);
    for (ki, &kv) in kernel.iter().enumerate() {
        let src = center as isize + ki as isize - radius as isize;
        if src >= 0 && src < len as isize {
            acc += sample(src as usize) * kv;
            wsum += kv;
        }
    }
    if wsum > 0.0 {
        acc / wsum
    } else {
        0.0
    }
}
633
634// ---------------------------------------------------------------------------
635// Tests
636// ---------------------------------------------------------------------------
637
#[cfg(test)]
mod tests {
    use super::*;

    /// Opaque `w × h` RGBA8 frame filled with a single colour.
    fn rgba_frame(w: u32, h: u32, r: u8, g: u8, b: u8) -> Vec<u8> {
        let n = w as usize * h as usize;
        (0..n).flat_map(|_| [r, g, b, 255]).collect()
    }

    /// `w × h` Gray8 frame filled with `v`.
    fn gray_frame(w: u32, h: u32, v: u8) -> Vec<u8> {
        vec![v; w as usize * h as usize]
    }

    // --- PixelFormat ---

    #[test]
    fn test_pixel_format_display() {
        assert_eq!(PixelFormat::Rgba8.to_string(), "RGBA8");
        assert_eq!(PixelFormat::Gray8.to_string(), "Gray8");
    }

    // --- GrayscaleStage ---

    #[test]
    fn test_grayscale_white() {
        let stage = GrayscaleStage;
        let input = rgba_frame(4, 4, 255, 255, 255);
        let out = stage.process(&input, 4, 4);
        assert_eq!(out.len(), 16);
        assert!(
            out.iter().all(|&v| v > 250),
            "white should map to ~255 gray"
        );
    }

    #[test]
    fn test_grayscale_format() {
        let stage = GrayscaleStage;
        assert_eq!(stage.input_format(), PixelFormat::Rgba8);
        assert_eq!(stage.output_format(), PixelFormat::Gray8);
    }

    // --- GaussianBlurStage ---

    #[test]
    fn test_gaussian_blur_constant_gray() {
        let stage = GaussianBlurStage { sigma: 1.5 };
        let input = gray_frame(8, 8, 100);
        let out = stage.process(&input, 8, 8);
        assert_eq!(out.len(), 64);
        for &v in &out {
            assert!(
                (v as i32 - 100).unsigned_abs() <= 2,
                "constant image should remain ~100, got {v}"
            );
        }
    }

    #[test]
    fn test_gaussian_blur_wrong_size() {
        let stage = GaussianBlurStage { sigma: 1.0 };
        let out = stage.process(&[0u8; 3], 4, 4);
        assert!(
            out.is_empty(),
            "wrong-size input should produce empty output"
        );
    }

    // --- SobelStage ---

    #[test]
    fn test_sobel_flat_is_zero() {
        let stage = SobelStage;
        let input = gray_frame(8, 8, 128);
        let out = stage.process(&input, 8, 8);
        for row in 1..7_usize {
            for col in 1..7_usize {
                assert_eq!(out[row * 8 + col], 0, "flat image interior should be 0");
            }
        }
    }

    #[test]
    fn test_sobel_detects_vertical_edge() {
        let stage = SobelStage;
        // 4×3 frame: two dark columns followed by two bright columns.
        let input: Vec<u8> = (0..3).flat_map(|_| [0u8, 0, 255, 255]).collect();
        let out = stage.process(&input, 4, 3);
        assert_eq!(out.len(), 12);
        // Both interior pixels straddle the edge; |gx| = 1020 saturates.
        assert_eq!(out[4 + 1], 255);
        assert_eq!(out[4 + 2], 255);
        // The one-pixel border is never written.
        assert_eq!(out[0], 0);
        assert_eq!(out[4 + 3], 0);
    }

    #[test]
    fn test_sobel_output_format() {
        let stage = SobelStage;
        assert_eq!(stage.input_format(), PixelFormat::Gray8);
        assert_eq!(stage.output_format(), PixelFormat::Gray8);
    }

    // --- ThresholdStage ---

    #[test]
    fn test_threshold_binary() {
        let stage = ThresholdStage { threshold: 128 };
        let input = vec![100u8, 128, 200, 50, 128, 255];
        let out = stage.process(&input, 6, 1);
        assert_eq!(out, vec![0, 255, 255, 0, 255, 255]);
    }

    #[test]
    fn test_threshold_wrong_size() {
        let stage = ThresholdStage { threshold: 10 };
        assert!(stage.process(&[1, 2, 3], 2, 2).is_empty());
    }

    // --- ColorConvertStage ---

    #[test]
    fn test_color_convert_identity() {
        let stage = ColorConvertStage {
            from: PixelFormat::Rgba8,
            to: PixelFormat::Rgba8,
        };
        let input = rgba_frame(2, 2, 10, 20, 30);
        let out = stage.process(&input, 2, 2);
        assert_eq!(out, input);
    }

    #[test]
    fn test_color_convert_rgba_to_gray() {
        let stage = ColorConvertStage {
            from: PixelFormat::Rgba8,
            to: PixelFormat::Gray8,
        };
        let input = rgba_frame(2, 2, 255, 255, 255);
        let out = stage.process(&input, 2, 2);
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|&v| v > 250));
    }

    #[test]
    fn test_color_convert_rgba_to_bgra_swap() {
        let stage = ColorConvertStage {
            from: PixelFormat::Rgba8,
            to: PixelFormat::Bgra8,
        };
        let input = vec![255u8, 0, 0, 255]; // red RGBA
        let out = stage.process(&input, 1, 1);
        assert_eq!(&out[0..4], &[0u8, 0, 255, 255]); // should be blue BGRA
    }

    #[test]
    fn test_color_convert_bgra_to_rgba_swap() {
        let stage = ColorConvertStage {
            from: PixelFormat::Bgra8,
            to: PixelFormat::Rgba8,
        };
        let out = stage.process(&[1, 2, 3, 4], 1, 1);
        assert_eq!(out, vec![3, 2, 1, 4]);
    }

    #[test]
    fn test_color_convert_gray_to_rgba() {
        let stage = ColorConvertStage {
            from: PixelFormat::Gray8,
            to: PixelFormat::Rgba8,
        };
        let out = stage.process(&[7, 200], 2, 1);
        assert_eq!(out, vec![7, 7, 7, 255, 200, 200, 200, 255]);
    }

    #[test]
    fn test_color_convert_wrong_size() {
        let stage = ColorConvertStage {
            from: PixelFormat::Rgba8,
            to: PixelFormat::Gray8,
        };
        assert!(stage.process(&[0u8; 5], 1, 1).is_empty());
    }

    // --- OverlayStage ---

    #[test]
    fn test_overlay_transparent_overlay() {
        let bg = rgba_frame(2, 2, 0, 0, 255);
        let overlay_data: Vec<u8> = (0..4).flat_map(|_| [255u8, 0, 0, 0u8]).collect(); // fully transparent red
        let stage = OverlayStage {
            overlay: overlay_data,
            alpha: 1.0,
        };
        let out = stage.process(&bg, 2, 2);
        // Fully transparent overlay → output ≈ bg.
        assert_eq!(&out[0..3], &[0u8, 0, 255]);
    }

    #[test]
    fn test_overlay_opaque_overlay_wins() {
        let bg = rgba_frame(2, 2, 0, 0, 255);
        let stage = OverlayStage {
            overlay: rgba_frame(2, 2, 255, 0, 0),
            alpha: 1.0,
        };
        assert_eq!(stage.process(&bg, 2, 2), rgba_frame(2, 2, 255, 0, 0));
    }

    #[test]
    fn test_overlay_size_mismatch_passthrough() {
        let bg = rgba_frame(2, 2, 1, 2, 3);
        let stage = OverlayStage {
            overlay: vec![0u8; 4],
            alpha: 1.0,
        };
        assert_eq!(stage.process(&bg, 2, 2), bg);
    }

    // --- ImageComputePipeline ---

    #[test]
    fn test_pipeline_empty_passthrough() {
        let pipeline = ImageComputePipeline::new(4, 4);
        let input = gray_frame(4, 4, 77);
        let out = pipeline.execute(&input).expect("execute");
        assert_eq!(out, input);
    }

    #[test]
    fn test_pipeline_add_stage_format_mismatch() {
        let mut pipeline = ImageComputePipeline::new(4, 4);
        pipeline
            .add_stage(Box::new(GrayscaleStage))
            .expect("add grayscale");
        // GrayscaleStage outputs Gray8, but GrayscaleStage itself accepts Rgba8 — mismatch.
        let result = pipeline.add_stage(Box::new(GrayscaleStage));
        assert!(result.is_err(), "should detect format mismatch");
    }

    #[test]
    fn test_pipeline_validate_ok() {
        let mut pipeline = ImageComputePipeline::new(4, 4);
        pipeline
            .add_stage(Box::new(GrayscaleStage))
            .expect("grayscale");
        pipeline.add_stage(Box::new(SobelStage)).expect("sobel");
        assert!(pipeline.validate().is_ok());
    }

    #[test]
    fn test_pipeline_stage_count() {
        let mut pipeline = ImageComputePipeline::new(4, 4);
        assert_eq!(pipeline.stage_count(), 0);
        pipeline.add_stage(Box::new(GrayscaleStage)).expect("add");
        assert_eq!(pipeline.stage_count(), 1);
        pipeline.add_stage(Box::new(SobelStage)).expect("add");
        assert_eq!(pipeline.stage_count(), 2);
    }

    #[test]
    fn test_pipeline_full_rgba_to_binary() {
        // RGBA → Gray8 → Threshold
        let mut pipeline = ImageComputePipeline::new(4, 4);
        pipeline.add_stage(Box::new(GrayscaleStage)).expect("gray");
        pipeline
            .add_stage(Box::new(ThresholdStage { threshold: 128 }))
            .expect("thresh");

        let input = rgba_frame(4, 4, 200, 200, 200); // bright grey → luma ≈ 200
        let out = pipeline.execute(&input).expect("execute");
        assert_eq!(out.len(), 16);
        assert!(
            out.iter().all(|&v| v == 255),
            "all pixels should be above threshold"
        );
    }

    #[test]
    fn test_pipeline_execute_dimension_mismatch_is_error() {
        let mut pipeline = ImageComputePipeline::new(4, 4);
        pipeline
            .add_stage(Box::new(GrayscaleStage))
            .expect("add grayscale");
        let err = pipeline
            .execute(&[0u8; 3])
            .expect_err("wrong-size input must fail");
        assert!(err.contains("Grayscale"), "error should name the stage: {err}");
    }

    #[test]
    fn test_pipeline_blur_chain_names_and_output() {
        // Mirrors the module-level example: RGBA → Gray8 → blur → Sobel.
        let mut pipeline = ImageComputePipeline::new(4, 4);
        pipeline.add_stage(Box::new(GrayscaleStage)).expect("gray");
        pipeline
            .add_stage(Box::new(GaussianBlurStage { sigma: 1.0 }))
            .expect("blur");
        pipeline.add_stage(Box::new(SobelStage)).expect("sobel");
        assert_eq!(
            pipeline.stage_names(),
            vec!["Grayscale", "GaussianBlur", "Sobel"]
        );

        let input: Vec<u8> = (0..16u8).flat_map(|i| [i * 4, i * 4, i * 4, 255]).collect();
        let out = pipeline.execute(&input).expect("execute");
        assert_eq!(out.len(), 16);
    }

    #[test]
    fn test_pipeline_debug_lists_stages() {
        let mut pipeline = ImageComputePipeline::new(4, 4);
        pipeline.add_stage(Box::new(GrayscaleStage)).expect("gray");
        pipeline.add_stage(Box::new(SobelStage)).expect("sobel");
        let text = format!("{pipeline:?}");
        assert!(text.contains("stage_count: 2"), "{text}");
        assert!(text.contains("stages: [\"Grayscale\", \"Sobel\"]"), "{text}");
    }
}