1use crate::{Filter, PaddingType};
4use photon_rs::transform::padding_uniform as uniform;
5use photon_rs::PhotonImage;
6use photon_rs::Rgba;
7
8fn convolve(img_padded: &PhotonImage, filter: &Filter, width_conv: u32, height_conv: u32, stride: u32) -> PhotonImage {
9 let raw_pixel_padded = img_padded.get_raw_pixels();
10 let width_padded = img_padded.get_width() as usize;
11 let height_padded = img_padded.get_height() as usize;
12 let mut img_conv = vec![];
13
14 let filter_width = filter.width();
15 let filter_height = filter.height();
16
17 let mut pixel = 0_usize;
18 let image_end = (width_padded * height_padded * 4) as usize;
19 let step = 4 * stride as usize;
20
21 while pixel < image_end - 4 {
22 if pixel != 0 && ((pixel / 4) % width_padded) > (width_padded - filter_width) {
23 pixel = ((pixel / 4) / width_padded + stride as usize) * width_padded * 4;
24
25 if (pixel / 4) / width_padded + filter_height > height_padded {
26 break;
27 }
28 }
29 let mut img_conv_r: f32 = 0_f32;
30 let mut img_conv_g: f32 = 0_f32;
31 let mut img_conv_b: f32 = 0_f32;
32
33 for x in 0..filter_width {
34 for y in 0..filter_height {
35 let kernel_element_val = filter
36 .get_element(x, y)
37 .expect("[ERROR]: Tried to access out-of-bounds value in the filter");
38 let img_pixel_r = raw_pixel_padded[x * width_padded * 4 + pixel + y * 4];
39 let img_pixel_g = raw_pixel_padded[x * width_padded * 4 + pixel + y * 4 + 1];
40 let img_pixel_b = raw_pixel_padded[x * width_padded * 4 + pixel + y * 4 + 2];
41
42 img_conv_r += img_pixel_r as f32 * kernel_element_val;
43 img_conv_g += img_pixel_g as f32 * kernel_element_val;
44 img_conv_b += img_pixel_b as f32 * kernel_element_val;
45 }
46 }
47
48 img_conv_r = f32::clamp(img_conv_r, 0.0, 255.0);
49 img_conv_g = f32::clamp(img_conv_g, 0.0, 255.0);
50 img_conv_b = f32::clamp(img_conv_b, 0.0, 255.0);
51
52 img_conv.push(img_conv_r as u8);
53 img_conv.push(img_conv_g as u8);
54 img_conv.push(img_conv_b as u8);
55 img_conv.push(255_u8);
56
57 pixel += step;
58 }
59
60 for _ in (img_conv.len()..(width_conv * height_conv * 4) as usize).step_by(1) {
61 img_conv.push(255_u8);
62 img_conv.push(255_u8);
63 img_conv.push(255_u8);
64 img_conv.push(255_u8);
65 }
66
67 #[cfg(debug_assertions)]
68 println!("Convolution done...");
69
70 PhotonImage::new(img_conv, width_conv, height_conv)
71}
72
73fn adjust_convolution_params(
74 img: &PhotonImage,
75 img_padded: &PhotonImage,
76 filter: &Filter,
77 stride: u32,
78 padding: PaddingType,
79) -> PhotonImage {
80 let mut img_conv_width: u32;
81 let mut img_conv_height: u32;
82
83 match padding {
84 PaddingType::UNIFORM(pad_amt) => {
85 img_conv_width = img.get_width() - filter.width() as u32 + 2 * pad_amt;
86 if img_conv_width % stride != 0 {
87 eprintln!("[WARNING]: stride value not suitable. Convolution may fail.");
88 }
89 img_conv_width /= stride;
90 img_conv_width += 1;
91
92 img_conv_height = img.get_height() - filter.height() as u32 + 2 * pad_amt;
93 if img_conv_height % stride != 0 {
94 eprintln!("[WARNING]: stride value not suitable. Convolution may fail.");
95 }
96 img_conv_height /= stride;
97 img_conv_height += 1;
98 }
99
100 PaddingType::NONE => {
101 img_conv_width = img.get_width() - filter.width() as u32;
102 if img_conv_width % stride != 0 {
103 eprintln!("[WARNING]: stride value not suitable. Convolution may fail.");
104 }
105 img_conv_width /= stride;
106 img_conv_width += 1;
107
108 img_conv_height = img.get_height() - filter.height() as u32;
109 if img_conv_height % stride != 0 {
110 eprintln!("[WARNING]: stride value not suitable. Convolution may fail.");
111 }
112 img_conv_height /= stride;
113 img_conv_height += 1;
114 }
115 };
116
117 convolve(img_padded, filter, img_conv_width, img_conv_height, stride)
118}
119
120pub fn convolution(img: &PhotonImage, filter: Filter, stride: u32, padding: PaddingType) -> PhotonImage {
141 match stride {
142 0 => {
143 eprintln!("[ERROR]: Stride provided = 0");
144 std::process::exit(1);
145 }
146
147 1 => match padding {
148 PaddingType::UNIFORM(padding_amt) => {
149 let padding_color = Rgba::new(0, 0, 0, 255);
150 let img_padded = uniform(&img, padding_amt, padding_color);
151 adjust_convolution_params(img, &img_padded, &filter, stride, padding)
152 }
153 PaddingType::NONE => {
154 let img_padded = img.clone();
155 adjust_convolution_params(img, &img_padded, &filter, stride, padding)
156 }
157 },
158 _ => match padding {
159 PaddingType::UNIFORM(padding_amt) => {
160 let padding_color = Rgba::new(0, 0, 0, 255);
161 let img_padded = uniform(&img, padding_amt, padding_color);
162 adjust_convolution_params(img, &img_padded, &filter, stride, padding)
163 }
164 PaddingType::NONE => {
165 let img_padded = img.clone();
166 adjust_convolution_params(img, &img_padded, &filter, stride, padding)
167 }
168 },
169 }
170}