Skip to main content

ferrolearn_core/
error.rs

1//! Error types for the ferrolearn framework.
2//!
3//! This module defines [`FerroError`], the unified error type used throughout
4//! all ferrolearn crates. Each variant carries diagnostic context to help
5//! users identify and fix problems.
6
7#[cfg(not(feature = "std"))]
8use alloc::{string::String, vec::Vec};
9
10use core::fmt;
11
12/// The unified error type for all ferrolearn operations.
13///
14/// Every public function in ferrolearn returns `Result<T, FerroError>`.
15/// The enum is `#[non_exhaustive]` so that new variants can be added in
16/// future minor releases without breaking downstream code.
17///
18/// # Examples
19///
20/// ```
21/// use ferrolearn_core::FerroError;
22///
23/// let err = FerroError::ShapeMismatch {
24///     expected: vec![100, 10],
25///     actual: vec![100, 5],
26///     context: "feature matrix".into(),
27/// };
28/// assert!(err.to_string().contains("Shape mismatch"));
29/// ```
30#[derive(Debug, thiserror::Error)]
31#[non_exhaustive]
32pub enum FerroError {
33    /// Array dimensions do not match the expected shape.
34    #[error("Shape mismatch in {context}: expected {expected:?}, got {actual:?}")]
35    ShapeMismatch {
36        /// The expected dimensions.
37        expected: Vec<usize>,
38        /// The actual dimensions encountered.
39        actual: Vec<usize>,
40        /// Human-readable description of where the mismatch occurred.
41        context: String,
42    },
43
44    /// Not enough samples were provided for the requested operation.
45    #[error("Insufficient samples: need at least {required}, got {actual} ({context})")]
46    InsufficientSamples {
47        /// The minimum number of samples required.
48        required: usize,
49        /// The actual number of samples provided.
50        actual: usize,
51        /// Human-readable description of the operation.
52        context: String,
53    },
54
55    /// An iterative algorithm did not converge within the allowed iterations.
56    #[error("Convergence failure after {iterations} iterations: {message}")]
57    ConvergenceFailure {
58        /// The number of iterations that were attempted.
59        iterations: usize,
60        /// A description of the convergence issue.
61        message: String,
62    },
63
64    /// A hyperparameter or configuration value is invalid.
65    #[error("Invalid parameter `{name}`: {reason}")]
66    InvalidParameter {
67        /// The name of the parameter.
68        name: String,
69        /// Why the value is invalid.
70        reason: String,
71    },
72
73    /// A numerical computation produced NaN, infinity, or other instability.
74    #[error("Numerical instability: {message}")]
75    NumericalInstability {
76        /// A description of the numerical issue.
77        message: String,
78    },
79
80    /// An I/O error occurred during data loading or model persistence.
81    #[cfg(feature = "std")]
82    #[error("I/O error: {0}")]
83    IoError(#[from] std::io::Error),
84
85    /// A serialization or deserialization error occurred.
86    #[error("Serialization error: {message}")]
87    SerdeError {
88        /// A description of the serialization issue.
89        message: String,
90    },
91}
92
93/// A convenience type alias for `Result<T, FerroError>`.
94pub type FerroResult<T> = Result<T, FerroError>;
95
96/// Diagnostic context attached to shape-mismatch errors.
97///
98/// This struct provides a builder-style API for constructing
99/// descriptive [`FerroError::ShapeMismatch`] errors.
100///
101/// # Examples
102///
103/// ```
104/// use ferrolearn_core::error::ShapeMismatchContext;
105///
106/// let ctx = ShapeMismatchContext::new("predict input")
107///     .expected(&[100, 10])
108///     .actual(&[100, 5]);
109/// let err = ctx.build();
110/// assert!(err.to_string().contains("predict input"));
111/// ```
112#[derive(Debug, Clone)]
113pub struct ShapeMismatchContext {
114    context: String,
115    expected: Vec<usize>,
116    actual: Vec<usize>,
117}
118
119impl ShapeMismatchContext {
120    /// Create a new context with the given description.
121    pub fn new(context: impl Into<String>) -> Self {
122        Self {
123            context: context.into(),
124            expected: Vec::new(),
125            actual: Vec::new(),
126        }
127    }
128
129    /// Set the expected shape.
130    #[must_use]
131    pub fn expected(mut self, shape: &[usize]) -> Self {
132        self.expected = shape.to_vec();
133        self
134    }
135
136    /// Set the actual shape.
137    #[must_use]
138    pub fn actual(mut self, shape: &[usize]) -> Self {
139        self.actual = shape.to_vec();
140        self
141    }
142
143    /// Build the [`FerroError::ShapeMismatch`] error.
144    pub fn build(self) -> FerroError {
145        FerroError::ShapeMismatch {
146            expected: self.expected,
147            actual: self.actual,
148            context: self.context,
149        }
150    }
151}
152
153impl fmt::Display for ShapeMismatchContext {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        write!(
156            f,
157            "ShapeMismatchContext({}, expected {:?}, actual {:?})",
158            self.context, self.expected, self.actual
159        )
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    // Messages are compared exactly. Substring checks on individual numbers or
    // shapes cannot detect fields that are swapped in a format string
    // (e.g. `expected` and `actual` reversed).

    #[test]
    fn test_shape_mismatch_display() {
        let err = FerroError::ShapeMismatch {
            expected: vec![100, 10],
            actual: vec![100, 5],
            context: "feature matrix".into(),
        };
        assert_eq!(
            err.to_string(),
            "Shape mismatch in feature matrix: expected [100, 10], got [100, 5]"
        );
    }

    #[test]
    fn test_insufficient_samples_display() {
        let err = FerroError::InsufficientSamples {
            required: 10,
            actual: 3,
            context: "cross-validation".into(),
        };
        assert_eq!(
            err.to_string(),
            "Insufficient samples: need at least 10, got 3 (cross-validation)"
        );
    }

    #[test]
    fn test_convergence_failure_display() {
        let err = FerroError::ConvergenceFailure {
            iterations: 1000,
            message: "loss did not decrease".into(),
        };
        assert_eq!(
            err.to_string(),
            "Convergence failure after 1000 iterations: loss did not decrease"
        );
    }

    #[test]
    fn test_invalid_parameter_display() {
        let err = FerroError::InvalidParameter {
            name: "n_clusters".into(),
            reason: "must be positive".into(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid parameter `n_clusters`: must be positive"
        );
    }

    #[test]
    fn test_numerical_instability_display() {
        let err = FerroError::NumericalInstability {
            message: "matrix is singular".into(),
        };
        assert_eq!(err.to_string(), "Numerical instability: matrix is singular");
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_io_error_from() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let ferro_err: FerroError = io_err.into();
        assert!(matches!(ferro_err, FerroError::IoError(_)));
        assert_eq!(ferro_err.to_string(), "I/O error: file not found");
        // `#[from]` also marks the wrapped error as the `source()`.
        let source = std::error::Error::source(&ferro_err)
            .expect("IoError should expose the wrapped io::Error as its source");
        assert_eq!(source.to_string(), "file not found");
    }

    #[test]
    fn test_serde_error_display() {
        let err = FerroError::SerdeError {
            message: "invalid JSON".into(),
        };
        assert_eq!(err.to_string(), "Serialization error: invalid JSON");
    }

    #[test]
    fn test_shape_mismatch_context_builder() {
        let err = ShapeMismatchContext::new("test context")
            .expected(&[3, 4])
            .actual(&[3, 5])
            .build();
        assert_eq!(
            err.to_string(),
            "Shape mismatch in test context: expected [3, 4], got [3, 5]"
        );
    }

    #[test]
    fn test_shape_mismatch_context_defaults_to_empty_shapes() {
        let err = ShapeMismatchContext::new("unset").build();
        assert_eq!(err.to_string(), "Shape mismatch in unset: expected [], got []");
    }

    #[test]
    fn test_shape_mismatch_context_display() {
        let ctx = ShapeMismatchContext::new("fit")
            .expected(&[2, 3])
            .actual(&[2]);
        assert_eq!(
            ctx.to_string(),
            "ShapeMismatchContext(fit, expected [2, 3], actual [2])"
        );
    }

    #[test]
    fn test_ferro_error_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<FerroError>();
    }
}