1use crate::error;
2use crate::error::{ImgProcResult, ImgProcError};
3use crate::image::{Image, BaseImage};
4
5use std::cmp::{Ordering, Reverse};
6
7pub fn median_filter(input: &Image<u8>, radius: u32) -> ImgProcResult<Image<u8>> {
12 let mut n_cols = (4.0 * (radius as f32).powf(2.0 / 3.0)).floor() as usize;
13 if n_cols % 2 == 0 {
14 n_cols += 1;
15 }
16
17 let mut output = Image::blank(input.info());
18
19 for x in (0..output.info().width).step_by(n_cols) {
20 process_cols_med(input, &mut output, radius, n_cols, x);
21 }
22
23 Ok(output)
24}
25
26pub fn alpha_trimmed_mean_filter(input: &Image<u8>, radius: u32, alpha: u32) -> ImgProcResult<Image<u8>> {
30 let size = 2 * radius + 1;
31 error::check_even(alpha, "alpha")?;
32 if alpha >= (size * size) {
33 return Err(ImgProcError::InvalidArgError(format!("invalid alpha: size is {}, but alpha is {}", size, alpha)));
34 }
35
36 let mut n_cols = (4.0 * (radius as f32).powf(2.0 / 3.0)).floor() as usize;
37 if n_cols % 2 == 0 {
38 n_cols += 1;
39 }
40
41 let mut output = Image::blank(input.info());
42
43 for x in (0..output.info().width).step_by(n_cols) {
44 process_cols_mean(input, &mut output, radius, alpha, n_cols, x);
45 }
46
47 Ok(output)
48}
49
/// Set of 256-bin histograms describing the windows of a strip of `n_cols` adjacent
/// columns, for a single channel.
///
/// Only the central entry, `data[n_half]`, is a complete histogram. Every other entry
/// stores the signed difference between that column's window and the central window,
/// so the real count for a column is the sum of both tables (see `get_count`).
#[derive(Debug, Clone)]
struct PartialHistograms {
    /// One table of signed per-value counts for each column of the strip.
    data: Vec<[i32; 256]>,
    /// Number of columns (and tables) in the strip.
    n_cols: usize,
    /// Index of the central table, `n_cols / 2`.
    n_half: usize,
    /// Window radius.
    radius: usize,
    /// Window side length, `2 * radius + 1`.
    size: usize,
}

impl PartialHistograms {
    /// Creates zeroed tables for a strip of `n_cols` columns and the given window radius.
    fn new(radius: usize, n_cols: usize) -> Self {
        PartialHistograms {
            data: vec![[0; 256]; n_cols],
            n_cols,
            n_half: n_cols / 2,
            radius,
            size: 2 * radius + 1,
        }
    }

    /// Adds (`add == true`) or removes (`add == false`) one row of pixels.
    ///
    /// `p_in[k]` is the pixel at horizontal position `k` of the row; positions
    /// `n..n + size` form the window of column `n`, so the row must hold at least
    /// `n_cols + 2 * radius` pixels. `channel_index` selects the channel of each pixel.
    fn update(&mut self, p_in: &[&[u8]], channel_index: usize, add: bool) {
        let delta = if add { 1 } else { -1 };
        let bin = |pos: usize| usize::from(p_in[pos][channel_index]);

        for left in 0..self.n_half {
            // Column mirrored around the centre of the strip.
            let right = self.n_cols - left - 1;

            for pos in left..self.n_half {
                // Left column: gains `pos`, loses the position one window further right.
                self.data[left][bin(pos)] += delta;
                self.data[left][bin(pos + self.size)] -= delta;

                // Right column: the mirrored pair of positions.
                let gained = self.n_cols + 2 * self.radius - pos - 1;
                let lost = gained - self.size;
                self.data[right][bin(lost)] -= delta;
                self.data[right][bin(gained)] += delta;
            }
        }

        // The central table receives its whole window.
        let centre = self.n_half;
        for pos in centre..(centre + self.size) {
            self.data[centre][bin(pos)] += delta;
        }
    }

