1use thiserror::Error;
4
5#[derive(Debug, Error, PartialEq)]
7pub enum SslError {
8 #[error("dimension mismatch: expected {expected}, got {got}")]
9 DimensionMismatch { expected: usize, got: usize },
10
11 #[error("empty input")]
12 EmptyInput,
13
14 #[error("invalid temperature {temp}: must be > 0 and finite")]
15 InvalidTemperature { temp: f32 },
16
17 #[error("invalid momentum {momentum}: must be in [0, 1]")]
18 InvalidMomentum { momentum: f32 },
19
20 #[error("invalid mask ratio {ratio}: must be in [0, 1)")]
21 InvalidMaskRatio { ratio: f32 },
22
23 #[error("invalid number of crops: must be >= 1")]
24 InvalidNumCrops,
25
26 #[error("invalid loss weight {weight}: must be finite")]
27 InvalidLossWeight { weight: f32 },
28
29 #[error("queue capacity must be >= 1")]
30 QueueCapacityTooSmall,
31
32 #[error("queue is empty")]
33 QueueEmpty,
34
35 #[error("number of prototypes must be >= 2")]
36 NumPrototypesTooSmall,
37
38 #[error("Sinkhorn-Knopp diverged after {iters} iterations")]
39 SinkhornDiverged { iters: usize },
40
41 #[error("invalid feature dimension: must be > 0")]
42 InvalidFeatureDim,
43
44 #[error("invalid batch size: must be >= 2 (need positive + negative pairs)")]
45 BatchTooSmall,
46
47 #[error("non-finite value at: {location}")]
48 NanEncountered { location: &'static str },
49
50 #[error("invalid projector layer dim: in/hidden/out must be > 0")]
51 InvalidProjectorDim,
52
53 #[error("invalid parameter `{name}`: {reason}")]
54 InvalidParameter { name: String, reason: String },
55
56 #[error("internal error: {0}")]
57 Internal(String),
58}
59
60pub type SslResult<T> = Result<T, SslError>;
62
#[cfg(test)]
mod tests {
    // Unit tests for `SslError` / `SslResult`. The Display tests compare the
    // full message with `assert_eq!` so that a reordered or reworded format
    // string (e.g. `expected`/`got` swapped) is caught, which substring
    // checks would miss.
    use super::*;

    #[test]
    fn error_display_dimension_mismatch() {
        let e = SslError::DimensionMismatch {
            expected: 64,
            got: 32,
        };
        assert_eq!(e.to_string(), "dimension mismatch: expected 64, got 32");
    }

    #[test]
    fn error_display_invalid_temperature() {
        // f32's Display omits a trailing ".0", so -1.0 renders as "-1".
        let e = SslError::InvalidTemperature { temp: -1.0 };
        assert_eq!(
            e.to_string(),
            "invalid temperature -1: must be > 0 and finite"
        );
    }

    #[test]
    fn error_display_invalid_momentum() {
        let e = SslError::InvalidMomentum { momentum: 1.5 };
        assert_eq!(e.to_string(), "invalid momentum 1.5: must be in [0, 1]");
    }

    #[test]
    fn error_display_invalid_mask_ratio() {
        let e = SslError::InvalidMaskRatio { ratio: 1.0 };
        assert_eq!(e.to_string(), "invalid mask ratio 1: must be in [0, 1)");
    }

    #[test]
    fn error_display_invalid_loss_weight() {
        let e = SslError::InvalidLossWeight {
            weight: f32::INFINITY,
        };
        assert_eq!(e.to_string(), "invalid loss weight inf: must be finite");
    }

    #[test]
    fn error_display_sinkhorn() {
        let e = SslError::SinkhornDiverged { iters: 100 };
        assert_eq!(
            e.to_string(),
            "Sinkhorn-Knopp diverged after 100 iterations"
        );
    }

    #[test]
    fn error_display_nan_encountered() {
        let e = SslError::NanEncountered { location: "loss" };
        assert_eq!(e.to_string(), "non-finite value at: loss");
    }

    #[test]
    fn error_display_invalid_parameter() {
        let e = SslError::InvalidParameter {
            name: "lr".into(),
            reason: "must be positive".into(),
        };
        assert_eq!(e.to_string(), "invalid parameter `lr`: must be positive");
    }

    #[test]
    fn error_display_internal() {
        let e = SslError::Internal("bad shape".into());
        assert_eq!(e.to_string(), "internal error: bad shape");
    }

    #[test]
    fn error_is_std_error() {
        // The thiserror derive must provide a `std::error::Error` impl.
        let boxed: Box<dyn std::error::Error> = Box::new(SslError::EmptyInput);
        assert_eq!(boxed.to_string(), "empty input");
    }

    #[test]
    fn error_equality() {
        let empty = SslError::EmptyInput;
        assert_eq!(empty, SslError::EmptyInput);
        assert_ne!(empty, SslError::QueueEmpty);
        // Payload fields take part in equality, not just the variant tag.
        assert_ne!(
            SslError::DimensionMismatch {
                expected: 64,
                got: 32
            },
            SslError::DimensionMismatch {
                expected: 32,
                got: 64
            },
        );
    }

    #[test]
    fn error_equality_nan_payload_is_not_reflexive() {
        // Derived PartialEq compares the f32 field with `==`, and NaN != NaN.
        let a = SslError::InvalidTemperature { temp: f32::NAN };
        let b = SslError::InvalidTemperature { temp: f32::NAN };
        assert_ne!(a, b);
        assert!(matches!(a, SslError::InvalidTemperature { temp } if temp.is_nan()));
    }

    #[test]
    fn ssl_result_ok() {
        let r: SslResult<i32> = Ok(42);
        assert_eq!(r, Ok(42));
    }

    #[test]
    fn ssl_result_err() {
        let r: SslResult<i32> = Err(SslError::InvalidNumCrops);
        assert_eq!(r, Err(SslError::InvalidNumCrops));
    }

    #[test]
    fn ssl_result_propagates_with_question_mark() {
        fn checked_crops(n: usize) -> SslResult<usize> {
            if n == 0 {
                Err(SslError::InvalidNumCrops)
            } else {
                Ok(n)
            }
        }
        fn total(a: usize, b: usize) -> SslResult<usize> {
            Ok(checked_crops(a)? + checked_crops(b)?)
        }
        assert_eq!(total(2, 3), Ok(5));
        assert_eq!(total(2, 0), Err(SslError::InvalidNumCrops));
    }
}