Skip to main content

image_conv/
conv.rs

1//! Image Convolution Engine
2//!
3//! ## What is convolution?
4//!
5//! ```text
6//! Convolution slides a filter kernel over an image, computing a weighted sum
7//! at each position. Each output pixel is the dot product of the kernel with
8//! the underlying image patch:
9//!
10//!   Input image patch         Kernel           Output pixel
11//!   ┌────┬────┬────┐      ┌────┬────┬────┐
12//!   │ a  │ b  │ c  │      │ k₀ │ k₁ │ k₂ │     out = a·k₀ + b·k₁ + c·k₂
13//!   ├────┼────┼────┤   ⊙  ├────┼────┼────┤        + d·k₃ + e·k₄ + f·k₅
14//!   │ d  │ e  │ f  │      │ k₃ │ k₄ │ k₅ │        + g·k₆ + h·k₇ + i·k₈
15//!   ├────┼────┼────┤      ├────┼────┼────┤
16//!   │ g  │ h  │ i  │      │ k₆ │ k₇ │ k₈ │     Then clamped to [0, 255]
17//!   └────┴────┴────┘      └────┴────┴────┘
18//!
19//! The kernel "slides" horizontally then vertically across the whole image,
20//! producing one output pixel per valid position.
21//! ```
22//!
23//! ## Three-tier performance
24//!
25//! ```text
26//!                    ┌──────────────────┐
27//!                    │  try_separable()  │
28//!                    └────────┬─────────┘
29//!                             │
30//!               ┌─────────────┴─────────────┐
31//!               │                           │
32//!          Separable                     Not separable
33//!               │                           │
34//!               ▼                           ▼
35//!   ┌──────────────────────┐    ┌──────────────────────┐
36//!   │  separable_convolve  │    │      convolve        │
37//!   │  (two 1D passes)     │    │    (one 2D pass)      │
38//!   │                      │    │                      │
39//!   │  Input ────┬──────── │    │  Input image          │
40//!   │            │horizontal│    │  7×7 kernel = 49 ops │
41//!   │            ▼   pass   │    │  per output pixel     │
42//!   │        Temp buffer    │    │                      │
43//!   │   (width reduced by   │    └──────────────────────┘
44//!   │    filter-1+2·pad)    │
45//!   │            │          │    Legend:
46//!   │            │vertical  │    7×7 separable = 14 ops
47//!   │            ▼  pass    │    per pixel — 3.5× faster
48//!   │        Output image   │
49//!   └──────────────────────┘
50//!
51//!   Both paths are parallelised over output rows via rayon.
52//! ```
53//!
54//! ## Threading model
55//!
56//! Each output row is completely independent — zero data dependencies.
57//! Rayon splits the output buffer into per-row slices and processes
58//! them in parallel across all available cores.
59//!
60//! ```text
61//!   Output buffer (hc rows × wc cols × 4 bytes RGBA)
62//!   ┌──────────────────────────────────────┐
63//!   │ Row 0 ──────▶ Thread 0               │
64//!   │ Row 1 ──────▶ Thread 1               │  Output rows are
65//!   │ Row 2 ──────▶ Thread 2               │  processed in
66//!   │ Row 3 ──────▶ Thread 3               │  parallel — each
67//!   │   ...                                │  thread reads from
68//!   │ Row hc-1 ───▶ Thread N               │  the input image
69//!   └──────────────────────────────────────┘  (immutable, safe)
70//! ```
71//!
72//! ## Output size formula
73//!
74//! ```text
75//! output_width  = (W - Fw + 2·P) / S + 1
76//! output_height = (H - Fh + 2·P) / S + 1
77//! ```
78
79use crate::{Filter, PaddingType};
80use photon_rs::transform::padding_uniform as uniform;
81use photon_rs::PhotonImage;
82use photon_rs::Rgba;
83use rayon::prelude::*;
84
85/// Standard 2D convolution — parallelised over output rows.
86///
87/// ```text
88/// Each output row is independent, so we split the output buffer
89/// into per-row slices and process them in parallel via rayon.
90///
91///   Padded image (width=wp)      Filter (fw×fh)     Output pixel
92///   ┌───────────────────────┐     ┌───┬───┬───┐
93///   │                       │     │   │   │   │
94///   │  (row_base, col_base)─┼──┐  │   │   │   │    r = Σ fy Σ fx
95///   │  │                    │  │  │   │   │   │      raw[px]·kernel[fy][fx]
96///   │  │  fh rows           │  │  ├───┼───┼───┤
97///   │  │                    │  │  │   │   │   │
98///   │  └── fw cols ────────┘  │  │   │   │   │
99///   │                       │  │  └───┴───┴───┘
100///   └───────────────────────┘  │
101///     px = (row_base+fy)*wp + col_base+fx  (×4 for RGBA)
102/// ```
103///
104/// All channels accumulate as `f32` and clamp to `[0, 255]` at the end.
105fn convolve(img_padded: &PhotonImage, filter: &Filter, width_conv: u32, height_conv: u32, stride: u32) -> PhotonImage {
106    let raw = img_padded.get_raw_pixels();
107    let wp = img_padded.get_width() as usize;
108    let fw = filter.width;
109    let fh = filter.height;
110    let kernel = &filter.kernel;
111    let wc = width_conv as usize;
112    let hc = height_conv as usize;
113    let stride = stride as usize;
114
115    let out_size = wc * hc * 4;
116    let mut out = vec![0u8; out_size];
117
118    out.par_chunks_mut(wc * 4).enumerate().for_each(|(yc, row_out)| {
119        let row_base = yc * stride;
120
121        for xc in 0..wc {
122            let col_base = xc * stride;
123
124            let mut r: f32 = 0.0;
125            let mut g: f32 = 0.0;
126            let mut b: f32 = 0.0;
127
128            for fy in 0..fh {
129                let row_offset = (row_base + fy) * wp;
130                let k_row = fy * fw;
131
132                for fx in 0..fw {
133                    let px = (row_offset + col_base + fx) * 4;
134                    let k = kernel[k_row + fx];
135
136                    r += raw[px] as f32 * k;
137                    g += raw[px + 1] as f32 * k;
138                    b += raw[px + 2] as f32 * k;
139                }
140            }
141
142            let i = xc * 4;
143            row_out[i] = r.clamp(0.0, 255.0) as u8;
144            row_out[i + 1] = g.clamp(0.0, 255.0) as u8;
145            row_out[i + 2] = b.clamp(0.0, 255.0) as u8;
146            row_out[i + 3] = 255_u8;
147        }
148    });
149
150    debug_assert_eq!(out.len(), out_size);
151
152    #[cfg(debug_assertions)]
153    println!("Convolution done (rayon)...");
154
155    PhotonImage::new(out, width_conv, height_conv)
156}
157
158/// Separable convolution — each pass is parallelised independently.
159///
160/// ## How it works
161///
162/// A separable kernel factors as `kernel[i][j] = col[i] × row[j]`.
163/// Two 1D passes replace one 2D pass: O(fw+fw) per pixel vs O(fw·fh).
164///
165/// ```text
166/// PASS 1 — Horizontal (parallel over rows)
167/// ─────────────────────────────────────────
168///   Input (padded, hp rows)   temp (hp rows × temp_w cols × 3 floats)
169///   ┌──────────────────┐      ┌──────────────────────┐
170///   │ Row 0 ──▶ T0     │      │ Row 0 convolved       │
171///   │ Row 1 ──▶ T1     │  →   │ Row 1 convolved       │  Each row is
172///   │ Row 2 ──▶ T2     │      │ ...                   │  independent
173///   │ ...               │      └──────────────────────┘
174///   └──────────────────┘
175///
176/// PASS 2 — Vertical (parallel over output rows)
177/// ──────────────────────────────────────────────
178///   temp buffer (read-only)      output (hc × wc)
179///   ┌──────────────────────┐    ┌─────────────┐
180///   │ r₀ g₀ b₀ r₁ g₁ b₁ ...│    │ Row 0 ──▶ T0 │
181///   │ r₀ g₀ b₀ r₁ g₁ b₁ ...│ →  │ Row 1 ──▶ T1 │
182///   │        ...            │    │ ...          │
183///   └──────────────────────┘    └─────────────┘
184/// ```
185fn separable_convolve(
186    img_padded: &PhotonImage,
187    row_vec: &[f32],
188    col_vec: &[f32],
189    width_conv: u32,
190    height_conv: u32,
191    stride: u32,
192) -> PhotonImage {
193    let raw = img_padded.get_raw_pixels();
194    let wp = img_padded.get_width() as usize;
195    let hp = img_padded.get_height() as usize;
196    let fw = row_vec.len();
197    let fh = col_vec.len();
198    let wc = width_conv as usize;
199    let hc = height_conv as usize;
200    let stride = stride as usize;
201
202    let temp_w = wc;
203    let temp_size = hp * temp_w * 3;
204    let mut temp: Vec<f32> = vec![0.0; temp_size];
205
206    // Horizontal pass — each row is independent, process in parallel
207    temp.par_chunks_mut(temp_w * 3).enumerate().for_each(|(y, row_temp)| {
208        let row_input = y * wp;
209        for x in 0..temp_w {
210            let col_input = x * stride;
211            let mut r: f32 = 0.0;
212            let mut g: f32 = 0.0;
213            let mut b: f32 = 0.0;
214            for fx in 0..fw {
215                let px = (row_input + col_input + fx) * 4;
216                let k = row_vec[fx];
217                r += raw[px] as f32 * k;
218                g += raw[px + 1] as f32 * k;
219                b += raw[px + 2] as f32 * k;
220            }
221            let t = x * 3;
222            row_temp[t] = r;
223            row_temp[t + 1] = g;
224            row_temp[t + 2] = b;
225        }
226    });
227
228    // Vertical pass — output rows are independent, process in parallel
229    let out_size = wc * hc * 4;
230    let mut out = vec![0u8; out_size];
231
232    out.par_chunks_mut(wc * 4).enumerate().for_each(|(yc, row_out)| {
233        let row_base = yc * stride;
234        for xc in 0..wc {
235            let mut r: f32 = 0.0;
236            let mut g: f32 = 0.0;
237            let mut b: f32 = 0.0;
238            for fy in 0..fh {
239                let t = ((row_base + fy) * temp_w + xc) * 3;
240                let k = col_vec[fy];
241                r += temp[t] * k;
242                g += temp[t + 1] * k;
243                b += temp[t + 2] * k;
244            }
245            let i = xc * 4;
246            row_out[i] = r.clamp(0.0, 255.0) as u8;
247            row_out[i + 1] = g.clamp(0.0, 255.0) as u8;
248            row_out[i + 2] = b.clamp(0.0, 255.0) as u8;
249            row_out[i + 3] = 255_u8;
250        }
251    });
252
253    debug_assert_eq!(out.len(), out_size);
254
255    #[cfg(debug_assertions)]
256    println!("Separable convolution done (rayon)...");
257
258    PhotonImage::new(out, width_conv, height_conv)
259}
260
/// Computes the output dimension for one axis.
///
/// ```text
/// output = (input_size + 2·padding - filter_size) / stride + 1
/// ```
///
/// The padding is added *before* the filter size is subtracted, and the
/// arithmetic is carried out in `u64`, so a filter that is larger than the
/// unpadded input but still fits inside the padded input (e.g. a 3-wide
/// filter on a 2-wide image with padding 1) is handled without an
/// intermediate underflow.
///
/// If the padded extent minus the filter size is not a multiple of `stride`,
/// a warning is written to stderr and the integer (floored) quotient is used.
///
/// # Panics
///
/// * if `stride` is zero;
/// * if `filter_size` is larger than `input_size + 2 * pad` (the kernel does
///   not fit anywhere in the padded input);
/// * if the resulting dimension does not fit in a `u32`.
#[inline]
fn output_dim(input_size: u32, filter_size: u32, pad: u32, stride: u32) -> u32 {
    assert!(stride != 0, "stride must be non-zero");

    // Widen to u64 so neither `2 * pad` nor the addition can overflow.
    let padded = u64::from(input_size) + 2 * u64::from(pad);
    let dim = match padded.checked_sub(u64::from(filter_size)) {
        Some(dim) => dim,
        None => panic!(
            "filter size {} exceeds padded input size {}",
            filter_size, padded
        ),
    };

    let stride = u64::from(stride);
    if dim % stride != 0 {
        eprintln!("[WARNING]: stride value not suitable. Convolution may fail.");
    }

    let out = dim / stride + 1;
    assert!(
        out <= u64::from(u32::MAX),
        "output dimension {} does not fit in u32",
        out
    );
    out as u32
}
274
275/// Applies convolution to an image using the given filter.
276///
277/// Each path (separable 1D and standard 2D) is parallelised across output
278/// rows via rayon — no data dependencies between rows, trivially parallel.
279///
280/// # Speedup Examples
281///
282/// | Kernel  | Size  | Sequential | Rayon (8 cores) |
283/// |---------|-------|------------|-----------------|
284/// | Gauss   | 7×7   | 61 ms      | ~8 ms           |
285/// | Gauss   | 15×15 | 94 ms      | ~12 ms          |
286/// | Laplacian|3×3   | 47 ms      | ~6 ms           |
287///
288/// # Example
289///
290/// ```no_run
291/// use image_conv::conv;
292/// use image_conv::{Filter, PaddingType};
293///
294/// let img = photon_rs::native::open_image("img.jpg").expect("No such file found");
295/// let sobel_x: Vec<f32> = vec![1.0, 0.0, -1.0, 2.0, 0.0, -2.0, 1.0, 0.0, -1.0];
296/// let filter = Filter::from(sobel_x, 3, 3);
297/// let img_conv = conv::convolution(&img, filter, 1, PaddingType::UNIFORM(1));
298///```
299pub fn convolution(img: &PhotonImage, filter: Filter, stride: u32, padding: PaddingType) -> PhotonImage {
300    if stride == 0 {
301        eprintln!("[ERROR]: Stride provided = 0");
302        std::process::exit(1);
303    }
304
305    let separable = filter.try_separable();
306
307    match &padding {
308        PaddingType::UNIFORM(pad_amt) => {
309            let img_padded = uniform(img, *pad_amt, Rgba::new(0, 0, 0, 255));
310            let wc = output_dim(img.get_width(), filter.width as u32, *pad_amt, stride);
311            let hc = output_dim(img.get_height(), filter.height as u32, *pad_amt, stride);
312
313            if let Some((col, row)) = separable {
314                separable_convolve(&img_padded, &row, &col, wc, hc, stride)
315            } else {
316                convolve(&img_padded, &filter, wc, hc, stride)
317            }
318        }
319        PaddingType::NONE => {
320            let wc = output_dim(img.get_width(), filter.width as u32, 0, stride);
321            let hc = output_dim(img.get_height(), filter.height as u32, 0, stride);
322
323            if let Some((col, row)) = separable {
324                separable_convolve(img, &row, &col, wc, hc, stride)
325            } else {
326                convolve(img, &filter, wc, hc, stride)
327            }
328        }
329    }
330}