/// Supported model variants.
///
/// Each variant has fixed per-model constants (sample rate, segment duration,
/// sample count, embeddings flag, expected label format) that are exposed
/// through the inherent methods of this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// BirdNET model, version 2.4 (per the variant name).
    BirdNetV24,
    /// BirdNET model, version 3.0 (per the variant name).
    BirdNetV30,
    /// Perch model, version 2 (per the variant name).
    PerchV2,
}
11
12impl ModelType {
13 #[must_use]
15 pub const fn sample_rate(&self) -> u32 {
16 match self {
17 Self::BirdNetV24 => 48_000,
18 Self::BirdNetV30 | Self::PerchV2 => 32_000,
19 }
20 }
21
22 #[must_use]
24 pub const fn segment_duration(&self) -> f32 {
25 match self {
26 Self::BirdNetV24 => 3.0,
27 Self::BirdNetV30 | Self::PerchV2 => 5.0,
28 }
29 }
30
31 #[must_use]
33 pub const fn sample_count(&self) -> usize {
34 match self {
35 Self::BirdNetV24 => 144_000,
36 Self::BirdNetV30 | Self::PerchV2 => 160_000,
37 }
38 }
39
40 #[must_use]
42 pub const fn has_embeddings(&self) -> bool {
43 match self {
44 Self::BirdNetV24 => false,
45 Self::BirdNetV30 | Self::PerchV2 => true,
46 }
47 }
48
49 #[must_use]
51 pub const fn expected_label_format(&self) -> LabelFormat {
52 match self {
53 Self::BirdNetV24 => LabelFormat::Text,
54 Self::BirdNetV30 | Self::PerchV2 => LabelFormat::Csv,
55 }
56 }
57}
58
/// Label formats a model may expect.
///
/// `ModelType::expected_label_format` returns `Text` or `Csv`; no model variant
/// in this file maps to `Json`.
///
/// Derives `Hash` like the other enums in this file so it can be used as a
/// map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LabelFormat {
    /// Text labels.
    Text,
    /// CSV labels.
    Csv,
    /// JSON labels.
    Json,
}
69
70#[derive(Debug, Clone)]
72pub struct ModelConfig {
73 pub model_type: ModelType,
75 pub sample_rate: u32,
77 pub segment_duration: f32,
79 pub sample_count: usize,
81 pub num_species: usize,
83 pub embedding_dim: Option<usize>,
85}
86
/// A single species prediction.
#[derive(Debug, Clone)]
pub struct Prediction {
    /// Species label.
    pub species: String,
    /// Confidence value for this prediction (range not enforced by this type).
    pub confidence: f32,
    /// Index of this prediction (presumably the position in the model's
    /// label/output list — confirm).
    pub index: usize,
}
97
98#[derive(Debug, Clone)]
100pub struct PredictionResult {
101 pub model_type: ModelType,
103 pub predictions: Vec<Prediction>,
105 pub embeddings: Option<Vec<f32>>,
107 pub raw_scores: Vec<f32>,
109}
110
/// A per-species score tied to a location.
///
/// NOTE(review): how the score is derived is not visible in this file.
#[derive(Debug, Clone)]
pub struct LocationScore {
    /// Species label (the unit tests use `"Turdus merula_Common Blackbird"`).
    pub species: String,
    /// Score for this species (range not enforced by this type).
    pub score: f32,
    /// Index of this species (presumably the position in the label list — confirm).
    pub index: usize,
}
121
/// Identifies an execution provider.
///
/// `as_str` yields the display name of a provider and `category` a coarse
/// hardware class (`"CPU"`, `"GPU"`, `"Neural Engine"`, `"NPU"` or
/// `"Accelerator"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExecutionProviderInfo {
    /// Displayed as `CPU`.
    Cpu,
    /// Displayed as `CUDA`.
    Cuda,
    /// Displayed as `TensorRT`.
    TensorRt,
    /// Displayed as `DirectML`.
    DirectMl,
    /// Displayed as `CoreML`.
    CoreMl,
    /// Displayed as `ROCm`.
    Rocm,
    /// Displayed as `OpenVINO`.
    OpenVino,
    /// Displayed as `oneDNN`.
    OneDnn,
    /// Displayed as `QNN`.
    Qnn,
    /// Displayed as `ACL`.
    Acl,
    /// Displayed as `ArmNN`.
    ArmNn,
}
148
149impl ExecutionProviderInfo {
150 #[must_use]
152 pub const fn as_str(self) -> &'static str {
153 match self {
154 Self::Cpu => "CPU",
155 Self::Cuda => "CUDA",
156 Self::TensorRt => "TensorRT",
157 Self::DirectMl => "DirectML",
158 Self::CoreMl => "CoreML",
159 Self::Rocm => "ROCm",
160 Self::OpenVino => "OpenVINO",
161 Self::OneDnn => "oneDNN",
162 Self::Qnn => "QNN",
163 Self::Acl => "ACL",
164 Self::ArmNn => "ArmNN",
165 }
166 }
167
168 #[must_use]
170 pub const fn category(self) -> &'static str {
171 match self {
172 Self::Cpu => "CPU",
173 Self::Cuda | Self::TensorRt | Self::Rocm | Self::DirectMl => "GPU",
174 Self::CoreMl => "Neural Engine",
175 Self::Qnn => "NPU",
176 Self::OpenVino | Self::OneDnn | Self::Acl | Self::ArmNn => "Accelerator",
177 }
178 }
179}
180
181impl std::fmt::Display for ExecutionProviderInfo {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 write!(f, "{}", self.as_str())
184 }
185}
186
/// Unit tests pinning the per-variant constants of `ModelType` and the names
/// and categories of `ExecutionProviderInfo`.
#[cfg(test)]
mod tests {
    // Test-only lint allowances:
    // - `unwrap_used`: panicking on failure is acceptable inside tests.
    // - `float_cmp`: the compared floats are exact literals, not computed values.
    // - `cast_precision_loss`: the sample rates cast to `f32` are small enough
    //   to be represented exactly.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::float_cmp)]
    #![allow(clippy::cast_precision_loss)]
    use super::*;

