Skip to main content

oxigdal_ml/
error.rs

//! Error types for OxiGDAL ML operations.
//!
//! This module provides ML-specific error types that integrate with
//! the core OxiGDAL error hierarchy.
5
6use thiserror::Error;
7
8/// The main result type for ML operations
9pub type Result<T> = core::result::Result<T, MlError>;
10
11/// ML-specific errors
12#[derive(Debug, Error)]
13pub enum MlError {
14    /// Model loading error
15    #[error("Model error: {0}")]
16    Model(#[from] ModelError),
17
18    /// Inference error
19    #[error("Inference error: {0}")]
20    Inference(#[from] InferenceError),
21
22    /// Preprocessing error
23    #[error("Preprocessing error: {0}")]
24    Preprocessing(#[from] PreprocessingError),
25
26    /// Postprocessing error
27    #[error("Postprocessing error: {0}")]
28    Postprocessing(#[from] PostprocessingError),
29
30    /// OxiGDAL core error
31    #[error("OxiGDAL error: {0}")]
32    OxiGdal(#[from] oxigdal_core::OxiGdalError),
33
34    /// ONNX Runtime error
35    #[error("ONNX Runtime error: {0}")]
36    Ort(String),
37
38    /// I/O error
39    #[error("I/O error: {0}")]
40    Io(#[from] std::io::Error),
41
42    /// Serialization error
43    #[error("Serialization error: {0}")]
44    Serialization(#[from] serde_json::Error),
45
46    /// Invalid configuration
47    #[error("Invalid configuration: {0}")]
48    InvalidConfig(String),
49
50    /// Feature not available
51    #[error("Feature not available: {feature}. Enable with feature flag: {flag}")]
52    FeatureNotAvailable {
53        /// The feature name
54        feature: String,
55        /// The required feature flag
56        flag: String,
57    },
58}
59
60/// Model-related errors
61#[derive(Debug, Error)]
62pub enum ModelError {
63    /// Model file not found
64    #[error("Model file not found: {path}")]
65    NotFound {
66        /// The model file path
67        path: String,
68    },
69
70    /// Model loading failed
71    #[error("Failed to load model: {reason}")]
72    LoadFailed {
73        /// The reason for failure
74        reason: String,
75    },
76
77    /// Invalid model format
78    #[error("Invalid model format: {message}")]
79    InvalidFormat {
80        /// Error message
81        message: String,
82    },
83
84    /// Model initialization failed
85    #[error("Model initialization failed: {reason}")]
86    InitializationFailed {
87        /// The reason for failure
88        reason: String,
89    },
90
91    /// Incompatible model version
92    #[error("Incompatible model version: expected {expected}, got {actual}")]
93    IncompatibleVersion {
94        /// Expected version
95        expected: String,
96        /// Actual version
97        actual: String,
98    },
99
100    /// Missing required input
101    #[error("Missing required model input: {input_name}")]
102    MissingInput {
103        /// Input name
104        input_name: String,
105    },
106
107    /// Missing required output
108    #[error("Missing required model output: {output_name}")]
109    MissingOutput {
110        /// Output name
111        output_name: String,
112    },
113}
114
115/// Inference-related errors
116#[derive(Debug, Error)]
117pub enum InferenceError {
118    /// Invalid input shape
119    #[error("Invalid input shape: expected {expected:?}, got {actual:?}")]
120    InvalidInputShape {
121        /// Expected shape
122        expected: Vec<usize>,
123        /// Actual shape
124        actual: Vec<usize>,
125    },
126
127    /// Invalid input type
128    #[error("Invalid input type: expected {expected}, got {actual}")]
129    InvalidInputType {
130        /// Expected type
131        expected: String,
132        /// Actual type
133        actual: String,
134    },
135
136    /// Batch size mismatch
137    #[error("Batch size mismatch: expected {expected}, got {actual}")]
138    BatchSizeMismatch {
139        /// Expected batch size
140        expected: usize,
141        /// Actual batch size
142        actual: usize,
143    },
144
145    /// Inference failed
146    #[error("Inference failed: {reason}")]
147    Failed {
148        /// The reason for failure
149        reason: String,
150    },
151
152    /// Output parsing failed
153    #[error("Failed to parse output: {reason}")]
154    OutputParsingFailed {
155        /// The reason for failure
156        reason: String,
157    },
158
159    /// GPU not available
160    #[error("GPU acceleration requested but not available: {message}")]
161    GpuNotAvailable {
162        /// Error message
163        message: String,
164    },
165}
166
167/// Preprocessing-related errors
168#[derive(Debug, Error)]
169pub enum PreprocessingError {
170    /// Invalid normalization parameters
171    #[error("Invalid normalization parameters: {message}")]
172    InvalidNormalization {
173        /// Error message
174        message: String,
175    },
176
177    /// Tiling failed
178    #[error("Tiling failed: {reason}")]
179    TilingFailed {
180        /// The reason for failure
181        reason: String,
182    },
183
184    /// Padding failed
185    #[error("Padding failed: {reason}")]
186    PaddingFailed {
187        /// The reason for failure
188        reason: String,
189    },
190
191    /// Invalid tile size
192    #[error("Invalid tile size: width={width}, height={height}")]
193    InvalidTileSize {
194        /// Tile width
195        width: usize,
196        /// Tile height
197        height: usize,
198    },
199
200    /// Channel mismatch
201    #[error("Channel mismatch: expected {expected}, got {actual}")]
202    ChannelMismatch {
203        /// Expected channels
204        expected: usize,
205        /// Actual channels
206        actual: usize,
207    },
208
209    /// Augmentation failed
210    #[error("Data augmentation failed: {reason}")]
211    AugmentationFailed {
212        /// The reason for failure
213        reason: String,
214    },
215}
216
217/// Postprocessing-related errors
218#[derive(Debug, Error)]
219pub enum PostprocessingError {
220    /// Tile merging failed
221    #[error("Tile merging failed: {reason}")]
222    MergingFailed {
223        /// The reason for failure
224        reason: String,
225    },
226
227    /// Threshold out of range
228    #[error("Threshold out of range: must be between 0.0 and 1.0, got {value}")]
229    InvalidThreshold {
230        /// The invalid threshold value
231        value: f32,
232    },
233
234    /// Polygon conversion failed
235    #[error("Polygon conversion failed: {reason}")]
236    PolygonConversionFailed {
237        /// The reason for failure
238        reason: String,
239    },
240
241    /// NMS failed
242    #[error("Non-maximum suppression failed: {reason}")]
243    NmsFailed {
244        /// The reason for failure
245        reason: String,
246    },
247
248    /// Export failed
249    #[error("Export failed: {reason}")]
250    ExportFailed {
251        /// The reason for failure
252        reason: String,
253    },
254
255    /// Invalid class ID
256    #[error("Invalid class ID: {class_id}")]
257    InvalidClassId {
258        /// The invalid class ID
259        class_id: usize,
260    },
261}
262
263impl From<ort::Error> for MlError {
264    fn from(err: ort::Error) -> Self {
265        MlError::Ort(err.to_string())
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// The rendered `NotFound` message names the failure and includes the path.
    #[test]
    fn test_error_display() {
        let rendered = ModelError::NotFound {
            path: String::from("/path/to/model.onnx"),
        }
        .to_string();

        assert!(rendered.contains("Model file not found"));
        assert!(rendered.contains("/path/to/model.onnx"));
    }

    /// A `ModelError` is lifted into the `MlError::Model` variant.
    #[test]
    fn test_error_conversion() {
        let lifted = MlError::from(ModelError::LoadFailed {
            reason: String::from("test"),
        });

        assert!(matches!(lifted, MlError::Model(_)));
    }

    /// The rejected threshold value appears in the rendered message.
    #[test]
    fn test_invalid_threshold() {
        let rendered = PostprocessingError::InvalidThreshold { value: 1.5 }.to_string();

        assert!(rendered.contains("1.5"));
    }
}