Skip to main content

apr_cli/
error.rs

//! Error types for apr-cli
//!
//! Toyota Way: Jidoka - Stop and highlight problems immediately.

5use std::path::PathBuf;
6use std::process::ExitCode;
7use thiserror::Error;
8
9/// Result type alias for CLI operations
10pub type Result<T> = std::result::Result<T, CliError>;
11
12/// CLI error types
13#[derive(Error, Debug)]
14pub enum CliError {
15    /// File not found
16    #[error("File not found: {0}")]
17    FileNotFound(PathBuf),
18
19    /// Not a file (e.g., directory)
20    #[error("Not a file: {0}")]
21    NotAFile(PathBuf),
22
23    /// Invalid APR format
24    #[error("Invalid APR format: {0}")]
25    InvalidFormat(String),
26
27    /// IO error
28    #[error("IO error: {0}")]
29    Io(#[from] std::io::Error),
30
31    /// Validation failed
32    #[error("Validation failed: {0}")]
33    ValidationFailed(String),
34
35    /// Aprender error
36    #[error("Aprender error: {0}")]
37    Aprender(String),
38
39    /// Model loading failed (used with inference feature)
40    #[error("Model load failed: {0}")]
41    #[allow(dead_code)]
42    ModelLoadFailed(String),
43
44    /// Inference failed (used with inference feature)
45    #[error("Inference failed: {0}")]
46    #[allow(dead_code)]
47    InferenceFailed(String),
48
49    /// Feature disabled (used when optional features are not compiled)
50    #[error("Feature not enabled: {0}")]
51    #[allow(dead_code)]
52    FeatureDisabled(String),
53
54    /// Network error
55    #[error("Network error: {0}")]
56    NetworkError(String),
57
58    /// HTTP 404 Not Found (GH-356: distinguish from other network errors)
59    #[error("HTTP 404 Not Found: {0}")]
60    HttpNotFound(String),
61}
62
63impl CliError {
64    /// Get exit code for this error
65    pub fn exit_code(&self) -> ExitCode {
66        contract_pre_exit_code_semantics!();
67        contract_pre_error_mapping!();
68        contract_pre_exit_code_on_error!();
69        match self {
70            Self::FileNotFound(_) | Self::NotAFile(_) => ExitCode::from(3),
71            Self::InvalidFormat(_) => ExitCode::from(4),
72            Self::Io(_) => ExitCode::from(7),
73            Self::ValidationFailed(_) => ExitCode::from(5),
74            Self::Aprender(_) => ExitCode::from(1),
75            Self::ModelLoadFailed(_) => ExitCode::from(6),
76            Self::InferenceFailed(_) => ExitCode::from(8),
77            Self::FeatureDisabled(_) => ExitCode::from(9),
78            Self::NetworkError(_) => ExitCode::from(10),
79            Self::HttpNotFound(_) => ExitCode::from(11),
80        }
81    }
82}
83
84impl From<aprender::error::AprenderError> for CliError {
85    fn from(e: aprender::error::AprenderError) -> Self {
86        Self::Aprender(e.to_string())
87    }
88}
89
90/// Resolve a model path: if given a directory, look for common model files inside.
91///
92/// HuggingFace models are stored as directories containing `model.safetensors`,
93/// `model-00001-of-NNNNN.safetensors`, or `*.gguf`. This function resolves such
94/// directories to the actual model file, avoiding "Not a file" errors.
95pub fn resolve_model_path(
96    path: &std::path::Path,
97) -> std::result::Result<std::path::PathBuf, CliError> {
98    if !path.exists() {
99        return Err(CliError::FileNotFound(path.to_path_buf()));
100    }
101    if path.is_file() {
102        return Ok(path.to_path_buf());
103    }
104    if path.is_dir() {
105        // GH-668: Reject common system/temp directories. Only resolve dirs that
106        // are plausible HuggingFace model checkpoints (not /, /tmp, /home, etc.).
107        if let Some(parent) = path.parent() {
108            let depth = path.components().count();
109            // Directories at filesystem root level (depth <= 2) are never model dirs
110            if depth <= 2 {
111                return Err(CliError::NotAFile(path.to_path_buf()));
112            }
113            let _ = parent; // suppress unused warning
114        }
115
116        // PMAT-314: Check sharded SafeTensors index FIRST — individual shard files
117        // only contain a subset of tensors and will fail the architecture gate.
118        let index = path.join("model.safetensors.index.json");
119        if index.is_file() {
120            return Ok(index);
121        }
122        // Try common model file names in priority order
123        let candidates = [
124            "model.safetensors",
125            "model-00001-of-00001.safetensors",
126            "model-00001-of-00002.safetensors",
127            "model-00001-of-00003.safetensors",
128            "model-00001-of-00004.safetensors",
129        ];
130        for candidate in &candidates {
131            let p = path.join(candidate);
132            if p.is_file() {
133                return Ok(p);
134            }
135        }
136        // Try first .gguf file
137        if let Ok(entries) = std::fs::read_dir(path) {
138            for entry in entries.flatten() {
139                let p = entry.path();
140                // GH-668: Skip temp files (rosetta_temp.apr, etc.) to avoid inspecting stale artifacts
141                let is_temp = p
142                    .file_name()
143                    .is_some_and(|n| n.to_string_lossy().starts_with("rosetta_temp"));
144                if !is_temp && p.extension().is_some_and(|ext| ext == "gguf") && p.is_file() {
145                    return Ok(p);
146                }
147            }
148        }
149        // Try first .apr file
150        if let Ok(entries) = std::fs::read_dir(path) {
151            for entry in entries.flatten() {
152                let p = entry.path();
153                let is_temp = p
154                    .file_name()
155                    .is_some_and(|n| n.to_string_lossy().starts_with("rosetta_temp"));
156                if !is_temp && p.extension().is_some_and(|ext| ext == "apr") && p.is_file() {
157                    return Ok(p);
158                }
159            }
160        }
161        Err(CliError::ValidationFailed(format!(
162            "Directory {} does not contain a model file (expected model.safetensors, *.gguf, or *.apr)",
163            path.display()
164        )))
165    } else {
166        Err(CliError::NotAFile(path.to_path_buf()))
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Fresh, empty scratch directory under the system temp dir.
    ///
    /// The process id is part of the name so that concurrent test processes
    /// (e.g. two `cargo test` invocations) never share or delete each other's fixtures.
    fn scratch_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!("apr-test-{tag}-{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("mkdir");
        dir
    }

    // ==================== Exit Code Tests ====================

    #[test]
    fn test_file_not_found_exit_code() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_not_a_file_exit_code() {
        let err = CliError::NotAFile(PathBuf::from("/test"));
        assert_eq!(err.exit_code(), ExitCode::from(3));
    }

    #[test]
    fn test_invalid_format_exit_code() {
        let err = CliError::InvalidFormat("bad".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(4));
    }

    #[test]
    fn test_io_error_exit_code() {
        let err = CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "test"));
        assert_eq!(err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_validation_failed_exit_code() {
        let err = CliError::ValidationFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(5));
    }

    #[test]
    fn test_aprender_error_exit_code() {
        let err = CliError::Aprender("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(1));
    }

    #[test]
    fn test_model_load_failed_exit_code() {
        let err = CliError::ModelLoadFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(6));
    }

    #[test]
    fn test_inference_failed_exit_code() {
        let err = CliError::InferenceFailed("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(8));
    }

    #[test]
    fn test_feature_disabled_exit_code() {
        let err = CliError::FeatureDisabled("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(9));
    }

    #[test]
    fn test_network_error_exit_code() {
        let err = CliError::NetworkError("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(10));
    }

    #[test]
    fn test_http_not_found_exit_code() {
        let err = CliError::HttpNotFound("test".to_string());
        assert_eq!(err.exit_code(), ExitCode::from(11));
    }

    // ==================== Display Tests ====================

    #[test]
    fn test_file_not_found_display() {
        let err = CliError::FileNotFound(PathBuf::from("/model.apr"));
        assert_eq!(err.to_string(), "File not found: /model.apr");
    }

    #[test]
    fn test_not_a_file_display() {
        let err = CliError::NotAFile(PathBuf::from("/dir"));
        assert_eq!(err.to_string(), "Not a file: /dir");
    }

    #[test]
    fn test_invalid_format_display() {
        let err = CliError::InvalidFormat("bad magic".to_string());
        assert_eq!(err.to_string(), "Invalid APR format: bad magic");
    }

    #[test]
    fn test_validation_failed_display() {
        let err = CliError::ValidationFailed("missing field".to_string());
        assert_eq!(err.to_string(), "Validation failed: missing field");
    }

    #[test]
    fn test_aprender_error_display() {
        let err = CliError::Aprender("internal".to_string());
        assert_eq!(err.to_string(), "Aprender error: internal");
    }

    #[test]
    fn test_model_load_failed_display() {
        let err = CliError::ModelLoadFailed("corrupt".to_string());
        assert_eq!(err.to_string(), "Model load failed: corrupt");
    }

    #[test]
    fn test_inference_failed_display() {
        let err = CliError::InferenceFailed("OOM".to_string());
        assert_eq!(err.to_string(), "Inference failed: OOM");
    }

    #[test]
    fn test_feature_disabled_display() {
        let err = CliError::FeatureDisabled("cuda".to_string());
        assert_eq!(err.to_string(), "Feature not enabled: cuda");
    }

    #[test]
    fn test_network_error_display() {
        let err = CliError::NetworkError("timeout".to_string());
        assert_eq!(err.to_string(), "Network error: timeout");
    }

    #[test]
    fn test_http_not_found_display() {
        let err = CliError::HttpNotFound("tokenizer.json".to_string());
        assert_eq!(err.to_string(), "HTTP 404 Not Found: tokenizer.json");
    }

    // ==================== Conversion Tests ====================

    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let cli_err: CliError = io_err.into();
        assert!(cli_err.to_string().contains("file missing"));
        assert_eq!(cli_err.exit_code(), ExitCode::from(7));
    }

    #[test]
    fn test_debug_impl() {
        let err = CliError::FileNotFound(PathBuf::from("/test"));
        let debug = format!("{:?}", err);
        assert!(debug.contains("FileNotFound"));
    }

    // ==================== Result Type Alias ====================

    #[test]
    fn test_result_type_ok() {
        let result: Result<i32> = Ok(42);
        assert_eq!(result.expect("value"), 42);
    }

    #[test]
    fn test_result_type_err() {
        let result: Result<i32> = Err(CliError::InvalidFormat("test".to_string()));
        assert!(result.is_err());
    }

    // ==================== Exit Code Uniqueness ====================

    #[test]
    fn test_all_exit_codes_are_distinct_per_category() {
        // One representative per category. NotAFile is checked separately below
        // because it intentionally shares exit code 3 with FileNotFound.
        let codes = [
            (
                CliError::FileNotFound(PathBuf::from("a")).exit_code(),
                "file",
            ),
            (
                CliError::InvalidFormat("a".to_string()).exit_code(),
                "format",
            ),
            (
                CliError::Io(std::io::Error::new(std::io::ErrorKind::Other, "")).exit_code(),
                "io",
            ),
            (
                CliError::ValidationFailed("a".to_string()).exit_code(),
                "validation",
            ),
            (CliError::Aprender("a".to_string()).exit_code(), "aprender"),
            (
                CliError::ModelLoadFailed("a".to_string()).exit_code(),
                "model_load",
            ),
            (
                CliError::InferenceFailed("a".to_string()).exit_code(),
                "inference",
            ),
            (
                CliError::FeatureDisabled("a".to_string()).exit_code(),
                "feature",
            ),
            (
                CliError::NetworkError("a".to_string()).exit_code(),
                "network",
            ),
            (
                CliError::HttpNotFound("a".to_string()).exit_code(),
                "http_not_found",
            ),
        ];
        for (i, (code_a, name_a)) in codes.iter().enumerate() {
            for (code_b, name_b) in &codes[i + 1..] {
                assert_ne!(
                    code_a, code_b,
                    "exit code collision between {name_a} and {name_b}"
                );
            }
        }
        // FileNotFound and NotAFile intentionally share exit code 3.
        assert_eq!(codes[0].0, ExitCode::from(3));
        assert_eq!(
            CliError::NotAFile(PathBuf::from("a")).exit_code(),
            codes[0].0
        );
    }

    // ==================== resolve_model_path Tests ====================

    #[test]
    fn test_resolve_model_path_nonexistent() {
        let result = resolve_model_path(std::path::Path::new("/nonexistent/path/model.gguf"));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CliError::FileNotFound(_)));
    }

    #[test]
    fn test_resolve_model_path_regular_file() {
        let tmp = std::env::temp_dir().join(format!(
            "apr-test-resolve-{}.safetensors",
            std::process::id()
        ));
        std::fs::write(&tmp, b"test").expect("write");
        let result = resolve_model_path(&tmp);
        // Clean up before asserting so a failure does not leak the fixture.
        std::fs::remove_file(&tmp).ok();
        assert_eq!(result.expect("value"), tmp);
    }

    #[test]
    fn test_resolve_model_path_dir_with_safetensors() {
        let dir = scratch_dir("resolve-dir");
        let model_file = dir.join("model.safetensors");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        std::fs::remove_dir_all(&dir).ok();
        assert_eq!(result.expect("value"), model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_gguf() {
        let dir = scratch_dir("resolve-gguf");
        let model_file = dir.join("model-q4.gguf");
        std::fs::write(&model_file, b"test").expect("write");
        let result = resolve_model_path(&dir);
        std::fs::remove_dir_all(&dir).ok();
        assert_eq!(result.expect("value"), model_file);
    }

    #[test]
    fn test_resolve_model_path_dir_with_sharded_safetensors() {
        // PMAT-314: Sharded models have index.json that MUST take priority
        // over individual shard files (model-00001-of-00002.safetensors)
        let dir = scratch_dir("resolve-sharded");
        let index_file = dir.join("model.safetensors.index.json");
        let shard_file = dir.join("model-00001-of-00002.safetensors");
        std::fs::write(&index_file, b"{}").expect("write index");
        std::fs::write(&shard_file, b"test").expect("write shard");
        let result = resolve_model_path(&dir);
        std::fs::remove_dir_all(&dir).ok();
        assert_eq!(
            result.expect("value"),
            index_file,
            "index.json must take priority over shard files"
        );
    }

    #[test]
    fn test_resolve_model_path_empty_dir() {
        let dir = scratch_dir("resolve-empty");
        let result = resolve_model_path(&dir);
        std::fs::remove_dir_all(&dir).ok();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CliError::ValidationFailed(_)));
    }
}