sharpened_bilinear/
lib.rs

1use image::{DynamicImage, RgbImage, RgbaImage};
2
/// Decodes one sRGB-encoded sample to linear light using the piecewise sRGB transfer function.
///
/// The input is clamped to `[0, 1]` first; NaN decodes to `0.0`.
fn srgb_to_lrgb(v: f32) -> f32 {
    if v.is_nan() || v <= 0.0 {
        return 0.0;
    }
    if v >= 1.0 {
        return 1.0;
    }
    if v <= 0.04045 {
        // Linear toe segment near black.
        v / 12.92
    } else {
        ((v + 0.055) / 1.055).powf(2.4)
    }
}
14
/// Encodes a linear-light sample as an 8-bit sRGB value, rounding to nearest.
///
/// The input is clamped to `[0, 1]` first; NaN encodes to `0`.
fn lrgb_to_srgb8(v: f32) -> u8 {
    let encoded = match v {
        x if x.is_nan() || x <= 0.0 => 0.0,
        x if x >= 1.0 => 1.0,
        // Linear toe segment near black.
        x if x <= 0.0031308 => 12.92 * x,
        x => 1.055 * x.powf(1.0 / 2.4) - 0.055,
    };
    // `encoded` lies in [0, 1], so the rounded value fits in a u8.
    (encoded * 255.0 + 0.5) as u8
}
27
28/// Processing-friendly image structure
29/// with separate channels in ARGB or RGB order
30/// in linear color space with alpha premultiplied.
31/// 
32/// If you are not using the [image] library,
33/// you will have to implement the conversion
34/// to this structure and back to your image
35/// format yourself.
36#[derive(Clone, Debug)]
37pub struct FullImage {
38    pub width: usize,
39    pub height: usize,
40    pub has_alpha: bool,
41    pub channels: Vec<ImageChannel>,
42}
43
44impl FullImage {
45    pub fn new(width: usize, height: usize, channel_count: usize, has_alpha: bool) -> Self {
46        Self {
47            width,
48            height,
49            has_alpha,
50            channels: vec![ImageChannel::new(width, height); channel_count],
51        }
52    }
53}
54
55impl From<&DynamicImage> for FullImage {
56    fn from(input: &DynamicImage) -> Self {
57        let (width, height) = (input.width() as usize, input.height() as usize);
58        let has_alpha = input.color().has_alpha();
59        let channel_count = if has_alpha { 4 } else { 3 };
60        let mut output = FullImage::new(width, height, channel_count, has_alpha);
61        if has_alpha {
62            let data = input.to_rgba32f().into_vec();
63            for i in 0..width * height {
64                let r = data[i * 4 + 0];
65                let g = data[i * 4 + 1];
66                let b = data[i * 4 + 2];
67                let a = data[i * 4 + 3];
68                output.channels[0].data[i] = a;
69                output.channels[1].data[i] = srgb_to_lrgb(r) * a;
70                output.channels[2].data[i] = srgb_to_lrgb(g) * a;
71                output.channels[3].data[i] = srgb_to_lrgb(b) * a;
72            }
73        } else {
74            let data = input.to_rgb32f().into_vec();
75            for i in 0..width * height {
76                let r = data[i * 3 + 0];
77                let g = data[i * 3 + 1];
78                let b = data[i * 3 + 2];
79                output.channels[0].data[i] = srgb_to_lrgb(r);
80                output.channels[1].data[i] = srgb_to_lrgb(g);
81                output.channels[2].data[i] = srgb_to_lrgb(b);
82            }
83        }
84        output
85    }
86}
87
88impl From<&FullImage> for DynamicImage {
89    fn from(input: &FullImage) -> Self {
90        if input.channels.len() == 3 && !input.has_alpha {
91            let mut buf = vec![0; input.width * input.height * 3];
92            for i in 0..input.width * input.height {
93                buf[i * 3 + 0] = lrgb_to_srgb8(input.channels[0].data[i]);
94                buf[i * 3 + 1] = lrgb_to_srgb8(input.channels[1].data[i]);
95                buf[i * 3 + 2] = lrgb_to_srgb8(input.channels[2].data[i]);
96            }
97            DynamicImage::ImageRgb8(
98                RgbImage::from_raw(input.width as u32, input.height as u32, buf).unwrap(),
99            )
100        } else if input.channels.len() == 4 && input.has_alpha {
101            let mut buf = vec![0; input.width * input.height * 4];
102            for i in 0..input.width * input.height {
103                let a = input.channels[0].data[i] + f32::EPSILON;
104                let r = input.channels[1].data[i];
105                let g = input.channels[2].data[i];
106                let b = input.channels[3].data[i];
107                buf[i * 4 + 0] = lrgb_to_srgb8(r / a);
108                buf[i * 4 + 1] = lrgb_to_srgb8(g / a);
109                buf[i * 4 + 2] = lrgb_to_srgb8(b / a);
110                buf[i * 4 + 3] = (a * 255.0 + 0.5) as u8;
111            }
112            DynamicImage::ImageRgba8(
113                RgbaImage::from_raw(input.width as u32, input.height as u32, buf).unwrap(),
114            )
115        } else {
116            panic!("This is not ARGB or RGB image");
117        }
118    }
119}
120
121impl From<FullImage> for DynamicImage {
122    fn from(input: FullImage) -> Self {
123        (&input).into()
124    }
125}
126
/// One plane of samples (a color or alpha channel) stored row by row.
///
/// Its `width` and `height` must match those of the whole image.
#[derive(Clone, Debug)]
pub struct ImageChannel {
    /// Number of samples per row.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
    /// Samples in row-major order: the sample at `(x, y)` is `data[y * width + x]`.
    pub data: Vec<f32>,
}
134
/// 3x3 neighbourhood of samples around one pixel, stored row-major:
/// indices `0..3` are the row above, `3..6` the pixel's own row, `6..9` the row below.
#[derive(Clone, Copy, Default)]
struct Area([f32; 9]);

