birdnet_onnx/
types.rs

/// Model families supported by this crate.
///
/// Each variant fixes the audio input format the model expects; the accessor
/// methods on this type report the exact values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// `BirdNET` v2.4: 48kHz audio, 3s segments, no embedding output.
    BirdNetV24,
    /// `BirdNET` v3.0: 32kHz audio, 5s segments, 1024-dim embeddings.
    BirdNetV30,
    /// Google `Perch` v2: 32kHz audio, 5s segments, variable-size embeddings.
    PerchV2,
}
11
12impl ModelType {
13    /// Sample rate in Hz.
14    #[must_use]
15    pub const fn sample_rate(&self) -> u32 {
16        match self {
17            Self::BirdNetV24 => 48_000,
18            Self::BirdNetV30 | Self::PerchV2 => 32_000,
19        }
20    }
21
22    /// Segment duration in seconds.
23    #[must_use]
24    pub const fn segment_duration(&self) -> f32 {
25        match self {
26            Self::BirdNetV24 => 3.0,
27            Self::BirdNetV30 | Self::PerchV2 => 5.0,
28        }
29    }
30
31    /// Expected sample count per segment.
32    #[must_use]
33    pub const fn sample_count(&self) -> usize {
34        match self {
35            Self::BirdNetV24 => 144_000,
36            Self::BirdNetV30 | Self::PerchV2 => 160_000,
37        }
38    }
39
40    /// Whether this model produces embeddings.
41    #[must_use]
42    pub const fn has_embeddings(&self) -> bool {
43        match self {
44            Self::BirdNetV24 => false,
45            Self::BirdNetV30 | Self::PerchV2 => true,
46        }
47    }
48
49    /// Expected label file format for this model type.
50    #[must_use]
51    pub const fn expected_label_format(&self) -> LabelFormat {
52        match self {
53            Self::BirdNetV24 => LabelFormat::Text,
54            Self::BirdNetV30 | Self::PerchV2 => LabelFormat::Csv,
55        }
56    }
57}
58
/// Format of a species label file.
///
/// [`ModelType::expected_label_format`] reports which of these a given model
/// type is expected to use.
///
/// Derives `Hash` like the other fieldless enums in this module, so values can
/// be used as keys in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LabelFormat {
    /// Plain text with one label per line.
    Text,
    /// CSV whose first column holds the label.
    Csv,
    /// JSON array or object.
    Json,
}
69
70/// Model configuration derived from detected model type.
71#[derive(Debug, Clone)]
72pub struct ModelConfig {
73    /// Detected or overridden model type.
74    pub model_type: ModelType,
75    /// Sample rate in Hz.
76    pub sample_rate: u32,
77    /// Segment duration in seconds.
78    pub segment_duration: f32,
79    /// Expected sample count per segment.
80    pub sample_count: usize,
81    /// Number of species classes in model output.
82    pub num_species: usize,
83    /// Embedding dimension (None for v2.4).
84    pub embedding_dim: Option<usize>,
85}
86
/// One species entry in a model's prediction list.
#[derive(Clone, Debug)]
pub struct Prediction {
    /// Species name, taken from the labels.
    pub species: String,
    /// Confidence in the range 0.0 to 1.0 (sigmoid already applied).
    pub confidence: f32,
    /// Position of this species in the model output.
    pub index: usize,
}
97
98/// Complete inference result.
99#[derive(Debug, Clone)]
100pub struct PredictionResult {
101    /// Model type that produced this result.
102    pub model_type: ModelType,
103    /// Top predictions sorted by confidence (descending).
104    pub predictions: Vec<Prediction>,
105    /// Feature embeddings (None for `BirdNET` v2.4).
106    pub embeddings: Option<Vec<f32>>,
107    /// Raw logits from model output.
108    pub raw_scores: Vec<f32>,
109}
110
/// Per-species probability from the meta model, based on location and date.
#[derive(Clone, Debug)]
pub struct LocationScore {
    /// Species name, taken from the labels.
    pub species: String,
    /// Probability (0.0 - 1.0) of this species at the given location/time.
    pub score: f32,
    /// Position of this species in the model output.
    pub index: usize,
}
121
/// Identifies an execution provider (hardware backend).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ExecutionProviderInfo {
    /// Plain CPU execution; always available.
    Cpu,
    /// NVIDIA CUDA.
    Cuda,
    /// NVIDIA `TensorRT`.
    TensorRt,
    /// `DirectML` (Windows).
    DirectMl,
    /// Apple `CoreML`.
    CoreMl,
    /// AMD `ROCm`.
    Rocm,
    /// Intel `OpenVINO`.
    OpenVino,
    /// Intel `oneDNN`.
    OneDnn,
    /// Qualcomm QNN.
    Qnn,
    /// Arm Compute Library.
    Acl,
    /// Arm NN.
    ArmNn,
}
148
149impl ExecutionProviderInfo {
150    /// Returns the execution provider name as a string.
151    #[must_use]
152    pub const fn as_str(self) -> &'static str {
153        match self {
154            Self::Cpu => "CPU",
155            Self::Cuda => "CUDA",
156            Self::TensorRt => "TensorRT",
157            Self::DirectMl => "DirectML",
158            Self::CoreMl => "CoreML",
159            Self::Rocm => "ROCm",
160            Self::OpenVino => "OpenVINO",
161            Self::OneDnn => "oneDNN",
162            Self::Qnn => "QNN",
163            Self::Acl => "ACL",
164            Self::ArmNn => "ArmNN",
165        }
166    }
167
168    /// Returns the hardware category for this execution provider.
169    #[must_use]
170    pub const fn category(self) -> &'static str {
171        match self {
172            Self::Cpu => "CPU",
173            Self::Cuda | Self::TensorRt | Self::Rocm | Self::DirectMl => "GPU",
174            Self::CoreMl => "Neural Engine",
175            Self::Qnn => "NPU",
176            Self::OpenVino | Self::OneDnn | Self::Acl | Self::ArmNn => "Accelerator",
177        }
178    }
179}
180
181impl std::fmt::Display for ExecutionProviderInfo {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        write!(f, "{}", self.as_str())
184    }
185}
186
#[cfg(test)]
mod tests {
    // Test-only lint relaxations: the float comparisons below are against the
    // same literals the code stores, and the `u32 -> f32` cast only sees the
    // small sample-rate values.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::float_cmp)]
    #![allow(clippy::cast_precision_loss)]
    use super::*;

