// imgproc_rs/filter/median.rs

1use crate::error;
2use crate::error::{ImgProcResult, ImgProcError};
3use crate::image::{Image, BaseImage};
4
5use std::cmp::{Ordering, Reverse};
6
7/// Applies a median filter, where each output pixel is the median of the pixels in a
8/// `(2 * radius + 1) x (2 * radius + 1)` kernel in the input image. Based on Ben Weiss' partial
9/// histogram method, using a tier radix of 2. A detailed description can be found
10/// [here](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.93.1608&rep=rep1&type=pdf).
11pub fn median_filter(input: &Image<u8>, radius: u32) -> ImgProcResult<Image<u8>> {
12    let mut n_cols = (4.0 * (radius as f32).powf(2.0 / 3.0)).floor() as usize;
13    if n_cols % 2 == 0 {
14        n_cols += 1;
15    }
16
17    let mut output = Image::blank(input.info());
18
19    for x in (0..output.info().width).step_by(n_cols) {
20        process_cols_med(input, &mut output, radius, n_cols, x);
21    }
22
23    Ok(output)
24}
25
26/// Applies an alpha-trimmed mean filter, where each output pixel is the mean of the
27/// pixels in a `(2 * radius + 1) x (2 * radius + 1)` kernel in the input image, with the lowest
28/// `alpha / 2` pixels and the highest `alpha / 2` pixels removed.
29pub fn alpha_trimmed_mean_filter(input: &Image<u8>, radius: u32, alpha: u32) -> ImgProcResult<Image<u8>> {
30    let size = 2 * radius + 1;
31    error::check_even(alpha, "alpha")?;
32    if alpha >= (size * size) {
33        return Err(ImgProcError::InvalidArgError(format!("invalid alpha: size is {}, but alpha is {}", size, alpha)));
34    }
35
36    let mut n_cols = (4.0 * (radius as f32).powf(2.0 / 3.0)).floor() as usize;
37    if n_cols % 2 == 0 {
38        n_cols += 1;
39    }
40
41    let mut output = Image::blank(input.info());
42
43    for x in (0..output.info().width).step_by(n_cols) {
44        process_cols_mean(input, &mut output, radius, alpha, n_cols, x);
45    }
46
47    Ok(output)
48}
49
/// Partial histograms for a strip of `n_cols` adjacent output columns (Weiss' method).
///
/// `data` holds `n_cols` histograms, and `n_cols` is always odd. Only the middle entry,
/// `data[n_half]`, is a complete histogram: it counts the pixel values in the kernel of the
/// strip's central column. Every other entry stores the *difference* between the complete
/// histogram of its own column's kernel and the central one, so its counts can be negative.
/// The complete histogram of column `i` is therefore `data[n_half] + data[i]` (see
/// `get_count`).
///
/// A strip is processed top to bottom. To advance one row, the row leaving the kernels is
/// removed and the row entering them is added through `update`. A single set of histograms
/// therefore serves the whole strip.
#[derive(Debug, Clone)]
struct PartialHistograms {
    /// One histogram per column; only `data[n_half]` is complete, the others are deltas.
    data: Vec<[i32; 256]>,
    /// Number of histograms, i.e. of columns processed at once (always odd).
    n_cols: usize,
    /// `n_cols / 2`, the index of the central (complete) histogram.
    n_half: usize,
    /// Radius of the kernel.
    radius: usize,
    /// Width of the kernel, `2 * radius + 1`.
    size: usize,
}

impl PartialHistograms {
    fn new(radius: usize, n_cols: usize) -> Self {
        PartialHistograms {
            data: vec![[0; 256]; n_cols],
            n_cols,
            n_half: n_cols / 2,
            radius,
            size: 2 * radius + 1,
        }
    }

