Skip to main content

entrenar/hf_pipeline/
error.rs

//! Error types for HuggingFace pipeline operations
//!
//! Comprehensive error handling with retry hints and recovery options.
4
use std::path::PathBuf;
use std::time::Duration;

use thiserror::Error;
8
9/// Result type for HF pipeline operations
10pub type Result<T> = std::result::Result<T, FetchError>;
11
12/// Errors that can occur during HuggingFace operations
13///
14/// Designed for Jidoka (built-in quality) - explicit error types enable
15/// proper error handling and recovery strategies.
16#[derive(Debug, Error)]
17pub enum FetchError {
18    /// Network timeout during download
19    #[error("Network timeout for {repo} after {elapsed:?}")]
20    NetworkTimeout { repo: String, elapsed: Duration },
21
22    /// Rate limited by HuggingFace API
23    #[error("Rate limited, retry after {retry_after:?}")]
24    RateLimited { retry_after: Duration },
25
26    /// Model repository not found
27    #[error("Repository not found: {repo}")]
28    ModelNotFound { repo: String },
29
30    /// File not found in repository
31    #[error("File not found in {repo}: {file}")]
32    FileNotFound { repo: String, file: String },
33
34    /// Downloaded file is corrupt (checksum mismatch)
35    #[error("Corrupt file at {path}: expected SHA256 {expected_hash}, got {actual_hash}")]
36    CorruptFile { path: PathBuf, expected_hash: String, actual_hash: String },
37
38    /// Insufficient disk space
39    #[error("Insufficient disk space: need {required} bytes, have {available} bytes")]
40    InsufficientDisk { required: u64, available: u64 },
41
42    /// Out of memory during model loading
43    #[error("Out of memory: model requires {required} bytes, available {available} bytes")]
44    OutOfMemory { required: u64, available: u64 },
45
46    /// Authentication failed
47    #[error("Authentication failed: {message}")]
48    AuthenticationFailed { message: String },
49
50    /// Missing authentication token
51    #[error("Missing HF_TOKEN - set environment variable or use with_token()")]
52    MissingToken,
53
54    /// Invalid repository ID format
55    #[error("Invalid repository ID format (expected 'org/name'): {repo_id}")]
56    InvalidRepoId { repo_id: String },
57
58    /// Unsupported model format
59    #[error("Unsupported model format: {format}")]
60    UnsupportedFormat { format: String },
61
62    /// SECURITY: PyTorch pickle file detected
63    #[error("SECURITY: PyTorch .bin files may contain arbitrary code. Enable allow_pytorch_pickle to proceed.")]
64    PickleSecurityRisk,
65
66    /// Model config parsing error
67    #[error("Failed to parse config.json: {message}")]
68    ConfigParseError { message: String },
69
70    /// Tensor shape mismatch
71    #[error("Tensor shape mismatch for {tensor}: expected {expected:?}, got {actual:?}")]
72    ShapeMismatch { tensor: String, expected: Vec<usize>, actual: Vec<usize> },
73
74    /// IO error
75    #[error("IO error: {0}")]
76    Io(#[from] std::io::Error),
77
78    /// JSON parsing error
79    #[error("JSON error: {0}")]
80    Json(#[from] serde_json::Error),
81
82    /// SafeTensors parsing error
83    #[error("SafeTensors parse error: {message}")]
84    SafeTensorsParseError { message: String },
85
86    /// Leaderboard not found
87    #[error("Leaderboard not found: {kind}")]
88    LeaderboardNotFound { kind: String },
89
90    /// Leaderboard data parsing error
91    #[error("Failed to parse leaderboard data: {message}")]
92    LeaderboardParseError { message: String },
93
94    /// HTTP request error
95    #[error("HTTP error: {message}")]
96    HttpError { message: String },
97
98    /// GGUF write/serialization error
99    #[error("GGUF write error: {message}")]
100    GgufWriteError { message: String },
101}
102
103impl FetchError {
104    /// Check if error is retryable
105    #[must_use]
106    pub fn is_retryable(&self) -> bool {
107        matches!(self, Self::NetworkTimeout { .. } | Self::RateLimited { .. })
108    }
109
110    /// Get retry delay if applicable
111    #[must_use]
112    pub fn retry_after(&self) -> Option<Duration> {
113        match self {
114            Self::RateLimited { retry_after } => Some(*retry_after),
115            Self::NetworkTimeout { .. } => Some(Duration::from_secs(5)),
116            _ => None,
117        }
118    }
119
120    /// Check if error is a security concern
121    #[must_use]
122    pub fn is_security_risk(&self) -> bool {
123        matches!(self, Self::PickleSecurityRisk | Self::CorruptFile { .. })
124    }
125}
126
/// Unit tests for [`FetchError`]: retry/security classification, `Display`
/// output of every variant, and the `#[from]` conversions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_network_timeout_is_retryable() {
        let err = FetchError::NetworkTimeout {
            repo: "test/model".into(),
            elapsed: Duration::from_secs(30),
        };
        assert!(err.is_retryable());
        assert!(err.retry_after().is_some());
    }

    #[test]
    fn test_network_timeout_default_retry_delay() {
        // Pins the fixed fallback delay used when the error carries none.
        let err = FetchError::NetworkTimeout {
            repo: "test/model".into(),
            elapsed: Duration::from_secs(30),
        };
        assert_eq!(err.retry_after(), Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_rate_limited_is_retryable() {
        let err = FetchError::RateLimited { retry_after: Duration::from_secs(60) };
        assert!(err.is_retryable());
        assert_eq!(err.retry_after(), Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_model_not_found_not_retryable() {
        let err = FetchError::ModelNotFound { repo: "test/model".into() };
        assert!(!err.is_retryable());
        assert!(err.retry_after().is_none());
    }

    #[test]
    fn test_pickle_is_security_risk() {
        let err = FetchError::PickleSecurityRisk;
        assert!(err.is_security_risk());
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_corrupt_file_is_security_risk() {
        let err = FetchError::CorruptFile {
            path: PathBuf::from("/tmp/model.safetensors"),
            expected_hash: "abc123".into(),
            actual_hash: "def456".into(),
        };
        assert!(err.is_security_risk());
    }

    #[test]
    fn test_ordinary_errors_are_not_security_risks() {
        let err = FetchError::ModelNotFound { repo: "test/model".into() };
        assert!(!err.is_security_risk());

        let err = FetchError::RateLimited { retry_after: Duration::from_secs(1) };
        assert!(!err.is_security_risk());
    }

    #[test]
    fn test_missing_token_display() {
        let err = FetchError::MissingToken;
        let msg = err.to_string();
        assert!(msg.contains("HF_TOKEN"));
    }

    #[test]
    fn test_invalid_repo_id_display() {
        let err = FetchError::InvalidRepoId { repo_id: "invalid".into() };
        let msg = err.to_string();
        assert!(msg.contains("org/name"));
    }

    #[test]
    fn test_all_error_variants_display() {
        // Ensure all variants have proper Display
        let errors: Vec<FetchError> = vec![
            FetchError::NetworkTimeout { repo: "r".into(), elapsed: Duration::from_secs(1) },
            FetchError::RateLimited { retry_after: Duration::from_secs(1) },
            FetchError::ModelNotFound { repo: "r".into() },
            FetchError::FileNotFound { repo: "r".into(), file: "f".into() },
            FetchError::CorruptFile {
                path: PathBuf::from("p"),
                expected_hash: "e".into(),
                actual_hash: "a".into(),
            },
            FetchError::InsufficientDisk { required: 100, available: 50 },
            FetchError::OutOfMemory { required: 100, available: 50 },
            FetchError::AuthenticationFailed { message: "m".into() },
            FetchError::MissingToken,
            FetchError::InvalidRepoId { repo_id: "r".into() },
            FetchError::UnsupportedFormat { format: "f".into() },
            FetchError::PickleSecurityRisk,
            FetchError::ConfigParseError { message: "m".into() },
            FetchError::ShapeMismatch {
                tensor: "t".into(),
                expected: vec![1, 2],
                actual: vec![3, 4],
            },
            FetchError::Io(std::io::Error::new(std::io::ErrorKind::NotFound, "io")),
            FetchError::Json(
                serde_json::from_str::<serde_json::Value>("{")
                    .expect_err("an unterminated object must fail to parse"),
            ),
            FetchError::SafeTensorsParseError { message: "bad header".into() },
            FetchError::LeaderboardNotFound { kind: "OpenASR".into() },
            FetchError::LeaderboardParseError { message: "missing field".into() },
            FetchError::HttpError { message: "connection refused".into() },
            FetchError::GgufWriteError { message: "alignment error".into() },
        ];

        for err in &errors {
            let msg = err.to_string();
            assert!(!msg.is_empty(), "Error display should not be empty: {err:?}");
        }
    }

    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let fetch_err: FetchError = io_err.into();
        assert!(matches!(fetch_err, FetchError::Io(_)));
    }

    #[test]
    fn test_json_error_conversion() {
        let json_err = serde_json::from_str::<serde_json::Value>("{")
            .expect_err("an unterminated object must fail to parse");
        let fetch_err: FetchError = json_err.into();
        assert!(matches!(fetch_err, FetchError::Json(_)));
    }
}