1use crate::error::{Error, Result};
57use crate::labels::parse_labels;
58use crate::types::{LabelFormat, LocationScore, Prediction};
59use ndarray::Array2;
60use ort::session::Session;
61use ort::value::Value;
62use std::sync::{Arc, Mutex};
63
/// Converts a calendar date into the week index used as the third model input.
///
/// Every month contributes a fixed four weeks and the day of month selects a
/// 1-based week inside it, so January 1st is week 1 and February 1st is
/// week 5. Days 29-31 land in a fifth week of their month, which makes
/// December 31st week 49.
///
/// This function performs no validation. A `month` or `day` of `0` is treated
/// like `1` and all arithmetic saturates, so out-of-range input can neither
/// panic nor wrap around; call [`validate_date`] first to reject such input.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub const fn calculate_week(month: u32, day: u32) -> f32 {
    // Saturating ops: `month - 1` / `day - 1` would underflow for 0, and the
    // multiplication could overflow for absurdly large months.
    let weeks_from_months = month.saturating_sub(1).saturating_mul(4);
    let week_in_month = day.saturating_sub(1) / 7 + 1;
    weeks_from_months.saturating_add(week_in_month) as f32
}
82
83pub fn validate_coordinates(latitude: f32, longitude: f32) -> Result<()> {
92 if !(-90.0..=90.0).contains(&latitude) {
93 return Err(Error::InvalidCoordinates {
94 latitude,
95 longitude,
96 reason: format!("latitude must be in range [-90, 90], got {latitude}"),
97 });
98 }
99 if !(-180.0..=180.0).contains(&longitude) {
100 return Err(Error::InvalidCoordinates {
101 latitude,
102 longitude,
103 reason: format!("longitude must be in range [-180, 180], got {longitude}"),
104 });
105 }
106 Ok(())
107}
108
109pub fn validate_date(month: u32, day: u32) -> Result<()> {
118 if !(1..=12).contains(&month) {
119 return Err(Error::InvalidDate {
120 month,
121 day,
122 reason: format!("month must be in range [1, 12], got {month}"),
123 });
124 }
125 if !(1..=31).contains(&day) {
126 return Err(Error::InvalidDate {
127 month,
128 day,
129 reason: format!("day must be in range [1, 31], got {day}"),
130 });
131 }
132 Ok(())
133}
134
/// Where the builder takes its species labels from.
#[derive(Debug)]
enum Labels {
    /// Path of a label file that is read and parsed when the filter is built.
    Path(String),
    /// Labels supplied directly by the caller; used as-is.
    InMemory(Vec<String>),
}
141
142#[derive(Debug)]
144pub struct RangeFilterBuilder {
145 model_path: Option<String>,
146 labels: Option<Labels>,
147 execution_providers: Vec<ort::execution_providers::ExecutionProviderDispatch>,
148 threshold: f32,
149}
150
151impl Default for RangeFilterBuilder {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157impl RangeFilterBuilder {
158 #[must_use]
160 pub const fn new() -> Self {
161 Self {
162 model_path: None,
163 labels: None,
164 execution_providers: Vec::new(),
165 threshold: 0.01,
166 }
167 }
168
169 #[must_use]
171 pub fn model_path(mut self, path: impl Into<String>) -> Self {
172 self.model_path = Some(path.into());
173 self
174 }
175
176 #[must_use]
178 pub fn labels_path(mut self, path: impl Into<String>) -> Self {
179 self.labels = Some(Labels::Path(path.into()));
180 self
181 }
182
183 #[must_use]
185 pub fn labels(mut self, labels: Vec<String>) -> Self {
186 self.labels = Some(Labels::InMemory(labels));
187 self
188 }
189
190 #[must_use]
195 pub fn from_classifier_labels(mut self, labels: &[String]) -> Self {
196 self.labels = Some(Labels::InMemory(labels.to_vec()));
197 self
198 }
199
200 #[must_use]
202 pub fn execution_provider(
203 mut self,
204 provider: impl Into<ort::execution_providers::ExecutionProviderDispatch>,
205 ) -> Self {
206 self.execution_providers.push(provider.into());
207 self
208 }
209
210 #[must_use]
212 pub const fn threshold(mut self, threshold: f32) -> Self {
213 self.threshold = threshold;
214 self
215 }
216
217 pub fn build(self) -> Result<RangeFilter> {
222 let model_path = self.model_path.ok_or(Error::ModelPathRequired)?;
223 let labels_source = self.labels.ok_or(Error::LabelsRequired)?;
224
225 let labels = match labels_source {
227 Labels::Path(path) => {
228 let content = std::fs::read_to_string(&path).map_err(|e| Error::LabelLoad {
230 path: path.clone(),
231 reason: e.to_string(),
232 })?;
233 parse_labels(&content, LabelFormat::Text)?
234 }
235 Labels::InMemory(labels) => labels,
236 };
237
238 let mut session_builder = Session::builder().map_err(Error::ModelLoad)?;
240
241 if !self.execution_providers.is_empty() {
242 session_builder = session_builder
243 .with_execution_providers(self.execution_providers)
244 .map_err(Error::ModelLoad)?;
245 }
246
247 let session = session_builder
248 .commit_from_file(&model_path)
249 .map_err(Error::ModelLoad)?;
250
251 let output_shapes = extract_output_shapes(&session)?;
253
254 if output_shapes.len() != 1 {
256 return Err(Error::ModelDetection {
257 reason: format!("meta model expects 1 output, got {}", output_shapes.len()),
258 });
259 }
260
261 let expected = extract_last_dim(&output_shapes[0])?;
262 if labels.len() != expected {
263 return Err(Error::LabelCount {
264 expected,
265 got: labels.len(),
266 });
267 }
268
269 Ok(RangeFilter {
270 inner: Arc::new(RangeFilterInner {
271 session: Mutex::new(session),
272 labels,
273 threshold: self.threshold,
274 }),
275 })
276 }
277}
278
279fn extract_output_shapes(session: &Session) -> Result<Vec<Vec<i64>>> {
281 session
282 .outputs
283 .iter()
284 .map(|output| {
285 let shape = output
286 .output_type
287 .tensor_shape()
288 .ok_or_else(|| Error::ModelDetection {
289 reason: "output is not a tensor".to_string(),
290 })?;
291 Ok(shape.iter().copied().collect())
292 })
293 .collect()
294}
295
296fn extract_last_dim(shape: &[i64]) -> Result<usize> {
298 let value = shape.last().copied().ok_or_else(|| Error::ModelDetection {
299 reason: "empty output shape".to_string(),
300 })?;
301
302 usize::try_from(value).map_err(|_| Error::ModelDetection {
303 reason: format!("invalid dimension: {value}"),
304 })
305}
306
307fn filter_batch_predictions_impl(
312 predictions_batch: Vec<Vec<crate::types::Prediction>>,
313 location_scores: &[LocationScore],
314 threshold: f32,
315 rerank: bool,
316) -> Vec<Vec<crate::types::Prediction>> {
317 predictions_batch
318 .into_iter()
319 .map(|preds| filter_predictions_impl(&preds, location_scores, threshold, rerank))
320 .collect()
321}
322
323fn filter_predictions_impl(
334 predictions: &[Prediction],
335 location_scores: &[LocationScore],
336 threshold: f32,
337 rerank: bool,
338) -> Vec<Prediction> {
339 let location_map: std::collections::HashMap<&str, f32> = location_scores
341 .iter()
342 .map(|score| (score.species.as_str(), score.score))
343 .collect();
344
345 let mut filtered: Vec<Prediction> = predictions
347 .iter()
348 .filter_map(|pred| {
349 let location_score = location_map.get(pred.species.as_str()).copied();
350 match location_score {
351 Some(score) if score >= threshold => {
352 let confidence = if rerank {
354 pred.confidence * score
355 } else {
356 pred.confidence
357 };
358 Some(Prediction {
359 species: pred.species.clone(),
360 confidence,
361 index: pred.index,
362 })
363 }
364 Some(_) => {
365 None
367 }
368 None => {
369 Some(Prediction {
371 species: pred.species.clone(),
372 confidence: pred.confidence,
373 index: pred.index,
374 })
375 }
376 }
377 })
378 .collect();
379
380 if rerank {
382 filtered.sort_unstable_by(|a, b| b.confidence.total_cmp(&a.confidence));
383 }
384
385 filtered
386}
387
388struct RangeFilterInner {
390 session: Mutex<Session>,
391 labels: Vec<String>,
392 threshold: f32,
393}
394
395#[derive(Clone)]
397pub struct RangeFilter {
398 inner: Arc<RangeFilterInner>,
399}
400
401impl std::fmt::Debug for RangeFilter {
402 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 f.debug_struct("RangeFilter")
404 .field("labels_count", &self.inner.labels.len())
405 .field("threshold", &self.inner.threshold)
406 .finish_non_exhaustive()
407 }
408}
409
410impl RangeFilter {
411 #[must_use]
413 pub const fn builder() -> RangeFilterBuilder {
414 RangeFilterBuilder::new()
415 }
416
417 #[allow(clippy::significant_drop_tightening)]
435 pub fn predict(
436 &self,
437 latitude: f32,
438 longitude: f32,
439 month: u32,
440 day: u32,
441 ) -> Result<Vec<LocationScore>> {
442 validate_coordinates(latitude, longitude)?;
444
445 validate_date(month, day)?;
447
448 let week = calculate_week(month, day);
450
451 let input_data = vec![latitude, longitude, week];
453 let input_array = Array2::from_shape_vec((1, 3), input_data).map_err(|e| {
454 Error::RangeFilterInference(format!("failed to create input array: {e}"))
455 })?;
456
457 let input_value = Value::from_array(input_array).map_err(|e| {
458 Error::RangeFilterInference(format!("failed to create input tensor: {e}"))
459 })?;
460
461 let mut session = self
463 .inner
464 .session
465 .lock()
466 .map_err(|e| Error::RangeFilterInference(format!("session lock poisoned: {e}")))?;
467
468 let outputs = session
469 .run(ort::inputs![input_value])
470 .map_err(|e| Error::RangeFilterInference(e.to_string()))?;
471
472 let tensor = outputs.values().next().ok_or_else(|| {
474 Error::RangeFilterInference("model returned no output tensors".to_string())
475 })?;
476
477 let (_, data) = tensor
478 .try_extract_tensor::<f32>()
479 .map_err(|e| Error::RangeFilterInference(e.to_string()))?;
480
481 let mut scores: Vec<LocationScore> = data
483 .iter()
484 .enumerate()
485 .filter_map(|(i, &score)| {
486 if score >= self.inner.threshold && i < self.inner.labels.len() {
487 Some(LocationScore {
488 species: self.inner.labels[i].clone(),
489 score,
490 index: i,
491 })
492 } else {
493 None
494 }
495 })
496 .collect();
497
498 scores.sort_unstable_by(|a, b| b.score.total_cmp(&a.score));
500
501 Ok(scores)
502 }
503
504 #[must_use]
527 pub fn filter_predictions(
528 &self,
529 predictions: &[Prediction],
530 location_scores: &[LocationScore],
531 rerank: bool,
532 ) -> Vec<Prediction> {
533 filter_predictions_impl(predictions, location_scores, self.inner.threshold, rerank)
534 }
535
536 #[must_use]
567 pub fn filter_batch_predictions(
568 &self,
569 predictions_batch: Vec<Vec<crate::types::Prediction>>,
570 location_scores: &[LocationScore],
571 rerank: bool,
572 ) -> Vec<Vec<crate::types::Prediction>> {
573 filter_batch_predictions_impl(
574 predictions_batch,
575 location_scores,
576 self.inner.threshold,
577 rerank,
578 )
579 }
580}
581
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::float_cmp)]
    use super::*;

    /// Builds a `Prediction` from `(species, confidence, index)`.
    macro_rules! prediction {
        ($species:expr, $confidence:expr, $index:expr) => {
            Prediction {
                species: $species.to_string(),
                confidence: $confidence,
                index: $index,
            }
        };
    }

    /// Builds a `LocationScore` from `(species, score, index)`.
    macro_rules! location_score {
        ($species:expr, $score:expr, $index:expr) => {
            LocationScore {
                species: $species.to_string(),
                score: $score,
                index: $index,
            }
        };
    }

    #[test]
    fn test_calculate_week_january_first() {
        assert_eq!(calculate_week(1, 1), 1.0);
    }

    #[test]
    fn test_calculate_week_january_eighth() {
        assert_eq!(calculate_week(1, 8), 2.0);
    }

    #[test]
    fn test_calculate_week_february_first() {
        assert_eq!(calculate_week(2, 1), 5.0);
    }

    #[test]
    fn test_calculate_week_december_last() {
        // 11 full months * 4 weeks + fifth week of December.
        assert_eq!(calculate_week(12, 31), 49.0);
    }

    #[test]
    fn test_validate_coordinates_valid() {
        for (latitude, longitude) in [(45.0, -122.0), (0.0, 0.0), (-90.0, -180.0), (90.0, 180.0)] {
            assert!(validate_coordinates(latitude, longitude).is_ok());
        }
    }

    #[test]
    fn test_validate_coordinates_invalid_latitude() {
        assert!(matches!(
            validate_coordinates(91.0, 0.0),
            Err(Error::InvalidCoordinates { .. })
        ));
    }

    #[test]
    fn test_validate_coordinates_invalid_longitude() {
        assert!(validate_coordinates(0.0, 181.0).is_err());
    }

    #[test]
    fn test_validate_date_valid() {
        for (month, day) in [(1, 1), (6, 15), (12, 31)] {
            assert!(validate_date(month, day).is_ok());
        }
    }

    #[test]
    fn test_validate_date_invalid_month_zero() {
        assert!(matches!(
            validate_date(0, 1),
            Err(Error::InvalidDate { .. })
        ));
    }

    #[test]
    fn test_validate_date_invalid_month_thirteen() {
        assert!(matches!(
            validate_date(13, 1),
            Err(Error::InvalidDate { .. })
        ));
    }

    #[test]
    fn test_validate_date_invalid_day_zero() {
        assert!(matches!(
            validate_date(1, 0),
            Err(Error::InvalidDate { .. })
        ));
    }

    #[test]
    fn test_validate_date_invalid_day_thirty_two() {
        assert!(matches!(
            validate_date(1, 32),
            Err(Error::InvalidDate { .. })
        ));
    }

    #[test]
    fn test_range_filter_builder_missing_model_path() {
        let outcome = RangeFilter::builder().build();
        assert!(matches!(outcome, Err(Error::ModelPathRequired)));
    }

    #[test]
    fn test_range_filter_builder_missing_labels() {
        let outcome = RangeFilter::builder().model_path("/tmp/model.onnx").build();
        assert!(matches!(outcome, Err(Error::LabelsRequired)));
    }

    #[test]
    fn test_filter_predictions_above_threshold() {
        let predictions = vec![
            prediction!("Species A", 0.8, 0),
            prediction!("Species B", 0.3, 1),
            prediction!("Species C", 0.05, 2),
        ];
        let location_scores = vec![
            location_score!("Species A", 0.9, 0),
            // Below the 0.03 threshold used here, so Species B must go.
            location_score!("Species B", 0.02, 1),
            location_score!("Species C", 0.5, 2),
        ];

        let filtered = filter_predictions_impl(&predictions, &location_scores, 0.03, false);

        let species: Vec<&str> = filtered.iter().map(|p| p.species.as_str()).collect();
        assert_eq!(species, ["Species A", "Species C"]);
    }

    #[test]
    fn test_filter_predictions_with_rerank() {
        let predictions = vec![
            prediction!("Species A", 0.9, 0),
            prediction!("Species B", 0.8, 1),
            prediction!("Species C", 0.7, 2),
        ];
        let location_scores = vec![
            location_score!("Species A", 0.5, 0),
            location_score!("Species B", 0.9, 1),
            location_score!("Species C", 0.6, 2),
        ];

        let filtered = filter_predictions_impl(&predictions, &location_scores, 0.03, true);

        // confidence * score, highest first: B = 0.72, A = 0.45, C = 0.42.
        let expected = [("Species B", 0.72), ("Species A", 0.45), ("Species C", 0.42)];
        assert_eq!(filtered.len(), expected.len());
        for (pred, (species, confidence)) in filtered.iter().zip(expected) {
            assert_eq!(pred.species, species);
            assert!((pred.confidence - confidence).abs() < 0.001);
        }
    }

    #[test]
    fn test_filter_predictions_species_not_in_meta_model() {
        let predictions = vec![
            prediction!("Species A", 0.8, 0),
            prediction!("Species B", 0.7, 1),
            prediction!("Species D", 0.9, 3),
        ];
        // Species B and D have no location score at all.
        let location_scores = vec![
            location_score!("Species A", 0.9, 0),
            location_score!("Species C", 0.8, 2),
        ];

        let filtered = filter_predictions_impl(&predictions, &location_scores, 0.03, false);

        // Everything is kept, in input order, with untouched confidences.
        let expected = [("Species A", 0.8), ("Species B", 0.7), ("Species D", 0.9)];
        assert_eq!(filtered.len(), expected.len());
        for (pred, (species, confidence)) in filtered.iter().zip(expected) {
            assert_eq!(pred.species, species);
            assert_eq!(pred.confidence, confidence);
        }
    }

    #[test]
    fn test_builder_from_classifier_labels() {
        let labels = vec!["Species A".to_string(), "Species B".to_string()];

        let builder = RangeFilterBuilder::new().from_classifier_labels(&labels);

        assert!(matches!(builder.labels, Some(Labels::InMemory(_))));
    }

    #[test]
    fn test_filter_batch_predictions() {
        let predictions_batch = vec![
            vec![prediction!("Species A", 0.8, 0)],
            vec![prediction!("Species B", 0.6, 1)],
        ];
        let location_scores = vec![
            location_score!("Species A", 0.9, 0),
            location_score!("Species B", 0.05, 1),
        ];

        let results =
            filter_batch_predictions_impl(predictions_batch, &location_scores, 0.1, false);

        // First list keeps Species A; second loses Species B (0.05 < 0.1).
        let sizes: Vec<usize> = results.iter().map(Vec::len).collect();
        assert_eq!(sizes, [1, 0]);
    }
}