// bashrs_oracle/lib.rs

1// Allow multiple crate versions from transitive dependencies (aprender -> wgpu -> foldhash)
2#![allow(clippy::multiple_crate_versions)]
3
//! ML-powered error classification oracle for bashrs.
//!
//! Uses the aprender Random Forest classifier (GPU-accelerated via trueno/wgpu) to:
//! - Classify shell errors into actionable categories (24 categories)
//! - Suggest fixes based on error patterns
//! - Detect error drift requiring model retraining
//!
//! ## GPU Acceleration
//!
//! Enable the `gpu` feature for GPU acceleration (e.g. RTX 4090) via wgpu/trueno:
//! ```toml
//! bashrs-oracle = { version = "*", features = ["gpu"] }
//! ```
//!
//! ## Performance Targets (from depyler-oracle)
//! - Accuracy: >90% (depyler achieved 97.73%)
//! - Training time: <1s
//! - Predictions/sec: >1000
//! - Model size: <1MB (with zstd compression)
23
24use std::collections::HashMap;
25use std::path::{Path, PathBuf};
26
27use aprender::format::{self, Compression, ModelType, SaveOptions};
28use aprender::metrics::drift::{DriftConfig, DriftDetector, DriftStatus};
29use aprender::primitives::Matrix;
30use aprender::tree::RandomForestClassifier;
31use serde::{Deserialize, Serialize};
32
33pub mod categories;
34pub mod classifier;
35pub mod corpus;
36pub mod features;
37
38pub use categories::ErrorCategory;
39pub use classifier::ErrorClassifier;
40pub use corpus::{Corpus, TrainingExample};
41pub use features::ErrorFeatures;
42
43/// Error types for the oracle.
44#[derive(Debug, thiserror::Error)]
45pub enum OracleError {
46    /// Model loading/saving error.
47    #[error("Model error: {0}")]
48    Model(String),
49    /// Feature extraction error.
50    #[error("Feature extraction error: {0}")]
51    Feature(String),
52    /// Training error.
53    #[error("Training error: {0}")]
54    Training(String),
55    /// Classification error.
56    #[error("Classification error: {0}")]
57    Classification(String),
58    /// IO error.
59    #[error("IO error: {0}")]
60    Io(#[from] std::io::Error),
61}
62
63/// Result type for oracle operations.
64pub type Result<T> = std::result::Result<T, OracleError>;
65
66/// Classification result with confidence and suggested fix.
67#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct ClassificationResult {
69    /// Predicted error category.
70    pub category: ErrorCategory,
71    /// Confidence score (0.0 - 1.0).
72    pub confidence: f32,
73    /// Suggested fix template.
74    pub suggested_fix: Option<String>,
75    /// Related fix patterns.
76    pub related_patterns: Vec<String>,
77}
78
/// Configuration for the Random Forest classifier.
///
/// All fields are plain values, so the type also implements `PartialEq`/`Eq`;
/// this lets callers and tests compare configurations directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OracleConfig {
    /// Number of trees in the forest (default: 100).
    /// IMPORTANT: 100 is sufficient. 10,000 causes 15+ min training!
    pub n_estimators: usize,
    /// Maximum tree depth (default: 10).
    pub max_depth: usize,
    /// Random seed for reproducibility; `None` leaves the classifier unseeded
    /// (default: `Some(42)`).
    pub random_state: Option<u64>,
}

