1use rayon::prelude::*;
2use yscv_tensor::{AlignedVec, Tensor};
3
4use super::super::ImgProcError;
5use super::super::shape::hwc_shape;
6
7use super::geometry::sobel_3x3_gradients;
8
9const RAYON_THRESHOLD: usize = 4096;
10
/// Binary threshold: `out[i] = max_val` where `input[i] > threshold`, else `0.0`.
///
/// `input` is interpreted as an HWC tensor; the output has the same shape.
/// Images with at least `RAYON_THRESHOLD` elements are processed one row per
/// task (GCD on macOS, rayon otherwise); smaller images run single-threaded.
///
/// # Errors
/// Returns an error if `input` is not a valid HWC tensor.
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn threshold_binary(
    input: &Tensor,
    threshold: f32,
    max_val: f32,
) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let len = data.len();
    // Uninitialized output: every element is written exactly once below
    // before the buffer is handed to `Tensor::from_aligned`.
    let mut out = AlignedVec::<f32>::uninitialized(len);

    let row_len = w * channels;

    // macOS fast path: one row per GCD work item. Pointers are smuggled as
    // `usize` so the closure is freely shareable; rows never overlap, so the
    // per-row `&mut` slices are disjoint.
    #[cfg(target_os = "macos")]
    if len >= RAYON_THRESHOLD && !cfg!(miri) {
        let src_ptr = data.as_ptr() as usize;
        let dst_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // SAFETY: y < h, so [y * row_len, (y + 1) * row_len) is in bounds
            // for both buffers, and each row is touched by exactly one task.
            let src = unsafe {
                std::slice::from_raw_parts((src_ptr as *const f32).add(y * row_len), row_len)
            };
            let dst = unsafe {
                std::slice::from_raw_parts_mut((dst_ptr as *mut f32).add(y * row_len), row_len)
            };
            threshold_binary_simd_slice(src, dst, threshold, max_val);
        });
        return Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into);
    }

    if len >= RAYON_THRESHOLD {
        // Row-parallel via rayon; `par_chunks_mut` hands out disjoint rows.
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, dst)| {
                let src = &data[y * row_len..(y + 1) * row_len];
                threshold_binary_simd_slice(src, dst, threshold, max_val);
            });
    } else {
        // Small image: single pass over the whole buffer.
        threshold_binary_simd_slice(data, &mut out, threshold, max_val);
    }
    Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into)
}
55
/// Inverse binary threshold: `out[i] = 0.0` where `input[i] > threshold`,
/// else `max_val`.
///
/// Same dispatch strategy as [`threshold_binary`]: GCD rows on macOS,
/// rayon rows elsewhere for large images, single-threaded otherwise.
///
/// # Errors
/// Returns an error if `input` is not a valid HWC tensor.
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn threshold_binary_inv(
    input: &Tensor,
    threshold: f32,
    max_val: f32,
) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let len = data.len();
    // Uninitialized output: fully overwritten by the per-row kernels below.
    let mut out = AlignedVec::<f32>::uninitialized(len);

    let row_len = w * channels;

    // macOS fast path: one row per GCD work item (see threshold_binary).
    #[cfg(target_os = "macos")]
    if len >= RAYON_THRESHOLD && !cfg!(miri) {
        let src_ptr = data.as_ptr() as usize;
        let dst_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // SAFETY: y < h keeps the row in bounds; rows are disjoint.
            let src = unsafe {
                std::slice::from_raw_parts((src_ptr as *const f32).add(y * row_len), row_len)
            };
            let dst = unsafe {
                std::slice::from_raw_parts_mut((dst_ptr as *mut f32).add(y * row_len), row_len)
            };
            threshold_binary_inv_simd_slice(src, dst, threshold, max_val);
        });
        return Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into);
    }

    if len >= RAYON_THRESHOLD {
        // Row-parallel via rayon.
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, dst)| {
                let src = &data[y * row_len..(y + 1) * row_len];
                threshold_binary_inv_simd_slice(src, dst, threshold, max_val);
            });
    } else {
        threshold_binary_inv_simd_slice(data, &mut out, threshold, max_val);
    }
    Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into)
}
100
/// Truncation threshold: `out[i] = min(input[i], threshold)` — values above
/// the threshold are clamped to it, the rest pass through unchanged.
///
/// Same dispatch strategy as [`threshold_binary`]: GCD rows on macOS,
/// rayon rows elsewhere for large images, single-threaded otherwise.
///
/// # Errors
/// Returns an error if `input` is not a valid HWC tensor.
#[allow(unsafe_code, clippy::uninit_vec)]
pub fn threshold_truncate(input: &Tensor, threshold: f32) -> Result<Tensor, ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    let data = input.data();
    let len = data.len();
    // Uninitialized output: fully overwritten by the per-row kernels below.
    let mut out = AlignedVec::<f32>::uninitialized(len);

    let row_len = w * channels;

    // macOS fast path: one row per GCD work item (see threshold_binary).
    #[cfg(target_os = "macos")]
    if len >= RAYON_THRESHOLD && !cfg!(miri) {
        let src_ptr = data.as_ptr() as usize;
        let dst_ptr = out.as_mut_ptr() as usize;
        use super::u8ops::gcd;
        gcd::parallel_for(h, |y| {
            // SAFETY: y < h keeps the row in bounds; rows are disjoint.
            let src = unsafe {
                std::slice::from_raw_parts((src_ptr as *const f32).add(y * row_len), row_len)
            };
            let dst = unsafe {
                std::slice::from_raw_parts_mut((dst_ptr as *mut f32).add(y * row_len), row_len)
            };
            threshold_truncate_simd_slice(src, dst, threshold);
        });
        return Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into);
    }

    if len >= RAYON_THRESHOLD {
        // Row-parallel via rayon.
        out.par_chunks_mut(row_len)
            .enumerate()
            .for_each(|(y, dst)| {
                let src = &data[y * row_len..(y + 1) * row_len];
                threshold_truncate_simd_slice(src, dst, threshold);
            });
    } else {
        threshold_truncate_simd_slice(data, &mut out, threshold);
    }
    Tensor::from_aligned(vec![h, w, channels], out).map_err(Into::into)
}
141
142#[allow(unsafe_code)]
144#[inline(always)]
145fn threshold_binary_simd_slice(src: &[f32], dst: &mut [f32], threshold: f32, max_val: f32) {
146 debug_assert_eq!(src.len(), dst.len());
147 let len = src.len();
148 let mut i = 0usize;
149
150 if !cfg!(miri) {
151 #[cfg(target_arch = "aarch64")]
152 {
153 if std::arch::is_aarch64_feature_detected!("neon") {
154 i = unsafe {
156 threshold_binary_neon(src.as_ptr(), dst.as_mut_ptr(), len, threshold, max_val)
157 };
158 }
159 }
160 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
161 {
162 if std::is_x86_feature_detected!("avx") {
163 i = unsafe {
165 threshold_binary_avx(src.as_ptr(), dst.as_mut_ptr(), len, threshold, max_val)
166 };
167 } else if std::is_x86_feature_detected!("sse") {
168 i = unsafe {
170 threshold_binary_sse(src.as_ptr(), dst.as_mut_ptr(), len, threshold, max_val)
171 };
172 }
173 }
174 }
175
176 while i < len {
178 dst[i] = if src[i] > threshold { max_val } else { 0.0 };
179 i += 1;
180 }
181}
182
/// NEON kernel for the binary threshold. Processes `len` rounded down to a
/// multiple of 4 and returns the count of elements written; the caller
/// handles the scalar remainder.
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and NEON must be available on the running CPU.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn threshold_binary_neon(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    use std::arch::aarch64::*;
    let thresh_v = vdupq_n_f32(threshold);
    let max_v = vdupq_n_f32(max_val);
    let zero_v = vdupq_n_f32(0.0);
    let mut x = 0usize;
    // Main loop: 8 × 4-lane vectors (32 floats) per iteration.
    // vbslq selects max_v where v > threshold, else 0.0.
    while x + 32 <= len {
        let v0 = vld1q_f32(src.add(x));
        let v1 = vld1q_f32(src.add(x + 4));
        let v2 = vld1q_f32(src.add(x + 8));
        let v3 = vld1q_f32(src.add(x + 12));
        let v4 = vld1q_f32(src.add(x + 16));
        let v5 = vld1q_f32(src.add(x + 20));
        let v6 = vld1q_f32(src.add(x + 24));
        let v7 = vld1q_f32(src.add(x + 28));
        vst1q_f32(
            dst.add(x),
            vbslq_f32(vcgtq_f32(v0, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 4),
            vbslq_f32(vcgtq_f32(v1, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 8),
            vbslq_f32(vcgtq_f32(v2, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 12),
            vbslq_f32(vcgtq_f32(v3, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 16),
            vbslq_f32(vcgtq_f32(v4, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 20),
            vbslq_f32(vcgtq_f32(v5, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 24),
            vbslq_f32(vcgtq_f32(v6, thresh_v), max_v, zero_v),
        );
        vst1q_f32(
            dst.add(x + 28),
            vbslq_f32(vcgtq_f32(v7, thresh_v), max_v, zero_v),
        );
        x += 32;
    }
    // 4-lane tail; any final < 4 elements are left to the scalar caller.
    while x + 4 <= len {
        let v = vld1q_f32(src.add(x));
        let mask = vcgtq_f32(v, thresh_v);
        let result = vbslq_f32(mask, max_v, zero_v);
        vst1q_f32(dst.add(x), result);
        x += 4;
    }
    x
}
251
/// AVX kernel for the binary threshold. Processes `len` rounded down to a
/// multiple of 8 and returns the count of elements written.
///
/// `_mm256_cmp_ps::<14>` is `_CMP_GT_OS` (ordered greater-than); ANDing the
/// resulting all-ones/all-zeros mask with `max_v` yields max_val or 0.0.
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and AVX must be available on the running CPU.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn threshold_binary_avx(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm256_set1_ps(threshold);
    let max_v = _mm256_set1_ps(max_val);
    let mut x = 0usize;
    // Main loop: 4 × 8-lane vectors (32 floats) per iteration.
    while x + 32 <= len {
        let v0 = _mm256_loadu_ps(src.add(x));
        let v1 = _mm256_loadu_ps(src.add(x + 8));
        let v2 = _mm256_loadu_ps(src.add(x + 16));
        let v3 = _mm256_loadu_ps(src.add(x + 24));
        _mm256_storeu_ps(
            dst.add(x),
            _mm256_and_ps(_mm256_cmp_ps::<14>(v0, thresh_v), max_v),
        );
        _mm256_storeu_ps(
            dst.add(x + 8),
            _mm256_and_ps(_mm256_cmp_ps::<14>(v1, thresh_v), max_v),
        );
        _mm256_storeu_ps(
            dst.add(x + 16),
            _mm256_and_ps(_mm256_cmp_ps::<14>(v2, thresh_v), max_v),
        );
        _mm256_storeu_ps(
            dst.add(x + 24),
            _mm256_and_ps(_mm256_cmp_ps::<14>(v3, thresh_v), max_v),
        );
        x += 32;
    }
    // 8-lane tail; any final < 8 elements are left to the scalar caller.
    while x + 8 <= len {
        _mm256_storeu_ps(
            dst.add(x),
            _mm256_and_ps(
                _mm256_cmp_ps::<14>(_mm256_loadu_ps(src.add(x)), thresh_v),
                max_v,
            ),
        );
        x += 8;
    }
    x
}
305
/// SSE kernel for the binary threshold (fallback when AVX is absent).
/// Processes `len` rounded down to a multiple of 4 and returns the count of
/// elements written; the mask from `_mm_cmpgt_ps` ANDed with `max_v` yields
/// max_val or 0.0 per lane.
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and SSE must be available on the running CPU.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn threshold_binary_sse(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm_set1_ps(threshold);
    let max_v = _mm_set1_ps(max_val);
    let mut x = 0usize;
    while x + 4 <= len {
        let v = _mm_loadu_ps(src.add(x));
        let mask = _mm_cmpgt_ps(v, thresh_v);
        let result = _mm_and_ps(mask, max_v);
        _mm_storeu_ps(dst.add(x), result);
        x += 4;
    }
    x
}
332
333#[allow(unsafe_code)]
335#[inline(always)]
336fn threshold_binary_inv_simd_slice(src: &[f32], dst: &mut [f32], threshold: f32, max_val: f32) {
337 debug_assert_eq!(src.len(), dst.len());
338 let len = src.len();
339 let mut i = 0usize;
340
341 if !cfg!(miri) {
342 #[cfg(target_arch = "aarch64")]
343 {
344 if std::arch::is_aarch64_feature_detected!("neon") {
345 i = unsafe {
347 threshold_binary_inv_neon(
348 src.as_ptr(),
349 dst.as_mut_ptr(),
350 len,
351 threshold,
352 max_val,
353 )
354 };
355 }
356 }
357 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
358 {
359 if std::is_x86_feature_detected!("avx") {
360 i = unsafe {
362 threshold_binary_inv_avx(
363 src.as_ptr(),
364 dst.as_mut_ptr(),
365 len,
366 threshold,
367 max_val,
368 )
369 };
370 } else if std::is_x86_feature_detected!("sse") {
371 i = unsafe {
373 threshold_binary_inv_sse(
374 src.as_ptr(),
375 dst.as_mut_ptr(),
376 len,
377 threshold,
378 max_val,
379 )
380 };
381 }
382 }
383 }
384
385 while i < len {
387 dst[i] = if src[i] > threshold { 0.0 } else { max_val };
388 i += 1;
389 }
390}
391
/// NEON kernel for the inverse binary threshold. Processes `len` rounded
/// down to a multiple of 4 and returns the count of elements written; the
/// select operands are swapped relative to the non-inverted kernel
/// (0.0 where v > threshold, max_val otherwise).
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and NEON must be available on the running CPU.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn threshold_binary_inv_neon(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    use std::arch::aarch64::*;
    let thresh_v = vdupq_n_f32(threshold);
    let max_v = vdupq_n_f32(max_val);
    let zero_v = vdupq_n_f32(0.0);
    let mut x = 0usize;
    // Main loop: 4 × 4-lane vectors (16 floats) per iteration.
    while x + 16 <= len {
        let v0 = vld1q_f32(src.add(x));
        let v1 = vld1q_f32(src.add(x + 4));
        let v2 = vld1q_f32(src.add(x + 8));
        let v3 = vld1q_f32(src.add(x + 12));
        vst1q_f32(
            dst.add(x),
            vbslq_f32(vcgtq_f32(v0, thresh_v), zero_v, max_v),
        );
        vst1q_f32(
            dst.add(x + 4),
            vbslq_f32(vcgtq_f32(v1, thresh_v), zero_v, max_v),
        );
        vst1q_f32(
            dst.add(x + 8),
            vbslq_f32(vcgtq_f32(v2, thresh_v), zero_v, max_v),
        );
        vst1q_f32(
            dst.add(x + 12),
            vbslq_f32(vcgtq_f32(v3, thresh_v), zero_v, max_v),
        );
        x += 16;
    }
    // 4-lane tail; any final < 4 elements are left to the scalar caller.
    while x + 4 <= len {
        let v = vld1q_f32(src.add(x));
        let mask = vcgtq_f32(v, thresh_v);
        let result = vbslq_f32(mask, zero_v, max_v);
        vst1q_f32(dst.add(x), result);
        x += 4;
    }
    x
}
440
/// AVX kernel for the inverse binary threshold. Processes `len` rounded down
/// to a multiple of 8 and returns the count of elements written.
/// `_mm256_andnot_ps(mask, max_v)` computes `(!mask) & max_v`, so lanes with
/// v > threshold (mask all-ones) become 0.0 and the rest become max_val.
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and AVX must be available on the running CPU.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn threshold_binary_inv_avx(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm256_set1_ps(threshold);
    let max_v = _mm256_set1_ps(max_val);
    let mut x = 0usize;
    while x + 8 <= len {
        let v = _mm256_loadu_ps(src.add(x));
        let mask = _mm256_cmp_ps::<14>(v, thresh_v);
        let result = _mm256_andnot_ps(mask, max_v);
        _mm256_storeu_ps(dst.add(x), result);
        x += 8;
    }
    x
}
468
/// SSE kernel for the inverse binary threshold (fallback when AVX is
/// absent). Processes `len` rounded down to a multiple of 4 and returns the
/// count of elements written; `_mm_andnot_ps` inverts the comparison mask.
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and SSE must be available on the running CPU.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn threshold_binary_inv_sse(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
    max_val: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm_set1_ps(threshold);
    let max_v = _mm_set1_ps(max_val);
    let mut x = 0usize;
    while x + 4 <= len {
        let v = _mm_loadu_ps(src.add(x));
        let mask = _mm_cmpgt_ps(v, thresh_v);
        let result = _mm_andnot_ps(mask, max_v);
        _mm_storeu_ps(dst.add(x), result);
        x += 4;
    }
    x
}
495
496#[allow(unsafe_code)]
498#[inline(always)]
499fn threshold_truncate_simd_slice(src: &[f32], dst: &mut [f32], threshold: f32) {
500 debug_assert_eq!(src.len(), dst.len());
501 let len = src.len();
502 let mut i = 0usize;
503
504 if !cfg!(miri) {
505 #[cfg(target_arch = "aarch64")]
506 {
507 if std::arch::is_aarch64_feature_detected!("neon") {
508 i = unsafe {
510 threshold_truncate_neon(src.as_ptr(), dst.as_mut_ptr(), len, threshold)
511 };
512 }
513 }
514 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
515 {
516 if std::is_x86_feature_detected!("avx") {
517 i = unsafe {
519 threshold_truncate_avx(src.as_ptr(), dst.as_mut_ptr(), len, threshold)
520 };
521 } else if std::is_x86_feature_detected!("sse") {
522 i = unsafe {
524 threshold_truncate_sse(src.as_ptr(), dst.as_mut_ptr(), len, threshold)
525 };
526 }
527 }
528 }
529
530 while i < len {
532 dst[i] = if src[i] > threshold {
533 threshold
534 } else {
535 src[i]
536 };
537 i += 1;
538 }
539}
540
/// NEON kernel for the truncation threshold: per-lane `min` against the
/// threshold. Processes `len` rounded down to a multiple of 4 and returns
/// the count of elements written.
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and NEON must be available on the running CPU.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn threshold_truncate_neon(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
) -> usize {
    use std::arch::aarch64::*;
    let thresh_v = vdupq_n_f32(threshold);
    let mut x = 0usize;
    // Main loop: 4 × 4-lane vectors (16 floats) per iteration.
    while x + 16 <= len {
        let v0 = vld1q_f32(src.add(x));
        let v1 = vld1q_f32(src.add(x + 4));
        let v2 = vld1q_f32(src.add(x + 8));
        let v3 = vld1q_f32(src.add(x + 12));
        vst1q_f32(dst.add(x), vminq_f32(v0, thresh_v));
        vst1q_f32(dst.add(x + 4), vminq_f32(v1, thresh_v));
        vst1q_f32(dst.add(x + 8), vminq_f32(v2, thresh_v));
        vst1q_f32(dst.add(x + 12), vminq_f32(v3, thresh_v));
        x += 16;
    }
    // 4-lane tail; any final < 4 elements are left to the scalar caller.
    while x + 4 <= len {
        let v = vld1q_f32(src.add(x));
        vst1q_f32(dst.add(x), vminq_f32(v, thresh_v));
        x += 4;
    }
    x
}
572
/// AVX kernel for the truncation threshold: per-lane `min` against the
/// threshold. Processes `len` rounded down to a multiple of 8 and returns
/// the count of elements written.
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and AVX must be available on the running CPU.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn threshold_truncate_avx(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm256_set1_ps(threshold);
    let mut x = 0usize;
    while x + 8 <= len {
        let v = _mm256_loadu_ps(src.add(x));
        let result = _mm256_min_ps(v, thresh_v);
        _mm256_storeu_ps(dst.add(x), result);
        x += 8;
    }
    x
}
596
/// SSE kernel for the truncation threshold (fallback when AVX is absent):
/// per-lane `min` against the threshold. Processes `len` rounded down to a
/// multiple of 4 and returns the count of elements written.
///
/// # Safety
/// `src` and `dst` must each be valid for `len` f32 reads/writes, must not
/// overlap, and SSE must be available on the running CPU.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn threshold_truncate_sse(
    src: *const f32,
    dst: *mut f32,
    len: usize,
    threshold: f32,
) -> usize {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let thresh_v = _mm_set1_ps(threshold);
    let mut x = 0usize;
    while x + 4 <= len {
        let v = _mm_loadu_ps(src.add(x));
        let result = _mm_min_ps(v, thresh_v);
        _mm_storeu_ps(dst.add(x), result);
        x += 4;
    }
    x
}
620
621pub fn threshold_otsu(input: &Tensor, max_val: f32) -> Result<(f32, Tensor), ImgProcError> {
624 let (_h, _w, channels) = hwc_shape(input)?;
625 if channels != 1 {
626 return Err(ImgProcError::InvalidChannelCount {
627 expected: 1,
628 got: channels,
629 });
630 }
631
632 let data = input.data();
634 let total = data.len() as f32;
635 let mut hist = [0u32; 256];
636 for &v in data {
637 let bin = (v.clamp(0.0, 1.0) * 255.0) as usize;
638 hist[bin.min(255)] += 1;
639 }
640
641 let mut sum_total = 0.0f64;
642 for (i, &count) in hist.iter().enumerate() {
643 sum_total += i as f64 * count as f64;
644 }
645
646 let mut sum_bg = 0.0f64;
647 let mut weight_bg = 0.0f64;
648 let mut max_variance = 0.0f64;
649 let mut best_t = 0usize;
650
651 for (t, &count) in hist.iter().enumerate() {
652 weight_bg += count as f64;
653 if weight_bg == 0.0 {
654 continue;
655 }
656 let weight_fg = total as f64 - weight_bg;
657 if weight_fg == 0.0 {
658 break;
659 }
660
661 sum_bg += t as f64 * count as f64;
662 let mean_bg = sum_bg / weight_bg;
663 let mean_fg = (sum_total - sum_bg) / weight_fg;
664 let diff = mean_bg - mean_fg;
665 let variance = weight_bg * weight_fg * diff * diff;
666
667 if variance > max_variance {
668 max_variance = variance;
669 best_t = t;
670 }
671 }
672
673 let threshold = best_t as f32 / 255.0;
674 let thresholded = threshold_binary(input, threshold, max_val)?;
675 Ok((threshold, thresholded))
676}
677
/// Reusable per-pixel intermediate buffers for `canny_with_scratch`, so
/// repeated Canny calls on same-sized images avoid reallocation.
pub struct CannyScratch {
    // Approximate gradient magnitude per pixel.
    magnitude: Vec<f32>,
    // Quantized gradient direction per pixel: 0 = horizontal, 1 = diagonal,
    // 2 = vertical, 3 = anti-diagonal (see the NMS neighbor selection).
    direction: Vec<u8>,
    // Magnitude after non-maximum suppression (0 where suppressed).
    nms: Vec<f32>,
    // Edge state per pixel: 0 = none, 1 = weak, 2 = strong/confirmed.
    edges: Vec<u8>,
    // BFS worklist of confirmed-edge indices for hysteresis tracking.
    queue: Vec<usize>,
}
693
impl CannyScratch {
    /// Creates an empty scratch; buffers are sized lazily by
    /// `ensure_capacity` on first use.
    pub fn new() -> Self {
        Self {
            magnitude: Vec::new(),
            direction: Vec::new(),
            nms: Vec::new(),
            edges: Vec::new(),
            queue: Vec::new(),
        }
    }

    /// Resizes every per-pixel buffer to exactly `len` elements
    /// (zero-filled when growing). `queue` is managed by the caller.
    fn ensure_capacity(&mut self, len: usize) {
        self.magnitude.resize(len, 0.0);
        self.direction.resize(len, 0);
        self.nms.resize(len, 0.0);
        self.edges.resize(len, 0);
    }
}
713
impl Default for CannyScratch {
    /// Equivalent to [`CannyScratch::new`].
    fn default() -> Self {
        Self::new()
    }
}
719
720pub fn canny(input: &Tensor, low_thresh: f32, high_thresh: f32) -> Result<Tensor, ImgProcError> {
721 let mut scratch = CannyScratch::new();
722 canny_with_scratch(input, low_thresh, high_thresh, &mut scratch)
723}
724
725pub fn canny_with_scratch(
732 input: &Tensor,
733 low_thresh: f32,
734 high_thresh: f32,
735 scratch: &mut CannyScratch,
736) -> Result<Tensor, ImgProcError> {
737 let (_h, _w, channels) = hwc_shape(input)?;
738 if channels != 1 {
739 return Err(ImgProcError::InvalidChannelCount {
740 expected: 1,
741 got: channels,
742 });
743 }
744
745 let (h, w, _) = hwc_shape(input)?;
746 let (gx, gy) = sobel_3x3_gradients(input)?;
747 let gx_data = gx.data();
748 let gy_data = gy.data();
749 let len = h * w;
750
751 scratch.ensure_capacity(len);
752 let magnitude = &mut scratch.magnitude;
753 let direction = &mut scratch.direction;
754 let nms = &mut scratch.nms;
755 let edges = &mut scratch.edges;
756
757 for i in 0..len {
761 let dx = gx_data[i];
762 let dy = gy_data[i];
763 let adx = dx.abs();
764 let ady = dy.abs();
765 let (big, small) = if adx > ady { (adx, ady) } else { (ady, adx) };
766 magnitude[i] = big + 0.414 * small;
767
768 direction[i] = if ady * 5.0 < adx * 2.0 {
775 0
777 } else if adx * 5.0 < ady * 2.0 {
778 2
780 } else if (dx > 0.0) == (dy > 0.0) {
781 1
783 } else {
784 3
786 };
787 }
788
789 for v in nms.iter_mut() {
791 *v = 0.0;
792 }
793 for y in 1..h.saturating_sub(1) {
794 for x in 1..w.saturating_sub(1) {
795 let idx = y * w + x;
796 let mag = magnitude[idx];
797 let (n1, n2) = match direction[idx] {
798 0 => (magnitude[y * w + x - 1], magnitude[y * w + x + 1]),
799 1 => (
800 magnitude[(y - 1) * w + x + 1],
801 magnitude[(y + 1) * w + x - 1],
802 ),
803 2 => (magnitude[(y - 1) * w + x], magnitude[(y + 1) * w + x]),
804 _ => (
805 magnitude[(y - 1) * w + x - 1],
806 magnitude[(y + 1) * w + x + 1],
807 ),
808 };
809 if mag >= n1 && mag >= n2 {
810 nms[idx] = mag;
811 }
812 }
813 }
814
815 for v in edges.iter_mut() {
817 *v = 0;
818 }
819 scratch.queue.clear();
820
821 for i in 0..len {
823 if nms[i] >= high_thresh {
824 edges[i] = 2;
825 scratch.queue.push(i);
826 } else if nms[i] >= low_thresh {
827 edges[i] = 1;
828 }
829 }
830
831 let mut head = 0;
833 while head < scratch.queue.len() {
834 let idx = scratch.queue[head];
835 head += 1;
836 let y = idx / w;
837 let x = idx % w;
838 if y == 0 || y >= h - 1 || x == 0 || x >= w - 1 {
839 continue;
840 }
841 for dy in [-1isize, 0, 1] {
843 for dx in [-1isize, 0, 1] {
844 if dy == 0 && dx == 0 {
845 continue;
846 }
847 let ny = (y as isize + dy) as usize;
848 let nx = (x as isize + dx) as usize;
849 let ni = ny * w + nx;
850 if edges[ni] == 1 {
851 edges[ni] = 2;
852 scratch.queue.push(ni);
853 }
854 }
855 }
856 }
857
858 let out: Vec<f32> = edges
859 .iter()
860 .map(|&e| if e == 2 { 1.0 } else { 0.0 })
861 .collect();
862 Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
863}
864
865pub fn adaptive_threshold_mean(
870 input: &Tensor,
871 max_val: f32,
872 block_size: usize,
873 constant: f32,
874) -> Result<Tensor, ImgProcError> {
875 let (h, w, channels) = hwc_shape(input)?;
876 if channels != 1 {
877 return Err(ImgProcError::InvalidChannelCount {
878 expected: 1,
879 got: channels,
880 });
881 }
882 if block_size == 0 || block_size.is_multiple_of(2) {
883 return Err(ImgProcError::InvalidBlockSize { block_size });
884 }
885
886 let data = input.data();
887 let half = (block_size / 2) as isize;
888 let mut out = vec![0.0f32; h * w];
889
890 for y in 0..h {
891 for x in 0..w {
892 let mut sum = 0.0f32;
893 let mut count = 0u32;
894 for ky in -half..=half {
895 for kx in -half..=half {
896 let sy = y as isize + ky;
897 let sx = x as isize + kx;
898 if sy >= 0 && sy < h as isize && sx >= 0 && sx < w as isize {
899 sum += data[sy as usize * w + sx as usize];
900 count += 1;
901 }
902 }
903 }
904 let local_mean = sum / count as f32;
905 let threshold = local_mean - constant;
906 out[y * w + x] = if data[y * w + x] > threshold {
907 max_val
908 } else {
909 0.0
910 };
911 }
912 }
913
914 Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
915}
916
917pub fn adaptive_threshold_gaussian(
921 input: &Tensor,
922 max_val: f32,
923 block_size: usize,
924 constant: f32,
925) -> Result<Tensor, ImgProcError> {
926 let (h, w, channels) = hwc_shape(input)?;
927 if channels != 1 {
928 return Err(ImgProcError::InvalidChannelCount {
929 expected: 1,
930 got: channels,
931 });
932 }
933 if block_size == 0 || block_size.is_multiple_of(2) {
934 return Err(ImgProcError::InvalidBlockSize { block_size });
935 }
936
937 let half = block_size / 2;
938 let sigma = 0.3 * ((block_size as f64 - 1.0) * 0.5 - 1.0) + 0.8;
939 let sigma2 = sigma * sigma;
940
941 let mut kernel = vec![0.0f64; block_size * block_size];
942 let mut ksum = 0.0f64;
943 for ky in 0..block_size {
944 for kx in 0..block_size {
945 let dy = ky as f64 - half as f64;
946 let dx = kx as f64 - half as f64;
947 let val = (-(dy * dy + dx * dx) / (2.0 * sigma2)).exp();
948 kernel[ky * block_size + kx] = val;
949 ksum += val;
950 }
951 }
952 for v in &mut kernel {
953 *v /= ksum;
954 }
955
956 let data = input.data();
957 let half_i = half as isize;
958 let mut out = vec![0.0f32; h * w];
959
960 for y in 0..h {
961 for x in 0..w {
962 let mut wsum = 0.0f64;
963 let mut wnorm = 0.0f64;
964 for ky in -half_i..=half_i {
965 for kx in -half_i..=half_i {
966 let sy = y as isize + ky;
967 let sx = x as isize + kx;
968 if sy >= 0 && sy < h as isize && sx >= 0 && sx < w as isize {
969 let kw =
970 kernel[(ky + half_i) as usize * block_size + (kx + half_i) as usize];
971 wsum += data[sy as usize * w + sx as usize] as f64 * kw;
972 wnorm += kw;
973 }
974 }
975 }
976 let local_mean = (wsum / wnorm) as f32;
977 let threshold = local_mean - constant;
978 out[y * w + x] = if data[y * w + x] > threshold {
979 max_val
980 } else {
981 0.0
982 };
983 }
984 }
985
986 Tensor::from_vec(vec![h, w, 1], out).map_err(Into::into)
987}
988
/// 4-connected components labeling of a binary single-channel image.
///
/// Pixels with value `> 0.0` are foreground. Returns a tensor of per-pixel
/// component labels (background `0.0`, components numbered `1..=count` in
/// raster-scan order of their roots) together with the component count.
///
/// Classic two-pass algorithm: the first raster scan assigns provisional
/// labels from the left/above neighbors and records merges in a union-find
/// parent table; the second pass compacts the surviving roots into
/// consecutive canonical labels.
pub fn connected_components_4(input: &Tensor) -> Result<(Tensor, usize), ImgProcError> {
    let (h, w, channels) = hwc_shape(input)?;
    if channels != 1 {
        return Err(ImgProcError::InvalidChannelCount {
            expected: 1,
            got: channels,
        });
    }

    let data = input.data();
    let len = h * w;
    let mut labels = vec![0u32; len];
    let mut next_label = 1u32;
    // equivalences[l] is l's union-find parent; index 0 is a dummy so that
    // provisional labels start at 1 (0 means "background / no label").
    let mut equivalences: Vec<u32> = vec![0];

    // Pass 1: provisional labels + equivalence recording.
    for y in 0..h {
        for x in 0..w {
            let idx = y * w + x;
            if data[idx] <= 0.0 {
                continue;
            }
            let left = if x > 0 { labels[y * w + x - 1] } else { 0 };
            let above = if y > 0 { labels[(y - 1) * w + x] } else { 0 };

            match (left > 0, above > 0) {
                (false, false) => {
                    // No labeled neighbor: open a new component whose
                    // parent is itself.
                    labels[idx] = next_label;
                    equivalences.push(next_label);
                    next_label += 1;
                }
                (true, false) => labels[idx] = left,
                (false, true) => labels[idx] = above,
                (true, true) => {
                    // Both neighbors labeled: take the smaller root and
                    // merge the larger root under it.
                    let rl = find_root(&equivalences, left);
                    let ra = find_root(&equivalences, above);
                    labels[idx] = rl.min(ra);
                    if rl != ra {
                        let (lo, hi) = if rl < ra { (rl, ra) } else { (ra, rl) };
                        equivalences[hi as usize] = lo;
                    }
                }
            }
        }
    }

    // Pass 2a: number the surviving roots 1..=label_count.
    let mut canonical = vec![0u32; next_label as usize];
    let mut label_count = 0u32;
    #[allow(clippy::needless_range_loop)]
    for i in 1..next_label as usize {
        let root = find_root(&equivalences, i as u32);
        if root == i as u32 {
            label_count += 1;
            canonical[i] = label_count;
        }
    }
    // Pass 2b: point every provisional label at its root's canonical label.
    #[allow(clippy::needless_range_loop)]
    for i in 1..next_label as usize {
        let root = find_root(&equivalences, i as u32);
        canonical[i] = canonical[root as usize];
    }

    // Emit canonical labels as f32 (background stays 0.0).
    let out: Vec<f32> = labels
        .iter()
        .map(|&l| {
            if l == 0 {
                0.0
            } else {
                canonical[l as usize] as f32
            }
        })
        .collect();

    Ok((Tensor::from_vec(vec![h, w, 1], out)?, label_count as usize))
}
1067
/// Follows parent links in the union-find table until reaching a label that
/// is its own parent, and returns that root (no path compression).
pub(crate) fn find_root(equiv: &[u32], mut label: u32) -> u32 {
    loop {
        let parent = equiv[label as usize];
        if parent == label {
            return label;
        }
        label = parent;
    }
}