impl Area {
    /// Applies `f` to each of the nine samples, in index order.
    fn map<F: FnMut(f32) -> f32>(&self, f: &mut F) -> Self {
        Self(std::array::from_fn(|i| f(self.0[i])))
    }

    /// Combines this neighbourhood with `other` sample by sample, in index order.
    fn zip_map<F: FnMut(f32, f32) -> f32>(&self, other: &Self, f: &mut F) -> Self {
        Self(std::array::from_fn(|i| f(self.0[i], other.0[i])))
    }

    /// The middle sample.
    fn center(&self) -> f32 {
        self.0[4]
    }

    /// Sum of the four edge-adjacent neighbours (above, left, right, below).
    fn borders(&self) -> f32 {
        let [_, up, _, left, _, right, _, down, _] = self.0;
        up + left + right + down
    }

    /// Sum of the four diagonal neighbours.
    fn corners(&self) -> f32 {
        let [top_left, _, top_right, _, _, _, bottom_left, _, bottom_right] = self.0;
        top_left + top_right + bottom_left + bottom_right
    }

    /// Average, over the central pixel, of the bilinear surface through the nine samples:
    /// weights 36/64 for the center, 6/64 per edge neighbour and 1/64 per corner.
    fn integral(&self) -> f32 {
        let weighted = self.center() * 36.0 + self.borders() * 6.0 + self.corners();
        weighted / 64.0
    }

    /// Sample at column `x`, row `y` (both in `0..3`).
    fn get(&self, x: usize, y: usize) -> f32 {
        let index = y * 3 + x;
        self.0[index]
    }
}
169
170impl ImageChannel {
171    pub fn new(width: usize, height: usize) -> Self {
172        ImageChannel {
173            width,
174            height,
175            data: vec![0.0; width * height],
176        }
177    }
178    pub fn get(&self, x: usize, y: usize) -> f32 {
179        self.data[y * self.width + x]
180    }
181    fn get_area(&self, x: usize, y: usize) -> Area {
182        let x = [x.saturating_sub(1), x, (x + 1).min(self.width - 1)];
183        let y = [y.saturating_sub(1), y, (y + 1).min(self.height - 1)];
184
185        Area([
186            self.get(x[0], y[0]),
187            self.get(x[1], y[0]),
188            self.get(x[2], y[0]),
189            self.get(x[0], y[1]),
190            self.get(x[1], y[1]),
191            self.get(x[2], y[1]),
192            self.get(x[0], y[2]),
193            self.get(x[1], y[2]),
194            self.get(x[2], y[2]),
195        ])
196    }
197    pub fn set(&mut self, x: usize, y: usize, value: f32) {
198        self.data[y * self.width + x] = value
199    }
200}
201
202fn sharp(
203    input_channel: &ImageChannel,
204    target: &ImageChannel,
205    alpha_channel: Option<&ImageChannel>,
206    output_channel: &mut ImageChannel,
207) {
208    for y in 0..input_channel.height {
209        for x in 0..input_channel.width {
210            let area = input_channel.get_area(x, y);
211
212            let target = target.get(x, y);
213            let borders = area.borders();
214            let corners = area.corners();
215            let result = (target * 64.0 - borders * 6.0 - corners) / 36.0;
216
217            let max = if let Some(alpha_channel) = alpha_channel {
218                alpha_channel.get(x, y)
219            } else {
220                1.0
221            };
222            let result = result.clamp(0.0, max);
223
224            output_channel.set(x, y, result);
225        }
226    }
227}
228
229fn adjust_3x3(target: f32, area: Area, alpha: Area) -> Area {
230    let current = area.integral();
231    if current > target {
232        let k = target / current;
233        return area.map(&mut |v| v * k);
234    }
235
236    let max = alpha.integral();
237    if max > current {
238        let k = (target - current) / (max - current);
239        return area.zip_map(&alpha, &mut |v, a| v * (1.0 - k) + a * k);
240    }
241
242    area
243}
244
/// The part of one output pixel that lies inside one half of one input pixel, along one axis.
#[derive(Clone, Debug)]
struct Segment {
    /// Index of the output pixel this segment belongs to.
    output_index: usize,
    /// Linear-interpolation parameter, evaluated at the middle of the segment, between the
    /// two input pixel centers that bracket this half of the input pixel.
    interpolation_factor: f32,
    /// Length of the segment as a fraction of one output pixel.
    size: f32,
}

