1#![allow(clippy::multiple_crate_versions)]
3
4use std::collections::HashMap;
25use std::path::{Path, PathBuf};
26
27use aprender::format::{self, Compression, ModelType, SaveOptions};
28use aprender::metrics::drift::{DriftConfig, DriftDetector, DriftStatus};
29use aprender::primitives::Matrix;
30use aprender::tree::RandomForestClassifier;
31use serde::{Deserialize, Serialize};
32
33pub mod categories;
34pub mod classifier;
35pub mod corpus;
36pub mod features;
37
38pub use categories::ErrorCategory;
39pub use classifier::ErrorClassifier;
40pub use corpus::{Corpus, TrainingExample};
41pub use features::ErrorFeatures;
42
43#[derive(Debug, thiserror::Error)]
45pub enum OracleError {
46 #[error("Model error: {0}")]
48 Model(String),
49 #[error("Feature extraction error: {0}")]
51 Feature(String),
52 #[error("Training error: {0}")]
54 Training(String),
55 #[error("Classification error: {0}")]
57 Classification(String),
58 #[error("IO error: {0}")]
60 Io(#[from] std::io::Error),
61}
62
63pub type Result<T> = std::result::Result<T, OracleError>;
65
66#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct ClassificationResult {
69 pub category: ErrorCategory,
71 pub confidence: f32,
73 pub suggested_fix: Option<String>,
75 pub related_patterns: Vec<String>,
77}
78
/// Hyperparameters for the random-forest classifier used by the oracle.
#[derive(Clone, Debug)]
pub struct OracleConfig {
    /// Number of trees in the forest.
    pub n_estimators: usize,
    /// Maximum depth of each tree.
    pub max_depth: usize,
    /// Optional RNG seed. When `None`, no explicit seed is applied to the classifier.
    pub random_state: Option<u64>,
}

