Skip to main content

apr_cli/
error.rs

1//! Error types for apr-cli
2//!
3//! Toyota Way: Jidoka - Stop and highlight problems immediately.
4
5use std::path::PathBuf;
6use std::process::ExitCode;
7use thiserror::Error;
8
9/// Result type alias for CLI operations
10pub type Result<T> = std::result::Result<T, CliError>;
11
12/// CLI error types
13#[derive(Error, Debug)]
14pub enum CliError {
15    /// File not found
16    #[error("File not found: {0}")]
17    FileNotFound(PathBuf),
18
19    /// Not a file (e.g., directory)
20    #[error("Not a file: {0}")]
21    NotAFile(PathBuf),
22
23    /// Invalid APR format
24    #[error("Invalid APR format: {0}")]
25    InvalidFormat(String),
26
27    /// IO error
28    #[error("IO error: {0}")]
29    Io(#[from] std::io::Error),
30
31    /// Validation failed
32    #[error("Validation failed: {0}")]
33    ValidationFailed(String),
34
35    /// Aprender error
36    #[error("Aprender error: {0}")]
37    Aprender(String),
38
39    /// Model loading failed (used with inference feature)
40    #[error("Model load failed: {0}")]
41    #[allow(dead_code)]
42    ModelLoadFailed(String),
43
44    /// Inference failed (used with inference feature)
45    #[error("Inference failed: {0}")]
46    #[allow(dead_code)]
47    InferenceFailed(String),
48
49    /// Feature disabled (used when optional features are not compiled)
50    #[error("Feature not enabled: {0}")]
51    #[allow(dead_code)]
52    FeatureDisabled(String),
53
54    /// Network error
55    #[error("Network error: {0}")]
56    NetworkError(String),
57
58    /// HTTP 404 Not Found (GH-356: distinguish from other network errors)
59    #[error("HTTP 404 Not Found: {0}")]
60    HttpNotFound(String),
61}
62
63impl CliError {
64    /// Get exit code for this error
65    pub fn exit_code(&self) -> ExitCode {
66        contract_pre_exit_code_semantics!();
67        match self {
68            Self::FileNotFound(_) | Self::NotAFile(_) => ExitCode::from(3),
69            Self::InvalidFormat(_) => ExitCode::from(4),
70            Self::Io(_) => ExitCode::from(7),
71            Self::ValidationFailed(_) => ExitCode::from(5),
72            Self::Aprender(_) => ExitCode::from(1),
73            Self::ModelLoadFailed(_) => ExitCode::from(6),
74            Self::InferenceFailed(_) => ExitCode::from(8),
75            Self::FeatureDisabled(_) => ExitCode::from(9),
76            Self::NetworkError(_) => ExitCode::from(10),
77            Self::HttpNotFound(_) => ExitCode::from(11),
78        }
79    }
80}
81
82impl From<aprender::error::AprenderError> for CliError {
83    fn from(e: aprender::error::AprenderError) -> Self {
84        Self::Aprender(e.to_string())
85    }
86}
87
88/// Resolve a model path: if given a directory, look for common model files inside.
89///
90/// HuggingFace models are stored as directories containing `model.safetensors`,
91/// `model-00001-of-NNNNN.safetensors`, or `*.gguf`. This function resolves such
92/// directories to the actual model file, avoiding "Not a file" errors.
93pub fn resolve_model_path(
94    path: &std::path::Path,
95) -> std::result::Result<std::path::PathBuf, CliError> {
96    if !path.exists() {
97        return Err(CliError::FileNotFound(path.to_path_buf()));
98    }
99    if path.is_file() {
100        return Ok(path.to_path_buf());
101    }
102    if path.is_dir() {
103        // GH-668: Reject common system/temp directories. Only resolve dirs that
104        // are plausible HuggingFace model checkpoints (not /, /tmp, /home, etc.).
105        if let Some(parent) = path.parent() {
106            let depth = path.components().count();
107            // Directories at filesystem root level (depth <= 2) are never model dirs
108            if depth <= 2 {
109                return Err(CliError::NotAFile(path.to_path_buf()));
110            }
111            let _ = parent; // suppress unused warning
112        }
113
114        // PMAT-314: Check sharded SafeTensors index FIRST — individual shard files
115        // only contain a subset of tensors and will fail the architecture gate.
116        let index = path.join("model.safetensors.index.json");
117        if index.is_file() {
118            return Ok(index);
119        }
120        // Try common model file names in priority order
121        let candidates = [
122            "model.safetensors",
123            "model-00001-of-00001.safetensors",
124            "model-00001-of-00002.safetensors",
125            "model-00001-of-00003.safetensors",
126            "model-00001-of-00004.safetensors",
127        ];
128        for candidate in &candidates {
129            let p = path.join(candidate);
130            if p.is_file() {
131                return Ok(p);
132            }
133        }
134        // Try first .gguf file
135        if let Ok(entries) = std::fs::read_dir(path) {
136            for entry in entries.flatten() {
137                let p = entry.path();
138                // GH-668: Skip temp files (rosetta_temp.apr, etc.) to avoid inspecting stale artifacts
139                let is_temp = p
140                    .file_name()
141                    .is_some_and(|n| n.to_string_lossy().starts_with("rosetta_temp"));
142                if !is_temp && p.extension().is_some_and(|ext| ext == "gguf") && p.is_file() {
143                    return Ok(p);
144                }
145            }
146        }
147        // Try first .apr file
148        if let Ok(entries) = std::fs::read_dir(path) {
149            for entry in entries.flatten() {
150                let p = entry.path();
151                let is_temp = p
152                    .file_name()
153                    .is_some_and(|n| n.to_string_lossy().starts_with("rosetta_temp"));
154                if !is_temp && p.extension().is_some_and(|ext| ext == "apr") && p.is_file() {
155                    return Ok(p);
156                }
157            }
158        }
159        Err(CliError::ValidationFailed(format!(
160            "Directory {} does not contain a model file (expected model.safetensors, *.gguf, or *.apr)",
161            path.display()
162        )))
163    } else {
164        Err(CliError::NotAFile(path.to_path_buf()))
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Scratch directory under the system temp dir whose name includes the
    /// process id, so concurrent test runs do not collide. It is removed on
    /// drop, which also happens when an assertion panics mid-test.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("apr-test-{tag}-{}", std::process::id()));
            // Start from a clean slate in case an earlier run left files behind.
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).expect("mkdir");
            Self(dir)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not mask the test outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    // ==================== Exit Code Tests ====================

    #[test]
    fn test_file_not_found_exit_code() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_not_a_file_exit_code() {
        let err = CliError::NotAFile(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_invalid_format_exit_code() {
        let err = CliError::InvalidFormat("bad".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(4));
    }

    #[test]
    fn test_io_error_exit_code() {
        let err = CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "test"));
        assert_eq!(err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_validation_failed_exit_code() {
        let err = CliError::ValidationFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(5));
    }

    #[test]
    fn test_aprender_error_exit_code() {
        let err = CliError::Aprender("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(1));
    }

    #[test]
    fn test_model_load_failed_exit_code() {
        let err = CliError::ModelLoadFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(6));
    }

    #[test]
    fn test_inference_failed_exit_code() {
        let err = CliError::InferenceFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(8));
    }

    #[test]
    fn test_feature_disabled_exit_code() {
        let err = CliError::FeatureDisabled("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(9));
    }

    #[test]
    fn test_network_error_exit_code() {
        let err = CliError::NetworkError("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(10));
    }

    #[test]
    fn test_http_not_found_exit_code() {
        let err = CliError::HttpNotFound("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(11));
    }

    // ==================== Display Tests ====================

    #[test]
    fn test_file_not_found_display() {
        let err = CliError::FileNotFound(PathBuf::from("/model.apr"));
        assert_eq!(err.to_string(), "File not found: /model.apr");
    }

    #[test]
    fn test_not_a_file_display() {
        let err = CliError::NotAFile(PathBuf::from("/dir"));
        assert_eq!(err.to_string(), "Not a file: /dir");
    }

    #[test]
    fn test_invalid_format_display() {
        let err = CliError::InvalidFormat("bad magic".to_string());
        assert_eq!(err.to_string(), "Invalid APR format: bad magic");
    }

    #[test]
    fn test_validation_failed_display() {
        let err = CliError::ValidationFailed("missing field".to_string());
        assert_eq!(err.to_string(), "Validation failed: missing field");
    }

    #[test]
    fn test_aprender_error_display() {
        let err = CliError::Aprender("internal".to_string());
        assert_eq!(err.to_string(), "Aprender error: internal");
    }

    #[test]
    fn test_model_load_failed_display() {
        let err = CliError::ModelLoadFailed("corrupt".to_string());
        assert_eq!(err.to_string(), "Model load failed: corrupt");
    }

    #[test]
    fn test_inference_failed_display() {
        let err = CliError::InferenceFailed("OOM".to_string());
        assert_eq!(err.to_string(), "Inference failed: OOM");
    }

    #[test]
    fn test_feature_disabled_display() {
        let err = CliError::FeatureDisabled("cuda".to_string());
        assert_eq!(err.to_string(), "Feature not enabled: cuda");
    }

    #[test]
    fn test_network_error_display() {
        let err = CliError::NetworkError("timeout".to_string());
        assert_eq!(err.to_string(), "Network error: timeout");
    }

    #[test]
    fn test_http_not_found_display() {
        let err = CliError::HttpNotFound("tokenizer.json".to_string());
        assert_eq!(err.to_string(), "HTTP 404 Not Found: tokenizer.json");
    }

    // ==================== Conversion Tests ====================

    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let cli_err: CliError = io_err.into();
        assert!(cli_err.to_string().contains("file missing"));
        assert_eq!(cli_err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_debug_impl() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        let debug = format!("{:?}", err);
        assert!(debug.contains("FileNotFound"));
    }

    // ==================== Result Type Alias ====================

    #[test]
    fn test_result_type_ok() {
        let result: Result<i32> = Ok(42);
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_result_type_err() {
        let result: Result<i32> = Err(CliError::InvalidFormat("test".to_string()));
        assert!(result.is_err());
    }

    // ==================== Exit Code Uniqueness ====================

    #[test]
    fn test_all_exit_codes_are_distinct_per_category() {
        // One representative per category. NotAFile is left out because it
        // intentionally shares FileNotFound's code; it is checked below.
        let codes = [
            (
                CliError::FileNotFound(PathBuf::from("a")).exit_code(),
                "file",
            ),
            (
                CliError::InvalidFormat("a".to_string()).exit_code(),
                "format",
            ),
            (
                CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "")).exit_code(),
                "io",
            ),
            (
                CliError::ValidationFailed("a".to_string()).exit_code(),
                "validation",
            ),
            (CliError::Aprender("a".to_string()).exit_code(), "aprender"),
            (
                CliError::ModelLoadFailed("a".to_string()).exit_code(),
                "model_load",
            ),
            (
                CliError::InferenceFailed("a".to_string()).exit_code(),
                "inference",
            ),
            (
                CliError::FeatureDisabled("a".to_string()).exit_code(),
                "feature",
            ),
            (
                CliError::NetworkError("a".to_string()).exit_code(),
                "network",
            ),
            (
                CliError::HttpNotFound("a".to_string()).exit_code(),
                "http_not_found",
            ),
        ];
        // Pairwise check: every category must map to its own exit code.
        for (i, (code_a, name_a)) in codes.iter().enumerate() {
            for (code_b, name_b) in &codes[i + 1..] {
                assert_ne!(code_a, code_b, "{name_a} and {name_b} share an exit code");
            }
        }
        // FileNotFound and NotAFile intentionally share exit code 3
        assert_eq!(codes[0].0, ExitCode::from(3));
        assert_eq!(CliError::NotAFile(PathBuf::from("a")).exit_code(), codes[0].0);
    }

    // ==================== resolve_model_path Tests ====================

    #[test]
    fn test_resolve_model_path_nonexistent() {
        let result = resolve_model_path(std::path::Path::new("/nonexistent/path/model.gguf"));
        assert!(matches!(result, Err(CliError::FileNotFound(_))));
    }

    #[test]
    fn test_resolve_model_path_regular_file() {
        let scratch = ScratchDir::new("resolve-file");
        let file = scratch.path().join("apr-test-resolve.safetensors");
        std::fs::write(&file, b"test").expect("write");
        let resolved = resolve_model_path(&file).expect("regular file resolves");
        assert_eq!(resolved, file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_safetensors() {
        let scratch = ScratchDir::new("resolve-dir");
        let model_file = scratch.path().join("model.safetensors");
        std::fs::write(&model_file, b"test").expect("write");
        let resolved = resolve_model_path(scratch.path()).expect("dir resolves");
        assert_eq!(resolved, model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_gguf() {
        let scratch = ScratchDir::new("resolve-gguf");
        let model_file = scratch.path().join("model-q4.gguf");
        std::fs::write(&model_file, b"test").expect("write");
        let resolved = resolve_model_path(scratch.path()).expect("dir resolves");
        assert_eq!(resolved, model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_sharded_safetensors() {
        // PMAT-314: Sharded models have index.json that MUST take priority
        // over individual shard files (model-00001-of-00002.safetensors)
        let scratch = ScratchDir::new("resolve-sharded");
        let index_file = scratch.path().join("model.safetensors.index.json");
        let shard_file = scratch.path().join("model-00001-of-00002.safetensors");
        std::fs::write(&index_file, b"{}").expect("write index");
        std::fs::write(&shard_file, b"test").expect("write shard");
        let resolved = resolve_model_path(scratch.path()).expect("dir resolves");
        assert_eq!(
            resolved, index_file,
            "index.json must take priority over shard files"
        );
    }

    #[test]
    fn test_resolve_model_path_empty_dir() {
        let scratch = ScratchDir::new("resolve-empty");
        let result = resolve_model_path(scratch.path());
        assert!(matches!(result, Err(CliError::ValidationFailed(_))));
    }
}