Skip to main content

oxicuda_ssl/
error.rs

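//! Error types for `oxicuda-ssl`.
//!
//! Defines [`SslError`], the error enum returned by the crate, and the
//! [`SslResult`] convenience alias.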
2
3use thiserror::Error;
4
5/// All errors that can be returned from `oxicuda-ssl`.
6#[derive(Debug, Error, PartialEq)]
7pub enum SslError {
8    #[error("dimension mismatch: expected {expected}, got {got}")]
9    DimensionMismatch { expected: usize, got: usize },
10
11    #[error("empty input")]
12    EmptyInput,
13
14    #[error("invalid temperature {temp}: must be > 0 and finite")]
15    InvalidTemperature { temp: f32 },
16
17    #[error("invalid momentum {momentum}: must be in [0, 1]")]
18    InvalidMomentum { momentum: f32 },
19
20    #[error("invalid mask ratio {ratio}: must be in [0, 1)")]
21    InvalidMaskRatio { ratio: f32 },
22
23    #[error("invalid number of crops: must be >= 1")]
24    InvalidNumCrops,
25
26    #[error("invalid loss weight {weight}: must be finite")]
27    InvalidLossWeight { weight: f32 },
28
29    #[error("queue capacity must be >= 1")]
30    QueueCapacityTooSmall,
31
32    #[error("queue is empty")]
33    QueueEmpty,
34
35    #[error("number of prototypes must be >= 2")]
36    NumPrototypesTooSmall,
37
38    #[error("Sinkhorn-Knopp diverged after {iters} iterations")]
39    SinkhornDiverged { iters: usize },
40
41    #[error("invalid feature dimension: must be > 0")]
42    InvalidFeatureDim,
43
44    #[error("invalid batch size: must be >= 2 (need positive + negative pairs)")]
45    BatchTooSmall,
46
47    #[error("non-finite value at: {location}")]
48    NanEncountered { location: &'static str },
49
50    #[error("invalid projector layer dim: in/hidden/out must be > 0")]
51    InvalidProjectorDim,
52
53    #[error("invalid parameter `{name}`: {reason}")]
54    InvalidParameter { name: String, reason: String },
55
56    #[error("internal error: {0}")]
57    Internal(String),
58}
59
60/// Convenience result alias.
61pub type SslResult<T> = Result<T, SslError>;
62
#[cfg(test)]
mod tests {
    use super::*;

    // Display tests pin the full message rather than substrings, so a
    // reworded or mis-interpolated `#[error(...)]` string is caught.

    #[test]
    fn error_display_dimension_mismatch() {
        let e = SslError::DimensionMismatch {
            expected: 64,
            got: 32,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 64, got 32");
    }

    #[test]
    fn error_display_invalid_temperature() {
        // `f32`'s `Display` renders `-1.0` as `-1`.
        let e = SslError::InvalidTemperature { temp: -1.0 };
        assert_eq!(e.to_string(), "invalid temperature -1: must be > 0 and finite");
    }

    #[test]
    fn error_display_invalid_momentum() {
        let e = SslError::InvalidMomentum { momentum: 1.5 };
        assert_eq!(e.to_string(), "invalid momentum 1.5: must be in [0, 1]");
    }

    #[test]
    fn error_display_sinkhorn() {
        let e = SslError::SinkhornDiverged { iters: 100 };
        assert_eq!(e.to_string(), "Sinkhorn-Knopp diverged after 100 iterations");
    }

    #[test]
    fn error_display_internal() {
        let e = SslError::Internal("bad shape".into());
        assert_eq!(e.to_string(), "internal error: bad shape");
    }

    #[test]
    fn error_equality() {
        assert_eq!(SslError::EmptyInput, SslError::EmptyInput);
        assert_ne!(SslError::EmptyInput, SslError::QueueEmpty);
        // Field values participate in equality, not just the variant.
        assert_ne!(
            SslError::DimensionMismatch {
                expected: 1,
                got: 2
            },
            SslError::DimensionMismatch {
                expected: 2,
                got: 1
            }
        );
    }

    #[test]
    fn float_payload_nan_is_not_self_equal() {
        // Derived `PartialEq` compares `f32` fields with IEEE semantics,
        // so a NaN payload makes the value unequal to itself.
        let e = SslError::InvalidTemperature { temp: f32::NAN };
        assert_ne!(e, SslError::InvalidTemperature { temp: f32::NAN });
    }

    #[test]
    fn ssl_result_ok() {
        let r: SslResult<i32> = Ok(42);
        assert_eq!(r, Ok(42));
    }

    #[test]
    fn ssl_result_err() {
        let r: SslResult<i32> = Err(SslError::InvalidNumCrops);
        assert_eq!(r, Err(SslError::InvalidNumCrops));
    }
}