impl Default for OracleConfig {
    /// 100 trees, depth limit 10, and a fixed seed of 42 so training is reproducible.
    fn default() -> Self {
        let n_estimators = 100;
        let max_depth = 10;
        let random_state = Some(42);
        Self {
            n_estimators,
            max_depth,
            random_state,
        }
    }
}
100
/// File name used for the cached model when locating or creating it on disk.
const DEFAULT_MODEL_NAME: &str = "bashrs_oracle.apr";
103
104pub struct Oracle {
106 classifier: RandomForestClassifier,
108 #[allow(dead_code)]
110 config: OracleConfig,
111 #[allow(dead_code)]
113 categories: Vec<ErrorCategory>,
114 fix_templates: HashMap<ErrorCategory, Vec<String>>,
116 drift_detector: DriftDetector,
118 performance_history: Vec<f32>,
120 is_trained: bool,
122}
123
124impl Default for Oracle {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130impl Oracle {
131 #[must_use]
133 pub fn default_model_path() -> PathBuf {
134 let mut path = std::env::current_dir().unwrap_or_default();
136 for _ in 0..5 {
137 if path.join("Cargo.toml").exists() {
138 return path.join(DEFAULT_MODEL_NAME);
139 }
140 if !path.pop() {
141 break;
142 }
143 }
144 PathBuf::from(DEFAULT_MODEL_NAME)
145 }
146
147 pub fn load_or_train() -> Result<Self> {
152 let path = Self::default_model_path();
153
154 if path.exists() {
155 match Self::load(&path) {
156 Ok(oracle) => return Ok(oracle),
157 Err(e) => {
158 tracing::warn!("Failed to load cached model: {e}. Retraining...");
159 }
160 }
161 }
162
163 let corpus = Corpus::generate_synthetic(5000);
165 let oracle = Self::train_from_corpus(&corpus, OracleConfig::default())?;
166
167 if let Err(e) = oracle.save(&path) {
169 tracing::warn!("Failed to cache model to {}: {e}", path.display());
170 }
171
172 Ok(oracle)
173 }
174
175 #[must_use]
177 pub fn new() -> Self {
178 Self::with_config(OracleConfig::default())
179 }
180
181 #[must_use]
183 pub fn with_config(config: OracleConfig) -> Self {
184 let mut classifier =
185 RandomForestClassifier::new(config.n_estimators).with_max_depth(config.max_depth);
186 if let Some(seed) = config.random_state {
187 classifier = classifier.with_random_state(seed);
188 }
189
190 Self {
191 classifier,
192 config,
193 categories: ErrorCategory::all().to_vec(),
194 fix_templates: Self::default_fix_templates(),
195 drift_detector: DriftDetector::new(
196 DriftConfig::default()
197 .with_min_samples(10)
198 .with_window_size(50),
199 ),
200 performance_history: Vec::new(),
201 is_trained: false,
202 }
203 }
204
205 pub fn train_from_corpus(corpus: &Corpus, config: OracleConfig) -> Result<Self> {
210 let (x, y) = corpus.to_training_data();
211
212 let n_samples = x.len();
214 let n_features = x.first().map_or(0, |row| row.len());
215 let flat: Vec<f32> = x.into_iter().flatten().collect();
216 let features = Matrix::from_vec(n_samples, n_features, flat)
217 .map_err(|e| OracleError::Training(format!("Failed to create feature matrix: {e}")))?;
218 let labels: Vec<usize> = y.into_iter().map(|l| l as usize).collect();
219
220 let mut oracle = Self::with_config(config);
221 oracle.train(&features, &labels)?;
222
223 Ok(oracle)
224 }
225
226 pub fn train(&mut self, features: &Matrix<f32>, labels: &[usize]) -> Result<()> {
231 self.classifier
232 .fit(features, labels)
233 .map_err(|e| OracleError::Training(e.to_string()))?;
234 self.is_trained = true;
235
236 Ok(())
237 }
238
239 pub fn classify(&self, features: &ErrorFeatures) -> Result<ClassificationResult> {
241 if !self.is_trained {
242 let kw_classifier = ErrorClassifier::new();
244 let category = kw_classifier.classify_by_keywords(
245 &features
246 .features
247 .iter()
248 .map(|f| f.to_string())
249 .collect::<Vec<_>>()
250 .join(" "),
251 );
252 return Ok(ClassificationResult {
253 category,
254 confidence: 0.5,
255 suggested_fix: Some(category.fix_suggestion().to_string()),
256 related_patterns: vec![],
257 });
258 }
259
260 let feature_matrix = Matrix::from_vec(1, ErrorFeatures::SIZE, features.as_slice().to_vec())
261 .map_err(|e| {
262 OracleError::Classification(format!("Failed to create feature matrix: {e}"))
263 })?;
264 let predictions = self.classifier.predict(&feature_matrix);
265
266 let pred_idx = predictions
267 .as_slice()
268 .first()
269 .copied()
270 .ok_or_else(|| OracleError::Classification("No prediction produced".to_string()))?;
271 let category = ErrorCategory::from_label_index(pred_idx);
272
273 let suggested_fix = self
274 .fix_templates
275 .get(&category)
276 .and_then(|fixes| fixes.first().cloned());
277
278 let related = self
279 .fix_templates
280 .get(&category)
281 .map(|fixes| fixes.iter().skip(1).cloned().collect())
282 .unwrap_or_default();
283
284 Ok(ClassificationResult {
285 category,
286 confidence: 0.85, suggested_fix,
288 related_patterns: related,
289 })
290 }
291
292 pub fn classify_error(
294 &self,
295 exit_code: i32,
296 stderr: &str,
297 command: Option<&str>,
298 ) -> Result<ClassificationResult> {
299 let features = ErrorFeatures::extract(exit_code, stderr, command);
300 self.classify(&features)
301 }
302
303 #[must_use]
305 pub fn suggest_fix(&self, exit_code: i32, stderr: &str, command: Option<&str>) -> String {
306 if !self.is_trained {
308 let kw_classifier = ErrorClassifier::new();
309 let category = kw_classifier.classify_by_keywords(stderr);
310 let confidence = kw_classifier.confidence(stderr, category);
311 return format!(
312 "[{:.0}% confident] {}: {}",
313 confidence * 100.0,
314 category.name(),
315 category.fix_suggestion()
316 );
317 }
318
319 match self.classify_error(exit_code, stderr, command) {
320 Ok(result) => {
321 format!(
322 "[{:.0}% confident] {}: {}",
323 result.confidence * 100.0,
324 result.category.name(),
325 result
326 .suggested_fix
327 .unwrap_or_else(|| result.category.fix_suggestion().to_string())
328 )
329 }
330 Err(_) => {
331 let kw_classifier = ErrorClassifier::new();
333 let category = kw_classifier.classify_by_keywords(stderr);
334 format!(
335 "[keyword] {}: {}",
336 category.name(),
337 category.fix_suggestion()
338 )
339 }
340 }
341 }
342
343 pub fn check_drift(&mut self, recent_accuracy: f32) -> DriftStatus {
345 self.performance_history.push(recent_accuracy);
346
347 if self.performance_history.len() < 10 {
348 return DriftStatus::NoDrift;
349 }
350
351 let mid = self.performance_history.len() / 2;
352 let baseline: Vec<f32> = self
353 .performance_history
354 .get(..mid)
355 .map(|s| s.to_vec())
356 .unwrap_or_default();
357 let current: Vec<f32> = self
358 .performance_history
359 .get(mid..)
360 .map(|s| s.to_vec())
361 .unwrap_or_default();
362
363 self.drift_detector
364 .detect_performance_drift(&baseline, ¤t)
365 }
366
367 pub fn save(&self, path: &Path) -> Result<()> {
372 let options = SaveOptions::default()
373 .with_name("bashrs-oracle")
374 .with_description("RandomForest error classification model for bashrs shell linter")
375 .with_compression(Compression::ZstdDefault); format::save(&self.classifier, ModelType::RandomForest, path, options)
378 .map_err(|e| OracleError::Model(e.to_string()))?;
379
380 Ok(())
381 }
382
383 pub fn load(path: &Path) -> Result<Self> {
388 let classifier: RandomForestClassifier = format::load(path, ModelType::RandomForest)
389 .map_err(|e| OracleError::Model(e.to_string()))?;
390
391 let config = OracleConfig::default();
392 Ok(Self {
393 classifier,
394 config,
395 categories: ErrorCategory::all().to_vec(),
396 fix_templates: Self::default_fix_templates(),
397 drift_detector: DriftDetector::new(
398 DriftConfig::default()
399 .with_min_samples(10)
400 .with_window_size(50),
401 ),
402 performance_history: Vec::new(),
403 is_trained: true,
404 })
405 }
406
407 #[must_use]
409 pub fn is_trained(&self) -> bool {
410 self.is_trained
411 }
412
413 fn default_fix_templates() -> HashMap<ErrorCategory, Vec<String>> {
415 let mut templates = HashMap::new();
416
417 templates.insert(
419 ErrorCategory::SyntaxQuoteMismatch,
420 vec![
421 "Check for unmatched quotes (' or \")".to_string(),
422 "Use shellcheck to identify the exact location".to_string(),
423 ],
424 );
425 templates.insert(
426 ErrorCategory::SyntaxBracketMismatch,
427 vec![
428 "Check for unmatched brackets ([], {}, ())".to_string(),
429 "Ensure conditionals have proper [ ] or [[ ]] syntax".to_string(),
430 ],
431 );
432 templates.insert(
433 ErrorCategory::SyntaxUnexpectedToken,
434 vec![
435 "Review syntax near the reported token".to_string(),
436 "Check for missing 'then', 'do', or 'fi'".to_string(),
437 ],
438 );
439 templates.insert(
440 ErrorCategory::SyntaxMissingOperand,
441 vec![
442 "Add missing operand to the expression".to_string(),
443 "Check arithmetic expressions for completeness".to_string(),
444 ],
445 );
446
447 templates.insert(
449 ErrorCategory::CommandNotFound,
450 vec![
451 "Check PATH or install the missing command".to_string(),
452 "Verify the command name spelling".to_string(),
453 "Try 'which <command>' or 'type <command>'".to_string(),
454 ],
455 );
456 templates.insert(
457 ErrorCategory::CommandPermissionDenied,
458 vec![
459 "Use chmod +x to make the script executable".to_string(),
460 "Run with sudo if elevated privileges needed".to_string(),
461 ],
462 );
463 templates.insert(
464 ErrorCategory::CommandInvalidOption,
465 vec![
466 "Check command documentation with --help or man page".to_string(),
467 "Verify option syntax (single dash vs double dash)".to_string(),
468 ],
469 );
470 templates.insert(
471 ErrorCategory::CommandMissingArgument,
472 vec![
473 "Provide required argument to the command".to_string(),
474 "Check command usage with --help".to_string(),
475 ],
476 );
477
478 templates.insert(
480 ErrorCategory::FileNotFound,
481 vec![
482 "Verify the file path exists".to_string(),
483 "Check for typos in the path".to_string(),
484 "Use 'ls' to list directory contents".to_string(),
485 ],
486 );
487 templates.insert(
488 ErrorCategory::FilePermissionDenied,
489 vec![
490 "Check file permissions with ls -la".to_string(),
491 "Use sudo if needed for system files".to_string(),
492 ],
493 );
494 templates.insert(
495 ErrorCategory::FileIsDirectory,
496 vec![
497 "Use a file path, not a directory".to_string(),
498 "Add /* to operate on directory contents".to_string(),
499 ],
500 );
501 templates.insert(
502 ErrorCategory::FileNotDirectory,
503 vec![
504 "Use a directory path, not a file".to_string(),
505 "Check parent directories exist".to_string(),
506 ],
507 );
508 templates.insert(
509 ErrorCategory::FileTooManyOpen,
510 vec![
511 "Close unused file descriptors".to_string(),
512 "Increase ulimit -n value".to_string(),
513 ],
514 );
515
516 templates.insert(
518 ErrorCategory::VariableUnbound,
519 vec![
520 "Initialize variable before use".to_string(),
521 "Use ${VAR:-default} for default values".to_string(),
522 "Check for typos in variable name".to_string(),
523 ],
524 );
525 templates.insert(
526 ErrorCategory::VariableReadonly,
527 vec![
528 "Cannot modify readonly variable".to_string(),
529 "Use a different variable name".to_string(),
530 ],
531 );
532 templates.insert(
533 ErrorCategory::VariableBadSubstitution,
534 vec![
535 "Fix parameter expansion syntax".to_string(),
536 "Check for proper ${} brace matching".to_string(),
537 ],
538 );
539
540 templates.insert(
542 ErrorCategory::ProcessSignaled,
543 vec![
544 "Process was killed by signal".to_string(),
545 "Check for memory issues (OOM killer)".to_string(),
546 ],
547 );
548 templates.insert(
549 ErrorCategory::ProcessExitNonZero,
550 vec![
551 "Check command exit status with echo $?".to_string(),
552 "Add error handling with || or set -e".to_string(),
553 ],
554 );
555 templates.insert(
556 ErrorCategory::ProcessTimeout,
557 vec![
558 "Increase timeout value".to_string(),
559 "Optimize the command for better performance".to_string(),
560 ],
561 );
562
563 templates.insert(
565 ErrorCategory::PipeBroken,
566 vec![
567 "Check if downstream process exited early".to_string(),
568 "Use || true to ignore SIGPIPE".to_string(),
569 ],
570 );
571 templates.insert(
572 ErrorCategory::RedirectFailed,
573 vec![
574 "Verify target path is writable".to_string(),
575 "Check disk space availability".to_string(),
576 ],
577 );
578 templates.insert(
579 ErrorCategory::HereDocUnterminated,
580 vec![
581 "Add terminating delimiter for here-doc".to_string(),
582 "Ensure delimiter is at start of line with no trailing spaces".to_string(),
583 ],
584 );
585
586 templates.insert(
588 ErrorCategory::Unknown,
589 vec!["Review the full error message for details".to_string()],
590 );
591
592 templates
593 }
594}
595
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used)]
    use super::*;

    #[test]
    fn test_oracle_creation() {
        let oracle = Oracle::new();
        assert_eq!(oracle.categories.len(), ErrorCategory::all().len());
        assert!(!oracle.is_trained());
    }

    #[test]
    fn test_fix_templates_coverage() {
        let oracle = Oracle::new();
        for category in ErrorCategory::all() {
            assert!(
                oracle.fix_templates.contains_key(category),
                "Missing fix template for {category:?}"
            );
        }
    }

    #[test]
    fn test_drift_detection_insufficient_data() {
        let mut oracle = Oracle::new();
        let status = oracle.check_drift(0.95);
        assert!(matches!(status, DriftStatus::NoDrift));
    }

    #[test]
    fn test_drift_detection_records_history_past_threshold() {
        // Push well past the minimum-history threshold so the detector itself is consulted
        // (previously untested), and confirm every sample is retained.
        let mut oracle = Oracle::new();
        for _ in 0..20 {
            let _ = oracle.check_drift(0.9);
        }
        assert_eq!(oracle.performance_history.len(), 20);
    }

    #[test]
    fn test_default_model_path() {
        let path = Oracle::default_model_path();
        assert!(path.to_string_lossy().contains("bashrs_oracle.apr"));
    }

    #[test]
    fn test_suggest_fix_fallback() {
        let oracle = Oracle::new();
        let suggestion = oracle.suggest_fix(127, "bash: foo: command not found", None);
        assert!(
            suggestion.contains("command") || suggestion.contains("Command"),
            "Got: {suggestion}"
        );
    }

    #[test]
    fn test_classify_untrained_uses_keyword_fallback() {
        let oracle = Oracle::new();
        let features = ErrorFeatures::extract(127, "command not found", None);
        let result = oracle
            .classify(&features)
            .expect("Untrained classification should fall back to keywords");

        assert!((result.confidence - 0.5).abs() < f32::EPSILON);
        assert!(result.suggested_fix.is_some());
        assert!(result.related_patterns.is_empty());
    }

    #[test]
    fn test_train_from_corpus() {
        let corpus = Corpus::generate_synthetic(100);
        let oracle = Oracle::train_from_corpus(&corpus, OracleConfig::default())
            .expect("Training should succeed");

        assert!(oracle.is_trained());

        let features = ErrorFeatures::extract(127, "command not found", None);
        let result = oracle.classify(&features);
        assert!(result.is_ok());
    }

    #[test]
    fn test_classify_error_convenience() {
        let corpus = Corpus::generate_synthetic(100);
        let oracle = Oracle::train_from_corpus(&corpus, OracleConfig::default())
            .expect("Training should succeed");

        let result = oracle
            .classify_error(127, "bash: foo: command not found", None)
            .expect("Classification should succeed");

        assert!(result.confidence > 0.0);
        assert!(result.suggested_fix.is_some());
    }

    #[test]
    fn test_save_and_load() {
        let corpus = Corpus::generate_synthetic(100);
        let oracle = Oracle::train_from_corpus(&corpus, OracleConfig::default())
            .expect("Training should succeed");

        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let path = temp_dir.path().join("test_model.apr");

        oracle.save(&path).expect("Save should succeed");
        assert!(path.exists());

        let loaded = Oracle::load(&path).expect("Load should succeed");
        assert_eq!(loaded.categories.len(), oracle.categories.len());
        assert!(loaded.is_trained());
    }

    #[test]
    fn test_oracle_config_default() {
        let config = OracleConfig::default();
        assert_eq!(config.n_estimators, 100);
        assert_eq!(config.max_depth, 10);
        assert_eq!(config.random_state, Some(42));
    }
}