impl Default for OracleConfig {
    /// Returns the standard configuration: 100 trees, depth 10, seed 42.
    fn default() -> Self {
        Self {
            n_estimators: 100,
            max_depth: 10,
            random_state: Some(42),
        }
    }
}
100
/// File name under which the trained model is cached on disk.
const DEFAULT_MODEL_NAME: &str = "bashrs_oracle.apr";
103
104/// ML-powered shell error classification oracle.
105pub struct Oracle {
106    /// Random Forest classifier (GPU-accelerated via aprender).
107    classifier: RandomForestClassifier,
108    /// Configuration used to create the classifier.
109    #[allow(dead_code)]
110    config: OracleConfig,
111    /// Category list for index mapping (kept for model introspection).
112    #[allow(dead_code)]
113    categories: Vec<ErrorCategory>,
114    /// Fix templates per category.
115    fix_templates: HashMap<ErrorCategory, Vec<String>>,
116    /// Drift detector for retraining triggers.
117    drift_detector: DriftDetector,
118    /// Historical performance scores.
119    performance_history: Vec<f32>,
120    /// Whether model has been trained.
121    is_trained: bool,
122}
123
124impl Default for Oracle {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130impl Oracle {
131    /// Get the default model path.
132    #[must_use]
133    pub fn default_model_path() -> PathBuf {
134        // Try to find project root via Cargo.toml
135        let mut path = std::env::current_dir().unwrap_or_default();
136        for _ in 0..5 {
137            if path.join("Cargo.toml").exists() {
138                return path.join(DEFAULT_MODEL_NAME);
139            }
140            if !path.pop() {
141                break;
142            }
143        }
144        PathBuf::from(DEFAULT_MODEL_NAME)
145    }
146
147    /// Load model from default path, or train and save if not found.
148    ///
149    /// # Errors
150    /// Returns error if training fails.
151    pub fn load_or_train() -> Result<Self> {
152        let path = Self::default_model_path();
153
154        if path.exists() {
155            match Self::load(&path) {
156                Ok(oracle) => return Ok(oracle),
157                Err(e) => {
158                    tracing::warn!("Failed to load cached model: {e}. Retraining...");
159                }
160            }
161        }
162
163        // Train using synthetic data (5000 samples for good accuracy)
164        let corpus = Corpus::generate_synthetic(5000);
165        let oracle = Self::train_from_corpus(&corpus, OracleConfig::default())?;
166
167        // Save for next time
168        if let Err(e) = oracle.save(&path) {
169            tracing::warn!("Failed to cache model to {}: {e}", path.display());
170        }
171
172        Ok(oracle)
173    }
174
175    /// Create a new oracle with default configuration.
176    #[must_use]
177    pub fn new() -> Self {
178        Self::with_config(OracleConfig::default())
179    }
180
181    /// Create a new oracle with custom configuration.
182    #[must_use]
183    pub fn with_config(config: OracleConfig) -> Self {
184        let mut classifier =
185            RandomForestClassifier::new(config.n_estimators).with_max_depth(config.max_depth);
186        if let Some(seed) = config.random_state {
187            classifier = classifier.with_random_state(seed);
188        }
189
190        Self {
191            classifier,
192            config,
193            categories: ErrorCategory::all().to_vec(),
194            fix_templates: Self::default_fix_templates(),
195            drift_detector: DriftDetector::new(
196                DriftConfig::default()
197                    .with_min_samples(10)
198                    .with_window_size(50),
199            ),
200            performance_history: Vec::new(),
201            is_trained: false,
202        }
203    }
204
205    /// Train oracle from a corpus.
206    ///
207    /// # Errors
208    /// Returns error if training fails.
209    pub fn train_from_corpus(corpus: &Corpus, config: OracleConfig) -> Result<Self> {
210        let (x, y) = corpus.to_training_data();
211
212        // Convert to Matrix for aprender
213        let n_samples = x.len();
214        let n_features = x.first().map(|row| row.len()).unwrap_or(0);
215        let flat: Vec<f32> = x.into_iter().flatten().collect();
216        let features = Matrix::from_vec(n_samples, n_features, flat)
217            .map_err(|e| OracleError::Training(format!("Failed to create feature matrix: {e}")))?;
218        let labels: Vec<usize> = y.into_iter().map(|l| l as usize).collect();
219
220        let mut oracle = Self::with_config(config);
221        oracle.train(&features, &labels)?;
222
223        Ok(oracle)
224    }
225
226    /// Train the oracle on labeled error data.
227    ///
228    /// # Errors
229    /// Returns error if training fails.
230    pub fn train(&mut self, features: &Matrix<f32>, labels: &[usize]) -> Result<()> {
231        self.classifier
232            .fit(features, labels)
233            .map_err(|e| OracleError::Training(e.to_string()))?;
234        self.is_trained = true;
235
236        Ok(())
237    }
238
239    /// Classify an error and return category with confidence.
240    pub fn classify(&self, features: &ErrorFeatures) -> Result<ClassificationResult> {
241        if !self.is_trained {
242            // Fallback to keyword-based classification
243            let kw_classifier = ErrorClassifier::new();
244            let category = kw_classifier.classify_by_keywords(
245                &features
246                    .features
247                    .iter()
248                    .map(|f| f.to_string())
249                    .collect::<Vec<_>>()
250                    .join(" "),
251            );
252            return Ok(ClassificationResult {
253                category,
254                confidence: 0.5,
255                suggested_fix: Some(category.fix_suggestion().to_string()),
256                related_patterns: vec![],
257            });
258        }
259
260        let feature_matrix = Matrix::from_vec(1, ErrorFeatures::SIZE, features.as_slice().to_vec())
261            .map_err(|e| {
262                OracleError::Classification(format!("Failed to create feature matrix: {e}"))
263            })?;
264        let predictions = self.classifier.predict(&feature_matrix);
265
266        let pred_idx = predictions
267            .as_slice()
268            .first()
269            .copied()
270            .ok_or_else(|| OracleError::Classification("No prediction produced".to_string()))?;
271        let category = ErrorCategory::from_label_index(pred_idx);
272
273        let suggested_fix = self
274            .fix_templates
275            .get(&category)
276            .and_then(|fixes| fixes.first().cloned());
277
278        let related = self
279            .fix_templates
280            .get(&category)
281            .map(|fixes| fixes.iter().skip(1).cloned().collect())
282            .unwrap_or_default();
283
284        Ok(ClassificationResult {
285            category,
286            confidence: 0.85, // TODO: Extract from tree probabilities
287            suggested_fix,
288            related_patterns: related,
289        })
290    }
291
292    /// Classify an error from raw inputs.
293    pub fn classify_error(
294        &self,
295        exit_code: i32,
296        stderr: &str,
297        command: Option<&str>,
298    ) -> Result<ClassificationResult> {
299        let features = ErrorFeatures::extract(exit_code, stderr, command);
300        self.classify(&features)
301    }
302
303    /// Get fix suggestion for an error.
304    #[must_use]
305    pub fn suggest_fix(&self, exit_code: i32, stderr: &str, command: Option<&str>) -> String {
306        // If not trained, use keyword classifier directly on the stderr message
307        if !self.is_trained {
308            let kw_classifier = ErrorClassifier::new();
309            let category = kw_classifier.classify_by_keywords(stderr);
310            let confidence = kw_classifier.confidence(stderr, category);
311            return format!(
312                "[{:.0}% confident] {}: {}",
313                confidence * 100.0,
314                category.name(),
315                category.fix_suggestion()
316            );
317        }
318
319        match self.classify_error(exit_code, stderr, command) {
320            Ok(result) => {
321                format!(
322                    "[{:.0}% confident] {}: {}",
323                    result.confidence * 100.0,
324                    result.category.name(),
325                    result
326                        .suggested_fix
327                        .unwrap_or_else(|| result.category.fix_suggestion().to_string())
328                )
329            }
330            Err(_) => {
331                // Fallback to keyword classifier
332                let kw_classifier = ErrorClassifier::new();
333                let category = kw_classifier.classify_by_keywords(stderr);
334                format!(
335                    "[keyword] {}: {}",
336                    category.name(),
337                    category.fix_suggestion()
338                )
339            }
340        }
341    }
342
343    /// Check if the model needs retraining based on performance drift.
344    pub fn check_drift(&mut self, recent_accuracy: f32) -> DriftStatus {
345        self.performance_history.push(recent_accuracy);
346
347        if self.performance_history.len() < 10 {
348            return DriftStatus::NoDrift;
349        }
350
351        let mid = self.performance_history.len() / 2;
352        let baseline: Vec<f32> = self
353            .performance_history
354            .get(..mid)
355            .map(|s| s.to_vec())
356            .unwrap_or_default();
357        let current: Vec<f32> = self
358            .performance_history
359            .get(mid..)
360            .map(|s| s.to_vec())
361            .unwrap_or_default();
362
363        self.drift_detector
364            .detect_performance_drift(&baseline, &current)
365    }
366
367    /// Save the oracle model to a file (with zstd compression).
368    ///
369    /// # Errors
370    /// Returns error if saving fails.
371    pub fn save(&self, path: &Path) -> Result<()> {
372        let options = SaveOptions::default()
373            .with_name("bashrs-oracle")
374            .with_description("RandomForest error classification model for bashrs shell linter")
375            .with_compression(Compression::ZstdDefault); // 14x smaller!
376
377        format::save(&self.classifier, ModelType::RandomForest, path, options)
378            .map_err(|e| OracleError::Model(e.to_string()))?;
379
380        Ok(())
381    }
382
383    /// Load an oracle model from a file.
384    ///
385    /// # Errors
386    /// Returns error if loading fails.
387    pub fn load(path: &Path) -> Result<Self> {
388        let classifier: RandomForestClassifier = format::load(path, ModelType::RandomForest)
389            .map_err(|e| OracleError::Model(e.to_string()))?;
390
391        let config = OracleConfig::default();
392        Ok(Self {
393            classifier,
394            config,
395            categories: ErrorCategory::all().to_vec(),
396            fix_templates: Self::default_fix_templates(),
397            drift_detector: DriftDetector::new(
398                DriftConfig::default()
399                    .with_min_samples(10)
400                    .with_window_size(50),
401            ),
402            performance_history: Vec::new(),
403            is_trained: true,
404        })
405    }
406
407    /// Check if the oracle has been trained.
408    #[must_use]
409    pub fn is_trained(&self) -> bool {
410        self.is_trained
411    }
412
413    /// Default fix templates for each category.
414    fn default_fix_templates() -> HashMap<ErrorCategory, Vec<String>> {
415        let mut templates = HashMap::new();
416
417        // Syntax errors
418        templates.insert(
419            ErrorCategory::SyntaxQuoteMismatch,
420            vec![
421                "Check for unmatched quotes (' or \")".to_string(),
422                "Use shellcheck to identify the exact location".to_string(),
423            ],
424        );
425        templates.insert(
426            ErrorCategory::SyntaxBracketMismatch,
427            vec![
428                "Check for unmatched brackets ([], {}, ())".to_string(),
429                "Ensure conditionals have proper [ ] or [[ ]] syntax".to_string(),
430            ],
431        );
432        templates.insert(
433            ErrorCategory::SyntaxUnexpectedToken,
434            vec![
435                "Review syntax near the reported token".to_string(),
436                "Check for missing 'then', 'do', or 'fi'".to_string(),
437            ],
438        );
439        templates.insert(
440            ErrorCategory::SyntaxMissingOperand,
441            vec![
442                "Add missing operand to the expression".to_string(),
443                "Check arithmetic expressions for completeness".to_string(),
444            ],
445        );
446
447        // Command errors
448        templates.insert(
449            ErrorCategory::CommandNotFound,
450            vec![
451                "Check PATH or install the missing command".to_string(),
452                "Verify the command name spelling".to_string(),
453                "Try 'which <command>' or 'type <command>'".to_string(),
454            ],
455        );
456        templates.insert(
457            ErrorCategory::CommandPermissionDenied,
458            vec![
459                "Use chmod +x to make the script executable".to_string(),
460                "Run with sudo if elevated privileges needed".to_string(),
461            ],
462        );
463        templates.insert(
464            ErrorCategory::CommandInvalidOption,
465            vec![
466                "Check command documentation with --help or man page".to_string(),
467                "Verify option syntax (single dash vs double dash)".to_string(),
468            ],
469        );
470        templates.insert(
471            ErrorCategory::CommandMissingArgument,
472            vec![
473                "Provide required argument to the command".to_string(),
474                "Check command usage with --help".to_string(),
475            ],
476        );
477
478        // File errors
479        templates.insert(
480            ErrorCategory::FileNotFound,
481            vec![
482                "Verify the file path exists".to_string(),
483                "Check for typos in the path".to_string(),
484                "Use 'ls' to list directory contents".to_string(),
485            ],
486        );
487        templates.insert(
488            ErrorCategory::FilePermissionDenied,
489            vec![
490                "Check file permissions with ls -la".to_string(),
491                "Use sudo if needed for system files".to_string(),
492            ],
493        );
494        templates.insert(
495            ErrorCategory::FileIsDirectory,
496            vec![
497                "Use a file path, not a directory".to_string(),
498                "Add /* to operate on directory contents".to_string(),
499            ],
500        );
501        templates.insert(
502            ErrorCategory::FileNotDirectory,
503            vec![
504                "Use a directory path, not a file".to_string(),
505                "Check parent directories exist".to_string(),
506            ],
507        );
508        templates.insert(
509            ErrorCategory::FileTooManyOpen,
510            vec![
511                "Close unused file descriptors".to_string(),
512                "Increase ulimit -n value".to_string(),
513            ],
514        );
515
516        // Variable errors
517        templates.insert(
518            ErrorCategory::VariableUnbound,
519            vec![
520                "Initialize variable before use".to_string(),
521                "Use ${VAR:-default} for default values".to_string(),
522                "Check for typos in variable name".to_string(),
523            ],
524        );
525        templates.insert(
526            ErrorCategory::VariableReadonly,
527            vec![
528                "Cannot modify readonly variable".to_string(),
529                "Use a different variable name".to_string(),
530            ],
531        );
532        templates.insert(
533            ErrorCategory::VariableBadSubstitution,
534            vec![
535                "Fix parameter expansion syntax".to_string(),
536                "Check for proper ${} brace matching".to_string(),
537            ],
538        );
539
540        // Process errors
541        templates.insert(
542            ErrorCategory::ProcessSignaled,
543            vec![
544                "Process was killed by signal".to_string(),
545                "Check for memory issues (OOM killer)".to_string(),
546            ],
547        );
548        templates.insert(
549            ErrorCategory::ProcessExitNonZero,
550            vec![
551                "Check command exit status with echo $?".to_string(),
552                "Add error handling with || or set -e".to_string(),
553            ],
554        );
555        templates.insert(
556            ErrorCategory::ProcessTimeout,
557            vec![
558                "Increase timeout value".to_string(),
559                "Optimize the command for better performance".to_string(),
560            ],
561        );
562
563        // Pipe/redirect errors
564        templates.insert(
565            ErrorCategory::PipeBroken,
566            vec![
567                "Check if downstream process exited early".to_string(),
568                "Use || true to ignore SIGPIPE".to_string(),
569            ],
570        );
571        templates.insert(
572            ErrorCategory::RedirectFailed,
573            vec![
574                "Verify target path is writable".to_string(),
575                "Check disk space availability".to_string(),
576            ],
577        );
578        templates.insert(
579            ErrorCategory::HereDocUnterminated,
580            vec![
581                "Add terminating delimiter for here-doc".to_string(),
582                "Ensure delimiter is at start of line with no trailing spaces".to_string(),
583            ],
584        );
585
586        // Unknown
587        templates.insert(
588            ErrorCategory::Unknown,
589            vec!["Review the full error message for details".to_string()],
590        );
591
592        templates
593    }
594}
595
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used)]
    use super::*;

    /// Train an oracle on a small synthetic corpus with the default configuration.
    fn trained_oracle() -> Oracle {
        Oracle::train_from_corpus(&Corpus::generate_synthetic(100), OracleConfig::default())
            .expect("Training should succeed")
    }

    /// A fresh oracle knows every category and is not yet trained.
    #[test]
    fn test_oracle_creation() {
        let oracle = Oracle::new();
        assert_eq!(oracle.categories.len(), ErrorCategory::all().len());
        assert!(!oracle.is_trained());
    }

    /// Every category must have at least an entry in the fix templates.
    #[test]
    fn test_fix_templates_coverage() {
        let oracle = Oracle::new();
        for category in ErrorCategory::all() {
            assert!(
                oracle.fix_templates.contains_key(category),
                "Missing fix template for {category:?}"
            );
        }
    }

    /// A single score is far below the minimum needed for drift detection.
    #[test]
    fn test_drift_detection_insufficient_data() {
        let mut oracle = Oracle::new();
        assert!(matches!(oracle.check_drift(0.95), DriftStatus::NoDrift));
    }

    /// The default path always ends in the default model file name.
    #[test]
    fn test_default_model_path() {
        let path = Oracle::default_model_path();
        assert!(path.to_string_lossy().contains("bashrs_oracle.apr"));
    }

    /// Without training, suggestions come from the keyword classifier.
    #[test]
    fn test_suggest_fix_fallback() {
        let oracle = Oracle::new();
        let suggestion = oracle.suggest_fix(127, "bash: foo: command not found", None);
        assert!(
            suggestion.contains("command") || suggestion.contains("Command"),
            "Got: {suggestion}"
        );
    }

    /// Training marks the oracle as trained and enables classification.
    #[test]
    fn test_train_from_corpus() {
        let oracle = trained_oracle();
        assert!(oracle.is_trained());

        let features = ErrorFeatures::extract(127, "command not found", None);
        assert!(oracle.classify(&features).is_ok());
    }

    /// The raw-input convenience wrapper yields a usable result.
    #[test]
    fn test_classify_error_convenience() {
        let result = trained_oracle()
            .classify_error(127, "bash: foo: command not found", None)
            .expect("Classification should succeed");

        assert!(result.confidence > 0.0);
        assert!(result.suggested_fix.is_some());
    }

    /// A saved model can be loaded back and is reported as trained.
    #[test]
    fn test_save_and_load() {
        let oracle = trained_oracle();

        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let path = temp_dir.path().join("test_model.apr");

        oracle.save(&path).expect("Save should succeed");
        assert!(path.exists());

        let loaded = Oracle::load(&path).expect("Load should succeed");
        assert_eq!(loaded.categories.len(), oracle.categories.len());
        assert!(loaded.is_trained());
    }

    /// The default configuration matches the documented values.
    #[test]
    fn test_oracle_config_default() {
        let OracleConfig {
            n_estimators,
            max_depth,
            random_state,
        } = OracleConfig::default();
        assert_eq!(n_estimators, 100);
        assert_eq!(max_depth, 10);
        assert_eq!(random_state, Some(42));
    }
}