1#[cfg(not(feature = "std"))]
8use alloc::{string::String, vec::Vec};
9
10use core::fmt;
11
/// Error type for fallible operations; [`FerroResult`] is the matching
/// `Result` alias.
///
/// Every variant renders a human-readable message through the `Display`
/// implementation generated by `thiserror` from the `#[error(...)]` attributes.
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate must
/// include a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum FerroError {
    /// An input did not have the shape an operation required.
    #[error("Shape mismatch in {context}: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        /// The shape the operation required (rendered with `Debug`, e.g. `[100, 10]`).
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        actual: Vec<usize>,
        /// Free-form description of where the mismatch occurred (e.g. "feature matrix").
        context: String,
    },

    /// Fewer samples were supplied than an operation needs.
    #[error("Insufficient samples: need at least {required}, got {actual} ({context})")]
    InsufficientSamples {
        /// Minimum number of samples required.
        required: usize,
        /// Number of samples actually supplied.
        actual: usize,
        /// Free-form description of the operation (e.g. "cross-validation").
        context: String,
    },

    /// An iterative procedure stopped without converging.
    #[error("Convergence failure after {iterations} iterations: {message}")]
    ConvergenceFailure {
        /// Number of iterations performed before giving up.
        iterations: usize,
        /// Explanation of why convergence failed.
        message: String,
    },

    /// A caller-supplied parameter had an unacceptable value.
    #[error("Invalid parameter `{name}`: {reason}")]
    InvalidParameter {
        /// Name of the offending parameter (shown in backticks in the message).
        name: String,
        /// Why the value was rejected.
        reason: String,
    },

    /// A numerical computation could not be carried out reliably
    /// (e.g. a singular matrix).
    #[error("Numerical instability: {message}")]
    NumericalInstability {
        /// Description of the instability.
        message: String,
    },

    /// Wraps a `std::io::Error`; this variant only exists with the `std` feature.
    ///
    /// `#[from]` generates `From<std::io::Error>` (so `?` converts automatically)
    /// and also exposes the wrapped error as this error's `source()`.
    /// NOTE(review): the inner message is also interpolated into `Display`, so
    /// reporters that walk the source chain will print it twice.
    #[cfg(feature = "std")]
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),

    /// A (de)serialization step failed.
    #[error("Serialization error: {message}")]
    SerdeError {
        /// Description of the failure (e.g. "invalid JSON").
        message: String,
    },
}
92
93pub type FerroResult<T> = Result<T, FerroError>;
95
/// Builder for a [`FerroError::ShapeMismatch`] error.
///
/// Create one with [`ShapeMismatchContext::new`], optionally record the shapes
/// with [`expected`](ShapeMismatchContext::expected) and
/// [`actual`](ShapeMismatchContext::actual), then call
/// [`build`](ShapeMismatchContext::build). A shape that is never set stays empty.
#[derive(Debug, Clone)]
pub struct ShapeMismatchContext {
    /// Where the mismatch occurred; becomes the error's `context` field.
    context: String,
    /// Shape the operation required; empty until `expected` is called.
    expected: Vec<usize>,
    /// Shape that was actually supplied; empty until `actual` is called.
    actual: Vec<usize>,
}
118
119impl ShapeMismatchContext {
120 pub fn new(context: impl Into<String>) -> Self {
122 Self {
123 context: context.into(),
124 expected: Vec::new(),
125 actual: Vec::new(),
126 }
127 }
128
129 #[must_use]
131 pub fn expected(mut self, shape: &[usize]) -> Self {
132 self.expected = shape.to_vec();
133 self
134 }
135
136 #[must_use]
138 pub fn actual(mut self, shape: &[usize]) -> Self {
139 self.actual = shape.to_vec();
140 self
141 }
142
143 pub fn build(self) -> FerroError {
145 FerroError::ShapeMismatch {
146 expected: self.expected,
147 actual: self.actual,
148 context: self.context,
149 }
150 }
151}
152
153impl fmt::Display for ShapeMismatchContext {
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 write!(
156 f,
157 "ShapeMismatchContext({}, expected {:?}, actual {:?})",
158 self.context, self.expected, self.actual
159 )
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    // Exact-string assertions pin both the wording and the position of every
    // interpolated field. A `contains` check would still pass if, say,
    // `required` and `actual` were swapped in a format string.

    #[test]
    fn test_shape_mismatch_display() {
        let err = FerroError::ShapeMismatch {
            expected: vec![100, 10],
            actual: vec![100, 5],
            context: "feature matrix".into(),
        };
        assert_eq!(
            err.to_string(),
            "Shape mismatch in feature matrix: expected [100, 10], got [100, 5]"
        );
    }

    #[test]
    fn test_insufficient_samples_display() {
        let err = FerroError::InsufficientSamples {
            required: 10,
            actual: 3,
            context: "cross-validation".into(),
        };
        assert_eq!(
            err.to_string(),
            "Insufficient samples: need at least 10, got 3 (cross-validation)"
        );
    }

    #[test]
    fn test_convergence_failure_display() {
        let err = FerroError::ConvergenceFailure {
            iterations: 1000,
            message: "loss did not decrease".into(),
        };
        assert_eq!(
            err.to_string(),
            "Convergence failure after 1000 iterations: loss did not decrease"
        );
    }

    #[test]
    fn test_invalid_parameter_display() {
        let err = FerroError::InvalidParameter {
            name: "n_clusters".into(),
            reason: "must be positive".into(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid parameter `n_clusters`: must be positive"
        );
    }

    #[test]
    fn test_numerical_instability_display() {
        let err = FerroError::NumericalInstability {
            message: "matrix is singular".into(),
        };
        assert_eq!(err.to_string(), "Numerical instability: matrix is singular");
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_io_error_from() {
        use std::error::Error as _;

        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let ferro_err: FerroError = io_err.into();
        assert_eq!(ferro_err.to_string(), "I/O error: file not found");
        // `#[from]` also registers the wrapped error as the `source()`.
        assert!(ferro_err.source().is_some());
        assert!(matches!(
            &ferro_err,
            FerroError::IoError(e) if e.kind() == std::io::ErrorKind::NotFound
        ));
    }

    #[test]
    fn test_serde_error_display() {
        let err = FerroError::SerdeError {
            message: "invalid JSON".into(),
        };
        assert_eq!(err.to_string(), "Serialization error: invalid JSON");
    }

    #[test]
    fn test_shape_mismatch_context_builder() {
        let err = ShapeMismatchContext::new("test context")
            .expected(&[3, 4])
            .actual(&[3, 5])
            .build();
        assert_eq!(
            err.to_string(),
            "Shape mismatch in test context: expected [3, 4], got [3, 5]"
        );
    }

    #[test]
    fn test_shape_mismatch_context_defaults_to_empty_shapes() {
        let err = ShapeMismatchContext::new("unset").build();
        assert_eq!(
            err.to_string(),
            "Shape mismatch in unset: expected [], got []"
        );
    }

    #[test]
    fn test_shape_mismatch_context_display() {
        let ctx = ShapeMismatchContext::new("ctx")
            .expected(&[1, 2])
            .actual(&[1]);
        assert_eq!(
            ctx.to_string(),
            "ShapeMismatchContext(ctx, expected [1, 2], actual [1])"
        );
    }

    #[test]
    fn test_ferro_error_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FerroError>();
    }
}