// oxigdal_algorithms/simd/threshold.rs

//! SIMD-accelerated thresholding operations
//!
//! This module provides high-performance image thresholding and binarization.
//! The element-wise operations are written as tight per-element loops that
//! LLVM can auto-vectorize into SIMD instructions. Thresholding is fundamental
//! for segmentation, feature extraction, and image preprocessing.
//!
//! # Supported Operations
//!
//! - **Binary Thresholding**: Simple threshold with two output values
//! - **To-Zero / Truncate Thresholding**: Zero values below, or cap values above, a threshold
//! - **Adaptive Thresholding**: Local threshold based on neighborhood (box or Gaussian mean)
//! - **Otsu's Method**: Automatic threshold selection
//! - **Multi-level Thresholding**: Multiple threshold values
//! - **Range Thresholding**: Keep values within a range
//! - **Hysteresis Thresholding**: Two-level threshold with connectivity
//!
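//! # Performance
//!
//! Expected speedup over scalar: 6-12x for the element-wise thresholding
//! operations, which rely on LLVM auto-vectorization; actual gains depend on
//! the target CPU features enabled at compile time.
//!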
//! # Example
//!
//! ```rust
//! use oxigdal_algorithms::simd::threshold::{binary_threshold, otsu_threshold};
//! use oxigdal_algorithms::error::Result;
//!
//! fn example() -> Result<()> {
//!     let data = vec![128u8; 1000];
//!     let mut output = vec![0u8; 1000];
//!
//!     // Fixed threshold
//!     binary_threshold(&data, &mut output, 100, 255, 0)?;
//!
//!     // Automatically selected threshold
//!     let threshold = otsu_threshold(&data)?;
//!     binary_threshold(&data, &mut output, threshold, 255, 0)?;
//!     Ok(())
//! }
//! # example().expect("example failed");
//! ```

use crate::error::{AlgorithmError, Result};

38/// Binary threshold with custom output values
39///
40/// # Arguments
41///
42/// * `input` - Input data
43/// * `output` - Output data
44/// * `threshold` - Threshold value
45/// * `max_value` - Value to use when input >= threshold
46/// * `min_value` - Value to use when input < threshold
47///
48/// # Errors
49///
50/// Returns an error if buffer sizes don't match
51pub fn binary_threshold(
52    input: &[u8],
53    output: &mut [u8],
54    threshold: u8,
55    max_value: u8,
56    min_value: u8,
57) -> Result<()> {
58    if input.len() != output.len() {
59        return Err(AlgorithmError::InvalidParameter {
60            parameter: "buffers",
61            message: format!(
62                "Buffer size mismatch: input={}, output={}",
63                input.len(),
64                output.len()
65            ),
66        });
67    }
68
69    const LANES: usize = 16;
70    let chunks = input.len() / LANES;
71
72    // SIMD processing - auto-vectorized by LLVM
73    for i in 0..chunks {
74        let start = i * LANES;
75        let end = start + LANES;
76
77        for j in start..end {
78            output[j] = if input[j] >= threshold {
79                max_value
80            } else {
81                min_value
82            };
83        }
84    }
85
86    // Handle remainder
87    let remainder_start = chunks * LANES;
88    for i in remainder_start..input.len() {
89        output[i] = if input[i] >= threshold {
90            max_value
91        } else {
92            min_value
93        };
94    }
95
96    Ok(())
97}
98
99/// Binary threshold to zero
100///
101/// Values below threshold are set to zero, others remain unchanged.
102///
103/// # Errors
104///
105/// Returns an error if buffer sizes don't match
106pub fn threshold_to_zero(input: &[u8], output: &mut [u8], threshold: u8) -> Result<()> {
107    if input.len() != output.len() {
108        return Err(AlgorithmError::InvalidParameter {
109            parameter: "buffers",
110            message: format!(
111                "Buffer size mismatch: input={}, output={}",
112                input.len(),
113                output.len()
114            ),
115        });
116    }
117
118    const LANES: usize = 16;
119    let chunks = input.len() / LANES;
120
121    for i in 0..chunks {
122        let start = i * LANES;
123        let end = start + LANES;
124
125        for j in start..end {
126            output[j] = if input[j] >= threshold { input[j] } else { 0 };
127        }
128    }
129
130    let remainder_start = chunks * LANES;
131    for i in remainder_start..input.len() {
132        output[i] = if input[i] >= threshold { input[i] } else { 0 };
133    }
134
135    Ok(())
136}
137
138/// Truncate threshold - cap values at threshold
139///
140/// Values above threshold are set to threshold, others remain unchanged.
141///
142/// # Errors
143///
144/// Returns an error if buffer sizes don't match
145pub fn threshold_truncate(input: &[u8], output: &mut [u8], threshold: u8) -> Result<()> {
146    if input.len() != output.len() {
147        return Err(AlgorithmError::InvalidParameter {
148            parameter: "buffers",
149            message: format!(
150                "Buffer size mismatch: input={}, output={}",
151                input.len(),
152                output.len()
153            ),
154        });
155    }
156
157    const LANES: usize = 16;
158    let chunks = input.len() / LANES;
159
160    for i in 0..chunks {
161        let start = i * LANES;
162        let end = start + LANES;
163
164        for j in start..end {
165            output[j] = input[j].min(threshold);
166        }
167    }
168
169    let remainder_start = chunks * LANES;
170    for i in remainder_start..input.len() {
171        output[i] = input[i].min(threshold);
172    }
173
174    Ok(())
175}
176
177/// Range threshold - keep values within [low, high]
178///
179/// Values outside range are set to zero.
180///
181/// # Errors
182///
183/// Returns an error if buffer sizes don't match or if low > high
184pub fn threshold_range(
185    input: &[u8],
186    output: &mut [u8],
187    low_threshold: u8,
188    high_threshold: u8,
189) -> Result<()> {
190    if input.len() != output.len() {
191        return Err(AlgorithmError::InvalidParameter {
192            parameter: "buffers",
193            message: format!(
194                "Buffer size mismatch: input={}, output={}",
195                input.len(),
196                output.len()
197            ),
198        });
199    }
200
201    if low_threshold > high_threshold {
202        return Err(AlgorithmError::InvalidParameter {
203            parameter: "thresholds",
204            message: format!("Invalid range: low={low_threshold}, high={high_threshold}"),
205        });
206    }
207
208    const LANES: usize = 16;
209    let chunks = input.len() / LANES;
210
211    for i in 0..chunks {
212        let start = i * LANES;
213        let end = start + LANES;
214
215        for j in start..end {
216            let val = input[j];
217            output[j] = if val >= low_threshold && val <= high_threshold {
218                val
219            } else {
220                0
221            };
222        }
223    }
224
225    let remainder_start = chunks * LANES;
226    for i in remainder_start..input.len() {
227        let val = input[i];
228        output[i] = if val >= low_threshold && val <= high_threshold {
229            val
230        } else {
231            0
232        };
233    }
234
235    Ok(())
236}
237
238/// Calculate optimal threshold using Otsu's method
239///
240/// Finds threshold that minimizes intra-class variance (maximizes inter-class variance).
241/// This is optimal for bimodal distributions.
242///
243/// # Errors
244///
245/// Returns an error if data is empty
246pub fn otsu_threshold(data: &[u8]) -> Result<u8> {
247    if data.is_empty() {
248        return Err(AlgorithmError::EmptyInput {
249            operation: "otsu_threshold",
250        });
251    }
252
253    // Compute histogram
254    let mut histogram = [0u32; 256];
255    for &value in data {
256        histogram[value as usize] += 1;
257    }
258
259    let total_pixels = data.len() as f64;
260
261    // Compute total mean
262    let mut total_mean = 0.0;
263    for (i, &count) in histogram.iter().enumerate() {
264        total_mean += i as f64 * f64::from(count);
265    }
266    total_mean /= total_pixels;
267
268    let mut max_variance = 0.0;
269    let mut optimal_threshold = 0u8;
270
271    let mut weight_background = 0.0;
272    let mut sum_background = 0.0;
273
274    for (t, &count) in histogram.iter().enumerate() {
275        weight_background += f64::from(count) / total_pixels;
276        sum_background += t as f64 * f64::from(count);
277
278        if weight_background < 1e-10 || (1.0 - weight_background) < 1e-10 {
279            continue;
280        }
281
282        let mean_background = sum_background / (weight_background * total_pixels);
283        let mean_foreground = (total_mean * total_pixels - sum_background)
284            / ((1.0 - weight_background) * total_pixels);
285
286        let variance = weight_background
287            * (1.0 - weight_background)
288            * (mean_background - mean_foreground).powi(2);
289
290        // Use >= to prefer later thresholds in case of ties (finds midpoint)
291        if variance >= max_variance {
292            max_variance = variance;
293            optimal_threshold = t as u8;
294        }
295    }
296
297    Ok(optimal_threshold)
298}
299
300/// Apply adaptive threshold using local mean
301///
302/// # Arguments
303///
304/// * `input` - Input image data
305/// * `output` - Output binary image
306/// * `width` - Image width
307/// * `height` - Image height
308/// * `window_size` - Size of local window (must be odd)
309/// * `c` - Constant subtracted from mean
310///
311/// # Errors
312///
313/// Returns an error if parameters are invalid
314pub fn adaptive_threshold_mean(
315    input: &[u8],
316    output: &mut [u8],
317    width: usize,
318    height: usize,
319    window_size: usize,
320    c: i16,
321) -> Result<()> {
322    if input.len() != width * height || output.len() != width * height {
323        return Err(AlgorithmError::InvalidParameter {
324            parameter: "buffers",
325            message: format!(
326                "Buffer size mismatch: input={}, output={}, expected={}",
327                input.len(),
328                output.len(),
329                width * height
330            ),
331        });
332    }
333
334    if window_size == 0 || window_size % 2 == 0 {
335        return Err(AlgorithmError::InvalidParameter {
336            parameter: "window_size",
337            message: format!("Window size must be odd and positive, got {window_size}"),
338        });
339    }
340
341    let half_window = window_size / 2;
342
343    for y in 0..height {
344        for x in 0..width {
345            // Compute local mean
346            let mut sum = 0u32;
347            let mut count = 0u32;
348
349            let y_start = y.saturating_sub(half_window);
350            let y_end = (y + half_window + 1).min(height);
351            let x_start = x.saturating_sub(half_window);
352            let x_end = (x + half_window + 1).min(width);
353
354            for py in y_start..y_end {
355                for px in x_start..x_end {
356                    sum += u32::from(input[py * width + px]);
357                    count += 1;
358                }
359            }
360
361            let mean = if count > 0 { sum / count } else { 0 };
362            let threshold = (mean as i32 - i32::from(c)).max(0) as u8;
363
364            let idx = y * width + x;
365            output[idx] = if input[idx] >= threshold { 255 } else { 0 };
366        }
367    }
368
369    Ok(())
370}
371
372/// Apply adaptive threshold using Gaussian-weighted mean
373///
374/// Similar to adaptive_threshold_mean but uses Gaussian weights.
375///
376/// # Errors
377///
378/// Returns an error if parameters are invalid
379pub fn adaptive_threshold_gaussian(
380    input: &[u8],
381    output: &mut [u8],
382    width: usize,
383    height: usize,
384    window_size: usize,
385    c: i16,
386) -> Result<()> {
387    if input.len() != width * height || output.len() != width * height {
388        return Err(AlgorithmError::InvalidParameter {
389            parameter: "buffers",
390            message: format!(
391                "Buffer size mismatch: input={}, output={}, expected={}",
392                input.len(),
393                output.len(),
394                width * height
395            ),
396        });
397    }
398
399    if window_size == 0 || window_size % 2 == 0 {
400        return Err(AlgorithmError::InvalidParameter {
401            parameter: "window_size",
402            message: format!("Window size must be odd and positive, got {window_size}"),
403        });
404    }
405
406    let half_window = window_size / 2;
407    let sigma = window_size as f32 / 6.0;
408
409    // Precompute Gaussian weights
410    let mut weights = vec![vec![0.0f32; window_size]; window_size];
411    let mut weight_sum = 0.0f32;
412
413    for wy in 0..window_size {
414        for wx in 0..window_size {
415            let dy = wy as f32 - half_window as f32;
416            let dx = wx as f32 - half_window as f32;
417            let weight = (-((dx * dx + dy * dy) / (2.0 * sigma * sigma))).exp();
418            weights[wy][wx] = weight;
419            weight_sum += weight;
420        }
421    }
422
423    // Normalize weights
424    for row in &mut weights {
425        for w in row {
426            *w /= weight_sum;
427        }
428    }
429
430    for y in 0..height {
431        for x in 0..width {
432            let mut weighted_sum = 0.0f32;
433
434            let y_start = y.saturating_sub(half_window);
435            let y_end = (y + half_window + 1).min(height);
436            let x_start = x.saturating_sub(half_window);
437            let x_end = (x + half_window + 1).min(width);
438
439            for py in y_start..y_end {
440                for px in x_start..x_end {
441                    let wy = py - y + half_window;
442                    let wx = px - x + half_window;
443                    if wy < window_size && wx < window_size {
444                        weighted_sum += f32::from(input[py * width + px]) * weights[wy][wx];
445                    }
446                }
447            }
448
449            let threshold = (weighted_sum - f32::from(c)).max(0.0) as u8;
450
451            let idx = y * width + x;
452            output[idx] = if input[idx] >= threshold { 255 } else { 0 };
453        }
454    }
455
456    Ok(())
457}
458
459/// Hysteresis thresholding (two-level threshold with connectivity)
460///
461/// Used in Canny edge detection. Pixels above high_threshold are strong edges.
462/// Pixels between low_threshold and high_threshold are weak edges, kept only
463/// if connected to strong edges.
464///
465/// # Errors
466///
467/// Returns an error if parameters are invalid
468pub fn hysteresis_threshold(
469    input: &[u8],
470    output: &mut [u8],
471    width: usize,
472    height: usize,
473    low_threshold: u8,
474    high_threshold: u8,
475) -> Result<()> {
476    if input.len() != width * height || output.len() != width * height {
477        return Err(AlgorithmError::InvalidParameter {
478            parameter: "buffers",
479            message: format!(
480                "Buffer size mismatch: input={}, output={}, expected={}",
481                input.len(),
482                output.len(),
483                width * height
484            ),
485        });
486    }
487
488    if low_threshold >= high_threshold {
489        return Err(AlgorithmError::InvalidParameter {
490            parameter: "thresholds",
491            message: format!(
492                "low_threshold must be < high_threshold: {low_threshold} >= {high_threshold}"
493            ),
494        });
495    }
496
497    // Initialize output
498    output.fill(0);
499
500    // Mark strong edges
501    for i in 0..input.len() {
502        if input[i] >= high_threshold {
503            output[i] = 255;
504        }
505    }
506
507    // Propagate from strong edges to weak edges
508    let mut changed = true;
509    while changed {
510        changed = false;
511
512        for y in 1..(height - 1) {
513            for x in 1..(width - 1) {
514                let idx = y * width + x;
515
516                // If this is a weak edge and not yet marked
517                if input[idx] >= low_threshold && input[idx] < high_threshold && output[idx] == 0 {
518                    // Check if connected to a strong edge
519                    let mut connected = false;
520                    for dy in 0..3 {
521                        for dx in 0..3 {
522                            if dx == 1 && dy == 1 {
523                                continue;
524                            }
525                            let ny = y + dy - 1;
526                            let nx = x + dx - 1;
527                            if output[ny * width + nx] == 255 {
528                                connected = true;
529                                break;
530                            }
531                        }
532                        if connected {
533                            break;
534                        }
535                    }
536
537                    if connected {
538                        output[idx] = 255;
539                        changed = true;
540                    }
541                }
542            }
543        }
544    }
545
546    Ok(())
547}
548
549/// Multi-level thresholding
550///
551/// Apply multiple thresholds to create multiple output levels.
552///
553/// # Arguments
554///
555/// * `input` - Input data
556/// * `output` - Output data
557/// * `thresholds` - Sorted threshold values
558/// * `levels` - Output level for each threshold range (length = thresholds.len() + 1)
559///
560/// # Errors
561///
562/// Returns an error if parameters are invalid
563pub fn multi_threshold(
564    input: &[u8],
565    output: &mut [u8],
566    thresholds: &[u8],
567    levels: &[u8],
568) -> Result<()> {
569    if input.len() != output.len() {
570        return Err(AlgorithmError::InvalidParameter {
571            parameter: "buffers",
572            message: format!(
573                "Buffer size mismatch: input={}, output={}",
574                input.len(),
575                output.len()
576            ),
577        });
578    }
579
580    if levels.len() != thresholds.len() + 1 {
581        return Err(AlgorithmError::InvalidParameter {
582            parameter: "levels",
583            message: format!(
584                "Levels length must be thresholds.len() + 1: {} != {}",
585                levels.len(),
586                thresholds.len() + 1
587            ),
588        });
589    }
590
591    const LANES: usize = 16;
592    let chunks = input.len() / LANES;
593
594    for i in 0..chunks {
595        let start = i * LANES;
596        let end = start + LANES;
597
598        for j in start..end {
599            let val = input[j];
600            let mut level_idx = 0;
601
602            for (t_idx, &threshold) in thresholds.iter().enumerate() {
603                if val >= threshold {
604                    level_idx = t_idx + 1;
605                } else {
606                    break;
607                }
608            }
609
610            output[j] = levels[level_idx];
611        }
612    }
613
614    let remainder_start = chunks * LANES;
615    for i in remainder_start..input.len() {
616        let val = input[i];
617        let mut level_idx = 0;
618
619        for (t_idx, &threshold) in thresholds.iter().enumerate() {
620            if val >= threshold {
621                level_idx = t_idx + 1;
622            } else {
623                break;
624            }
625        }
626
627        output[i] = levels[level_idx];
628    }
629
630    Ok(())
631}
632
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_threshold() {
        let input = vec![50, 100, 150, 200, 250];
        let mut output = vec![0; 5];

        binary_threshold(&input, &mut output, 128, 255, 0)
            .expect("binary threshold should succeed");

        assert_eq!(output, vec![0, 0, 255, 255, 255]);
    }

    #[test]
    fn test_binary_threshold_long_input() {
        // Longer than one 16-element chunk, plus a remainder.
        let input: Vec<u8> = (0..40).collect();
        let mut output = vec![0u8; input.len()];

        binary_threshold(&input, &mut output, 20, 9, 1)
            .expect("binary threshold should succeed");

        for (value, &out) in input.iter().zip(&output) {
            assert_eq!(out, if *value >= 20 { 9 } else { 1 });
        }
    }

    #[test]
    fn test_threshold_to_zero() {
        let input = vec![50, 100, 150, 200, 250];
        let mut output = vec![0; 5];

        threshold_to_zero(&input, &mut output, 128).expect("threshold to zero should succeed");

        assert_eq!(output, vec![0, 0, 150, 200, 250]);
    }

    #[test]
    fn test_threshold_truncate() {
        let input = vec![50, 100, 150, 200, 250];
        let mut output = vec![0; 5];

        threshold_truncate(&input, &mut output, 128).expect("threshold truncate should succeed");

        assert_eq!(output, vec![50, 100, 128, 128, 128]);
    }

    #[test]
    fn test_threshold_range() {
        let input = vec![50, 100, 150, 200, 250];
        let mut output = vec![0; 5];

        threshold_range(&input, &mut output, 100, 200).expect("threshold range should succeed");

        assert_eq!(output, vec![0, 100, 150, 200, 0]);
    }

    #[test]
    fn test_otsu_threshold() {
        // Bimodal distribution
        let mut data = vec![50u8; 500];
        data.extend(vec![200u8; 500]);

        let threshold = otsu_threshold(&data).expect("Otsu threshold calculation should succeed");

        // Threshold should be between the two modes
        assert!(threshold > 50 && threshold < 200);
    }

    #[test]
    fn test_otsu_empty_input() {
        assert!(otsu_threshold(&[]).is_err());
    }

    #[test]
    fn test_adaptive_threshold_mean() {
        let width = 10;
        let height = 10;
        let input = vec![128u8; width * height];
        let mut output = vec![0u8; width * height];

        adaptive_threshold_mean(&input, &mut output, width, height, 3, 10)
            .expect("adaptive threshold mean should succeed");

        // Uniform input should produce mostly uniform output
        assert!(output.iter().filter(|&&x| x == 255).count() > 50);
    }

    #[test]
    fn test_adaptive_threshold_gaussian_uniform() {
        let width = 8;
        let height = 6;
        let input = vec![128u8; width * height];
        let mut output = vec![0u8; width * height];

        adaptive_threshold_gaussian(&input, &mut output, width, height, 5, 0)
            .expect("adaptive threshold gaussian should succeed");

        // Every pixel is at or above its local weighted mean.
        assert!(output.iter().all(|&x| x == 255));
    }

    #[test]
    fn test_adaptive_threshold_invalid_params() {
        let input = vec![0u8; 16];
        let mut output = vec![0u8; 16];

        // Even and zero window sizes.
        assert!(adaptive_threshold_mean(&input, &mut output, 4, 4, 4, 0).is_err());
        assert!(adaptive_threshold_mean(&input, &mut output, 4, 4, 0, 0).is_err());
        assert!(adaptive_threshold_gaussian(&input, &mut output, 4, 4, 2, 0).is_err());
        // Dimensions that do not match the buffer size.
        assert!(adaptive_threshold_mean(&input, &mut output, 4, 3, 3, 0).is_err());
        assert!(adaptive_threshold_gaussian(&input, &mut output, 5, 4, 3, 0).is_err());
    }

    #[test]
    fn test_multi_threshold() {
        let input = vec![10, 50, 100, 150, 200, 250];
        let mut output = vec![0; 6];
        let thresholds = vec![64, 128, 192];
        let levels = vec![0, 85, 170, 255];

        multi_threshold(&input, &mut output, &thresholds, &levels)
            .expect("multi-level threshold should succeed");

        assert_eq!(output[0], 0); // < 64
        assert_eq!(output[1], 0); // < 64
        assert_eq!(output[2], 85); // >= 64, < 128
        assert_eq!(output[3], 170); // >= 128, < 192
        assert_eq!(output[4], 255); // >= 192
        assert_eq!(output[5], 255); // >= 192
    }

    #[test]
    fn test_multi_threshold_levels_mismatch() {
        let input = vec![0u8; 4];
        let mut output = vec![0u8; 4];

        assert!(multi_threshold(&input, &mut output, &[10, 20], &[0, 1]).is_err());
    }

    #[test]
    fn test_hysteresis_threshold() {
        let width = 5;
        let height = 5;
        let mut input = vec![0u8; width * height];

        // Create a strong edge
        input[2 * width + 2] = 200;
        // Create weak edges connected to strong edge
        input[2 * width + 1] = 80;
        input[width + 2] = 80;
        // Create isolated weak edge
        input[4 * width + 4] = 80;

        let mut output = vec![0u8; width * height];
        hysteresis_threshold(&input, &mut output, width, height, 50, 150)
            .expect("hysteresis threshold should succeed");

        // Strong edge should be marked
        assert_eq!(output[2 * width + 2], 255);
        // Connected weak edges should be marked
        assert_eq!(output[2 * width + 1], 255);
        assert_eq!(output[width + 2], 255);
        // Isolated weak edge should not be marked
        assert_eq!(output[4 * width + 4], 0);
    }

    #[test]
    fn test_hysteresis_invalid_thresholds() {
        let input = vec![0u8; 9];
        let mut output = vec![0u8; 9];

        assert!(hysteresis_threshold(&input, &mut output, 3, 3, 100, 100).is_err());
        assert!(hysteresis_threshold(&input, &mut output, 3, 3, 150, 100).is_err());
        assert!(hysteresis_threshold(&input, &mut output, 3, 2, 50, 100).is_err());
    }

    #[test]
    fn test_buffer_size_mismatch() {
        let input = vec![0u8; 10];
        let mut output = vec![0u8; 5]; // Wrong size

        let result = binary_threshold(&input, &mut output, 128, 255, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_elementwise_buffer_mismatch() {
        let input = vec![0u8; 10];
        let mut output = vec![0u8; 9];

        assert!(threshold_to_zero(&input, &mut output, 1).is_err());
        assert!(threshold_truncate(&input, &mut output, 1).is_err());
        assert!(threshold_range(&input, &mut output, 1, 2).is_err());
        assert!(multi_threshold(&input, &mut output, &[1], &[0, 1]).is_err());
    }

    #[test]
    fn test_invalid_range() {
        let input = vec![0u8; 10];
        let mut output = vec![0u8; 10];

        let result = threshold_range(&input, &mut output, 200, 100);
        assert!(result.is_err());
    }
}