birdnet_onnx/
rangefilter.rs

1//! Range filter for location and date-based species filtering
2//!
3//! # Usage Examples
4//!
5//! ## Basic Filtering
6//!
7//! ```ignore
8//! use birdnet_onnx::{Classifier, RangeFilter};
9//!
10//! let classifier = Classifier::builder()
11//!     .model_path("birdnet.onnx")
12//!     .labels_path("labels.txt")
13//!     .build()?;
14//!
15//! let range_filter = RangeFilter::builder()
16//!     .model_path("meta_model.onnx")
17//!     .from_classifier_labels(classifier.labels())
18//!     .threshold(0.01)
19//!     .build()?;
20//!
21//! // Get predictions
22//! let result = classifier.predict(&audio_segment)?;
23//!
24//! // Get location scores
25//! let location_scores = range_filter.predict(60.17, 24.94, 6, 15)?;
26//!
27//! // Filter predictions
28//! let filtered = range_filter.filter_predictions(
29//!     &result.predictions,
30//!     &location_scores,
31//!     false, // don't rerank
32//! );
33//! ```
34//!
35//! ## Batch Processing
36//!
37//! ```ignore
38//! // Calculate location scores once
39//! let location_scores = range_filter.predict(lat, lon, month, day)?;
40//!
41//! // Process multiple audio segments
42//! let mut predictions_batch = Vec::new();
43//! for segment in audio_segments {
44//!     let result = classifier.predict(&segment)?;
45//!     predictions_batch.push(result.predictions);
46//! }
47//!
48//! // Filter all at once
49//! let filtered_batch = range_filter.filter_batch_predictions(
50//!     predictions_batch,
51//!     &location_scores,
52//!     true, // rerank by location score
53//! );
54//! ```
55
56use crate::error::{Error, Result};
57use crate::labels::parse_labels;
58use crate::types::{LabelFormat, LocationScore, Prediction};
59use ndarray::Array2;
60use ort::session::Session;
61use ort::value::Value;
62use std::sync::{Arc, Mutex};
63
/// Calculate week number for `BirdNET` meta model (48-week year, 4 weeks per month).
///
/// `BirdNET` assumes each month has exactly 4 weeks, creating a 48-week year.
/// Week calculation: weeksFromMonths = (month - 1) * 4; weekInMonth = (day - 1) / 7 + 1
///
/// This function does not validate its inputs (use [`validate_date`] for that),
/// but it never panics: all arithmetic saturates, so `month == 0` is treated
/// like month 1 and `day == 0` like day 1 instead of underflowing.
///
/// # Arguments
/// * `month` - Month number (1-12)
/// * `day` - Day of month (1-31)
///
/// # Returns
/// Week number as f32 (typically 1-48, but can exceed 48 for days 29-31)
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub const fn calculate_week(month: u32, day: u32) -> f32 {
    // Saturating ops keep out-of-range inputs from panicking (debug builds)
    // or silently wrapping to huge values (release builds).
    let weeks_from_months = month.saturating_sub(1).saturating_mul(4);
    let week_in_month = day.saturating_sub(1) / 7 + 1;
    weeks_from_months.saturating_add(week_in_month) as f32
}
82
83/// Validate geographic coordinates.
84///
85/// # Arguments
86/// * `latitude` - Latitude in degrees (-90 to 90)
87/// * `longitude` - Longitude in degrees (-180 to 180)
88///
89/// # Errors
90/// Returns `Error::InvalidCoordinates` if values are out of range
91pub fn validate_coordinates(latitude: f32, longitude: f32) -> Result<()> {
92    if !(-90.0..=90.0).contains(&latitude) {
93        return Err(Error::InvalidCoordinates {
94            latitude,
95            longitude,
96            reason: format!("latitude must be in range [-90, 90], got {latitude}"),
97        });
98    }
99    if !(-180.0..=180.0).contains(&longitude) {
100        return Err(Error::InvalidCoordinates {
101            latitude,
102            longitude,
103            reason: format!("longitude must be in range [-180, 180], got {longitude}"),
104        });
105    }
106    Ok(())
107}
108
109/// Validate date parameters for `BirdNET` calendar.
110///
111/// # Arguments
112/// * `month` - Month number (1-12)
113/// * `day` - Day of month (1-31)
114///
115/// # Errors
116/// Returns `Error::InvalidDate` if values are out of range
117pub fn validate_date(month: u32, day: u32) -> Result<()> {
118    if !(1..=12).contains(&month) {
119        return Err(Error::InvalidDate {
120            month,
121            day,
122            reason: format!("month must be in range [1, 12], got {month}"),
123        });
124    }
125    if !(1..=31).contains(&day) {
126        return Err(Error::InvalidDate {
127            month,
128            day,
129            reason: format!("day must be in range [1, 31], got {day}"),
130        });
131    }
132    Ok(())
133}
134
/// Where the builder obtains its species labels from.
#[derive(Debug)]
enum Labels {
    /// Path to a text labels file; it is read and parsed when the filter is built.
    Path(String),
    /// Labels supplied directly by the caller.
    InMemory(Vec<String>),
}
141
142/// Builder for constructing a `RangeFilter`
143#[derive(Debug)]
144pub struct RangeFilterBuilder {
145    model_path: Option<String>,
146    labels: Option<Labels>,
147    execution_providers: Vec<ort::execution_providers::ExecutionProviderDispatch>,
148    threshold: f32,
149}
150
151impl Default for RangeFilterBuilder {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157impl RangeFilterBuilder {
158    /// Create a new range filter builder
159    #[must_use]
160    pub const fn new() -> Self {
161        Self {
162            model_path: None,
163            labels: None,
164            execution_providers: Vec::new(),
165            threshold: 0.01,
166        }
167    }
168
169    /// Set the path to the ONNX meta model file (required)
170    #[must_use]
171    pub fn model_path(mut self, path: impl Into<String>) -> Self {
172        self.model_path = Some(path.into());
173        self
174    }
175
176    /// Set the path to the labels file (required, must match model output size)
177    #[must_use]
178    pub fn labels_path(mut self, path: impl Into<String>) -> Self {
179        self.labels = Some(Labels::Path(path.into()));
180        self
181    }
182
183    /// Set species labels directly (required, must match model output size)
184    #[must_use]
185    pub fn labels(mut self, labels: Vec<String>) -> Self {
186        self.labels = Some(Labels::InMemory(labels));
187        self
188    }
189
190    /// Use labels from an existing Classifier.
191    ///
192    /// This is a convenience method that copies labels from a classifier,
193    /// ensuring they stay in sync with the main model.
194    #[must_use]
195    pub fn from_classifier_labels(mut self, labels: &[String]) -> Self {
196        self.labels = Some(Labels::InMemory(labels.to_vec()));
197        self
198    }
199
200    /// Add an execution provider (GPU, CPU, etc.)
201    #[must_use]
202    pub fn execution_provider(
203        mut self,
204        provider: impl Into<ort::execution_providers::ExecutionProviderDispatch>,
205    ) -> Self {
206        self.execution_providers.push(provider.into());
207        self
208    }
209
210    /// Set minimum score threshold (default: 0.01)
211    #[must_use]
212    pub const fn threshold(mut self, threshold: f32) -> Self {
213        self.threshold = threshold;
214        self
215    }
216
217    /// Build the range filter
218    ///
219    /// # Errors
220    /// Returns error if model path or labels not set, or if model loading fails
221    pub fn build(self) -> Result<RangeFilter> {
222        let model_path = self.model_path.ok_or(Error::ModelPathRequired)?;
223        let labels_source = self.labels.ok_or(Error::LabelsRequired)?;
224
225        // Load labels from file or use in-memory vector
226        let labels = match labels_source {
227            Labels::Path(path) => {
228                // Read and parse labels file (text format: one label per line)
229                let content = std::fs::read_to_string(&path).map_err(|e| Error::LabelLoad {
230                    path: path.clone(),
231                    reason: e.to_string(),
232                })?;
233                parse_labels(&content, LabelFormat::Text)?
234            }
235            Labels::InMemory(labels) => labels,
236        };
237
238        // Build session with execution providers
239        let mut session_builder = Session::builder().map_err(Error::ModelLoad)?;
240
241        if !self.execution_providers.is_empty() {
242            session_builder = session_builder
243                .with_execution_providers(self.execution_providers)
244                .map_err(Error::ModelLoad)?;
245        }
246
247        let session = session_builder
248            .commit_from_file(&model_path)
249            .map_err(Error::ModelLoad)?;
250
251        // Validate label count matches model output
252        let output_shapes = extract_output_shapes(&session)?;
253
254        // Meta model should have exactly one output
255        if output_shapes.len() != 1 {
256            return Err(Error::ModelDetection {
257                reason: format!("meta model expects 1 output, got {}", output_shapes.len()),
258            });
259        }
260
261        let expected = extract_last_dim(&output_shapes[0])?;
262        if labels.len() != expected {
263            return Err(Error::LabelCount {
264                expected,
265                got: labels.len(),
266            });
267        }
268
269        Ok(RangeFilter {
270            inner: Arc::new(RangeFilterInner {
271                session: Mutex::new(session),
272                labels,
273                threshold: self.threshold,
274            }),
275        })
276    }
277}
278
279/// Extract output tensor shapes from session
280fn extract_output_shapes(session: &Session) -> Result<Vec<Vec<i64>>> {
281    session
282        .outputs
283        .iter()
284        .map(|output| {
285            let shape = output
286                .output_type
287                .tensor_shape()
288                .ok_or_else(|| Error::ModelDetection {
289                    reason: "output is not a tensor".to_string(),
290                })?;
291            Ok(shape.iter().copied().collect())
292        })
293        .collect()
294}
295
296/// Extract last dimension from output shape
297fn extract_last_dim(shape: &[i64]) -> Result<usize> {
298    let value = shape.last().copied().ok_or_else(|| Error::ModelDetection {
299        reason: "empty output shape".to_string(),
300    })?;
301
302    usize::try_from(value).map_err(|_| Error::ModelDetection {
303        reason: format!("invalid dimension: {value}"),
304    })
305}
306
307/// Filter multiple prediction sets with the same location scores.
308///
309/// This is a helper for batch processing - runs filtering on each
310/// prediction set using the same location scores.
311fn filter_batch_predictions_impl(
312    predictions_batch: Vec<Vec<crate::types::Prediction>>,
313    location_scores: &[LocationScore],
314    threshold: f32,
315    rerank: bool,
316) -> Vec<Vec<crate::types::Prediction>> {
317    predictions_batch
318        .into_iter()
319        .map(|preds| filter_predictions_impl(&preds, location_scores, threshold, rerank))
320        .collect()
321}
322
323/// Filter predictions based on meta model location scores
324///
325/// # Arguments
326/// * `predictions` - Original predictions from audio analysis
327/// * `location_scores` - Location-based species scores from meta model
328/// * `threshold` - Minimum location score threshold
329/// * `rerank` - Whether to rerank by location score (multiply confidence by location score)
330///
331/// # Returns
332/// Filtered predictions, optionally reranked by location score
333fn filter_predictions_impl(
334    predictions: &[Prediction],
335    location_scores: &[LocationScore],
336    threshold: f32,
337    rerank: bool,
338) -> Vec<Prediction> {
339    // Build lookup map from species to location score
340    let location_map: std::collections::HashMap<&str, f32> = location_scores
341        .iter()
342        .map(|score| (score.species.as_str(), score.score))
343        .collect();
344
345    // Filter and optionally rerank predictions
346    let mut filtered: Vec<Prediction> = predictions
347        .iter()
348        .filter_map(|pred| {
349            let location_score = location_map.get(pred.species.as_str()).copied();
350            match location_score {
351                Some(score) if score >= threshold => {
352                    // Species in meta model with score >= threshold: keep and optionally rerank
353                    let confidence = if rerank {
354                        pred.confidence * score
355                    } else {
356                        pred.confidence
357                    };
358                    Some(Prediction {
359                        species: pred.species.clone(),
360                        confidence,
361                        index: pred.index,
362                    })
363                }
364                Some(_) => {
365                    // Species in meta model with score < threshold: filter out
366                    None
367                }
368                None => {
369                    // Species NOT in meta model: keep unchanged
370                    Some(Prediction {
371                        species: pred.species.clone(),
372                        confidence: pred.confidence,
373                        index: pred.index,
374                    })
375                }
376            }
377        })
378        .collect();
379
380    // Re-sort by confidence descending if reranked
381    if rerank {
382        filtered.sort_unstable_by(|a, b| b.confidence.total_cmp(&a.confidence));
383    }
384
385    filtered
386}
387
388/// Internal state for `RangeFilter`
389struct RangeFilterInner {
390    session: Mutex<Session>,
391    labels: Vec<String>,
392    threshold: f32,
393}
394
395/// Thread-safe range filter for location-based species filtering
396#[derive(Clone)]
397pub struct RangeFilter {
398    inner: Arc<RangeFilterInner>,
399}
400
401impl std::fmt::Debug for RangeFilter {
402    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403        f.debug_struct("RangeFilter")
404            .field("labels_count", &self.inner.labels.len())
405            .field("threshold", &self.inner.threshold)
406            .finish_non_exhaustive()
407    }
408}
409
410impl RangeFilter {
411    /// Create a new range filter builder
412    #[must_use]
413    pub const fn builder() -> RangeFilterBuilder {
414        RangeFilterBuilder::new()
415    }
416
417    /// Get species probability scores for given location and date
418    ///
419    /// # Arguments
420    /// * `latitude` - Latitude in degrees (-90 to 90)
421    /// * `longitude` - Longitude in degrees (-180 to 180)
422    /// * `month` - Month number (1-12)
423    /// * `day` - Day of month (1-31)
424    ///
425    /// # Returns
426    /// Vector of `LocationScore` sorted by score (descending)
427    ///
428    /// # Errors
429    /// Returns error if:
430    /// - Coordinates are invalid (latitude not in [-90, 90] or longitude not in [-180, 180])
431    /// - Date parameters are invalid (month not in [1, 12] or day not in [1, 31])
432    /// - Session lock is poisoned
433    /// - ONNX inference fails
434    #[allow(clippy::significant_drop_tightening)]
435    pub fn predict(
436        &self,
437        latitude: f32,
438        longitude: f32,
439        month: u32,
440        day: u32,
441    ) -> Result<Vec<LocationScore>> {
442        // Validate coordinates
443        validate_coordinates(latitude, longitude)?;
444
445        // Validate date parameters
446        validate_date(month, day)?;
447
448        // Calculate week number
449        let week = calculate_week(month, day);
450
451        // Create input tensor [1, 3] with [latitude, longitude, week]
452        let input_data = vec![latitude, longitude, week];
453        let input_array = Array2::from_shape_vec((1, 3), input_data).map_err(|e| {
454            Error::RangeFilterInference(format!("failed to create input array: {e}"))
455        })?;
456
457        let input_value = Value::from_array(input_array).map_err(|e| {
458            Error::RangeFilterInference(format!("failed to create input tensor: {e}"))
459        })?;
460
461        // Run inference with locked session
462        let mut session = self
463            .inner
464            .session
465            .lock()
466            .map_err(|e| Error::RangeFilterInference(format!("session lock poisoned: {e}")))?;
467
468        let outputs = session
469            .run(ort::inputs![input_value])
470            .map_err(|e| Error::RangeFilterInference(e.to_string()))?;
471
472        // Extract output tensor (build validates exactly one output exists)
473        let tensor = outputs.values().next().ok_or_else(|| {
474            Error::RangeFilterInference("model returned no output tensors".to_string())
475        })?;
476
477        let (_, data) = tensor
478            .try_extract_tensor::<f32>()
479            .map_err(|e| Error::RangeFilterInference(e.to_string()))?;
480
481        // Build scores above threshold
482        let mut scores: Vec<LocationScore> = data
483            .iter()
484            .enumerate()
485            .filter_map(|(i, &score)| {
486                if score >= self.inner.threshold && i < self.inner.labels.len() {
487                    Some(LocationScore {
488                        species: self.inner.labels[i].clone(),
489                        score,
490                        index: i,
491                    })
492                } else {
493                    None
494                }
495            })
496            .collect();
497
498        // Sort by score descending
499        scores.sort_unstable_by(|a, b| b.score.total_cmp(&a.score));
500
501        Ok(scores)
502    }
503
504    /// Filter predictions based on location scores from meta model
505    ///
506    /// # Arguments
507    /// * `predictions` - Original predictions from audio analysis
508    /// * `location_scores` - Location-based species scores (from `predict`)
509    /// * `rerank` - Whether to rerank by multiplying confidence by location score
510    ///
511    /// # Returns
512    /// Filtered predictions, optionally reranked by location score
513    ///
514    /// # Example
515    /// ```no_run
516    /// # use birdnet_onnx::{RangeFilter, Prediction};
517    /// # fn example(filter: &RangeFilter, predictions: Vec<Prediction>) -> birdnet_onnx::Result<()> {
518    /// // Get location scores for a specific place and time
519    /// let location_scores = filter.predict(45.0, -122.0, 6, 15)?;
520    ///
521    /// // Filter predictions to only include species likely at this location
522    /// let filtered = filter.filter_predictions(&predictions, &location_scores, false);
523    /// # Ok(())
524    /// # }
525    /// ```
526    #[must_use]
527    pub fn filter_predictions(
528        &self,
529        predictions: &[Prediction],
530        location_scores: &[LocationScore],
531        rerank: bool,
532    ) -> Vec<Prediction> {
533        filter_predictions_impl(predictions, location_scores, self.inner.threshold, rerank)
534    }
535
536    /// Filter multiple prediction sets using location scores.
537    ///
538    /// This is a convenience method for batch processing multiple audio files
539    /// from the same location. Predict location scores once, then apply to
540    /// multiple prediction sets.
541    ///
542    /// # Arguments
543    /// * `predictions_batch` - Vector of prediction vectors to filter
544    /// * `location_scores` - Location scores from `predict()`
545    /// * `rerank` - Whether to rerank each prediction set
546    ///
547    /// # Returns
548    /// Vector of filtered prediction vectors
549    ///
550    /// # Example
551    /// ```ignore
552    /// let location_scores = range_filter.predict(lat, lon, month, day)?;
553    ///
554    /// let mut predictions_batch = Vec::new();
555    /// for segment in audio_segments {
556    ///     let result = classifier.predict(&segment)?;
557    ///     predictions_batch.push(result.predictions);
558    /// }
559    ///
560    /// let filtered_batch = range_filter.filter_batch_predictions(
561    ///     predictions_batch,
562    ///     &location_scores,
563    ///     rerank,
564    /// );
565    /// ```
566    #[must_use]
567    pub fn filter_batch_predictions(
568        &self,
569        predictions_batch: Vec<Vec<crate::types::Prediction>>,
570        location_scores: &[LocationScore],
571        rerank: bool,
572    ) -> Vec<Vec<crate::types::Prediction>> {
573        filter_batch_predictions_impl(
574            predictions_batch,
575            location_scores,
576            self.inner.threshold,
577            rerank,
578        )
579    }
580}
581
/// Unit tests for the pure helpers in this module (week calculation, input
/// validation, builder configuration and prediction filtering). Nothing here
/// loads an ONNX model.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::float_cmp)]
    use super::*;

    #[test]
    fn test_calculate_week_january_first() {
        // January 1st = month 1, day 1
        // weeksFromMonths = (1 - 1) * 4 = 0
        // weekInMonth = (1 - 1) / 7 + 1 = 1
        // week = 0 + 1 = 1
        let week = calculate_week(1, 1);
        assert_eq!(week, 1.0);
    }

    #[test]
    fn test_calculate_week_january_eighth() {
        // January 8th = month 1, day 8
        // weeksFromMonths = 0
        // weekInMonth = (8 - 1) / 7 + 1 = 2
        // week = 0 + 2 = 2
        let week = calculate_week(1, 8);
        assert_eq!(week, 2.0);
    }

    #[test]
    fn test_calculate_week_february_first() {
        // February 1st = month 2, day 1
        // weeksFromMonths = (2 - 1) * 4 = 4
        // weekInMonth = (1 - 1) / 7 + 1 = 1
        // week = 4 + 1 = 5
        let week = calculate_week(2, 1);
        assert_eq!(week, 5.0);
    }

    #[test]
    fn test_calculate_week_day_twenty_nine_spills_into_fifth_week() {
        // January 29th = month 1, day 29
        // weeksFromMonths = 0
        // weekInMonth = (29 - 1) / 7 + 1 = 5
        // week = 0 + 5 = 5 (same value as February 1st)
        let week = calculate_week(1, 29);
        assert_eq!(week, 5.0);
    }

    #[test]
    fn test_calculate_week_december_last() {
        // December 31st = month 12, day 31
        // weeksFromMonths = (12 - 1) * 4 = 44
        // weekInMonth = (31 - 1) / 7 + 1 = 5
        // week = 44 + 5 = 49
        // Note: BirdNET uses 48-week year, but days 29-31 can exceed 48
        let week = calculate_week(12, 31);
        assert_eq!(week, 49.0);
    }

    #[test]
    fn test_validate_coordinates_valid() {
        assert!(validate_coordinates(45.0, -122.0).is_ok());
        assert!(validate_coordinates(0.0, 0.0).is_ok());
        assert!(validate_coordinates(-90.0, -180.0).is_ok());
        assert!(validate_coordinates(90.0, 180.0).is_ok());
    }

    #[test]
    fn test_validate_coordinates_invalid_latitude() {
        let result = validate_coordinates(91.0, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            Error::InvalidCoordinates { .. }
        ));
    }

    #[test]
    fn test_validate_coordinates_invalid_longitude() {
        let result = validate_coordinates(0.0, 181.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            Error::InvalidCoordinates { .. }
        ));
    }

    #[test]
    fn test_validate_coordinates_negative_out_of_range() {
        assert!(matches!(
            validate_coordinates(-90.5, 0.0).unwrap_err(),
            Error::InvalidCoordinates { .. }
        ));
        assert!(matches!(
            validate_coordinates(0.0, -180.5).unwrap_err(),
            Error::InvalidCoordinates { .. }
        ));
    }

    #[test]
    fn test_validate_coordinates_nan_rejected() {
        // NaN is never contained in a range, so it must be reported as invalid
        assert!(validate_coordinates(f32::NAN, 0.0).is_err());
        assert!(validate_coordinates(0.0, f32::NAN).is_err());
    }

    #[test]
    fn test_validate_date_valid() {
        assert!(validate_date(1, 1).is_ok());
        assert!(validate_date(6, 15).is_ok());
        assert!(validate_date(12, 31).is_ok());
    }

    #[test]
    fn test_validate_date_invalid_month_zero() {
        let result = validate_date(0, 1);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::InvalidDate { .. }));
    }

    #[test]
    fn test_validate_date_invalid_month_thirteen() {
        let result = validate_date(13, 1);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::InvalidDate { .. }));
    }

    #[test]
    fn test_validate_date_invalid_day_zero() {
        let result = validate_date(1, 0);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::InvalidDate { .. }));
    }

    #[test]
    fn test_validate_date_invalid_day_thirty_two() {
        let result = validate_date(1, 32);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::InvalidDate { .. }));
    }

    #[test]
    fn test_range_filter_builder_missing_model_path() {
        let result = RangeFilter::builder().build();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::ModelPathRequired));
    }

    #[test]
    fn test_range_filter_builder_missing_labels() {
        // The labels check happens before the model file is opened, so the
        // path does not need to exist.
        let result = RangeFilter::builder().model_path("/tmp/model.onnx").build();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::LabelsRequired));
    }

    #[test]
    fn test_filter_predictions_above_threshold() {
        // Setup test data
        let predictions = vec![
            Prediction {
                species: "Species A".to_string(),
                confidence: 0.8,
                index: 0,
            },
            Prediction {
                species: "Species B".to_string(),
                confidence: 0.3,
                index: 1,
            },
            Prediction {
                species: "Species C".to_string(),
                confidence: 0.05,
                index: 2,
            },
        ];

        let location_scores = vec![
            LocationScore {
                species: "Species A".to_string(),
                score: 0.9,
                index: 0,
            },
            LocationScore {
                species: "Species B".to_string(),
                score: 0.02,
                index: 1,
            },
            LocationScore {
                species: "Species C".to_string(),
                score: 0.5,
                index: 2,
            },
        ];

        let threshold = 0.03;
        let rerank = false;

        let filtered = filter_predictions_impl(&predictions, &location_scores, threshold, rerank);

        // Species B should be filtered out (score 0.02 < threshold 0.03)
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].species, "Species A");
        assert_eq!(filtered[1].species, "Species C");
    }

    #[test]
    fn test_filter_predictions_with_rerank() {
        // Setup test data with different confidence and location scores
        let predictions = vec![
            Prediction {
                species: "Species A".to_string(),
                confidence: 0.9, // High confidence
                index: 0,
            },
            Prediction {
                species: "Species B".to_string(),
                confidence: 0.8, // Medium-high confidence
                index: 1,
            },
            Prediction {
                species: "Species C".to_string(),
                confidence: 0.7, // Medium confidence
                index: 2,
            },
        ];

        let location_scores = vec![
            LocationScore {
                species: "Species A".to_string(),
                score: 0.5, // Medium location score
                index: 0,
            },
            LocationScore {
                species: "Species B".to_string(),
                score: 0.9, // High location score
                index: 1,
            },
            LocationScore {
                species: "Species C".to_string(),
                score: 0.6, // Medium location score
                index: 2,
            },
        ];

        let threshold = 0.03;
        let rerank = true;

        let filtered = filter_predictions_impl(&predictions, &location_scores, threshold, rerank);

        // All should pass threshold
        assert_eq!(filtered.len(), 3);

        // After reranking (confidence * location_score):
        // Species A: 0.9 * 0.5 = 0.45
        // Species B: 0.8 * 0.9 = 0.72 (highest)
        // Species C: 0.7 * 0.6 = 0.42
        // Should be sorted: B, A, C
        assert_eq!(filtered[0].species, "Species B");
        assert_eq!(filtered[1].species, "Species A");
        assert_eq!(filtered[2].species, "Species C");

        // Verify reranked scores
        assert!((filtered[0].confidence - 0.72).abs() < 0.001);
        assert!((filtered[1].confidence - 0.45).abs() < 0.001);
        assert!((filtered[2].confidence - 0.42).abs() < 0.001);
    }

    #[test]
    fn test_filter_predictions_species_not_in_meta_model() {
        // Setup test data where some predictions have no location score
        let predictions = vec![
            Prediction {
                species: "Species A".to_string(),
                confidence: 0.8,
                index: 0,
            },
            Prediction {
                species: "Species B".to_string(),
                confidence: 0.7,
                index: 1,
            },
            Prediction {
                species: "Species D".to_string(), // No location score
                confidence: 0.9,
                index: 3,
            },
        ];

        let location_scores = vec![
            LocationScore {
                species: "Species A".to_string(),
                score: 0.9,
                index: 0,
            },
            LocationScore {
                species: "Species C".to_string(), // Not in predictions
                score: 0.8,
                index: 2,
            },
        ];

        let threshold = 0.03;
        let rerank = false;

        let filtered = filter_predictions_impl(&predictions, &location_scores, threshold, rerank);

        // Species A (has a score >= threshold): KEEP
        // Species B (no location score): KEEP unchanged
        // Species D (no location score): KEEP unchanged
        assert_eq!(filtered.len(), 3);
        assert_eq!(filtered[0].species, "Species A");
        assert_eq!(filtered[0].confidence, 0.8);
        assert_eq!(filtered[1].species, "Species B");
        assert_eq!(filtered[1].confidence, 0.7);
        assert_eq!(filtered[2].species, "Species D");
        assert_eq!(filtered[2].confidence, 0.9);
    }

    #[test]
    fn test_builder_from_classifier_labels() {
        // A real Classifier needs a model file, so only the builder
        // configuration is checked here.
        let labels = vec!["Species A".to_string(), "Species B".to_string()];
        let builder = RangeFilterBuilder::new().from_classifier_labels(&labels);

        // The private `labels` field is visible because this is a child module.
        assert!(matches!(builder.labels, Some(Labels::InMemory(_))));
    }

    #[test]
    fn test_filter_batch_predictions() {
        let batch1 = vec![Prediction {
            species: "Species A".to_string(),
            confidence: 0.8,
            index: 0,
        }];
        let batch2 = vec![Prediction {
            species: "Species B".to_string(),
            confidence: 0.6,
            index: 1,
        }];

        let predictions_batch = vec![batch1, batch2];

        let location_scores = vec![
            LocationScore {
                species: "Species A".to_string(),
                score: 0.9,
                index: 0,
            },
            LocationScore {
                species: "Species B".to_string(),
                score: 0.05,
                index: 1,
            },
        ];

        let threshold = 0.1;
        let results =
            filter_batch_predictions_impl(predictions_batch, &location_scores, threshold, false);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 1); // Species A kept
        assert_eq!(results[1].len(), 0); // Species B filtered
    }
}