Skip to main content

corrmatch/bank/
mod.rs

1//! Precomputed template assets for coarse-to-fine search.
2//!
3//! Compiling template assets once amortizes the cost of building pyramids and
4//! rotated variants across multiple match calls. Each cached rotation stores
5//! precomputed masked plans (ZNCC and SSD) for fast score evaluation.
6//! Rotated templates are cached lazily per level; each angle slot is populated
7//! at most once and stored in a `OnceLock` for thread-safe reuse when parallel
8//! search is introduced later.
9
10mod angles;
11
12pub use angles::AngleGrid;
13
14use crate::image::pyramid::{downsample_u8_2x2_box, ImagePyramid};
15use crate::image::{ImageView, OwnedImage};
16use crate::template::rotate::rotate_u8_bilinear_masked;
17use crate::template::{
18    MaskedSsdTemplatePlan, MaskedTemplatePlan, SsdTemplatePlan, Template, TemplatePlan,
19};
20use crate::util::{CorrMatchError, CorrMatchResult};
21#[cfg(feature = "rayon")]
22use rayon::prelude::*;
23use std::sync::{Arc, OnceLock};
24
25fn trim_degenerate_levels(levels: &mut Vec<OwnedImage>, min_dim: usize) -> CorrMatchResult<()> {
26    let mut last_err: Option<CorrMatchError> = None;
27    loop {
28        let level = match levels.last() {
29            Some(level) => level,
30            None => {
31                return Err(last_err.unwrap_or(CorrMatchError::DegenerateTemplate {
32                    reason: "zero variance",
33                }));
34            }
35        };
36
37        if level.width() < min_dim || level.height() < min_dim {
38            levels.pop();
39            last_err = Some(CorrMatchError::DegenerateTemplate {
40                reason: "template too small for rotation",
41            });
42            continue;
43        }
44
45        match TemplatePlan::from_view(level.view()) {
46            Ok(_) => return Ok(()),
47            Err(err @ CorrMatchError::DegenerateTemplate { .. }) => {
48                levels.pop();
49                last_err = Some(err);
50            }
51            Err(err) => return Err(err),
52        }
53    }
54}
55
56fn downsample_mask(mask: &[u8], width: usize, height: usize) -> CorrMatchResult<Vec<u8>> {
57    let needed = width
58        .checked_mul(height)
59        .ok_or(CorrMatchError::InvalidDimensions { width, height })?;
60    if mask.len() < needed {
61        return Err(CorrMatchError::BufferTooSmall {
62            needed,
63            got: mask.len(),
64        });
65    }
66    if mask.len() > needed {
67        return Err(CorrMatchError::InvalidDimensions { width, height });
68    }
69    if width < 2 || height < 2 {
70        return Err(CorrMatchError::InvalidDimensions { width, height });
71    }
72
73    let dst_width = width / 2;
74    let dst_height = height / 2;
75    let dst_len = dst_width
76        .checked_mul(dst_height)
77        .ok_or(CorrMatchError::InvalidDimensions {
78            width: dst_width,
79            height: dst_height,
80        })?;
81    let mut dst = vec![0u8; dst_len];
82
83    for y in 0..dst_height {
84        let row0 = &mask[(y * 2) * width..(y * 2) * width + width];
85        let row1 = &mask[(y * 2 + 1) * width..(y * 2 + 1) * width + width];
86        for x in 0..dst_width {
87            let idx = 2 * x;
88            let m = row0[idx] & row0[idx + 1] & row1[idx] & row1[idx + 1];
89            dst[y * dst_width + x] = if m == 0 { 0 } else { 1 };
90        }
91    }
92
93    Ok(dst)
94}
95
96fn rotate_downsample_to_level(
97    base: ImageView<'_, u8>,
98    angle: f32,
99    fill: u8,
100    level: usize,
101) -> CorrMatchResult<(OwnedImage, Vec<u8>)> {
102    let (mut img, mut mask) = rotate_u8_bilinear_masked(base, angle, fill);
103    for _ in 0..level {
104        let view = img.view();
105        let next_img = downsample_u8_2x2_box(view)?;
106        let next_mask = downsample_mask(&mask, view.width(), view.height())?;
107        img = next_img;
108        mask = next_mask;
109    }
110    Ok((img, mask))
111}
112
/// Configuration for compiling template assets with rotation support.
///
/// `CompileConfig::validate` can be used to check the numeric fields before
/// compiling.
#[derive(Clone, Debug)]
pub struct CompileConfig {
    /// Maximum pyramid levels to build.
    ///
    /// Degenerate coarse levels are trimmed after the pyramid is built, so a
    /// compiled template may end up with fewer levels than this.
    pub max_levels: usize,
    /// Rotation step in degrees at the coarsest pyramid level.
    ///
    /// Each finer level halves the step, but never below `min_step_deg`.
    pub coarse_step_deg: f32,
    /// Minimum rotation step in degrees across levels.
    pub min_step_deg: f32,
    /// Fill value used for out-of-bounds rotations.
    pub fill_value: u8,
    /// Precompute all rotations for the coarsest level.
    ///
    /// Rotations that are not precomputed are built on first access.
    pub precompute_coarsest: bool,
}
127
128impl Default for CompileConfig {
129    fn default() -> Self {
130        Self {
131            max_levels: 6,
132            coarse_step_deg: 10.0,
133            min_step_deg: 0.5,
134            fill_value: 0,
135            precompute_coarsest: true,
136        }
137    }
138}
139
140impl CompileConfig {
141    /// Validates the configuration, returning an error if any parameter is invalid.
142    pub fn validate(&self) -> CorrMatchResult<()> {
143        if self.max_levels == 0 {
144            return Err(CorrMatchError::InvalidConfig {
145                reason: "max_levels must be at least 1",
146            });
147        }
148        if !self.coarse_step_deg.is_finite() || self.coarse_step_deg <= 0.0 {
149            return Err(CorrMatchError::InvalidConfig {
150                reason: "coarse_step_deg must be a positive finite value",
151            });
152        }
153        if !self.min_step_deg.is_finite() || self.min_step_deg <= 0.0 {
154            return Err(CorrMatchError::InvalidConfig {
155                reason: "min_step_deg must be a positive finite value",
156            });
157        }
158        if self.min_step_deg > self.coarse_step_deg {
159            return Err(CorrMatchError::InvalidConfig {
160                reason: "min_step_deg must not exceed coarse_step_deg",
161            });
162        }
163        Ok(())
164    }
165}
166
/// Configuration for compiling template assets without rotation support.
#[derive(Clone, Debug)]
pub struct CompileConfigNoRot {
    /// Maximum pyramid levels to build.
    ///
    /// Degenerate coarse levels are trimmed after the pyramid is built, so a
    /// compiled template may end up with fewer levels than this.
    pub max_levels: usize,
}
173
174impl Default for CompileConfigNoRot {
175    fn default() -> Self {
176        Self { max_levels: 6 }
177    }
178}
179
/// One cached rotation of the template at a single pyramid level.
pub(crate) struct RotatedTemplate {
    /// Angle from the level's angle grid that this entry was built for.
    angle_deg: f32,
    /// Masked ZNCC plan built from the rotated image and its mask.
    zncc: MaskedTemplatePlan,
    /// Masked SSD plan built from the same rotated image and mask.
    ssd: MaskedSsdTemplatePlan,
}
185
186impl RotatedTemplate {
187    pub(crate) fn zncc_plan(&self) -> &MaskedTemplatePlan {
188        &self.zncc
189    }
190
191    pub(crate) fn ssd_plan(&self) -> &MaskedSsdTemplatePlan {
192        &self.ssd
193    }
194}
195
/// Rotation cache for a single pyramid level.
struct LevelBank {
    /// Angles available at this level.
    grid: AngleGrid,
    /// One write-once cell per grid angle, indexed like `grid`; a cell stays
    /// empty until the rotation for that angle has been built.
    slots: Vec<OnceLock<RotatedTemplate>>,
}
200
/// Compiled template assets with rotation support.
pub struct CompiledTemplateRot {
    /// Pyramid levels kept after trimming; the first entry is the image that
    /// every rotation is generated from.
    levels: Vec<OwnedImage>,
    /// Per-level angle grid and rotation cache, indexed like `levels`.
    banks: Vec<LevelBank>,
    /// Unmasked ZNCC plan for each level, indexed like `levels`.
    unmasked_zncc: Vec<TemplatePlan>,
    /// Unmasked SSD plan for each level, indexed like `levels`.
    unmasked_ssd: Vec<SsdTemplatePlan>,
    /// Configuration used at compile time; its fill value is reused when
    /// rotations are built on demand.
    cfg: CompileConfig,
}
209
210impl CompiledTemplateRot {
211    /// Compiles template assets for matching with rotation support.
212    pub fn compile(tpl: &Template, cfg: CompileConfig) -> CorrMatchResult<Self> {
213        let _span = trace_span!(
214            "compile_template",
215            rotation = true,
216            max_levels = cfg.max_levels
217        )
218        .entered();
219
220        let pyramid = ImagePyramid::build_u8(tpl.view(), cfg.max_levels)?;
221        let mut levels = pyramid.into_levels();
222        trim_degenerate_levels(&mut levels, 3)?;
223
224        let mut unmasked_zncc = Vec::with_capacity(levels.len());
225        let mut unmasked_ssd = Vec::with_capacity(levels.len());
226        for level in levels.iter() {
227            unmasked_zncc.push(TemplatePlan::from_view(level.view())?);
228            unmasked_ssd.push(SsdTemplatePlan::from_view(level.view())?);
229        }
230
231        let mut banks = Vec::with_capacity(levels.len());
232        let coarsest_idx = levels.len().saturating_sub(1);
233        for (level_idx, _level) in levels.iter().enumerate() {
234            let shift = coarsest_idx.saturating_sub(level_idx);
235            let factor = (1u64.checked_shl(shift as u32).unwrap_or(u64::MAX)) as f32;
236            let step = (cfg.coarse_step_deg / factor).max(cfg.min_step_deg);
237            let grid = AngleGrid::full(step)?;
238            let slots = (0..grid.len()).map(|_| OnceLock::new()).collect();
239            banks.push(LevelBank { grid, slots });
240        }
241
242        if cfg.precompute_coarsest {
243            let coarsest_idx = levels.len().saturating_sub(1);
244            let base = levels.first().ok_or(CorrMatchError::IndexOutOfBounds {
245                index: 0,
246                len: levels.len(),
247                context: "level",
248            })?;
249            let coarsest = levels
250                .get(coarsest_idx)
251                .ok_or(CorrMatchError::IndexOutOfBounds {
252                    index: coarsest_idx,
253                    len: levels.len(),
254                    context: "level",
255                })?;
256            if let Some(bank) = banks.get_mut(coarsest_idx) {
257                let _precompute_span =
258                    trace_span!("precompute_rotations", count = bank.grid.len()).entered();
259
260                #[cfg(feature = "rayon")]
261                {
262                    // Parallel precomputation of rotated templates
263                    let angles: Vec<(usize, f32)> = bank.grid.iter().enumerate().collect();
264                    let results: Vec<CorrMatchResult<(usize, RotatedTemplate)>> = angles
265                        .into_par_iter()
266                        .map(|(idx, angle)| {
267                            let (rotated_img, mask) = rotate_downsample_to_level(
268                                base.view(),
269                                angle,
270                                cfg.fill_value,
271                                coarsest_idx,
272                            )?;
273                            debug_assert_eq!(rotated_img.width(), coarsest.width());
274                            debug_assert_eq!(rotated_img.height(), coarsest.height());
275                            let mask: Arc<[u8]> = Arc::from(mask);
276                            let zncc_plan = MaskedTemplatePlan::from_rotated_parts(
277                                rotated_img.view(),
278                                mask.clone(),
279                                angle,
280                            )?;
281                            let ssd_plan = MaskedSsdTemplatePlan::from_rotated_parts(
282                                rotated_img.view(),
283                                mask,
284                                angle,
285                            )?;
286                            let rotated = RotatedTemplate {
287                                angle_deg: angle,
288                                zncc: zncc_plan,
289                                ssd: ssd_plan,
290                            };
291                            Ok((idx, rotated))
292                        })
293                        .collect();
294
295                    // Store results (checking for errors)
296                    for result in results {
297                        let (idx, rotated) = result?;
298                        let _ = bank.slots[idx].set(rotated);
299                    }
300                }
301
302                #[cfg(not(feature = "rayon"))]
303                {
304                    // Sequential precomputation
305                    for (idx, angle) in bank.grid.iter().enumerate() {
306                        let (rotated_img, mask) = rotate_downsample_to_level(
307                            base.view(),
308                            angle,
309                            cfg.fill_value,
310                            coarsest_idx,
311                        )?;
312                        debug_assert_eq!(rotated_img.width(), coarsest.width());
313                        debug_assert_eq!(rotated_img.height(), coarsest.height());
314                        let mask: Arc<[u8]> = Arc::from(mask);
315                        let zncc_plan = MaskedTemplatePlan::from_rotated_parts(
316                            rotated_img.view(),
317                            mask.clone(),
318                            angle,
319                        )?;
320                        let ssd_plan = MaskedSsdTemplatePlan::from_rotated_parts(
321                            rotated_img.view(),
322                            mask,
323                            angle,
324                        )?;
325                        let rotated = RotatedTemplate {
326                            angle_deg: angle,
327                            zncc: zncc_plan,
328                            ssd: ssd_plan,
329                        };
330                        let _ = bank.slots[idx].set(rotated);
331                    }
332                }
333            }
334        }
335
336        Ok(Self {
337            levels,
338            banks,
339            unmasked_zncc,
340            unmasked_ssd,
341            cfg,
342        })
343    }
344
345    /// Returns the number of pyramid levels.
346    pub fn num_levels(&self) -> usize {
347        self.levels.len()
348    }
349
350    /// Returns the width and height for a pyramid level.
351    pub fn level_size(&self, level: usize) -> Option<(usize, usize)> {
352        self.levels
353            .get(level)
354            .map(|img| (img.width(), img.height()))
355    }
356
357    /// Returns the angle grid for a pyramid level.
358    pub fn angle_grid(&self, level: usize) -> Option<&AngleGrid> {
359        self.banks.get(level).map(|bank| &bank.grid)
360    }
361
362    /// Returns an unmasked ZNCC template plan for a given level.
363    pub fn unmasked_zncc_plan(&self, level: usize) -> CorrMatchResult<&TemplatePlan> {
364        self.unmasked_zncc
365            .get(level)
366            .ok_or(CorrMatchError::IndexOutOfBounds {
367                index: level,
368                len: self.unmasked_zncc.len(),
369                context: "level",
370            })
371    }
372
373    /// Returns an unmasked SSD template plan for a given level.
374    pub fn unmasked_ssd_plan(&self, level: usize) -> CorrMatchResult<&SsdTemplatePlan> {
375        self.unmasked_ssd
376            .get(level)
377            .ok_or(CorrMatchError::IndexOutOfBounds {
378                index: level,
379                len: self.unmasked_ssd.len(),
380                context: "level",
381            })
382    }
383
384    pub(crate) fn rotated(
385        &self,
386        level: usize,
387        angle_idx: usize,
388    ) -> CorrMatchResult<&RotatedTemplate> {
389        let bank = self
390            .banks
391            .get(level)
392            .ok_or(CorrMatchError::IndexOutOfBounds {
393                index: level,
394                len: self.banks.len(),
395                context: "level",
396            })?;
397        let slot = bank
398            .slots
399            .get(angle_idx)
400            .ok_or(CorrMatchError::IndexOutOfBounds {
401                index: angle_idx,
402                len: bank.slots.len(),
403                context: "angle_idx",
404            })?;
405        let level_img = self
406            .levels
407            .get(level)
408            .ok_or(CorrMatchError::IndexOutOfBounds {
409                index: level,
410                len: self.levels.len(),
411                context: "level",
412            })?;
413        let angle = bank.grid.angle_at(angle_idx);
414        if let Some(rotated) = slot.get() {
415            debug_assert!((rotated.angle_deg - angle).abs() < 1e-6);
416            debug_assert_eq!(rotated.zncc.width(), level_img.width());
417            debug_assert_eq!(rotated.zncc.height(), level_img.height());
418            return Ok(rotated);
419        }
420        let base = self
421            .levels
422            .first()
423            .ok_or(CorrMatchError::IndexOutOfBounds {
424                index: 0,
425                len: self.levels.len(),
426                context: "level",
427            })?;
428        let (rotated_img, mask) =
429            rotate_downsample_to_level(base.view(), angle, self.cfg.fill_value, level)?;
430        debug_assert_eq!(rotated_img.width(), level_img.width());
431        debug_assert_eq!(rotated_img.height(), level_img.height());
432        let mask: Arc<[u8]> = Arc::from(mask);
433        let zncc_plan =
434            MaskedTemplatePlan::from_rotated_parts(rotated_img.view(), mask.clone(), angle)?;
435        let ssd_plan = MaskedSsdTemplatePlan::from_rotated_parts(rotated_img.view(), mask, angle)?;
436        let rotated = RotatedTemplate {
437            angle_deg: angle,
438            zncc: zncc_plan,
439            ssd: ssd_plan,
440        };
441        let _ = slot.set(rotated);
442        Ok(slot.get().expect("rotated template should be initialized"))
443    }
444}
445
/// Compiled template assets without rotation support.
pub struct CompiledTemplateNoRot {
    /// Pyramid levels kept after trimming.
    levels: Vec<OwnedImage>,
    /// Unmasked ZNCC plan for each level, indexed like `levels`.
    unmasked_zncc: Vec<TemplatePlan>,
    /// Unmasked SSD plan for each level, indexed like `levels`.
    unmasked_ssd: Vec<SsdTemplatePlan>,
}
452
453impl CompiledTemplateNoRot {
454    /// Compiles template assets without rotation support.
455    pub fn compile(tpl: &Template, cfg: CompileConfigNoRot) -> CorrMatchResult<Self> {
456        let _span = trace_span!(
457            "compile_template",
458            rotation = false,
459            max_levels = cfg.max_levels
460        )
461        .entered();
462
463        let pyramid = ImagePyramid::build_u8(tpl.view(), cfg.max_levels)?;
464        let mut levels = pyramid.into_levels();
465        trim_degenerate_levels(&mut levels, 1)?;
466
467        let mut unmasked_zncc = Vec::with_capacity(levels.len());
468        let mut unmasked_ssd = Vec::with_capacity(levels.len());
469        for level in levels.iter() {
470            unmasked_zncc.push(TemplatePlan::from_view(level.view())?);
471            unmasked_ssd.push(SsdTemplatePlan::from_view(level.view())?);
472        }
473
474        Ok(Self {
475            levels,
476            unmasked_zncc,
477            unmasked_ssd,
478        })
479    }
480
481    /// Returns the number of pyramid levels.
482    pub fn num_levels(&self) -> usize {
483        self.levels.len()
484    }
485
486    /// Returns the width and height for a pyramid level.
487    pub fn level_size(&self, level: usize) -> Option<(usize, usize)> {
488        self.levels
489            .get(level)
490            .map(|img| (img.width(), img.height()))
491    }
492
493    /// Returns an unmasked ZNCC template plan for a given level.
494    pub fn unmasked_zncc_plan(&self, level: usize) -> CorrMatchResult<&TemplatePlan> {
495        self.unmasked_zncc
496            .get(level)
497            .ok_or(CorrMatchError::IndexOutOfBounds {
498                index: level,
499                len: self.unmasked_zncc.len(),
500                context: "level",
501            })
502    }
503
504    /// Returns an unmasked SSD template plan for a given level.
505    pub fn unmasked_ssd_plan(&self, level: usize) -> CorrMatchResult<&SsdTemplatePlan> {
506        self.unmasked_ssd
507            .get(level)
508            .ok_or(CorrMatchError::IndexOutOfBounds {
509                index: level,
510                len: self.unmasked_ssd.len(),
511                context: "level",
512            })
513    }
514}
515
/// Compiled template assets for rotated or unrotated matching.
///
/// Use `Template::compile`/`CompiledTemplate::compile_rotated` when rotation
/// search is required, or `CompiledTemplate::compile_unrotated` for the fast
/// translation-only path.
///
/// Rotation-specific accessors return `None` or a
/// `CorrMatchError::RotationUnavailable` error on the `Unrotated` variant.
pub enum CompiledTemplate {
    /// Rotation-enabled assets.
    Rotated(CompiledTemplateRot),
    /// Rotation-disabled assets.
    Unrotated(CompiledTemplateNoRot),
}
527
528impl CompiledTemplate {
529    /// Compiles rotation-enabled template assets.
530    pub fn compile_rotated(tpl: &Template, cfg: CompileConfig) -> CorrMatchResult<Self> {
531        Ok(Self::Rotated(CompiledTemplateRot::compile(tpl, cfg)?))
532    }
533
534    /// Compiles rotation-disabled template assets.
535    pub fn compile_unrotated(tpl: &Template, cfg: CompileConfigNoRot) -> CorrMatchResult<Self> {
536        Ok(Self::Unrotated(CompiledTemplateNoRot::compile(tpl, cfg)?))
537    }
538
539    /// Compiles rotation-enabled template assets (backwards-compatible default).
540    pub fn compile(tpl: &Template, cfg: CompileConfig) -> CorrMatchResult<Self> {
541        Self::compile_rotated(tpl, cfg)
542    }
543
544    /// Returns the number of pyramid levels.
545    pub fn num_levels(&self) -> usize {
546        match self {
547            Self::Rotated(rot) => rot.num_levels(),
548            Self::Unrotated(unrot) => unrot.num_levels(),
549        }
550    }
551
552    /// Returns the width and height for a pyramid level.
553    pub fn level_size(&self, level: usize) -> Option<(usize, usize)> {
554        match self {
555            Self::Rotated(rot) => rot.level_size(level),
556            Self::Unrotated(unrot) => unrot.level_size(level),
557        }
558    }
559
560    /// Returns the angle grid for a pyramid level.
561    pub fn angle_grid(&self, level: usize) -> Option<&AngleGrid> {
562        match self {
563            Self::Rotated(rot) => rot.angle_grid(level),
564            Self::Unrotated(_) => None,
565        }
566    }
567
568    /// Returns an unmasked ZNCC template plan for a given level.
569    pub fn unmasked_zncc_plan(&self, level: usize) -> CorrMatchResult<&TemplatePlan> {
570        match self {
571            Self::Rotated(rot) => rot.unmasked_zncc_plan(level),
572            Self::Unrotated(unrot) => unrot.unmasked_zncc_plan(level),
573        }
574    }
575
576    /// Returns an unmasked SSD template plan for a given level.
577    pub fn unmasked_ssd_plan(&self, level: usize) -> CorrMatchResult<&SsdTemplatePlan> {
578        match self {
579            Self::Rotated(rot) => rot.unmasked_ssd_plan(level),
580            Self::Unrotated(unrot) => unrot.unmasked_ssd_plan(level),
581        }
582    }
583
584    /// Returns the rotated template entry for a given level and angle.
585    pub(crate) fn rotated(
586        &self,
587        level: usize,
588        angle_idx: usize,
589    ) -> CorrMatchResult<&RotatedTemplate> {
590        match self {
591            Self::Rotated(rot) => rot.rotated(level, angle_idx),
592            Self::Unrotated(_) => Err(CorrMatchError::RotationUnavailable {
593                reason: "compiled without rotation support",
594            }),
595        }
596    }
597
598    /// Returns a masked ZNCC template plan for a given level and angle.
599    pub fn rotated_zncc_plan(
600        &self,
601        level: usize,
602        angle_idx: usize,
603    ) -> CorrMatchResult<&MaskedTemplatePlan> {
604        Ok(self.rotated(level, angle_idx)?.zncc_plan())
605    }
606
607    /// Returns a masked SSD template plan for a given level and angle.
608    pub fn rotated_ssd_plan(
609        &self,
610        level: usize,
611        angle_idx: usize,
612    ) -> CorrMatchResult<&MaskedSsdTemplatePlan> {
613        Ok(self.rotated(level, angle_idx)?.ssd_plan())
614    }
615}