Skip to main content

anno/
error.rs

//! Error types for anno.
2
3use thiserror::Error;
4
5/// Result type for anno operations.
6pub type Result<T> = std::result::Result<T, Error>;
7
8/// Error type for anno operations.
9#[derive(Error, Debug)]
10#[non_exhaustive]
11pub enum Error {
12    /// Model initialization failed.
13    #[error("Model initialization failed: {0}")]
14    ModelInit(String),
15
16    /// Model inference failed.
17    #[error("Inference failed: {0}")]
18    Inference(String),
19
20    /// Invalid input provided.
21    #[error("Invalid input: {0}")]
22    InvalidInput(String),
23
24    /// IO error.
25    #[error("IO error: {0}")]
26    Io(#[from] std::io::Error),
27
28    /// Dataset loading/parsing error.
29    #[error("Dataset error: {0}")]
30    Dataset(String),
31
32    /// Feature not available.
33    #[error("Feature not available: {0}")]
34    FeatureNotAvailable(String),
35
36    /// Parse error.
37    #[error("Parse error: {0}")]
38    Parse(String),
39
40    /// Evaluation error.
41    #[error("Evaluation error: {0}")]
42    Evaluation(String),
43
44    /// Model retrieval error (downloading from HuggingFace).
45    #[error("Retrieval error: {0}")]
46    Retrieval(String),
47
48    /// Candle ML error (when candle feature enabled).
49    #[cfg(feature = "candle")]
50    #[error("Candle error: {0}")]
51    Candle(#[from] candle_core::Error),
52
53    /// Corpus operation error.
54    #[error("Corpus error: {0}")]
55    Corpus(String),
56
57    /// Track reference error.
58    #[error("Track reference error: {0}")]
59    TrackRef(String),
60}
61
62impl Error {
63    /// Create a model initialization error.
64    pub fn model_init(msg: impl Into<String>) -> Self {
65        Error::ModelInit(msg.into())
66    }
67
68    /// Create an inference error.
69    pub fn inference(msg: impl Into<String>) -> Self {
70        Error::Inference(msg.into())
71    }
72
73    /// Create an invalid input error.
74    pub fn invalid_input(msg: impl Into<String>) -> Self {
75        Error::InvalidInput(msg.into())
76    }
77
78    /// Create a dataset error.
79    pub fn dataset(msg: impl Into<String>) -> Self {
80        Error::Dataset(msg.into())
81    }
82
83    /// Create a feature not available error.
84    pub fn feature_not_available(feature: impl Into<String>) -> Self {
85        Error::FeatureNotAvailable(feature.into())
86    }
87
88    /// Create a parse error.
89    pub fn parse(msg: impl Into<String>) -> Self {
90        Error::Parse(msg.into())
91    }
92
93    /// Create an evaluation error.
94    pub fn evaluation(msg: impl Into<String>) -> Self {
95        Error::Evaluation(msg.into())
96    }
97
98    /// Create a retrieval error.
99    pub fn retrieval(msg: impl Into<String>) -> Self {
100        Error::Retrieval(msg.into())
101    }
102
103    /// Create a corpus error.
104    pub fn corpus(msg: impl Into<String>) -> Self {
105        Error::Corpus(msg.into())
106    }
107
108    /// Create a track reference error.
109    pub fn track_ref(msg: impl Into<String>) -> Self {
110        Error::TrackRef(msg.into())
111    }
112}
113
/// Converts a HuggingFace Hub API error into `Error::Retrieval`.
///
/// Only compiled when `hf-hub` is in the dependency tree, i.e. when the `onnx` or
/// `candle` feature is enabled.
///
/// The conversion is lossy. Only the error's `Display` text is kept, because
/// `Retrieval` stores a plain `String` with no `#[source]` field. The original
/// `ApiError` is therefore not reachable through `source()` afterwards.
#[cfg(any(feature = "onnx", feature = "candle"))]
impl From<hf_hub::api::sync::ApiError> for Error {
    fn from(err: hf_hub::api::sync::ApiError) -> Self {
        Self::Retrieval(err.to_string())
    }
}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Each constructor must render as `"<variant prefix>: <message>"`, with the
    /// caller's message reproduced verbatim (not merely containing the prefix).
    #[test]
    fn test_error_constructors() {
        let cases = vec![
            (Error::model_init("test model init"), "Model initialization failed: test model init"),
            (Error::inference("test inference"), "Inference failed: test inference"),
            (Error::invalid_input("test input"), "Invalid input: test input"),
            (Error::dataset("test dataset"), "Dataset error: test dataset"),
            (Error::feature_not_available("test feature"), "Feature not available: test feature"),
            (Error::parse("test parse"), "Parse error: test parse"),
            (Error::evaluation("test eval"), "Evaluation error: test eval"),
            (Error::retrieval("test retrieval"), "Retrieval error: test retrieval"),
            (Error::corpus("test corpus"), "Corpus error: test corpus"),
            (Error::track_ref("test track"), "Track reference error: test track"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected, "unexpected Display for {:?}", err);
        }
    }

    /// Constructors must select the matching variant and store the message unchanged.
    #[test]
    fn test_constructors_select_variant() {
        assert!(matches!(Error::model_init("m"), Error::ModelInit(s) if s == "m"));
        assert!(matches!(Error::inference("m"), Error::Inference(s) if s == "m"));
        assert!(matches!(Error::invalid_input("m"), Error::InvalidInput(s) if s == "m"));
        assert!(matches!(Error::dataset("m"), Error::Dataset(s) if s == "m"));
        assert!(matches!(
            Error::feature_not_available("m"),
            Error::FeatureNotAvailable(s) if s == "m"
        ));
        assert!(matches!(Error::parse("m"), Error::Parse(s) if s == "m"));
        assert!(matches!(Error::evaluation("m"), Error::Evaluation(s) if s == "m"));
        assert!(matches!(Error::retrieval("m"), Error::Retrieval(s) if s == "m"));
        assert!(matches!(Error::corpus("m"), Error::Corpus(s) if s == "m"));
        assert!(matches!(Error::track_ref("m"), Error::TrackRef(s) if s == "m"));
    }

    #[test]
    fn test_error_debug_display() {
        let e = Error::ModelInit("debug test".to_string());
        let debug = format!("{:?}", e);
        assert!(debug.contains("ModelInit"));
        assert!(debug.contains("debug test"));

        assert_eq!(e.to_string(), "Model initialization failed: debug test");
    }

    /// `From<io::Error>` must yield `Error::Io`, keep the inner error's kind,
    /// and expose it via `source()` (`#[from]` implies `#[source]`).
    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let e: Error = io_err.into();
        assert_eq!(e.to_string(), "IO error: file not found");
        assert!(matches!(&e, Error::Io(inner) if inner.kind() == std::io::ErrorKind::NotFound));
        assert!(std::error::Error::source(&e).is_some());
    }

    /// `?` on an `io::Result` inside a function returning this crate's `Result`
    /// must convert through the `From` impl.
    #[test]
    fn test_question_mark_converts_io_error() {
        fn fails() -> Result<()> {
            let io: std::io::Result<()> =
                Err(std::io::Error::new(std::io::ErrorKind::Other, "boom"));
            io?;
            Ok(())
        }
        assert!(matches!(fails(), Err(Error::Io(_))));
    }

    #[test]
    fn test_result_type_alias() {
        fn returns_result() -> Result<i32> {
            Ok(42)
        }
        assert_eq!(returns_result().unwrap(), 42);

        fn returns_error() -> Result<i32> {
            Err(Error::invalid_input("bad"))
        }
        assert!(matches!(returns_error(), Err(Error::InvalidInput(s)) if s == "bad"));
    }
}