Skip to main content

ferrolearn_core/
error.rs

//! Error types for the ferrolearn framework.
//!
//! This module defines [`FerroError`], the unified error type used throughout
//! the ferrolearn crates, together with the [`FerroResult`] alias and the
//! [`ShapeMismatchContext`] builder. Each error variant carries diagnostic
//! context to help users identify and fix problems.

7use std::fmt;
8
9/// The unified error type for all ferrolearn operations.
10///
11/// Every public function in ferrolearn returns `Result<T, FerroError>`.
12/// The enum is `#[non_exhaustive]` so that new variants can be added in
13/// future minor releases without breaking downstream code.
14///
15/// # Examples
16///
17/// ```
18/// use ferrolearn_core::FerroError;
19///
20/// let err = FerroError::ShapeMismatch {
21///     expected: vec![100, 10],
22///     actual: vec![100, 5],
23///     context: "feature matrix".into(),
24/// };
25/// assert!(err.to_string().contains("Shape mismatch"));
26/// ```
27#[derive(Debug, thiserror::Error)]
28#[non_exhaustive]
29pub enum FerroError {
30    /// Array dimensions do not match the expected shape.
31    #[error("Shape mismatch in {context}: expected {expected:?}, got {actual:?}")]
32    ShapeMismatch {
33        /// The expected dimensions.
34        expected: Vec<usize>,
35        /// The actual dimensions encountered.
36        actual: Vec<usize>,
37        /// Human-readable description of where the mismatch occurred.
38        context: String,
39    },
40
41    /// Not enough samples were provided for the requested operation.
42    #[error("Insufficient samples: need at least {required}, got {actual} ({context})")]
43    InsufficientSamples {
44        /// The minimum number of samples required.
45        required: usize,
46        /// The actual number of samples provided.
47        actual: usize,
48        /// Human-readable description of the operation.
49        context: String,
50    },
51
52    /// An iterative algorithm did not converge within the allowed iterations.
53    #[error("Convergence failure after {iterations} iterations: {message}")]
54    ConvergenceFailure {
55        /// The number of iterations that were attempted.
56        iterations: usize,
57        /// A description of the convergence issue.
58        message: String,
59    },
60
61    /// A hyperparameter or configuration value is invalid.
62    #[error("Invalid parameter `{name}`: {reason}")]
63    InvalidParameter {
64        /// The name of the parameter.
65        name: String,
66        /// Why the value is invalid.
67        reason: String,
68    },
69
70    /// A numerical computation produced NaN, infinity, or other instability.
71    #[error("Numerical instability: {message}")]
72    NumericalInstability {
73        /// A description of the numerical issue.
74        message: String,
75    },
76
77    /// An I/O error occurred during data loading or model persistence.
78    #[error("I/O error: {0}")]
79    IoError(#[from] std::io::Error),
80
81    /// A serialization or deserialization error occurred.
82    #[error("Serialization error: {message}")]
83    SerdeError {
84        /// A description of the serialization issue.
85        message: String,
86    },
87}
88
89/// A convenience type alias for `Result<T, FerroError>`.
90pub type FerroResult<T> = Result<T, FerroError>;
91
/// Builder that accumulates the pieces of a [`FerroError::ShapeMismatch`].
///
/// Start with [`ShapeMismatchContext::new`], record the two shapes with
/// [`expected`](ShapeMismatchContext::expected) and
/// [`actual`](ShapeMismatchContext::actual), then call
/// [`build`](ShapeMismatchContext::build) to obtain the error. A shape that is
/// never set stays empty (`[]`).
///
/// The type derives `PartialEq`/`Eq`, so two contexts can be compared directly,
/// for example in tests.
///
/// # Examples
///
/// ```
/// use ferrolearn_core::error::ShapeMismatchContext;
///
/// let ctx = ShapeMismatchContext::new("predict input")
///     .expected(&[100, 10])
///     .actual(&[100, 5]);
/// let err = ctx.build();
/// assert!(err.to_string().contains("predict input"));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShapeMismatchContext {
    /// Description of where the mismatch happened (e.g. `"predict input"`).
    context: String,
    /// Dimensions the operation required.
    expected: Vec<usize>,
    /// Dimensions that were actually observed.
    actual: Vec<usize>,
}