/// Output segments overlapped by one input pixel along one axis.
///
/// `.0[0]` holds the segments inside the first half of the input pixel (`[idx, idx + 0.5)` in
/// input coordinates), interpolated between the centers of pixels `idx - 1` and `idx`.
/// `.0[1]` holds the segments inside the second half (`[idx + 0.5, idx + 1)`), interpolated
/// between the centers of pixels `idx` and `idx + 1`.
#[derive(Clone, Debug)]
struct IntersectedPixels([Vec<Segment>; 2]);

impl IntersectedPixels {
    /// Computes the segments of input pixel `inp_idx` when an axis of `old` input pixels is
    /// resampled to `new` output pixels.
    ///
    /// Input pixel `i` spans `[i, i + 1)` and output pixel `j` spans
    /// `[j * old / new, (j + 1) * old / new)`, both in input coordinates.
    /// `old` and `new` must be non-zero.
    fn new(old: usize, new: usize, inp_idx: usize) -> Self {
        let idx = inp_idx as f64;

        // Output pixels touching `[idx, idx + 0.5)`: from floor(idx * new / old)
        // up to ceil((2 * idx + 1) * new / (2 * old)), exclusive.
        let before_center = Self::clip_segments(
            old,
            new,
            (inp_idx * new) / old..((inp_idx * 2 + 1) * new).div_ceil(old * 2),
            (idx, idx + 0.5),
            // Center of the previous input pixel: interpolation factor 0 there.
            idx - 0.5,
        );

        // Output pixels touching `[idx + 0.5, idx + 1)`: from
        // floor((2 * idx + 1) * new / (2 * old)) up to ceil((idx + 1) * new / old), exclusive.
        let after_center = Self::clip_segments(
            old,
            new,
            ((inp_idx * 2 + 1) * new) / (old * 2)..((inp_idx + 1) * new).div_ceil(old),
            (idx + 0.5, idx + 1.0),
            // Center of this input pixel: interpolation factor 0 there.
            idx + 0.5,
        );

        IntersectedPixels([before_center, after_center])
    }

