1use thiserror::Error;
7
8pub type Result<T> = core::result::Result<T, MlError>;
10
11#[derive(Debug, Error)]
13pub enum MlError {
14 #[error("Model error: {0}")]
16 Model(#[from] ModelError),
17
18 #[error("Inference error: {0}")]
20 Inference(#[from] InferenceError),
21
22 #[error("Preprocessing error: {0}")]
24 Preprocessing(#[from] PreprocessingError),
25
26 #[error("Postprocessing error: {0}")]
28 Postprocessing(#[from] PostprocessingError),
29
30 #[error("OxiGDAL error: {0}")]
32 OxiGdal(#[from] oxigdal_core::OxiGdalError),
33
34 #[error("ONNX Runtime error: {0}")]
36 Ort(String),
37
38 #[error("I/O error: {0}")]
40 Io(#[from] std::io::Error),
41
42 #[error("Serialization error: {0}")]
44 Serialization(#[from] serde_json::Error),
45
46 #[error("Invalid configuration: {0}")]
48 InvalidConfig(String),
49
50 #[error("Feature not available: {feature}. Enable with feature flag: {flag}")]
52 FeatureNotAvailable {
53 feature: String,
55 flag: String,
57 },
58}
59
60#[derive(Debug, Error)]
62pub enum ModelError {
63 #[error("Model file not found: {path}")]
65 NotFound {
66 path: String,
68 },
69
70 #[error("Failed to load model: {reason}")]
72 LoadFailed {
73 reason: String,
75 },
76
77 #[error("Invalid model format: {message}")]
79 InvalidFormat {
80 message: String,
82 },
83
84 #[error("Model initialization failed: {reason}")]
86 InitializationFailed {
87 reason: String,
89 },
90
91 #[error("Incompatible model version: expected {expected}, got {actual}")]
93 IncompatibleVersion {
94 expected: String,
96 actual: String,
98 },
99
100 #[error("Missing required model input: {input_name}")]
102 MissingInput {
103 input_name: String,
105 },
106
107 #[error("Missing required model output: {output_name}")]
109 MissingOutput {
110 output_name: String,
112 },
113}
114
115#[derive(Debug, Error)]
117pub enum InferenceError {
118 #[error("Invalid input shape: expected {expected:?}, got {actual:?}")]
120 InvalidInputShape {
121 expected: Vec<usize>,
123 actual: Vec<usize>,
125 },
126
127 #[error("Invalid input type: expected {expected}, got {actual}")]
129 InvalidInputType {
130 expected: String,
132 actual: String,
134 },
135
136 #[error("Batch size mismatch: expected {expected}, got {actual}")]
138 BatchSizeMismatch {
139 expected: usize,
141 actual: usize,
143 },
144
145 #[error("Inference failed: {reason}")]
147 Failed {
148 reason: String,
150 },
151
152 #[error("Failed to parse output: {reason}")]
154 OutputParsingFailed {
155 reason: String,
157 },
158
159 #[error("GPU acceleration requested but not available: {message}")]
161 GpuNotAvailable {
162 message: String,
164 },
165}
166
167#[derive(Debug, Error)]
169pub enum PreprocessingError {
170 #[error("Invalid normalization parameters: {message}")]
172 InvalidNormalization {
173 message: String,
175 },
176
177 #[error("Tiling failed: {reason}")]
179 TilingFailed {
180 reason: String,
182 },
183
184 #[error("Padding failed: {reason}")]
186 PaddingFailed {
187 reason: String,
189 },
190
191 #[error("Invalid tile size: width={width}, height={height}")]
193 InvalidTileSize {
194 width: usize,
196 height: usize,
198 },
199
200 #[error("Channel mismatch: expected {expected}, got {actual}")]
202 ChannelMismatch {
203 expected: usize,
205 actual: usize,
207 },
208
209 #[error("Data augmentation failed: {reason}")]
211 AugmentationFailed {
212 reason: String,
214 },
215}
216
217#[derive(Debug, Error)]
219pub enum PostprocessingError {
220 #[error("Tile merging failed: {reason}")]
222 MergingFailed {
223 reason: String,
225 },
226
227 #[error("Threshold out of range: must be between 0.0 and 1.0, got {value}")]
229 InvalidThreshold {
230 value: f32,
232 },
233
234 #[error("Polygon conversion failed: {reason}")]
236 PolygonConversionFailed {
237 reason: String,
239 },
240
241 #[error("Non-maximum suppression failed: {reason}")]
243 NmsFailed {
244 reason: String,
246 },
247
248 #[error("Export failed: {reason}")]
250 ExportFailed {
251 reason: String,
253 },
254
255 #[error("Invalid class ID: {class_id}")]
257 InvalidClassId {
258 class_id: usize,
260 },
261}
262
263impl From<ort::Error> for MlError {
264 fn from(err: ort::Error) -> Self {
265 MlError::Ort(err.to_string())
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// The `NotFound` message names both the failure and the missing path.
    #[test]
    fn test_error_display() {
        let path = "/path/to/model.onnx";
        let rendered = ModelError::NotFound {
            path: path.to_owned(),
        }
        .to_string();

        assert!(rendered.contains("Model file not found"));
        assert!(rendered.contains(path));
    }

    /// `#[from]` lets a `ModelError` become `MlError::Model`.
    #[test]
    fn test_error_conversion() {
        let converted = MlError::from(ModelError::LoadFailed {
            reason: String::from("test"),
        });

        assert!(matches!(converted, MlError::Model(_)));
    }

    /// The rejected threshold value appears in the rendered message.
    #[test]
    fn test_invalid_threshold() {
        let rendered = PostprocessingError::InvalidThreshold { value: 1.5 }.to_string();

        assert!(rendered.contains("1.5"));
    }
}