//! Error type and `Result` alias for `birdnet_onnx` (`birdnet_onnx/error.rs`).

use thiserror::Error;

3/// Errors that can occur during classifier operations.
4#[derive(Debug, Error)]
5pub enum Error {
6    /// Audio segment has wrong number of samples.
7    #[error("input size mismatch: expected {expected} samples, got {got}")]
8    InputSize {
9        /// Expected sample count.
10        expected: usize,
11        /// Actual sample count.
12        got: usize,
13    },
14
15    /// One segment in a batch has wrong number of samples.
16    #[error("batch input size mismatch: segment {index} has {got} samples, expected {expected}")]
17    BatchInputSize {
18        /// Index of the problematic segment.
19        index: usize,
20        /// Expected sample count.
21        expected: usize,
22        /// Actual sample count.
23        got: usize,
24    },
25
26    /// Failed to detect model type from ONNX structure.
27    #[error("model detection failed: {reason}")]
28    ModelDetection {
29        /// Reason for detection failure.
30        reason: String,
31    },
32
33    /// Number of labels doesn't match model output size.
34    #[error("label count mismatch: model expects {expected}, got {got}")]
35    LabelCount {
36        /// Expected label count.
37        expected: usize,
38        /// Actual label count.
39        got: usize,
40    },
41
42    /// Model path was not provided to builder.
43    #[error("model path required")]
44    ModelPathRequired,
45
46    /// Labels were not provided to builder.
47    #[error("labels required (provide path or vec)")]
48    LabelsRequired,
49
50    /// Failed to load ONNX model.
51    #[error("failed to load model: {0}")]
52    ModelLoad(#[from] ort::Error),
53
54    /// Failed to load labels from file.
55    #[error("failed to load labels from {path}: {reason}")]
56    LabelLoad {
57        /// Path that failed to load.
58        path: String,
59        /// Reason for failure.
60        reason: String,
61    },
62
63    /// Failed to parse label file content.
64    #[error("failed to parse labels: {0}")]
65    LabelParse(String),
66
67    /// Inference execution failed.
68    #[error("inference failed: {0}")]
69    Inference(String),
70
71    /// Invalid geographic coordinates provided.
72    #[error("invalid coordinates: latitude: {latitude}, longitude: {longitude}, reason: {reason}")]
73    InvalidCoordinates {
74        /// Latitude value.
75        latitude: f32,
76        /// Longitude value.
77        longitude: f32,
78        /// Reason for invalidity.
79        reason: String,
80    },
81
82    /// Invalid date provided.
83    #[error("invalid date: month: {month}, day: {day}, reason: {reason}")]
84    InvalidDate {
85        /// Month value.
86        month: u32,
87        /// Day value.
88        day: u32,
89        /// Reason for invalidity.
90        reason: String,
91    },
92
93    /// Range filter inference failed.
94    #[error("range filter inference failed: {0}")]
95    RangeFilterInference(String),
96
97    /// Failed to initialize ONNX Runtime.
98    #[error("failed to initialize ONNX Runtime: {0}")]
99    RuntimeInit(String),
100
101    /// Audio file format is not supported.
102    #[error("unsupported audio format: {reason}")]
103    AudioFormat {
104        /// Reason for format rejection.
105        reason: String,
106    },
107
108    /// Failed to read audio file.
109    #[error("failed to read audio file {path}: {reason}")]
110    AudioRead {
111        /// Path to the audio file.
112        path: String,
113        /// Reason for failure.
114        reason: String,
115    },
116}
117
118/// Result type alias using [`Error`].
119pub type Result<T> = std::result::Result<T, Error>;
120
#[cfg(test)]
mod tests {
    // Test code may unwrap freely; this presumably relaxes a crate-wide
    // `clippy::unwrap_used` lint (TODO confirm it is still needed).
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[test]
    fn test_input_size_error_display() {
        let err = Error::InputSize {
            expected: 144_000,
            got: 100_000,
        };
        assert_eq!(
            err.to_string(),
            "input size mismatch: expected 144000 samples, got 100000"
        );
    }

    #[test]
    fn test_batch_input_size_error_display() {
        let err = Error::BatchInputSize {
            index: 3,
            expected: 144_000,
            got: 50_000,
        };
        assert_eq!(
            err.to_string(),
            "batch input size mismatch: segment 3 has 50000 samples, expected 144000"
        );
    }

    #[test]
    fn test_model_detection_error_display() {
        let err = Error::ModelDetection {
            reason: "unsupported model".to_string(),
        };
        assert_eq!(err.to_string(), "model detection failed: unsupported model");
    }

    #[test]
    fn test_label_count_error_display() {
        let err = Error::LabelCount {
            expected: 6522,
            got: 1000,
        };
        assert_eq!(
            err.to_string(),
            "label count mismatch: model expects 6522, got 1000"
        );
    }

    #[test]
    fn test_builder_required_errors_display() {
        assert_eq!(Error::ModelPathRequired.to_string(), "model path required");
        assert_eq!(
            Error::LabelsRequired.to_string(),
            "labels required (provide path or vec)"
        );
    }

    #[test]
    fn test_label_load_error_display() {
        let err = Error::LabelLoad {
            path: "labels.txt".to_string(),
            reason: "No such file or directory".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "failed to load labels from labels.txt: No such file or directory"
        );
    }

    #[test]
    fn test_string_payload_errors_display() {
        assert_eq!(
            Error::LabelParse("empty file".to_string()).to_string(),
            "failed to parse labels: empty file"
        );
        assert_eq!(
            Error::Inference("session run failed".to_string()).to_string(),
            "inference failed: session run failed"
        );
        assert_eq!(
            Error::RuntimeInit("library not found".to_string()).to_string(),
            "failed to initialize ONNX Runtime: library not found"
        );
    }

    #[test]
    fn test_audio_format_error_display() {
        let err = Error::AudioFormat {
            reason: "WAV must be mono".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "unsupported audio format: WAV must be mono"
        );
    }

    #[test]
    fn test_audio_read_error_display() {
        let err = Error::AudioRead {
            path: "/path/to/file.wav".to_string(),
            reason: "file not found".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "failed to read audio file /path/to/file.wav: file not found"
        );
    }

    #[test]
    fn test_invalid_coordinates_error() {
        let err = Error::InvalidCoordinates {
            latitude: 95.0,
            longitude: 200.0,
            reason: "latitude out of range".to_string(),
        };
        // `f32` Display drops a trailing `.0`, so 95.0 renders as "95".
        assert_eq!(
            err.to_string(),
            "invalid coordinates: latitude: 95, longitude: 200, reason: latitude out of range"
        );
    }

    #[test]
    fn test_invalid_date_error() {
        let err = Error::InvalidDate {
            month: 13,
            day: 32,
            reason: "month out of range".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "invalid date: month: 13, day: 32, reason: month out of range"
        );
    }

    #[test]
    fn test_range_filter_inference_error() {
        let err = Error::RangeFilterInference("model invoke failed".to_string());
        assert_eq!(
            err.to_string(),
            "range filter inference failed: model invoke failed"
        );
    }

    #[test]
    fn test_result_alias_propagates_error() {
        fn fails() -> Result<()> {
            Err(Error::ModelPathRequired)
        }
        assert!(matches!(fails(), Err(Error::ModelPathRequired)));
    }
}