    #[test]
    fn test_birdnet_v24_properties() {
        let model = ModelType::BirdNetV24;
        assert_eq!(model.sample_rate(), 48_000);
        assert_eq!(model.segment_duration(), 3.0);
        assert_eq!(model.sample_count(), 144_000);
        assert!(!model.has_embeddings());
        assert_eq!(model.expected_label_format(), LabelFormat::Text);
    }

    #[test]
    fn test_birdnet_v30_properties() {
        let model = ModelType::BirdNetV30;
        assert_eq!(model.sample_rate(), 32_000);
        assert_eq!(model.segment_duration(), 5.0);
        assert_eq!(model.sample_count(), 160_000);
        assert!(model.has_embeddings());
        assert_eq!(model.expected_label_format(), LabelFormat::Csv);
    }

    #[test]
    fn test_perch_v2_properties() {
        let model = ModelType::PerchV2;
        assert_eq!(model.sample_rate(), 32_000);
        assert_eq!(model.segment_duration(), 5.0);
        assert_eq!(model.sample_count(), 160_000);
        assert!(model.has_embeddings());
        assert_eq!(model.expected_label_format(), LabelFormat::Csv);
    }

    /// Guards the hard-coded `sample_count` values against drifting from
    /// `sample_rate * segment_duration`.
    #[test]
    fn test_sample_count_matches_rate_times_duration() {
        for model in [
            ModelType::BirdNetV24,
            ModelType::BirdNetV30,
            ModelType::PerchV2,
        ] {
            // The product is a small non-negative whole number, so the cast is exact.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let expected = (model.sample_rate() as f32 * model.segment_duration()) as usize;
            assert_eq!(model.sample_count(), expected);
        }
    }

    #[test]
    fn test_location_score_creation() {
        let score = LocationScore {
            species: "Turdus merula_Common Blackbird".to_string(),
            score: 0.85,
            index: 42,
        };
        assert_eq!(score.species, "Turdus merula_Common Blackbird");
        assert_eq!(score.score, 0.85);
        assert_eq!(score.index, 42);
    }

    #[test]
    fn test_execution_provider_display() {
        assert_eq!(ExecutionProviderInfo::Cpu.to_string(), "CPU");
        assert_eq!(ExecutionProviderInfo::Cuda.to_string(), "CUDA");
        assert_eq!(ExecutionProviderInfo::TensorRt.to_string(), "TensorRT");
        assert_eq!(ExecutionProviderInfo::DirectMl.to_string(), "DirectML");
        assert_eq!(ExecutionProviderInfo::CoreMl.to_string(), "CoreML");
        assert_eq!(ExecutionProviderInfo::Rocm.to_string(), "ROCm");
        assert_eq!(ExecutionProviderInfo::OpenVino.to_string(), "OpenVINO");
        assert_eq!(ExecutionProviderInfo::OneDnn.to_string(), "oneDNN");
        assert_eq!(ExecutionProviderInfo::Qnn.to_string(), "QNN");
        assert_eq!(ExecutionProviderInfo::Acl.to_string(), "ACL");
        assert_eq!(ExecutionProviderInfo::ArmNn.to_string(), "ArmNN");
    }

    #[test]
    fn test_execution_provider_category_cpu() {
        assert_eq!(ExecutionProviderInfo::Cpu.category(), "CPU");
    }

    #[test]
    fn test_execution_provider_category_gpu() {
        assert_eq!(ExecutionProviderInfo::Cuda.category(), "GPU");
        assert_eq!(ExecutionProviderInfo::TensorRt.category(), "GPU");
        assert_eq!(ExecutionProviderInfo::Rocm.category(), "GPU");
        assert_eq!(ExecutionProviderInfo::DirectMl.category(), "GPU");
    }

    #[test]
    fn test_execution_provider_category_neural_engine() {
        assert_eq!(ExecutionProviderInfo::CoreMl.category(), "Neural Engine");
    }

    #[test]
    fn test_execution_provider_category_npu() {
        assert_eq!(ExecutionProviderInfo::Qnn.category(), "NPU");
    }

    #[test]
    fn test_execution_provider_category_accelerator() {
        assert_eq!(ExecutionProviderInfo::OpenVino.category(), "Accelerator");
        assert_eq!(ExecutionProviderInfo::OneDnn.category(), "Accelerator");
        assert_eq!(ExecutionProviderInfo::Acl.category(), "Accelerator");
        assert_eq!(ExecutionProviderInfo::ArmNn.category(), "Accelerator");
    }
}