use super::result::{CheckError, CheckResult};
use super::traits::LightCheck;
/// How a cascade handles dependency types that lack the requested trait.
///
/// The default is [`CascadeStrategy::TryToAddDerive`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum CascadeStrategy {
    /// Recursively cascade the same derive onto each missing type.
    #[default]
    TryToAddDerive,
    /// Emit a `GenerateImpl` mutation (with `call_new` set) for each missing type.
    TryToCallNew,
    /// Record each missing type in the result's skipped list.
    SkipAndReport,
    /// Fail the whole cascade at the first missing type.
    ImmediateError,
}
/// Outcome of a derive cascade: the planned mutations plus anything left undone.
#[derive(Debug, Clone)]
pub struct CascadeResult {
    /// Mutations to apply, in the order they were planned.
    pub mutations: Vec<CascadeMutation>,
    /// Items that were not handled. For a result built by
    /// [`CascadeResult::failed`], this holds the failure reason instead.
    pub skipped: Vec<String>,
    /// Overall status of the cascade.
    pub status: CascadeStatus,
}
impl CascadeResult {
    /// A successful result with no mutations and nothing skipped.
    pub fn ok() -> Self {
        Self::with_mutations(Vec::new())
    }
    /// A successful result carrying `mutations` and nothing skipped.
    pub fn with_mutations(mutations: Vec<CascadeMutation>) -> Self {
        Self {
            status: CascadeStatus::Success,
            skipped: Vec::new(),
            mutations,
        }
    }
    /// A partial result: `mutations` were planned but `skipped` were not handled.
    pub fn partial(mutations: Vec<CascadeMutation>, skipped: Vec<String>) -> Self {
        Self {
            status: CascadeStatus::Partial,
            skipped,
            ..Self::with_mutations(mutations)
        }
    }
    /// A failed result with no mutations; `reason` is stored as the sole
    /// entry of `skipped`.
    pub fn failed(reason: String) -> Self {
        Self {
            status: CascadeStatus::Failed,
            skipped: vec![reason],
            ..Self::ok()
        }
    }
    /// Returns `true` only when the status is [`CascadeStatus::Success`].
    pub fn is_success(&self) -> bool {
        self.status == CascadeStatus::Success
    }
}
/// Overall status of a [`CascadeResult`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CascadeStatus {
    /// All planned mutations were produced and nothing was skipped.
    Success,
    /// Mutations were produced, but some items were skipped.
    Partial,
    /// The cascade could not proceed; `skipped` carries the reason.
    Failed,
}
/// A single source change planned by a cascade.
#[derive(Clone, Debug)]
pub enum CascadeMutation {
    /// Add a derive of each trait in `derives` to the type named `target`.
    AddDerive {
        /// Type that receives the derive.
        target: String,
        /// Trait names to derive.
        derives: Vec<String>,
    },
    /// Generate an implementation of `trait_name` for the type named `target`.
    GenerateImpl {
        /// Type the implementation is generated for.
        target: String,
        /// Trait being implemented.
        trait_name: String,
        /// Set to `true` by the `TryToCallNew` strategy. Presumably the
        /// generated impl delegates to `target`'s `new` constructor — TODO
        /// confirm against the code generator that consumes this.
        call_new: bool,
    },
}
/// Plans the mutations needed to derive `trait_name` on `target`.
///
/// This is the entry point of the cascade. It starts at depth 0 with no types
/// visited yet. Unmet dependencies are handled according to `strategy`.
pub fn cascade_add_derive(
    checker: &impl LightCheck,
    target: &str,
    trait_name: &str,
    strategy: CascadeStrategy,
) -> CascadeResult {
    let mut visited: Vec<String> = Vec::new();
    let root_depth = 0;
    cascade_add_derive_recursive(checker, target, trait_name, strategy, &mut visited, root_depth)
}
/// Deepest recursion level accepted by `cascade_add_derive_recursive`.
/// A call whose `depth` exceeds this value fails instead of recursing further.
const MAX_CASCADE_DEPTH: usize = 10;
/// Recursively plans the mutations needed to derive `trait_name` on `target`.
///
/// `visited` records every type already entered during this cascade. A type
/// seen again (a cycle or a shared dependency) yields an empty successful
/// result, so it is planned only by its first visit. `depth` is the current
/// recursion level, starting at 0.
///
/// If the checker accepts the derive (with or without a warning), the result
/// is a single `AddDerive` for `target`. Otherwise the reported errors are
/// resolved by `handle_cascade_errors` according to `strategy`.
fn cascade_add_derive_recursive(
    checker: &impl LightCheck,
    target: &str,
    trait_name: &str,
    strategy: CascadeStrategy,
    visited: &mut Vec<String>,
    depth: usize,
) -> CascadeResult {
    if depth > MAX_CASCADE_DEPTH {
        return CascadeResult::failed(format!(
            "cascade depth exceeded for {}::{}",
            target, trait_name
        ));
    }
    // Compare as `&str` so the membership test does not allocate a `String`.
    if visited.iter().any(|seen| seen.as_str() == target) {
        return CascadeResult::ok();
    }
    visited.push(target.to_string());
    match checker.check_derive_possible(target, trait_name) {
        // A warning does not block the derive, so it is handled exactly like `Ok`.
        CheckResult::Ok | CheckResult::Warning(_) => {
            CascadeResult::with_mutations(vec![CascadeMutation::AddDerive {
                target: target.to_string(),
                derives: vec![trait_name.to_string()],
            }])
        }
        CheckResult::Error(errors) => handle_cascade_errors(
            checker, target, trait_name, strategy, visited, depth, errors,
        ),
    }
}
/// Turns the errors reported for `target` into a cascade plan, per `strategy`.
///
/// Only `CheckError::DeriveFailed` errors can be fixed by cascading. Each type
/// in their `missing_impls` is handled according to `strategy`:
///
/// * `TryToAddDerive` recurses into the missing type at `depth + 1`. Its
///   mutations (and skipped items, if partial) are kept. If it fails, the
///   missing type itself is recorded as skipped.
/// * `TryToCallNew` emits a `GenerateImpl` mutation with `call_new: true`.
/// * `SkipAndReport` records the missing type as skipped.
/// * `ImmediateError` fails the whole cascade at the first missing type.
///
/// Any other error means the derive on `target` cannot be fixed by cascading
/// (for example, the type was not found). In that case the result is `Failed`
/// regardless of `strategy`, and no `AddDerive` is planned for `target`.
///
/// Otherwise the `AddDerive` for `target` is appended after the dependency
/// mutations. The status is `Partial` whenever anything was skipped and
/// `Success` otherwise.
fn handle_cascade_errors(
    checker: &impl LightCheck,
    target: &str,
    trait_name: &str,
    strategy: CascadeStrategy,
    visited: &mut Vec<String>,
    depth: usize,
    errors: Vec<CheckError>,
) -> CascadeResult {
    // Collect every missing impl first, so an unfixable error is detected
    // before any recursion records dependency types in `visited`.
    let mut missing_types: Vec<String> = Vec::new();
    let mut has_unfixable_error = false;
    for error in errors {
        if let CheckError::DeriveFailed { missing_impls, .. } = error {
            missing_types.extend(missing_impls);
        } else {
            has_unfixable_error = true;
        }
    }
    if has_unfixable_error {
        // Planning an `AddDerive` here would apply a derive the checker has
        // rejected for reasons no dependency mutation can address.
        return CascadeResult::failed(format!(
            "cannot derive {} for {}",
            trait_name, target
        ));
    }

    let mut all_mutations = Vec::new();
    let mut skipped = Vec::new();
    for missing in missing_types {
        match strategy {
            CascadeStrategy::TryToAddDerive => {
                let sub_result = cascade_add_derive_recursive(
                    checker,
                    &missing,
                    trait_name,
                    strategy,
                    visited,
                    depth + 1,
                );
                match sub_result.status {
                    CascadeStatus::Success => all_mutations.extend(sub_result.mutations),
                    CascadeStatus::Partial => {
                        all_mutations.extend(sub_result.mutations);
                        skipped.extend(sub_result.skipped);
                    }
                    CascadeStatus::Failed => skipped.push(missing),
                }
            }
            CascadeStrategy::TryToCallNew => {
                all_mutations.push(CascadeMutation::GenerateImpl {
                    target: missing,
                    trait_name: trait_name.to_string(),
                    call_new: true,
                });
            }
            CascadeStrategy::SkipAndReport => skipped.push(missing),
            CascadeStrategy::ImmediateError => {
                return CascadeResult::failed(format!(
                    "{} does not implement {}",
                    missing, trait_name
                ));
            }
        }
    }
    all_mutations.push(CascadeMutation::AddDerive {
        target: target.to_string(),
        derives: vec![trait_name.to_string()],
    });
    if skipped.is_empty() {
        CascadeResult::with_mutations(all_mutations)
    } else {
        CascadeResult::partial(all_mutations, skipped)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `LightCheck` stand-in backed by fixed symbol and impl tables.
    struct MockChecker {
        symbols: Vec<&'static str>,
        trait_impls: Vec<(&'static str, &'static str)>,
    }

    impl LightCheck for MockChecker {
        fn check_symbol_exists(&self, name: &str) -> bool {
            self.symbols.iter().any(|&sym| sym == name)
        }

        fn check_trait_impl(&self, type_name: &str, trait_name: &str) -> bool {
            self.trait_impls
                .iter()
                .any(|&(ty, tr)| ty == type_name && tr == trait_name)
        }

        /// Accepts any derive on a known symbol; unknown symbols report
        /// `type_not_found`.
        fn check_derive_possible(&self, target: &str, trait_name: &str) -> CheckResult {
            if self.check_symbol_exists(target) {
                // The impl lookup is evaluated but its answer is deliberately unused.
                let _ = self.check_trait_impl(target, trait_name);
                CheckResult::Ok
            } else {
                CheckResult::Error(vec![CheckError::type_not_found(target)])
            }
        }
    }

    #[test]
    fn test_cascade_simple() {
        let checker = MockChecker {
            symbols: vec!["MyStruct"],
            trait_impls: Vec::new(),
        };
        let strategy = CascadeStrategy::default();
        let result = cascade_add_derive(&checker, "MyStruct", "Default", strategy);
        assert!(result.is_success());
        assert_eq!(result.mutations.len(), 1);
    }

    #[test]
    fn test_cascade_strategy_immediate_error() {
        let checker = MockChecker {
            symbols: Vec::new(),
            trait_impls: Vec::new(),
        };
        let strategy = CascadeStrategy::ImmediateError;
        let result = cascade_add_derive(&checker, "NonExistent", "Default", strategy);
        // Only requires the cascade to finish as either Success or Failed.
        assert!(matches!(
            result.status,
            CascadeStatus::Success | CascadeStatus::Failed
        ));
    }
}