    /// Adds (`add == true`) or removes one image row to/from every column's kernel.
    ///
    /// `p_in` spans the `n_cols + 2 * radius` pixels covered by the strip's kernels, so column
    /// `k`'s kernel row is `p_in[k..k + size]`. Each pixel is a slice of channel values, and
    /// `channel_index` selects the channel to count.
    fn update(&mut self, p_in: &[&[u8]], channel_index: usize, add: bool) {
        let delta: i32 = if add { 1 } else { -1 };
        let value = |idx: usize| p_in[idx][channel_index] as usize;

        for left in 0..self.n_half {
            let right = self.n_cols - 1 - left;

            for i in left..self.n_half {
                // Column `left` sees pixels left of the central kernel and misses the ones
                // at its right end
                self.data[left][value(i)] += delta;
                self.data[left][value(i + self.size)] -= delta;

                // Mirror image for column `right`
                let hi = self.n_cols + 2 * self.radius - 1 - i;
                let lo = hi - self.size;
                self.data[right][value(lo)] -= delta;
                self.data[right][value(hi)] += delta;
            }
        }

        let central = &mut self.data[self.n_half];
        for px in &p_in[self.n_half..self.n_half + self.size] {
            central[px[channel_index] as usize] += delta;
        }
    }

    /// Number of pixels with value `key` in the kernel of column `index`.
    fn get_count(&self, key: usize, index: usize) -> i32 {
        let central = self.data[self.n_half][key];
        if index == self.n_half {
            central
        } else {
            central + self.data[index][key]
        }
    }
}

////////////////////////////
// Median filter functions
////////////////////////////

