Skip to main content

apr_cli/
error.rs

//! Error types for apr-cli.
//!
//! Toyota Way: Jidoka - stop and highlight problems immediately.
4
5use std::path::PathBuf;
6use std::process::ExitCode;
7use thiserror::Error;
8
9/// Result type alias for CLI operations
10pub type Result<T> = std::result::Result<T, CliError>;
11
12/// CLI error types
13#[derive(Error, Debug)]
14pub enum CliError {
15    /// File not found
16    #[error("File not found: {0}")]
17    FileNotFound(PathBuf),
18
19    /// Not a file (e.g., directory)
20    #[error("Not a file: {0}")]
21    NotAFile(PathBuf),
22
23    /// Invalid APR format
24    #[error("Invalid APR format: {0}")]
25    InvalidFormat(String),
26
27    /// IO error
28    #[error("IO error: {0}")]
29    Io(#[from] std::io::Error),
30
31    /// Validation failed
32    #[error("Validation failed: {0}")]
33    ValidationFailed(String),
34
35    /// Aprender error
36    #[error("Aprender error: {0}")]
37    Aprender(String),
38
39    /// Model loading failed (used with inference feature)
40    #[error("Model load failed: {0}")]
41    #[allow(dead_code)]
42    ModelLoadFailed(String),
43
44    /// Inference failed (used with inference feature)
45    #[error("Inference failed: {0}")]
46    #[allow(dead_code)]
47    InferenceFailed(String),
48
49    /// Feature disabled (used when optional features are not compiled)
50    #[error("Feature not enabled: {0}")]
51    #[allow(dead_code)]
52    FeatureDisabled(String),
53
54    /// Network error
55    #[error("Network error: {0}")]
56    NetworkError(String),
57
58    /// HTTP 404 Not Found (GH-356: distinguish from other network errors)
59    #[error("HTTP 404 Not Found: {0}")]
60    HttpNotFound(String),
61}
62
63impl CliError {
64    /// Get exit code for this error
65    pub fn exit_code(&self) -> ExitCode {
66        match self {
67            Self::FileNotFound(_) | Self::NotAFile(_) => ExitCode::from(3),
68            Self::InvalidFormat(_) => ExitCode::from(4),
69            Self::Io(_) => ExitCode::from(7),
70            Self::ValidationFailed(_) => ExitCode::from(5),
71            Self::Aprender(_) => ExitCode::from(1),
72            Self::ModelLoadFailed(_) => ExitCode::from(6),
73            Self::InferenceFailed(_) => ExitCode::from(8),
74            Self::FeatureDisabled(_) => ExitCode::from(9),
75            Self::NetworkError(_) => ExitCode::from(10),
76            Self::HttpNotFound(_) => ExitCode::from(11),
77        }
78    }
79}
80
81impl From<aprender::error::AprenderError> for CliError {
82    fn from(e: aprender::error::AprenderError) -> Self {
83        Self::Aprender(e.to_string())
84    }
85}
86
87/// Resolve a model path: if given a directory, look for common model files inside.
88///
89/// HuggingFace models are stored as directories containing `model.safetensors`,
90/// `model-00001-of-NNNNN.safetensors`, or `*.gguf`. This function resolves such
91/// directories to the actual model file, avoiding "Not a file" errors.
92pub fn resolve_model_path(
93    path: &std::path::Path,
94) -> std::result::Result<std::path::PathBuf, CliError> {
95    if !path.exists() {
96        return Err(CliError::FileNotFound(path.to_path_buf()));
97    }
98    if path.is_file() {
99        return Ok(path.to_path_buf());
100    }
101    if path.is_dir() {
102        // PMAT-314: Check sharded SafeTensors index FIRST — individual shard files
103        // only contain a subset of tensors and will fail the architecture gate.
104        let index = path.join("model.safetensors.index.json");
105        if index.is_file() {
106            return Ok(index);
107        }
108        // Try common model file names in priority order
109        let candidates = [
110            "model.safetensors",
111            "model-00001-of-00001.safetensors",
112            "model-00001-of-00002.safetensors",
113            "model-00001-of-00003.safetensors",
114            "model-00001-of-00004.safetensors",
115        ];
116        for candidate in &candidates {
117            let p = path.join(candidate);
118            if p.is_file() {
119                return Ok(p);
120            }
121        }
122        // Try first .gguf file
123        if let Ok(entries) = std::fs::read_dir(path) {
124            for entry in entries.flatten() {
125                let p = entry.path();
126                if p.extension().is_some_and(|ext| ext == "gguf") && p.is_file() {
127                    return Ok(p);
128                }
129            }
130        }
131        // Try first .apr file
132        if let Ok(entries) = std::fs::read_dir(path) {
133            for entry in entries.flatten() {
134                let p = entry.path();
135                if p.extension().is_some_and(|ext| ext == "apr") && p.is_file() {
136                    return Ok(p);
137                }
138            }
139        }
140        Err(CliError::ValidationFailed(format!(
141            "Directory {} does not contain a model file (expected model.safetensors, *.gguf, or *.apr)",
142            path.display()
143        )))
144    } else {
145        Err(CliError::NotAFile(path.to_path_buf()))
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    // ==================== Helpers ====================

    /// Build a `std::io::Error` of kind `Other` carrying `msg`.
    fn io_error(msg: &str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::Other, msg)
    }

    /// Scratch directory under the system temp dir, unique per test tag and
    /// per process. It is removed on drop, so cleanup also happens when an
    /// assertion panics (the old fixed-name dirs leaked on failure and could
    /// collide between concurrent test runs).
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("apr-test-{tag}-{}", std::process::id()));
            // Clear leftovers from an earlier aborted run that reused this PID.
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).expect("mkdir");
            Self(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }

        /// Create `name` inside the directory with placeholder contents and
        /// return its full path.
        fn touch(&self, name: &str) -> PathBuf {
            let file = self.0.join(name);
            std::fs::write(&file, b"test").expect("write");
            file
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    // ==================== Exit Code Tests ====================

    #[test]
    fn test_file_not_found_exit_code() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_not_a_file_exit_code() {
        let err = CliError::NotAFile(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_invalid_format_exit_code() {
        let err = CliError::InvalidFormat("bad".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(4));
    }

    #[test]
    fn test_io_error_exit_code() {
        let err = CliError::Io(io_error("test"));
        assert_eq!(err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_validation_failed_exit_code() {
        let err = CliError::ValidationFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(5));
    }

    #[test]
    fn test_aprender_error_exit_code() {
        let err = CliError::Aprender("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(1));
    }

    #[test]
    fn test_model_load_failed_exit_code() {
        let err = CliError::ModelLoadFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(6));
    }

    #[test]
    fn test_inference_failed_exit_code() {
        let err = CliError::InferenceFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(8));
    }

    #[test]
    fn test_feature_disabled_exit_code() {
        let err = CliError::FeatureDisabled("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(9));
    }

    #[test]
    fn test_network_error_exit_code() {
        let err = CliError::NetworkError("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(10));
    }

    #[test]
    fn test_http_not_found_exit_code() {
        let err = CliError::HttpNotFound("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(11));
    }

    // ==================== Display Tests ====================

    #[test]
    fn test_file_not_found_display() {
        let err = CliError::FileNotFound(PathBuf::from("/model.apr"));
        assert_eq!(err.to_string(), "File not found: /model.apr");
    }

    #[test]
    fn test_not_a_file_display() {
        let err = CliError::NotAFile(PathBuf::from("/dir"));
        assert_eq!(err.to_string(), "Not a file: /dir");
    }

    #[test]
    fn test_invalid_format_display() {
        let err = CliError::InvalidFormat("bad magic".to_string());
        assert_eq!(err.to_string(), "Invalid APR format: bad magic");
    }

    #[test]
    fn test_validation_failed_display() {
        let err = CliError::ValidationFailed("missing field".to_string());
        assert_eq!(err.to_string(), "Validation failed: missing field");
    }

    #[test]
    fn test_aprender_error_display() {
        let err = CliError::Aprender("internal".to_string());
        assert_eq!(err.to_string(), "Aprender error: internal");
    }

    #[test]
    fn test_model_load_failed_display() {
        let err = CliError::ModelLoadFailed("corrupt".to_string());
        assert_eq!(err.to_string(), "Model load failed: corrupt");
    }

    #[test]
    fn test_inference_failed_display() {
        let err = CliError::InferenceFailed("OOM".to_string());
        assert_eq!(err.to_string(), "Inference failed: OOM");
    }

    #[test]
    fn test_feature_disabled_display() {
        let err = CliError::FeatureDisabled("cuda".to_string());
        assert_eq!(err.to_string(), "Feature not enabled: cuda");
    }

    #[test]
    fn test_network_error_display() {
        let err = CliError::NetworkError("timeout".to_string());
        assert_eq!(err.to_string(), "Network error: timeout");
    }

    #[test]
    fn test_http_not_found_display() {
        let err = CliError::HttpNotFound("tokenizer.json".to_string());
        assert_eq!(err.to_string(), "HTTP 404 Not Found: tokenizer.json");
    }

    // ==================== Conversion Tests ====================

    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let cli_err: CliError = io_err.into();
        assert!(cli_err.to_string().contains("file missing"));
        assert_eq!(cli_err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_debug_impl() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        let debug = format!("{err:?}");
        assert!(debug.contains("FileNotFound"));
    }

    // ==================== Result Type Alias ====================

    #[test]
    fn test_result_type_ok() {
        let result: Result<i32> = Ok(42);
        assert_eq!(result.expect("ok"), 42);
    }

    #[test]
    fn test_result_type_err() {
        let result: Result<i32> = Err(CliError::InvalidFormat("test".to_string()));
        assert!(result.is_err());
    }

    // ==================== Exit Code Uniqueness ====================

    #[test]
    fn test_all_exit_codes_are_distinct_per_category() {
        // One representative per category. NotAFile is left out on purpose: it
        // intentionally shares exit code 3 with FileNotFound (checked below).
        let codes = [
            ("file", CliError::FileNotFound(PathBuf::from("a")).exit_code()),
            ("format", CliError::InvalidFormat("a".to_string()).exit_code()),
            ("io", CliError::Io(io_error("")).exit_code()),
            ("validation", CliError::ValidationFailed("a".to_string()).exit_code()),
            ("aprender", CliError::Aprender("a".to_string()).exit_code()),
            ("model_load", CliError::ModelLoadFailed("a".to_string()).exit_code()),
            ("inference", CliError::InferenceFailed("a".to_string()).exit_code()),
            ("feature", CliError::FeatureDisabled("a".to_string()).exit_code()),
            ("network", CliError::NetworkError("a".to_string()).exit_code()),
            ("http_not_found", CliError::HttpNotFound("a".to_string()).exit_code()),
        ];
        // Pairwise check: every category must map to its own exit code.
        for (i, (name_a, code_a)) in codes.iter().enumerate() {
            for (name_b, code_b) in &codes[i + 1..] {
                assert_ne!(code_a, code_b, "{name_a} and {name_b} share an exit code");
            }
        }
        // FileNotFound and NotAFile intentionally share exit code 3.
        assert_eq!(codes[0].1, ExitCode::from(3));
        assert_eq!(CliError::NotAFile(PathBuf::from("a")).exit_code(), codes[0].1);
    }

    // ==================== resolve_model_path Tests ====================

    #[test]
    fn test_resolve_model_path_nonexistent() {
        let result = resolve_model_path(Path::new("/nonexistent/path/model.gguf"));
        assert!(matches!(result, Err(CliError::FileNotFound(_))));
    }

    #[test]
    fn test_resolve_model_path_regular_file() {
        let dir = ScratchDir::new("resolve-file");
        let file = dir.touch("apr-test-resolve.safetensors");
        assert_eq!(resolve_model_path(&file).expect("resolve"), file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_safetensors() {
        let dir = ScratchDir::new("resolve-dir");
        let model_file = dir.touch("model.safetensors");
        assert_eq!(resolve_model_path(dir.path()).expect("resolve"), model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_gguf() {
        let dir = ScratchDir::new("resolve-gguf");
        let model_file = dir.touch("model-q4.gguf");
        assert_eq!(resolve_model_path(dir.path()).expect("resolve"), model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_apr() {
        let dir = ScratchDir::new("resolve-apr");
        let model_file = dir.touch("weights.apr");
        assert_eq!(resolve_model_path(dir.path()).expect("resolve"), model_file);
    }

    #[test]
    fn test_resolve_model_path_prefers_safetensors_over_gguf() {
        let dir = ScratchDir::new("resolve-priority");
        dir.touch("model-q4.gguf");
        dir.touch("weights.apr");
        let model_file = dir.touch("model.safetensors");
        assert_eq!(resolve_model_path(dir.path()).expect("resolve"), model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_shard_without_index() {
        let dir = ScratchDir::new("resolve-shard-no-index");
        let first = dir.touch("model-00001-of-00002.safetensors");
        dir.touch("model-00002-of-00002.safetensors");
        assert_eq!(resolve_model_path(dir.path()).expect("resolve"), first);
    }

    #[test]
    fn test_resolve_model_path_dir_with_sharded_safetensors() {
        // PMAT-314: Sharded models have index.json that MUST take priority
        // over individual shard files (model-00001-of-00002.safetensors)
        let dir = ScratchDir::new("resolve-sharded");
        let index_file = dir.path().join("model.safetensors.index.json");
        std::fs::write(&index_file, b"{}").expect("write index");
        dir.touch("model-00001-of-00002.safetensors");
        assert_eq!(
            resolve_model_path(dir.path()).expect("resolve"),
            index_file,
            "index.json must take priority over shard files"
        );
    }

    #[test]
    fn test_resolve_model_path_empty_dir() {
        let dir = ScratchDir::new("resolve-empty");
        let result = resolve_model_path(dir.path());
        assert!(matches!(result, Err(CliError::ValidationFailed(_))));
    }

    #[test]
    fn test_resolve_model_path_dir_without_model_files() {
        let dir = ScratchDir::new("resolve-unrelated");
        dir.touch("README.md");
        dir.touch("config.json");
        let result = resolve_model_path(dir.path());
        assert!(matches!(result, Err(CliError::ValidationFailed(_))));
    }
}