    /// Checks every accessor of `model` against the expected values.
    fn assert_model_properties(
        model: ModelType,
        rate: u32,
        duration: f32,
        samples: usize,
        embeddings: bool,
        labels: LabelFormat,
    ) {
        assert_eq!(model.sample_rate(), rate);
        assert_eq!(model.segment_duration(), duration);
        assert_eq!(model.sample_count(), samples);
        assert_eq!(model.has_embeddings(), embeddings);
        assert_eq!(model.expected_label_format(), labels);
    }

    /// Checks that every provider in `providers` reports `category`.
    fn assert_category(providers: &[ExecutionProviderInfo], category: &str) {
        for provider in providers {
            assert_eq!(provider.category(), category);
        }
    }

    #[test]
    fn test_birdnet_v24_properties() {
        assert_model_properties(
            ModelType::BirdNetV24,
            48_000,
            3.0,
            144_000,
            false,
            LabelFormat::Text,
        );
    }

    #[test]
    fn test_birdnet_v30_properties() {
        assert_model_properties(
            ModelType::BirdNetV30,
            32_000,
            5.0,
            160_000,
            true,
            LabelFormat::Csv,
        );
    }

    #[test]
    fn test_perch_v2_properties() {
        assert_model_properties(
            ModelType::PerchV2,
            32_000,
            5.0,
            160_000,
            true,
            LabelFormat::Csv,
        );
    }

    #[test]
    fn test_sample_count_matches_rate_times_duration() {
        let all_models = [
            ModelType::BirdNetV24,
            ModelType::BirdNetV30,
            ModelType::PerchV2,
        ];
        for model in all_models {
            let rate = model.sample_rate() as f32;
            let seconds = model.segment_duration();
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let expected = (rate * seconds) as usize;
            assert_eq!(model.sample_count(), expected);
        }
    }

    #[test]
    fn test_location_score_creation() {
        let LocationScore {
            species,
            score,
            index,
        } = LocationScore {
            species: String::from("Turdus merula_Common Blackbird"),
            score: 0.85,
            index: 42,
        };
        assert_eq!(species, "Turdus merula_Common Blackbird");
        assert_eq!(score, 0.85);
        assert_eq!(index, 42);
    }

    #[test]
    fn test_execution_provider_display() {
        let expected_names = [
            (ExecutionProviderInfo::Cpu, "CPU"),
            (ExecutionProviderInfo::Cuda, "CUDA"),
            (ExecutionProviderInfo::TensorRt, "TensorRT"),
            (ExecutionProviderInfo::DirectMl, "DirectML"),
            (ExecutionProviderInfo::CoreMl, "CoreML"),
            (ExecutionProviderInfo::Rocm, "ROCm"),
            (ExecutionProviderInfo::OpenVino, "OpenVINO"),
            (ExecutionProviderInfo::OneDnn, "oneDNN"),
            (ExecutionProviderInfo::Qnn, "QNN"),
            (ExecutionProviderInfo::Acl, "ACL"),
            (ExecutionProviderInfo::ArmNn, "ArmNN"),
        ];
        for (provider, name) in expected_names {
            assert_eq!(provider.to_string(), name);
        }
    }

    #[test]
    fn test_execution_provider_category_cpu() {
        assert_category(&[ExecutionProviderInfo::Cpu], "CPU");
    }

    #[test]
    fn test_execution_provider_category_gpu() {
        assert_category(
            &[
                ExecutionProviderInfo::Cuda,
                ExecutionProviderInfo::TensorRt,
                ExecutionProviderInfo::Rocm,
                ExecutionProviderInfo::DirectMl,
            ],
            "GPU",
        );
    }

    #[test]
    fn test_execution_provider_category_neural_engine() {
        assert_category(&[ExecutionProviderInfo::CoreMl], "Neural Engine");
    }

    #[test]
    fn test_execution_provider_category_npu() {
        assert_category(&[ExecutionProviderInfo::Qnn], "NPU");
    }

    #[test]
    fn test_execution_provider_category_accelerator() {
        assert_category(
            &[
                ExecutionProviderInfo::OpenVino,
                ExecutionProviderInfo::OneDnn,
                ExecutionProviderInfo::Acl,
                ExecutionProviderInfo::ArmNn,
            ],
            "Accelerator",
        );
    }
}