140/*
141 * The MedianHist struct:
142 *
143 * In addition to containing the partial histograms, this struct keeps track of each previous
144 * median of the kernel for each pixel. These medians are used as "pivots" to find the next
145 * median: to find the median of the kernel for a given pixel, instead of scanning its histogram
146 * starting from 0, we start from the median of the kernel of the previous pixel in that column.
147 * This value is typically much closer to the current median since the majority of the pixels
148 * in the previous and current kernels are the same, which makes scanning the histogram much
149 * quicker. The "sums", or the number of values in the current histogram that are less than the
150 * previous median, is used to determine if the current median is greater than or less than the
151 * previous median, thus determining if the current histogram should be scanned upwards or
152 * downwards from the previous median, respectively, to find the current median.
153 */
154#[derive(Debug, Clone)]
155struct MedianHist {
156    data: PartialHistograms,
157    sums: Vec<i32>, // Sums to keep track of the number of values less than the previous median
158    pivots: Vec<u8>, // Previous medians to act as "pivots" to find the next median
159}
160
161impl MedianHist {
162    fn new(radius: usize, n_cols: usize) -> Self {
163        MedianHist {
164            data: PartialHistograms::new(radius, n_cols),
165            sums: vec![0; n_cols],
166            pivots: Vec::with_capacity(n_cols),
167        }
168    }
169
170    fn data(&self) -> &PartialHistograms {
171        &self.data
172    }
173
174    fn sums(&self) -> &[i32] {
175        &self.sums
176    }
177
178    fn pivots(&self) -> &[u8] {
179        &self.pivots
180    }
181
182    fn init_pivots(&mut self) {
183        self.pivots = vec![0; self.data.n_cols];
184    }
185
186    fn set_pivot(&mut self, pivot: u8, index: usize) {
187        self.pivots[index] = pivot;
188    }
189
190    fn set_sum(&mut self, sum: i32, index: usize) {
191        self.sums[index] = sum;
192    }
193
194    fn update(&mut self, p_in: &[&[u8]], channel_index: usize, add: bool) {
195        self.data.update(p_in, channel_index, add);
196
197        let mut inc = 1;
198        if !add {
199            inc *= -1;
200        }
201
202        // Update the number of values less than the previous median
203        if !self.pivots.is_empty() {
204            for n in 0..self.data.n_cols {
205                for i in n..(n + self.data.size) {
206                    if p_in[i][channel_index] < self.pivots[n] {
207                        self.sums[n] += inc;
208                    }
209                }
210            }
211        }
212    }
213}
214
215fn process_cols_med(input: &Image<u8>, output: &mut Image<u8>, radius: u32, n_cols: usize, x: u32) {
216    let size = 2 * radius + 1;
217    let center = ((size * size) / 2 + 1) as i32; // Half the number of pixels in a kernel. If
218                                                      // all the pixels in the kernel were sorted,
219                                                      // the index of the median would be (center - 1).
220    let (width, height, channels) = input.info().whc();
221    let mut histograms = vec![MedianHist::new(radius as usize, n_cols); channels as usize];
222    let mut p_out = Vec::with_capacity(channels as usize);
223
224    // Initialize histogram and process first row
225    init_cols_med(input, output, &mut histograms, &mut p_out, radius, center, n_cols, x);
226
227    // Update histogram and process remaining rows
228    let mut row_in = Vec::with_capacity(n_cols);
229    let mut row_out = Vec::with_capacity(n_cols);
230    for j in 1..height {
231        // Update histograms
232        let j_in = (j + radius).clamp(0, input.info().height - 1);
233        let j_out = (j as i32 - radius as i32 - 1).clamp(0, input.info().height as i32 - 1) as u32;
234
235        for i in (x as i32 - radius as i32)..((x + n_cols as u32 + radius) as i32) {
236            let i_clamp = i.clamp(0, width as i32 - 1) as u32;
237            row_in.push(input.get_pixel_unchecked(i_clamp, j_in));
238            row_out.push(input.get_pixel_unchecked(i_clamp, j_out));
239        }
240
241        add_row_med(&mut histograms, &row_in);
242        remove_row_med(&mut histograms, &row_out);
243
244        process_row_med(output, &mut histograms, &mut p_out, center, n_cols, x, j);
245
246        row_in.clear();
247        row_out.clear();
248    }
249}
250
251fn init_cols_med(input: &Image<u8>, output: &mut Image<u8>, histograms: &mut Vec<MedianHist>,
252                 p_out: &mut Vec<u8>, radius: u32, center: i32, n_cols: usize, x: u32) {
253    let (width, height) = input.info().wh();
254
255    // Initialize histograms
256    let mut row_in = Vec::with_capacity(n_cols);
257    for j in -(radius as i32)..(radius as i32 + 1) {
258        for i in (x as i32 - radius as i32)..((x + n_cols as u32 + radius) as i32) {
259            row_in.push(input.get_pixel_unchecked(i.clamp(0, width as i32 - 1) as u32,
260                                                  j.clamp(0, height as i32 - 1) as u32));
261        }
262
263        add_row_med(histograms, &row_in);
264        row_in.clear();
265    }
266
267    // Initialize histogram pivots
268    for hist in histograms.iter_mut() {
269        hist.init_pivots();
270    }
271
272    // Compute first median values
273    for i in 0..n_cols {
274        p_out.clear();
275        for hist in histograms.iter_mut() {
276            let mut sum = 0;
277
278            for key in 0u8..=255 {
279                let add = hist.data().get_count(key as usize, i);
280
281                if sum + add >= center {
282                    p_out.push(key);
283                    hist.set_sum(sum, i);
284                    break;
285                }
286
287                sum += add;
288            }
289        }
290
291        let x_clamp = (x + i as u32).clamp(0, output.info().width - 1);
292        output.set_pixel(x_clamp, 0, &p_out);
293
294        set_pivots_med(histograms, &p_out, i);
295    }
296}
297
298fn process_row_med(output: &mut Image<u8>, histograms: &mut Vec<MedianHist>, p_out: &mut Vec<u8>, center: i32, n_cols: usize, x: u32, y: u32) {
299    for i in 0..n_cols {
300        p_out.clear();
301        for hist in histograms.iter_mut() {
302            let pivot = hist.pivots()[i]; // Get the previous median
303            let mut sum = hist.sums()[i]; // Get the number of values less than
304                                                        // the previous median
305
306            match sum.cmp(&center) {
307                Ordering::Equal => { // The current median is equal to the previous median
308                    p_out.push(pivot);
309                },
310                Ordering::Less => { // The current median is greater than the previous median,
311                                    // so the histogram should be scanned upwards
312                    for key in pivot..=255 {
313                        let add = hist.data().get_count(key as usize, i);
314
315                        if sum + add >= center {
316                            p_out.push(key);
317                            hist.set_sum(sum, i);
318                            break;
319                        }
320
321                        sum += add;
322                    }
323                },
324                Ordering::Greater => { // The current median is less than the previous median, so the histogram
325                                       // should be scanned downwards
326                    for key in (0..pivot).rev() {
327                        sum -= hist.data().get_count(key as usize, i);
328
329                        if sum < center {
330                            p_out.push(key);
331                            hist.set_sum(sum, i);
332                            break;
333                        }
334                    }
335                }
336            }
337        }
338
339        let x_clamp = (x + i as u32).clamp(0, output.info().width - 1);
340        output.set_pixel(x_clamp, y, &p_out);
341
342        set_pivots_med(histograms, &p_out, i);
343    }
344}
345
346fn add_row_med(histograms: &mut Vec<MedianHist>, p_in: &[&[u8]]) {
347    for (c, hist) in histograms.iter_mut().enumerate() {
348        hist.update(p_in, c, true);
349    }
350}
351
352fn remove_row_med(histograms: &mut Vec<MedianHist>, p_in: &[&[u8]]) {
353    for (c, hist) in histograms.iter_mut().enumerate() {
354        hist.update(p_in, c, false);
355    }
356}
357
358fn set_pivots_med(histograms: &mut Vec<MedianHist>, pivots: &[u8], index: usize) {
359    for c in 0..pivots.len() {
360        histograms[c].set_pivot(pivots[c], index);
361    }
362}

