use crate::error::{SammError, SourceLocation};
use std::collections::HashMap;
/// The action a caller should take in response to an error, as chosen by
/// [`ErrorRecoveryStrategy::recover`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RecoveryAction {
    /// Ignore the offending input and carry on.
    Skip,
    /// Insert the given text.
    Insert(String),
    /// Give up; the error is not recoverable.
    Abort,
    /// Substitute the given default value.
    UseDefault(String),
    /// Replace the offending text with the given text.
    Replace(String),
}
/// Configuration that decides which [`RecoveryAction`] is chosen for an error.
#[derive(Clone, Debug)]
pub struct ErrorRecoveryStrategy {
    /// Error budget; `RecoveryContext::is_max_errors_exceeded` reports `true`
    /// once the recorded error count reaches this value.
    pub max_errors: usize,
    /// Enables the URN, undefined-prefix and datatype corrections.
    pub auto_correct_typos: bool,
    /// Enables inserting missing `;`, `.`, `]` and `)` for parse errors.
    pub auto_insert_punctuation: bool,
    /// Enables skipping parse errors reported as malformed triples or
    /// invalid syntax.
    pub skip_malformed: bool,
    /// Enables substituting default values for missing elements and missing
    /// required values.
    pub use_defaults: bool,
    /// User-supplied rules: when the key occurs as a substring of an error
    /// message, the associated action is returned.
    pub custom_rules: HashMap<String, RecoveryAction>,
}
impl Default for ErrorRecoveryStrategy {
fn default() -> Self {
Self {
max_errors: 100,
auto_correct_typos: true,
auto_insert_punctuation: true,
skip_malformed: true,
use_defaults: false,
custom_rules: HashMap::new(),
}
}
}
impl ErrorRecoveryStrategy {
pub fn strict() -> Self {
Self {
max_errors: 1,
auto_correct_typos: false,
auto_insert_punctuation: false,
skip_malformed: false,
use_defaults: false,
custom_rules: HashMap::new(),
}
}
pub fn lenient() -> Self {
Self {
max_errors: 1000,
auto_correct_typos: true,
auto_insert_punctuation: true,
skip_malformed: true,
use_defaults: true,
custom_rules: HashMap::new(),
}
}
pub fn recover(&self, error: &SammError, context: &str) -> RecoveryAction {
if let Some(action) = self.check_custom_rules(error) {
return action;
}
match error {
SammError::ParseError(msg) | SammError::ParseErrorWithLocation { message: msg, .. } => {
self.recover_from_parse_error(msg, context)
}
SammError::ValidationError(msg)
| SammError::ValidationErrorWithLocation { message: msg, .. } => {
self.recover_from_validation_error(msg)
}
SammError::MissingElement(elem) => {
if self.use_defaults {
RecoveryAction::UseDefault(format!("urn:samm:default:1.0.0#{}", elem))
} else {
RecoveryAction::Abort
}
}
SammError::InvalidUrn(msg) => {
if self.auto_correct_typos {
self.try_correct_urn(msg)
} else {
RecoveryAction::Skip
}
}
_ => RecoveryAction::Abort,
}
}
fn check_custom_rules(&self, error: &SammError) -> Option<RecoveryAction> {
let error_msg = match error {
SammError::ParseError(msg) => msg,
SammError::ParseErrorWithLocation { message, .. } => message,
SammError::ValidationError(msg) => msg,
SammError::InvalidUrn(msg) => msg,
_ => return None,
};
for (pattern, action) in &self.custom_rules {
if error_msg.contains(pattern) {
return Some(action.clone());
}
}
None
}
fn recover_from_parse_error(&self, msg: &str, context: &str) -> RecoveryAction {
if (msg.contains("expected ';'") || msg.contains("missing semicolon"))
&& self.auto_insert_punctuation
{
return RecoveryAction::Insert(";".to_string());
}
if (msg.contains("expected '.'") || msg.contains("missing period"))
&& self.auto_insert_punctuation
{
return RecoveryAction::Insert(".".to_string());
}
if (msg.contains("unclosed bracket") || msg.contains("expected ']'"))
&& self.auto_insert_punctuation
{
return RecoveryAction::Insert("]".to_string());
}
if (msg.contains("unclosed parenthesis") || msg.contains("expected ')"))
&& self.auto_insert_punctuation
{
return RecoveryAction::Insert(")".to_string());
}
if msg.contains("undefined prefix") && self.auto_correct_typos {
return self.try_correct_prefix(msg, context);
}
if (msg.contains("malformed triple") || msg.contains("invalid syntax"))
&& self.skip_malformed
{
return RecoveryAction::Skip;
}
RecoveryAction::Abort
}
fn recover_from_validation_error(&self, msg: &str) -> RecoveryAction {
if msg.contains("missing required") && self.use_defaults {
return RecoveryAction::UseDefault("default_value".to_string());
}
if msg.contains("invalid type") && self.auto_correct_typos {
return self.try_correct_datatype(msg);
}
RecoveryAction::Skip
}
fn try_correct_urn(&self, msg: &str) -> RecoveryAction {
if msg.contains("urn:bamm:") {
return RecoveryAction::Replace("urn:samm:".to_string());
}
if msg.contains("missing '#'") {
return RecoveryAction::Insert("#".to_string());
}
if msg.contains("missing version") {
return RecoveryAction::Insert("1.0.0".to_string());
}
RecoveryAction::Skip
}
fn try_correct_prefix(&self, msg: &str, context: &str) -> RecoveryAction {
let common_prefixes = vec![
(
"samm",
"@prefix samm: <urn:samm:org.eclipse.esmf.samm:meta-model:2.3.0#> .",
),
(
"samm-c",
"@prefix samm-c: <urn:samm:org.eclipse.esmf.samm:characteristic:2.3.0#> .",
),
(
"samm-e",
"@prefix samm-e: <urn:samm:org.eclipse.esmf.samm:entity:2.3.0#> .",
),
("xsd", "@prefix xsd: <http://www.w3.org/2001/XMLSchema#> ."),
];
for (prefix, definition) in common_prefixes {
if msg.contains(prefix) {
return RecoveryAction::Insert(definition.to_string());
}
}
RecoveryAction::Skip
}
fn try_correct_datatype(&self, msg: &str) -> RecoveryAction {
let corrections = vec![
("String", "xsd:string"),
("Integer", "xsd:integer"),
("Boolean", "xsd:boolean"),
("Float", "xsd:float"),
("Double", "xsd:double"),
("Date", "xsd:date"),
("DateTime", "xsd:dateTime"),
];
for (typo, correct) in corrections {
if msg.contains(typo) {
return RecoveryAction::Replace(correct.to_string());
}
}
RecoveryAction::Skip
}
pub fn add_custom_rule(&mut self, pattern: String, action: RecoveryAction) {
self.custom_rules.insert(pattern, action);
}
pub fn is_recoverable(&self, error: &SammError) -> bool {
!matches!(self.recover(error, ""), RecoveryAction::Abort)
}
}
/// Bookkeeping for the errors seen so far, together with the strategy whose
/// `max_errors` budget they are measured against.
#[derive(Debug)]
pub struct RecoveryContext {
    /// Number of errors recorded, recovered and fatal combined.
    pub error_count: usize,
    /// Errors that were recorded as recovered, paired with the action taken.
    pub recovered_errors: Vec<(SammError, RecoveryAction)>,
    /// Errors that were recorded as fatal.
    pub fatal_errors: Vec<SammError>,
    /// Strategy supplied at construction time.
    pub strategy: ErrorRecoveryStrategy,
}
impl RecoveryContext {
pub fn new(strategy: ErrorRecoveryStrategy) -> Self {
Self {
error_count: 0,
recovered_errors: Vec::new(),
fatal_errors: Vec::new(),
strategy,
}
}
pub fn record_recovery(&mut self, error: SammError, action: RecoveryAction) {
self.error_count += 1;
self.recovered_errors.push((error, action));
}
pub fn record_fatal(&mut self, error: SammError) {
self.error_count += 1;
self.fatal_errors.push(error);
}
pub fn is_max_errors_exceeded(&self) -> bool {
self.error_count >= self.strategy.max_errors
}
pub fn total_errors(&self) -> usize {
self.error_count
}
pub fn success_rate(&self) -> f64 {
if self.error_count == 0 {
1.0
} else {
self.recovered_errors.len() as f64 / self.error_count as f64
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain parse error carrying `text` as its message.
    fn parse_error(text: &str) -> SammError {
        SammError::ParseError(text.to_string())
    }

    /// Runs the default strategy on `error` with an empty context.
    fn recover_with_defaults(error: &SammError) -> RecoveryAction {
        ErrorRecoveryStrategy::default().recover(error, "")
    }

    #[test]
    fn test_default_strategy() {
        let defaults = ErrorRecoveryStrategy::default();
        assert_eq!(defaults.max_errors, 100);
        assert!(defaults.auto_correct_typos);
        assert!(defaults.skip_malformed);
    }

    #[test]
    fn test_strict_strategy() {
        let strict = ErrorRecoveryStrategy::strict();
        assert_eq!(strict.max_errors, 1);
        assert!(!strict.auto_correct_typos);
        assert!(!strict.skip_malformed);
    }

    #[test]
    fn test_lenient_strategy() {
        let lenient = ErrorRecoveryStrategy::lenient();
        assert_eq!(lenient.max_errors, 1000);
        assert!(lenient.auto_correct_typos);
        assert!(lenient.use_defaults);
    }

    #[test]
    fn test_recover_missing_semicolon() {
        let chosen = recover_with_defaults(&parse_error("expected ';'"));
        assert_eq!(chosen, RecoveryAction::Insert(";".to_string()));
    }

    #[test]
    fn test_recover_malformed_triple() {
        let chosen = recover_with_defaults(&parse_error("malformed triple"));
        assert_eq!(chosen, RecoveryAction::Skip);
    }

    #[test]
    fn test_recover_old_bamm_namespace() {
        let deprecated = SammError::InvalidUrn("urn:bamm: is deprecated".to_string());
        let chosen = recover_with_defaults(&deprecated);
        assert_eq!(chosen, RecoveryAction::Replace("urn:samm:".to_string()));
    }

    #[test]
    fn test_custom_recovery_rule() {
        let expected = RecoveryAction::UseDefault("custom_value".to_string());

        let mut strategy = ErrorRecoveryStrategy::default();
        strategy.add_custom_rule("my custom error".to_string(), expected.clone());

        let chosen = strategy.recover(&parse_error("my custom error occurred"), "");
        assert_eq!(chosen, expected);
    }

    #[test]
    fn test_is_recoverable() {
        let strategy = ErrorRecoveryStrategy::default();

        let recoverable = parse_error("expected ';'");
        assert!(strategy.is_recoverable(&recoverable));

        let fatal = SammError::Other("unknown error".to_string());
        assert!(!strategy.is_recoverable(&fatal));
    }

    #[test]
    fn test_recovery_context() {
        let mut ctx = RecoveryContext::new(ErrorRecoveryStrategy::default());
        assert_eq!(ctx.total_errors(), 0);
        assert_eq!(ctx.success_rate(), 1.0);

        // One recovered error: everything seen so far was recovered.
        ctx.record_recovery(parse_error("test"), RecoveryAction::Skip);
        assert_eq!(ctx.total_errors(), 1);
        assert_eq!(ctx.success_rate(), 1.0);

        // One fatal error on top: half of the errors were recovered.
        ctx.record_fatal(SammError::Other("fatal".to_string()));
        assert_eq!(ctx.total_errors(), 2);
        assert_eq!(ctx.success_rate(), 0.5);
    }

    #[test]
    fn test_max_errors_exceeded() {
        let mut ctx = RecoveryContext::new(ErrorRecoveryStrategy {
            max_errors: 2,
            ..Default::default()
        });
        assert!(!ctx.is_max_errors_exceeded());

        ctx.record_recovery(parse_error("test1"), RecoveryAction::Skip);
        assert!(!ctx.is_max_errors_exceeded());

        // The second error reaches the budget of two.
        ctx.record_recovery(parse_error("test2"), RecoveryAction::Skip);
        assert!(ctx.is_max_errors_exceeded());
    }

    #[test]
    fn test_datatype_correction() {
        let wrong_type = SammError::ValidationError("invalid type String".to_string());
        let chosen = recover_with_defaults(&wrong_type);
        assert_eq!(chosen, RecoveryAction::Replace("xsd:string".to_string()));
    }
}