1use std::fmt;
8
/// Error type returned by the fallible operations of this crate.
///
/// Each variant's `Display` text is generated by `thiserror` from its
/// `#[error(...)]` attribute. The enum is `#[non_exhaustive]`, so `match`
/// expressions in other crates must include a wildcard arm; new variants can
/// then be added without a breaking change.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum FerroError {
    /// A shape differed from the one an operation required.
    ///
    /// Displays as `Shape mismatch in {context}: expected {expected:?}, got {actual:?}`.
    #[error("Shape mismatch in {context}: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        /// The shape the operation expected.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        actual: Vec<usize>,
        /// Free-form description of where the mismatch was detected.
        context: String,
    },

    /// Fewer samples were available than an operation needs.
    ///
    /// Displays as `Insufficient samples: need at least {required}, got {actual} ({context})`.
    #[error("Insufficient samples: need at least {required}, got {actual} ({context})")]
    InsufficientSamples {
        /// Minimum number of samples required.
        required: usize,
        /// Number of samples actually provided.
        actual: usize,
        /// Free-form description of the operation that needed the samples.
        context: String,
    },

    /// An iterative procedure stopped after `iterations` iterations without
    /// converging.
    #[error("Convergence failure after {iterations} iterations: {message}")]
    ConvergenceFailure {
        /// Number of iterations performed before giving up.
        iterations: usize,
        /// Human-readable explanation of the failure.
        message: String,
    },

    /// A caller-supplied parameter was rejected.
    #[error("Invalid parameter `{name}`: {reason}")]
    InvalidParameter {
        /// Name of the offending parameter.
        name: String,
        /// Why the value was rejected.
        reason: String,
    },

    /// A computation hit a numerically unstable condition.
    #[error("Numerical instability: {message}")]
    NumericalInstability {
        /// Human-readable description of the condition.
        message: String,
    },

    /// Wraps a [`std::io::Error`].
    ///
    /// `#[from]` generates `From<std::io::Error> for FerroError`, so `?`
    /// converts I/O errors automatically, and it also makes the wrapped error
    /// available through `std::error::Error::source`.
    // NOTE(review): the `{0}` in the message repeats the inner error's text
    // that `source()` also exposes, so reporters walking the source chain will
    // print it twice — consider whether that is intended.
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),

    /// A serialization-related failure, carried only as a message string.
    #[error("Serialization error: {message}")]
    SerdeError {
        /// Human-readable description of the failure.
        message: String,
    },
}

/// Convenience alias for `Result<T, FerroError>`.
pub type FerroResult<T> = Result<T, FerroError>;
91
92#[derive(Debug, Clone)]
109pub struct ShapeMismatchContext {
110 context: String,
111 expected: Vec<usize>,
112 actual: Vec<usize>,
113}
114
115impl ShapeMismatchContext {
116 pub fn new(context: impl Into<String>) -> Self {
118 Self {
119 context: context.into(),
120 expected: Vec::new(),
121 actual: Vec::new(),
122 }
123 }
124
125 #[must_use]
127 pub fn expected(mut self, shape: &[usize]) -> Self {
128 self.expected = shape.to_vec();
129 self
130 }
131
132 #[must_use]
134 pub fn actual(mut self, shape: &[usize]) -> Self {
135 self.actual = shape.to_vec();
136 self
137 }
138
139 pub fn build(self) -> FerroError {
141 FerroError::ShapeMismatch {
142 expected: self.expected,
143 actual: self.actual,
144 context: self.context,
145 }
146 }
147}
148
149impl fmt::Display for ShapeMismatchContext {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 write!(
152 f,
153 "ShapeMismatchContext({}, expected {:?}, actual {:?})",
154 self.context, self.expected, self.actual
155 )
156 }
157}
158
#[cfg(test)]
mod tests {
    //! Pins the exact rendered messages so that format-string regressions or
    //! swapped fields are caught, not just the presence of substrings.

    use super::*;
    // Brings `source()` into scope for the `#[from]` check below.
    use std::error::Error as _;

    #[test]
    fn test_shape_mismatch_display() {
        let err = FerroError::ShapeMismatch {
            expected: vec![100, 10],
            actual: vec![100, 5],
            context: "feature matrix".into(),
        };
        assert_eq!(
            err.to_string(),
            "Shape mismatch in feature matrix: expected [100, 10], got [100, 5]"
        );
    }

    #[test]
    fn test_insufficient_samples_display() {
        let err = FerroError::InsufficientSamples {
            required: 10,
            actual: 3,
            context: "cross-validation".into(),
        };
        // An exact comparison also detects `required`/`actual` being swapped,
        // which substring checks on "10" and "3" cannot.
        assert_eq!(
            err.to_string(),
            "Insufficient samples: need at least 10, got 3 (cross-validation)"
        );
    }

    #[test]
    fn test_convergence_failure_display() {
        let err = FerroError::ConvergenceFailure {
            iterations: 1000,
            message: "loss did not decrease".into(),
        };
        assert_eq!(
            err.to_string(),
            "Convergence failure after 1000 iterations: loss did not decrease"
        );
    }

    #[test]
    fn test_invalid_parameter_display() {
        let err = FerroError::InvalidParameter {
            name: "n_clusters".into(),
            reason: "must be positive".into(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid parameter `n_clusters`: must be positive"
        );
    }

    #[test]
    fn test_numerical_instability_display() {
        let err = FerroError::NumericalInstability {
            message: "matrix is singular".into(),
        };
        assert_eq!(err.to_string(), "Numerical instability: matrix is singular");
    }

    #[test]
    fn test_io_error_from() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let ferro_err: FerroError = io_err.into();
        assert!(matches!(
            &ferro_err,
            FerroError::IoError(e) if e.kind() == std::io::ErrorKind::NotFound
        ));
        assert_eq!(ferro_err.to_string(), "I/O error: file not found");
        // `#[from]` also registers the wrapped error as the source.
        assert!(ferro_err.source().is_some());
    }

    #[test]
    fn test_serde_error_display() {
        let err = FerroError::SerdeError {
            message: "invalid JSON".into(),
        };
        assert_eq!(err.to_string(), "Serialization error: invalid JSON");
    }

    #[test]
    fn test_shape_mismatch_context_builder() {
        let err = ShapeMismatchContext::new("test context")
            .expected(&[3, 4])
            .actual(&[3, 5])
            .build();
        assert_eq!(
            err.to_string(),
            "Shape mismatch in test context: expected [3, 4], got [3, 5]"
        );
    }

    #[test]
    fn test_shape_mismatch_context_setters_replace() {
        let err = ShapeMismatchContext::new("ctx")
            .expected(&[1])
            .expected(&[2, 2])
            .build();
        match err {
            FerroError::ShapeMismatch {
                expected,
                actual,
                context,
            } => {
                assert_eq!(expected, vec![2, 2]);
                assert!(actual.is_empty());
                assert_eq!(context, "ctx");
            }
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    #[test]
    fn test_shape_mismatch_context_display() {
        let ctx = ShapeMismatchContext::new("ctx").expected(&[1]);
        assert_eq!(ctx.to_string(), "ShapeMismatchContext(ctx, expected [1], actual [])");
    }

    #[test]
    fn test_ferro_error_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FerroError>();
    }
}