////////////////////////////////////////
// Alpha-trimmed mean filter functions
////////////////////////////////////////

368#[derive(Debug, Clone)]
369struct MeanHist {
370    data: PartialHistograms,
371    sums: Vec<i32>, // The sum of all the participating pixel values in the kernel for each pixel
372    lower: Vec<Vec<u8>>, // Vectors of all the lowest discarded pixel values in the kernel
373                         // for each pixel
374    upper: Vec<Vec<u8>>, // Vectors of all the highest discarded pixel values in the kernel for
375                         // each pixel
376    trim: usize, // The number of pixel values discarded at the low and high ends of each kernel
377                 // (equal to half of alpha)
378    len: f32, // The number of participating pixel values in each kernel
379}
380
381impl MeanHist {
382    fn new(radius: usize, n_cols: usize, alpha: u32) -> Self {
383        let size = 2 * radius + 1;
384        let len = ((size * size) - alpha as usize) as f32;
385
386        MeanHist {
387            data: PartialHistograms::new(radius, n_cols),
388            sums: Vec::with_capacity(n_cols),
389            lower: Vec::with_capacity(n_cols),
390            upper: Vec::with_capacity(n_cols),
391            trim: (alpha as usize) / 2,
392            len,
393        }
394    }
395
396    fn data(&self) -> &PartialHistograms {
397        &self.data
398    }
399
400    fn init(&mut self) {
401        self.sums = vec![0; self.data.n_cols];
402        self.lower = vec![Vec::with_capacity(self.trim); self.data.n_cols];
403        self.upper = vec![Vec::with_capacity(self.trim); self.data.n_cols];
404    }
405
406    // By some miracle, this seems to work!
407    fn update(&mut self, p_in: &[&[u8]], channel_index: usize, add: bool) {
408        if !self.sums.is_empty() {
409            if add {
410                for n in 0..self.data.n_cols {
411                    for i in n..(n + self.data.size) {
412                        let val = p_in[i][channel_index];
413                        let lower = self.lower(n);
414                        let upper = self.upper(n);
415
416                        if val < lower {
417                            self.lower[n].remove(self.trim -  1);
418                            self.sums[n] += lower as i32;
419
420                            let pos = self.lower[n].binary_search(&val).unwrap_or_else(|e| e);
421                            self.lower[n].insert(pos, val);
422                        } else if val > upper {
423                            self.upper[n].remove(self.trim - 1);
424                            self.sums[n] += upper as i32;
425
426                            let pos = self.lower[n].binary_search_by_key(&Reverse(&val), Reverse).unwrap_or_else(|e| e);
427                            self.upper[n].insert(pos, val);
428                        } else {
429                            self.sums[n] += val as i32;
430                        }
431                    }
432                }
433                self.data.update(p_in, channel_index, add);
434            } else {
435                self.data.update(p_in, channel_index, add);
436                for n in 0..self.data.n_cols {
437                    for i in n..(n + self.data.size) {
438                        let val = p_in[i][channel_index];
439                        let lower = self.lower(n);
440                        let upper = self.upper(n);
441
442                        let mut lower_count = self.data.get_count(lower as usize, n);
443                        let mut upper_count = self.data.get_count(upper as usize, n);
444
445                        for j in i..(n + self.data.size) {
446                            if p_in[j][channel_index] == lower {
447                                lower_count += 1;
448                            } else if p_in[j][channel_index] == upper {
449                                upper_count += 1;
450                            }
451                        }
452
453                        for j in self.lower[n].iter().rev() {
454                            if *j == lower {
455                                lower_count -= 1;
456                            } else {
457                                break;
458                            }
459                        }
460
461                        for j in self.upper[n].iter().rev() {
462                            if *j == upper {
463                                upper_count -= 1;
464                            } else {
465                                break;
466                            }
467                        }
468
469                        if val == lower && lower_count == 0 {
470                            self.lower[n].remove(self.trim - 1);
471                            self.get_next_lower(n, lower_count, lower);
472                        } else if val < lower {
473                            let res = self.lower[n].binary_search(&val);
474
475                            match res {
476                                Ok(pos) => {
477                                    self.lower[n].remove(pos);
478                                    self.get_next_lower(n, lower_count, lower);
479                                },
480                                Err(_) => {
481                                    self.sums[n] -= val as i32;
482                                }
483                            }
484                        } else if val == upper && upper_count == 0 {
485                            self.upper[n].remove(self.trim - 1);
486                            self.get_next_upper(n, upper_count, upper);
487                        } else if val > upper {
488                            let res = self.lower[n].binary_search_by_key(&Reverse(&val), Reverse);
489
490                            match res {
491                                Ok(pos) => {
492                                    self.upper[n].remove(pos);
493                                    self.get_next_upper(n, upper_count, upper);
494                                },
495                                Err(_) => {
496                                    self.sums[n] -= val as i32;
497                                }
498                            }
499                        } else {
500                            self.sums[n] -= val as i32;
501                        }
502                    }
503                }
504            }
505        } else {
506            self.data.update(p_in, channel_index, add);
507        }
508    }
509
510    fn set_sum(&mut self, sum: i32, index: usize) {
511        self.sums[index] = sum;
512    }
513
514    fn set_upper(&mut self, vals: Vec<u8>, index: usize) {
515        self.upper[index] = vals;
516    }
517
518    fn set_lower(&mut self, vals: Vec<u8>, index: usize) {
519        self.lower[index] = vals;
520    }
521
522    fn upper(&self, index: usize) -> u8 {
523        self.upper[index][self.trim-1]
524    }
525
526    fn lower(&self, index: usize) -> u8 {
527        self.lower[index][self.trim-1]
528    }
529
530    fn get_mean(&self, index: usize) -> u8 {
531        ((self.sums[index] as f32) / self.len).round() as u8
532    }
533
534    fn get_next_lower(&mut self, n: usize, lower_count: i32, lower: u8) {
535        if lower_count > 0 {
536            self.lower[n].push(lower);
537            self.sums[n] -= lower as i32;
538        } else {
539            for key in (lower + 1)..=255 {
540                if self.data.get_count(key as usize, n) > 0 {
541                    self.lower[n].push(key);
542                    self.sums[n] -= key as i32;
543                    break;
544                }
545            }
546        }
547    }
548
549    fn get_next_upper(&mut self, n: usize, upper_count: i32, upper: u8) {
550        if upper_count > 0 {
551            self.upper[n].push(upper);
552            self.sums[n] -= upper as i32;
553        } else {
554            for key in (0..upper).rev() {
555                if self.data.get_count(key as usize, n) > 0 {
556                    self.upper[n].push(key);
557                    self.sums[n] -= key as i32;
558                    break;
559                }
560            }
561        }
562    }
563}
564
565fn process_cols_mean(input: &Image<u8>, output: &mut Image<u8>, radius: u32, alpha: u32, n_cols: usize, x: u32) {
566    let (width, height, channels) = input.info().whc();
567    let mut histograms = vec![MeanHist::new(radius as usize, n_cols, alpha); channels as usize];
568    let mut p_out = Vec::with_capacity(channels as usize);
569
570    // Initialize histogram and process first row
571    init_cols_mean(input, output, &mut histograms, &mut p_out, radius, alpha, n_cols, x);
572
573    // Update histogram and process remaining rows
574    let mut row_in = Vec::with_capacity(n_cols);
575    let mut row_out = Vec::with_capacity(n_cols);
576    for j in 1..height {
577        // Update histograms
578        let j_in = (j + radius).clamp(0, input.info().height - 1);
579        let j_out = (j as i32 - radius as i32 - 1).clamp(0, input.info().height as i32 - 1) as u32;
580
581        for i in (x as i32 - radius as i32)..((x + n_cols as u32 + radius) as i32) {
582            let i_clamp = i.clamp(0, width as i32 - 1) as u32;
583            row_in.push(input.get_pixel_unchecked(i_clamp, j_in));
584            row_out.push(input.get_pixel_unchecked(i_clamp, j_out));
585        }
586
587        add_row_mean(&mut histograms, &row_in);
588        remove_row_mean(&mut histograms, &row_out);
589
590        process_row_mean(output, &mut histograms, &mut p_out, n_cols, x, j);
591
592        row_in.clear();
593        row_out.clear();
594    }
595}
596
597fn init_cols_mean(input: &Image<u8>, output: &mut Image<u8>, histograms: &mut Vec<MeanHist>,
598                  p_out: &mut Vec<u8>, radius: u32, alpha: u32, n_cols: usize, x: u32) {
599    let (width, height) = input.info().wh();
600    let size = 2 * radius + 1;
601
602    // Initialize histograms
603    let mut row_in = Vec::with_capacity(n_cols);
604    for j in -(radius as i32)..(radius as i32 + 1) {
605        for i in (x as i32 - radius as i32)..((x + n_cols as u32 + radius) as i32) {
606            row_in.push(input.get_pixel_unchecked(i.clamp(0, width as i32 - 1) as u32,
607                                                j.clamp(0, height as i32 - 1) as u32));
608        }
609
610        add_row_mean(histograms, &row_in);
611    }
612
613    // Initialize histograms
614    for hist in histograms.iter_mut() {
615        hist.init();
616    }
617
618    // Compute first mean values
619    let trim = (alpha as usize) / 2;
620    let upper_trim = (size * size) as usize - trim;
621    for i in 0..n_cols {
622        p_out.clear();
623        for hist in histograms.iter_mut() {
624            let mut count = 0;
625            let mut sum = 0;
626            let mut upper = Vec::with_capacity(trim);
627            let mut lower = Vec::with_capacity(trim);
628
629            for key in 0u8..=255 {
630                let mut add = hist.data().get_count(key as usize, i);
631                count += add;
632                sum += add * key as i32;
633
634                while lower.len() < trim && add > 0 {
635                    lower.push(key);
636                    sum -= key as i32;
637                    add -= 1;
638                }
639
640                while (count as usize) > upper_trim && upper.len() < trim && add > 0 {
641                    upper.insert(0, key);
642                    sum -= key as i32;
643                    add -= 1;
644                }
645            }
646
647            hist.set_sum(sum, i);
648            hist.set_upper(upper, i);
649            hist.set_lower(lower, i);
650
651            p_out.push(hist.get_mean(i));
652        }
653
654        let x_clamp = (x + i as u32).clamp(0, output.info().width - 1);
655        output.set_pixel(x_clamp, 0, &p_out);
656    }
657}
658
659fn process_row_mean(output: &mut Image<u8>, histograms: &mut Vec<MeanHist>, p_out: &mut Vec<u8>, n_cols: usize, x: u32, y: u32) {
660    for i in 0..n_cols {
661        p_out.clear();
662        for hist in histograms.iter_mut() {
663            p_out.push(hist.get_mean(i));
664        }
665
666        let x_clamp = (x + i as u32).clamp(0, output.info().width - 1);
667        output.set_pixel(x_clamp, y, &p_out);
668    }
669}
670
671fn add_row_mean(histograms: &mut Vec<MeanHist>, p_in: &[&[u8]]) {
672    for (c, hist) in histograms.iter_mut().enumerate() {
673        hist.update(p_in, c, true);
674    }
675}
676
677fn remove_row_mean(histograms: &mut Vec<MeanHist>, p_in: &[&[u8]]) {
678    for (c, hist) in histograms.iter_mut().enumerate() {
679        hist.update(p_in, c, false);
680    }
681}