    /// Returns how many window values of column `index` are equal to `key`.
    fn get_count(&self, key: usize, index: usize) -> i32 {
        let central = self.data[self.n_half][key];
        if index == self.n_half {
            central
        } else {
            central + self.data[index][key]
        }
    }
}
135
136#[derive(Debug, Clone)]
155struct MedianHist {
156 data: PartialHistograms,
157 sums: Vec<i32>, pivots: Vec<u8>, }
160
161impl MedianHist {
162 fn new(radius: usize, n_cols: usize) -> Self {
163 MedianHist {
164 data: PartialHistograms::new(radius, n_cols),
165 sums: vec![0; n_cols],
166 pivots: Vec::with_capacity(n_cols),
167 }
168 }
169
170 fn data(&self) -> &PartialHistograms {
171 &self.data
172 }
173
174 fn sums(&self) -> &[i32] {
175 &self.sums
176 }
177
178 fn pivots(&self) -> &[u8] {
179 &self.pivots
180 }
181
182 fn init_pivots(&mut self) {
183 self.pivots = vec![0; self.data.n_cols];
184 }
185
186 fn set_pivot(&mut self, pivot: u8, index: usize) {
187 self.pivots[index] = pivot;
188 }
189
190 fn set_sum(&mut self, sum: i32, index: usize) {
191 self.sums[index] = sum;
192 }
193
194 fn update(&mut self, p_in: &[&[u8]], channel_index: usize, add: bool) {
195 self.data.update(p_in, channel_index, add);
196
197 let mut inc = 1;
198 if !add {
199 inc *= -1;
200 }
201
202 if !self.pivots.is_empty() {
204 for n in 0..self.data.n_cols {
205 for i in n..(n + self.data.size) {
206 if p_in[i][channel_index] < self.pivots[n] {
207 self.sums[n] += inc;
208 }
209 }
210 }
211 }
212 }
213}
214
215fn process_cols_med(input: &Image<u8>, output: &mut Image<u8>, radius: u32, n_cols: usize, x: u32) {
216 let size = 2 * radius + 1;
217 let center = ((size * size) / 2 + 1) as i32; let (width, height, channels) = input.info().whc();
221 let mut histograms = vec![MedianHist::new(radius as usize, n_cols); channels as usize];
222 let mut p_out = Vec::with_capacity(channels as usize);
223
224 init_cols_med(input, output, &mut histograms, &mut p_out, radius, center, n_cols, x);
226
227 let mut row_in = Vec::with_capacity(n_cols);
229 let mut row_out = Vec::with_capacity(n_cols);
230 for j in 1..height {
231 let j_in = (j + radius).clamp(0, input.info().height - 1);
233 let j_out = (j as i32 - radius as i32 - 1).clamp(0, input.info().height as i32 - 1) as u32;
234
235 for i in (x as i32 - radius as i32)..((x + n_cols as u32 + radius) as i32) {
236 let i_clamp = i.clamp(0, width as i32 - 1) as u32;
237 row_in.push(input.get_pixel_unchecked(i_clamp, j_in));
238 row_out.push(input.get_pixel_unchecked(i_clamp, j_out));
239 }
240
241 add_row_med(&mut histograms, &row_in);
242 remove_row_med(&mut histograms, &row_out);
243
244 process_row_med(output, &mut histograms, &mut p_out, center, n_cols, x, j);
245
246 row_in.clear();
247 row_out.clear();
248 }
249}
250
251fn init_cols_med(input: &Image<u8>, output: &mut Image<u8>, histograms: &mut Vec<MedianHist>,
252 p_out: &mut Vec<u8>, radius: u32, center: i32, n_cols: usize, x: u32) {
253 let (width, height) = input.info().wh();
254
255 let mut row_in = Vec::with_capacity(n_cols);
257 for j in -(radius as i32)..(radius as i32 + 1) {
258 for i in (x as i32 - radius as i32)..((x + n_cols as u32 + radius) as i32) {
259 row_in.push(input.get_pixel_unchecked(i.clamp(0, width as i32 - 1) as u32,
260 j.clamp(0, height as i32 - 1) as u32));
261 }
262
263 add_row_med(histograms, &row_in);
264 row_in.clear();
265 }
266
267 for hist in histograms.iter_mut() {
269 hist.init_pivots();
270 }
271
272 for i in 0..n_cols {
274 p_out.clear();
275 for hist in histograms.iter_mut() {
276 let mut sum = 0;
277
278 for key in 0u8..=255 {
279 let add = hist.data().get_count(key as usize, i);
280
281 if sum + add >= center {
282 p_out.push(key);
283 hist.set_sum(sum, i);
284 break;
285 }
286
287 sum += add;
288 }
289 }
290
291 let x_clamp = (x + i as u32).clamp(0, output.info().width - 1);
292 output.set_pixel(x_clamp, 0, &p_out);
293
294 set_pivots_med(histograms, &p_out, i);
295 }
296}
297
298fn process_row_med(output: &mut Image<u8>, histograms: &mut Vec<MedianHist>, p_out: &mut Vec<u8>, center: i32, n_cols: usize, x: u32, y: u32) {
299 for i in 0..n_cols {
300 p_out.clear();
301 for hist in histograms.iter_mut() {
302 let pivot = hist.pivots()[i]; let mut sum = hist.sums()[i]; match sum.cmp(¢er) {
307 Ordering::Equal => { p_out.push(pivot);
309 },
310 Ordering::Less => { for key in pivot..=255 {
313 let add = hist.data().get_count(key as usize, i);
314
315 if sum + add >= center {
316 p_out.push(key);
317 hist.set_sum(sum, i);
318 break;
319 }
320
321 sum += add;
322 }
323 },
324 Ordering::Greater => { for key in (0..pivot).rev() {
327 sum -= hist.data().get_count(key as usize, i);
328
329 if sum < center {
330 p_out.push(key);
331 hist.set_sum(sum, i);
332 break;
333 }
334 }
335 }
336 }
337 }
338
339 let x_clamp = (x + i as u32).clamp(0, output.info().width - 1);
340 output.set_pixel(x_clamp, y, &p_out);
341
342 set_pivots_med(histograms, &p_out, i);
343 }
344}
345
346fn add_row_med(histograms: &mut Vec<MedianHist>, p_in: &[&[u8]]) {
347 for (c, hist) in histograms.iter_mut().enumerate() {
348 hist.update(p_in, c, true);
349 }
350}
351
352fn remove_row_med(histograms: &mut Vec<MedianHist>, p_in: &[&[u8]]) {
353 for (c, hist) in histograms.iter_mut().enumerate() {
354 hist.update(p_in, c, false);
355 }
356}
357
358fn set_pivots_med(histograms: &mut Vec<MedianHist>, pivots: &[u8], index: usize) {
359 for c in 0..pivots.len() {
360 histograms[c].set_pivot(pivots[c], index);
361 }
362}
363
364#[derive(Debug, Clone)]
369struct MeanHist {
370 data: PartialHistograms,
371 sums: Vec<i32>, lower: Vec<Vec<u8>>, upper: Vec<Vec<u8>>, trim: usize, len: f32, }
380
381impl MeanHist {
382 fn new(radius: usize, n_cols: usize, alpha: u32) -> Self {
383 let size = 2 * radius + 1;
384 let len = ((size * size) - alpha as usize) as f32;
385
386 MeanHist {
387 data: PartialHistograms::new(radius, n_cols),
388 sums: Vec::with_capacity(n_cols),
389 lower: Vec::with_capacity(n_cols),
390 upper: Vec::with_capacity(n_cols),
391 trim: (alpha as usize) / 2,
392 len,
393 }
394 }
395
396 fn data(&self) -> &PartialHistograms {
397 &self.data
398 }
399
400 fn init(&mut self) {
401 self.sums = vec![0; self.data.n_cols];
402 self.lower = vec![Vec::with_capacity(self.trim); self.data.n_cols];
403 self.upper = vec![Vec::with_capacity(self.trim); self.data.n_cols];
404 }
405
406 fn update(&mut self, p_in: &[&[u8]], channel_index: usize, add: bool) {
408 if !self.sums.is_empty() {
409 if add {
410 for n in 0..self.data.n_cols {
411 for i in n..(n + self.data.size) {
412 let val = p_in[i][channel_index];
413 let lower = self.lower(n);
414 let upper = self.upper(n);
415
416 if val < lower {
417 self.lower[n].remove(self.trim - 1);
418 self.sums[n] += lower as i32;
419
420 let pos = self.lower[n].binary_search(&val).unwrap_or_else(|e| e);
421 self.lower[n].insert(pos, val);
422 } else if val > upper {
423 self.upper[n].remove(self.trim - 1);
424 self.sums[n] += upper as i32;
425
426 let pos = self.lower[n].binary_search_by_key(&Reverse(&val), Reverse).unwrap_or_else(|e| e);
427 self.upper[n].insert(pos, val);
428 } else {
429 self.sums[n] += val as i32;
430 }
431 }
432 }
433 self.data.update(p_in, channel_index, add);
434 } else {
435 self.data.update(p_in, channel_index, add);
436 for n in 0..self.data.n_cols {
437 for i in n..(n + self.data.size) {
438 let val = p_in[i][channel_index];
439 let lower = self.lower(n);
440 let upper = self.upper(n);
441
442 let mut lower_count = self.data.get_count(lower as usize, n);
443 let mut upper_count = self.data.get_count(upper as usize, n);
444
445 for j in i..(n + self.data.size) {
446 if p_in[j][channel_index] == lower {
447 lower_count += 1;
448 } else if p_in[j][channel_index] == upper {
449 upper_count += 1;
450 }
451 }
452
453 for j in self.lower[n].iter().rev() {
454 if *j == lower {
455 lower_count -= 1;
456 } else {
457 break;
458 }
459 }
460
461 for j in self.upper[n].iter().rev() {
462 if *j == upper {
463 upper_count -= 1;
464 } else {
465 break;
466 }
467 }
468
469 if val == lower && lower_count == 0 {
470 self.lower[n].remove(self.trim - 1);
471 self.get_next_lower(n, lower_count, lower);
472 } else if val < lower {
473 let res = self.lower[n].binary_search(&val);
474
475 match res {
476 Ok(pos) => {
477 self.lower[n].remove(pos);
478 self.get_next_lower(n, lower_count, lower);
479 },
480 Err(_) => {
481 self.sums[n] -= val as i32;
482 }
483 }
484 } else if val == upper && upper_count == 0 {
485 self.upper[n].remove(self.trim - 1);
486 self.get_next_upper(n, upper_count, upper);
487 } else if val > upper {
488 let res = self.lower[n].binary_search_by_key(&Reverse(&val), Reverse);
489
490 match res {
491 Ok(pos) => {
492 self.upper[n].remove(pos);
493 self.get_next_upper(n, upper_count, upper);
494 },
495 Err(_) => {
496 self.sums[n] -= val as i32;
497 }
498 }
499 } else {
500 self.sums[n] -= val as i32;
501 }
502 }
503 }
504 }
505 } else {
506 self.data.update(p_in, channel_index, add);
507 }
508 }
509
510 fn set_sum(&mut self, sum: i32, index: usize) {
511 self.sums[index] = sum;
512 }
513
514 fn set_upper(&mut self, vals: Vec<u8>, index: usize) {
515 self.upper[index] = vals;
516 }
517
518 fn set_lower(&mut self, vals: Vec<u8>, index: usize) {
519 self.lower[index] = vals;
520 }
521
522 fn upper(&self, index: usize) -> u8 {
523 self.upper[index][self.trim-1]
524 }
525
526 fn lower(&self, index: usize) -> u8 {
527 self.lower[index][self.trim-1]
528 }
529
530 fn get_mean(&self, index: usize) -> u8 {
531 ((self.sums[index] as f32) / self.len).round() as u8
532 }
533
534 fn get_next_lower(&mut self, n: usize, lower_count: i32, lower: u8) {
535 if lower_count > 0 {
536 self.lower[n].push(lower);
537 self.sums[n] -= lower as i32;
538 } else {
539 for key in (lower + 1)..=255 {
540 if self.data.get_count(key as usize, n) > 0 {
541 self.lower[n].push(key);
542 self.sums[n] -= key as i32;
543 break;
544 }
545 }
546 }
547 }
548
549 fn get_next_upper(&mut self, n: usize, upper_count: i32, upper: u8) {
550 if upper_count > 0 {
551 self.upper[n].push(upper);
552 self.sums[n] -= upper as i32;
553 } else {
554 for key in (0..upper).rev() {
555 if self.data.get_count(key as usize, n) > 0 {
556 self.upper[n].push(key);
557 self.sums[n] -= key as i32;
558 break;
559 }
560 }
561 }
562 }
563}
564
565fn process_cols_mean(input: &Image<u8>, output: &mut Image<u8>, radius: u32, alpha: u32, n_cols: usize, x: u32) {
566 let (width, height, channels) = input.info().whc();
567 let mut histograms = vec![MeanHist::new(radius as usize, n_cols, alpha); channels as usize];
568 let mut p_out = Vec::with_capacity(channels as usize);
569
570 init_cols_mean(input, output, &mut histograms, &mut p_out, radius, alpha, n_cols, x);
572
573 let mut row_in = Vec::with_capacity(n_cols);
575 let mut row_out = Vec::with_capacity(n_cols);
576 for j in 1..height {
577 let j_in = (j + radius).clamp(0, input.info().height - 1);
579 let j_out = (j as i32 - radius as i32 - 1).clamp(0, input.info().height as i32 - 1) as u32;
580
581 for i in (x as i32 - radius as i32)..((x + n_cols as u32 + radius) as i32) {
582 let i_clamp = i.clamp(0, width as i32 - 1) as u32;
583 row_in.push(input.get_pixel_unchecked(i_clamp, j_in));
584 row_out.push(input.get_pixel_unchecked(i_clamp, j_out));
585 }
586
587 add_row_mean(&mut histograms, &row_in);
588 remove_row_mean(&mut histograms, &row_out);
589
590 process_row_mean(output, &mut histograms, &mut p_out, n_cols, x, j);
591
592 row_in.clear();
593 row_out.clear();
594 }
595}
596
597fn init_cols_mean(input: &Image<u8>, output: &mut Image<u8>, histograms: &mut Vec<MeanHist>,
598 p_out: &mut Vec<u8>, radius: u32, alpha: u32, n_cols: usize, x: u32) {
599 let (width, height) = input.info().wh();
600 let size = 2 * radius + 1;
601
602 let mut row_in = Vec::with_capacity(n_cols);
604 for j in -(radius as i32)..(radius as i32 + 1) {
605 for i in (x as i32 - radius as i32)..((x + n_cols as u32 + radius) as i32) {
606 row_in.push(input.get_pixel_unchecked(i.clamp(0, width as i32 - 1) as u32,
607 j.clamp(0, height as i32 - 1) as u32));
608 }
609
610 add_row_mean(histograms, &row_in);
611 }
612
613 for hist in histograms.iter_mut() {
615 hist.init();
616 }
617
618 let trim = (alpha as usize) / 2;
620 let upper_trim = (size * size) as usize - trim;
621 for i in 0..n_cols {
622 p_out.clear();
623 for hist in histograms.iter_mut() {
624 let mut count = 0;
625 let mut sum = 0;
626 let mut upper = Vec::with_capacity(trim);
627 let mut lower = Vec::with_capacity(trim);
628
629 for key in 0u8..=255 {
630 let mut add = hist.data().get_count(key as usize, i);
631 count += add;
632 sum += add * key as i32;
633
634 while lower.len() < trim && add > 0 {
635 lower.push(key);
636 sum -= key as i32;
637 add -= 1;
638 }
639
640 while (count as usize) > upper_trim && upper.len() < trim && add > 0 {
641 upper.insert(0, key);
642 sum -= key as i32;
643 add -= 1;
644 }
645 }
646
647 hist.set_sum(sum, i);
648 hist.set_upper(upper, i);
649 hist.set_lower(lower, i);
650
651 p_out.push(hist.get_mean(i));
652 }
653
654 let x_clamp = (x + i as u32).clamp(0, output.info().width - 1);
655 output.set_pixel(x_clamp, 0, &p_out);
656 }
657}
658
659fn process_row_mean(output: &mut Image<u8>, histograms: &mut Vec<MeanHist>, p_out: &mut Vec<u8>, n_cols: usize, x: u32, y: u32) {
660 for i in 0..n_cols {
661 p_out.clear();
662 for hist in histograms.iter_mut() {
663 p_out.push(hist.get_mean(i));
664 }
665
666 let x_clamp = (x + i as u32).clamp(0, output.info().width - 1);
667 output.set_pixel(x_clamp, y, &p_out);
668 }
669}
670
671fn add_row_mean(histograms: &mut Vec<MeanHist>, p_in: &[&[u8]]) {
672 for (c, hist) in histograms.iter_mut().enumerate() {
673 hist.update(p_in, c, true);
674 }
675}
676
677fn remove_row_mean(histograms: &mut Vec<MeanHist>, p_in: &[&[u8]]) {
678 for (c, hist) in histograms.iter_mut().enumerate() {
679 hist.update(p_in, c, false);
680 }
681}