115impl ShapeMismatchContext {
116    /// Create a new context with the given description.
117    pub fn new(context: impl Into<String>) -> Self {
118        Self {
119            context: context.into(),
120            expected: Vec::new(),
121            actual: Vec::new(),
122        }
123    }
124
125    /// Set the expected shape.
126    #[must_use]
127    pub fn expected(mut self, shape: &[usize]) -> Self {
128        self.expected = shape.to_vec();
129        self
130    }
131
132    /// Set the actual shape.
133    #[must_use]
134    pub fn actual(mut self, shape: &[usize]) -> Self {
135        self.actual = shape.to_vec();
136        self
137    }
138
139    /// Build the [`FerroError::ShapeMismatch`] error.
140    pub fn build(self) -> FerroError {
141        FerroError::ShapeMismatch {
142            expected: self.expected,
143            actual: self.actual,
144            context: self.context,
145        }
146    }
147}
148
149impl fmt::Display for ShapeMismatchContext {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        write!(
152            f,
153            "ShapeMismatchContext({}, expected {:?}, actual {:?})",
154            self.context, self.expected, self.actual
155        )
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    // These tests compare whole rendered messages on purpose. Substring checks
    // such as `contains("3")` would still pass if two fields were swapped in an
    // `#[error(...)]` format string.

    #[test]
    fn test_shape_mismatch_display() {
        let err = FerroError::ShapeMismatch {
            expected: vec![100, 10],
            actual: vec![100, 5],
            context: "feature matrix".into(),
        };
        assert_eq!(
            err.to_string(),
            "Shape mismatch in feature matrix: expected [100, 10], got [100, 5]"
        );
    }

    #[test]
    fn test_insufficient_samples_display() {
        let err = FerroError::InsufficientSamples {
            required: 10,
            actual: 3,
            context: "cross-validation".into(),
        };
        assert_eq!(
            err.to_string(),
            "Insufficient samples: need at least 10, got 3 (cross-validation)"
        );
    }

    #[test]
    fn test_convergence_failure_display() {
        let err = FerroError::ConvergenceFailure {
            iterations: 1000,
            message: "loss did not decrease".into(),
        };
        assert_eq!(
            err.to_string(),
            "Convergence failure after 1000 iterations: loss did not decrease"
        );
    }

    #[test]
    fn test_invalid_parameter_display() {
        let err = FerroError::InvalidParameter {
            name: "n_clusters".into(),
            reason: "must be positive".into(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid parameter `n_clusters`: must be positive"
        );
    }

    #[test]
    fn test_numerical_instability_display() {
        let err = FerroError::NumericalInstability {
            message: "matrix is singular".into(),
        };
        assert_eq!(err.to_string(), "Numerical instability: matrix is singular");
    }

    #[test]
    fn test_io_error_from() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let ferro_err: FerroError = io_err.into();
        assert_eq!(ferro_err.to_string(), "I/O error: file not found");

        // `#[from]` also registers the wrapped error as the `source()`.
        let source = std::error::Error::source(&ferro_err)
            .expect("IoError should expose the wrapped io::Error as its source");
        assert_eq!(source.to_string(), "file not found");

        assert!(matches!(
            ferro_err,
            FerroError::IoError(ref e) if e.kind() == std::io::ErrorKind::NotFound
        ));
    }

    #[test]
    fn test_serde_error_display() {
        let err = FerroError::SerdeError {
            message: "invalid JSON".into(),
        };
        assert_eq!(err.to_string(), "Serialization error: invalid JSON");
    }

    #[test]
    fn test_shape_mismatch_context_builder() {
        let err = ShapeMismatchContext::new("test context")
            .expected(&[3, 4])
            .actual(&[3, 5])
            .build();
        assert_eq!(
            err.to_string(),
            "Shape mismatch in test context: expected [3, 4], got [3, 5]"
        );
    }

    #[test]
    fn test_shape_mismatch_context_defaults_to_empty_shapes() {
        let err = ShapeMismatchContext::new("unset").build();
        assert_eq!(
            err.to_string(),
            "Shape mismatch in unset: expected [], got []"
        );
    }

    #[test]
    fn test_shape_mismatch_context_display() {
        let ctx = ShapeMismatchContext::new("ctx")
            .expected(&[1, 2])
            .actual(&[1, 3]);
        assert_eq!(
            ctx.to_string(),
            "ShapeMismatchContext(ctx, expected [1, 2], actual [1, 3])"
        );
    }

    #[test]
    fn test_ferro_error_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FerroError>();
    }
}