1use thiserror::Error;
4
5pub type Result<T> = std::result::Result<T, Error>;
7
8#[derive(Error, Debug)]
10#[non_exhaustive]
11pub enum Error {
12 #[error("Model initialization failed: {0}")]
14 ModelInit(String),
15
16 #[error("Inference failed: {0}")]
18 Inference(String),
19
20 #[error("Invalid input: {0}")]
22 InvalidInput(String),
23
24 #[error("IO error: {0}")]
26 Io(#[from] std::io::Error),
27
28 #[error("Dataset error: {0}")]
30 Dataset(String),
31
32 #[error("Feature not available: {0}")]
34 FeatureNotAvailable(String),
35
36 #[error("Parse error: {0}")]
38 Parse(String),
39
40 #[error("Evaluation error: {0}")]
42 Evaluation(String),
43
44 #[error("Retrieval error: {0}")]
46 Retrieval(String),
47
48 #[cfg(feature = "candle")]
50 #[error("Candle error: {0}")]
51 Candle(#[from] candle_core::Error),
52
53 #[error("Corpus error: {0}")]
55 Corpus(String),
56
57 #[error("Track reference error: {0}")]
59 TrackRef(String),
60}
61
62impl Error {
63 pub fn model_init(msg: impl Into<String>) -> Self {
65 Error::ModelInit(msg.into())
66 }
67
68 pub fn inference(msg: impl Into<String>) -> Self {
70 Error::Inference(msg.into())
71 }
72
73 pub fn invalid_input(msg: impl Into<String>) -> Self {
75 Error::InvalidInput(msg.into())
76 }
77
78 pub fn dataset(msg: impl Into<String>) -> Self {
80 Error::Dataset(msg.into())
81 }
82
83 pub fn feature_not_available(feature: impl Into<String>) -> Self {
85 Error::FeatureNotAvailable(feature.into())
86 }
87
88 pub fn parse(msg: impl Into<String>) -> Self {
90 Error::Parse(msg.into())
91 }
92
93 pub fn evaluation(msg: impl Into<String>) -> Self {
95 Error::Evaluation(msg.into())
96 }
97
98 pub fn retrieval(msg: impl Into<String>) -> Self {
100 Error::Retrieval(msg.into())
101 }
102
103 pub fn corpus(msg: impl Into<String>) -> Self {
105 Error::Corpus(msg.into())
106 }
107
108 pub fn track_ref(msg: impl Into<String>) -> Self {
110 Error::TrackRef(msg.into())
111 }
112}
113
/// Converts Hugging Face Hub sync-API failures into [`Error::Retrieval`], so `?` works on
/// `hf_hub::api::sync` calls. Only compiled when the `onnx` or `candle` feature is enabled.
#[cfg(any(feature = "onnx", feature = "candle"))]
impl From<hf_hub::api::sync::ApiError> for Error {
    fn from(err: hf_hub::api::sync::ApiError) -> Self {
        // NOTE(review): the original `ApiError` is flattened to its Display text here, so it
        // is not reachable via `source()`; add a dedicated `#[from]` variant if callers need
        // to inspect it.
        Error::Retrieval(err.to_string())
    }
}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Each string constructor must render exactly as `"<variant prefix>: <message>"`,
    /// passing the caller's message through verbatim (a prefix-only check would not
    /// catch a constructor that drops or mangles the payload).
    #[test]
    fn test_error_constructors() {
        let cases = [
            (Error::model_init("msg"), "Model initialization failed: msg"),
            (Error::inference("msg"), "Inference failed: msg"),
            (Error::invalid_input("msg"), "Invalid input: msg"),
            (Error::dataset("msg"), "Dataset error: msg"),
            (Error::feature_not_available("msg"), "Feature not available: msg"),
            (Error::parse("msg"), "Parse error: msg"),
            (Error::evaluation("msg"), "Evaluation error: msg"),
            (Error::retrieval("msg"), "Retrieval error: msg"),
            (Error::corpus("msg"), "Corpus error: msg"),
            (Error::track_ref("msg"), "Track reference error: msg"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    /// Each constructor must select its own variant, not merely produce matching text.
    #[test]
    fn test_constructors_select_variant() {
        assert!(matches!(Error::model_init("m"), Error::ModelInit(_)));
        assert!(matches!(Error::inference("m"), Error::Inference(_)));
        assert!(matches!(Error::invalid_input("m"), Error::InvalidInput(_)));
        assert!(matches!(Error::dataset("m"), Error::Dataset(_)));
        assert!(matches!(Error::feature_not_available("m"), Error::FeatureNotAvailable(_)));
        assert!(matches!(Error::parse("m"), Error::Parse(_)));
        assert!(matches!(Error::evaluation("m"), Error::Evaluation(_)));
        assert!(matches!(Error::retrieval("m"), Error::Retrieval(_)));
        assert!(matches!(Error::corpus("m"), Error::Corpus(_)));
        assert!(matches!(Error::track_ref("m"), Error::TrackRef(_)));
    }

    /// Debug shows the variant and raw payload; Display shows the user-facing message.
    #[test]
    fn test_error_debug_display() {
        let e = Error::ModelInit("debug test".to_string());
        assert_eq!(format!("{:?}", e), r#"ModelInit("debug test")"#);
        assert_eq!(e.to_string(), "Model initialization failed: debug test");
    }

    /// `From<std::io::Error>` keeps the original error (kind and source chain intact).
    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let e: Error = io_err.into();
        assert!(
            matches!(e, Error::Io(ref inner) if inner.kind() == std::io::ErrorKind::NotFound)
        );
        assert_eq!(e.to_string(), "IO error: file not found");
        // `#[from]` also exposes the wrapped error through `source()`.
        assert!(std::error::Error::source(&e).is_some());
    }

    /// The crate `Result` alias carries [`Error`], and `?` converts I/O errors into it.
    #[test]
    fn test_result_type_alias() {
        fn returns_result() -> Result<i32> {
            Ok(42)
        }
        fn returns_error() -> Result<i32> {
            Err(Error::invalid_input("bad"))
        }
        fn propagates_io() -> Result<()> {
            let failed: std::io::Result<()> =
                Err(std::io::Error::new(std::io::ErrorKind::Other, "boom"));
            failed?;
            Ok(())
        }

        assert_eq!(returns_result().unwrap(), 42);
        assert!(matches!(returns_error(), Err(Error::InvalidInput(ref m)) if m == "bad"));
        assert!(matches!(propagates_io(), Err(Error::Io(_))));
    }
}