    /// Clips every output pixel in `outputs` to the input-coordinate interval `bounds` and
    /// describes each overlap as a [`Segment`]; `origin` is where the interpolation factor is 0.
    ///
    /// Boundaries are computed in `f64` from exact integer products. `f32` resolves
    /// coordinates around 10 000 only to about 1e-3 pixel, which made segment weights drift
    /// (and could even invert tiny segments) on large images.
    fn clip_segments(
        old: usize,
        new: usize,
        outputs: std::ops::Range<usize>,
        (lo, hi): (f64, f64),
        origin: f64,
    ) -> Vec<Segment> {
        let new_f = new as f64;
        let new_div_old = new_f / old as f64;
        outputs
            .map(|out_idx| {
                let start = ((out_idx * old) as f64 / new_f).max(lo);
                let end = (((out_idx + 1) * old) as f64 / new_f).min(hi);
                let center = (start + end) * 0.5;
                Segment {
                    output_index: out_idx,
                    interpolation_factor: (center - origin) as f32,
                    size: ((end - start) * new_div_old) as f32,
                }
            })
            .collect()
    }
}
320}
321
/// Bilinear interpolation over a unit cell with samples `a` (top-left), `b` (top-right),
/// `c` (bottom-left) and `d` (bottom-right).
///
/// `tx` moves from the left column to the right, and `ty` moves from the top row to the bottom.
fn bilinear_interpolation(a: f32, b: f32, c: f32, d: f32, tx: f32, ty: f32) -> f32 {
    let lerp = |from: f32, to: f32, t: f32| (to - from) * t + from;
    let top = lerp(a, b, tx);
    let bottom = lerp(c, d, tx);
    lerp(top, bottom, ty)
}
327
328/// Performs image resampling.
329/// 
330/// `input` - something implementing the [Into]<[FullImage]> trait,
331/// this trait is implemented for [image::DynamicImage];
332/// 
333/// `width` and `height` are the dimensions of the output image, must not be zero;
334/// 
335/// the output type is [FullImage], so the `into()` method must be called to convert it to [image::DynamicImage].
336///
337/// Usage:
338/// ```no_run
339/// # let (width, height) = (1, 1);
340/// let input_image = image::open("input.png").unwrap();
341/// let resized_image: image::DynamicImage =
342///     sharpened_bilinear::resize(&input_image, width, height).into();
343/// resized_image.save("output.png").unwrap();
344/// ```
345pub fn resize(input: impl Into<FullImage>, width: usize, height: usize) -> FullImage {
346    assert!(width > 0);
347    assert!(height > 0);
348    
349    let input = input.into();
350    let mut sharpened_image = input.clone();
351    let mut temp_channel = sharpened_image.channels[0].clone();
352
353    let steps = 8;
354    // independed channels
355    for z in 0..input.channels.len() {
356        if input.has_alpha && z != 0 {
357            continue;
358        }
359        for _ in 0..steps {
360            let target = &input.channels[z];
361            let current_channel = &mut sharpened_image.channels[z];
362            sharp(current_channel, target, None, &mut temp_channel);
363            sharp(&temp_channel, target, None, current_channel);
364        }
365    }
366    // alpha depended channels
367    if input.has_alpha {
368        for z in 1..input.channels.len() {
369            for _ in 0..steps {
370                let target = &input.channels[z];
371                let current_channel = &mut sharpened_image.channels[z];
372                let max = Some(&input.channels[0]);
373                sharp(current_channel, target, max, &mut temp_channel);
374                sharp(&temp_channel, target, max, current_channel);
375            }
376        }
377    }
378
379    let intersected_by_x: Vec<_> = (0..input.width)
380        .map(|inp_idx| IntersectedPixels::new(input.width, width, inp_idx))
381        .collect();
382
383    let intersected_by_y: Vec<_> = (0..input.height)
384        .map(|inp_idx| IntersectedPixels::new(input.height, height, inp_idx))
385        .collect();
386
387    let mut output_image = FullImage::new(width, height, input.channels.len(), input.has_alpha);
388
389    for y in 0..input.height {
390        for x in 0..input.width {
391            let mut alpha_area = Area([1.0; 9]);
392
393            let intersected_x = &intersected_by_x[x];
394            let intersected_y = &intersected_by_y[y];
395
396            for z in 0..input.channels.len() {
397                let area = sharpened_image.channels[z].get_area(x, y);
398                let target = input.channels[z].get(x, y);
399                let area = adjust_3x3(target, area, alpha_area);
400                if input.has_alpha && z == 0 {
401                    alpha_area = area;
402                }
403                for h in 0..2 {
404                    for w in 0..2 {
405                        for y_segment in intersected_y.0[h].iter() {
406                            for x_segment in intersected_x.0[w].iter() {
407                                let result = bilinear_interpolation(
408                                    area.get(0 + w, 0 + h),
409                                    area.get(1 + w, 0 + h),
410                                    area.get(0 + w, 1 + h),
411                                    area.get(1 + w, 1 + h),
412                                    x_segment.interpolation_factor,
413                                    y_segment.interpolation_factor,
414                                ) * y_segment.size
415                                    * x_segment.size;
416                                let x = x_segment.output_index;
417                                let y = y_segment.output_index;
418                                let old = output_image.channels[z].get(x, y);
419                                output_image.channels[z].set(x, y, old + result);
420                            }
421                        }
422                    }
423                }
424            }
425        }
426    }
427
428    output_image
429}