1use super::result::{CheckError, CheckResult};
7use super::traits::LightCheck;
8
/// How the cascade reacts when deriving a trait on a type is blocked because
/// other types lack that trait (reported as `CheckError::DeriveFailed`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum CascadeStrategy {
    /// Recursively try to add the same derive to every missing type (the default).
    #[default]
    TryToAddDerive,

    /// Emit a `CascadeMutation::GenerateImpl` with `call_new: true` for every
    /// missing type instead of deriving it.
    TryToCallNew,

    /// Attempt no fix; list every missing type in `CascadeResult::skipped`.
    SkipAndReport,

    /// Abort with a failed `CascadeResult` at the first missing type.
    ImmediateError,
}
25
26#[derive(Debug, Clone)]
28pub struct CascadeResult {
29 pub mutations: Vec<CascadeMutation>,
31 pub skipped: Vec<String>,
33 pub status: CascadeStatus,
35}
36
37impl CascadeResult {
38 pub fn ok() -> Self {
40 Self {
41 mutations: Vec::new(),
42 skipped: Vec::new(),
43 status: CascadeStatus::Success,
44 }
45 }
46
47 pub fn with_mutations(mutations: Vec<CascadeMutation>) -> Self {
49 Self {
50 mutations,
51 skipped: Vec::new(),
52 status: CascadeStatus::Success,
53 }
54 }
55
56 pub fn partial(mutations: Vec<CascadeMutation>, skipped: Vec<String>) -> Self {
58 Self {
59 mutations,
60 skipped,
61 status: CascadeStatus::Partial,
62 }
63 }
64
65 pub fn failed(reason: String) -> Self {
67 Self {
68 mutations: Vec::new(),
69 skipped: vec![reason],
70 status: CascadeStatus::Failed,
71 }
72 }
73
74 pub fn is_success(&self) -> bool {
76 matches!(self.status, CascadeStatus::Success)
77 }
78}
79
/// Overall outcome recorded in `CascadeResult::status`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum CascadeStatus {
    /// Every requested derive could be planned.
    Success,

    /// Some mutations were planned, but at least one entry was skipped.
    Partial,

    /// The cascade was aborted; results built by `CascadeResult::failed` carry no mutations.
    Failed,
}
90
/// A single source edit proposed by the cascade.
#[derive(Debug, Clone)]
pub enum CascadeMutation {
    /// Derive the listed traits on `target`.
    AddDerive {
        /// Name of the type to annotate.
        target: String,
        /// Names of the traits to derive.
        derives: Vec<String>,
    },

    /// Request a generated (non-derived) impl of `trait_name` for `target`.
    GenerateImpl {
        /// Name of the type that needs the impl.
        target: String,
        /// Name of the trait to implement.
        trait_name: String,
        /// Set to `true` by `CascadeStrategy::TryToCallNew`. Presumably the
        /// generated impl should delegate to the type's `new()` — confirm with
        /// the code generator.
        call_new: bool,
    },
}
111
112pub fn cascade_add_derive(
140 checker: &impl LightCheck,
141 target: &str,
142 trait_name: &str,
143 strategy: CascadeStrategy,
144) -> CascadeResult {
145 cascade_add_derive_recursive(checker, target, trait_name, strategy, &mut Vec::new(), 0)
146}
147
/// Deepest recursion level the cascade accepts; the top-level target is depth 0,
/// so depths `0..=MAX_CASCADE_DEPTH` are allowed.
const MAX_CASCADE_DEPTH: usize = 10;
150
151fn cascade_add_derive_recursive(
152 checker: &impl LightCheck,
153 target: &str,
154 trait_name: &str,
155 strategy: CascadeStrategy,
156 visited: &mut Vec<String>,
157 depth: usize,
158) -> CascadeResult {
159 if depth > MAX_CASCADE_DEPTH {
161 return CascadeResult::failed(format!(
162 "cascade depth exceeded for {}::{}",
163 target, trait_name
164 ));
165 }
166
167 if visited.contains(&target.to_string()) {
169 return CascadeResult::ok();
170 }
171 visited.push(target.to_string());
172
173 let check_result = checker.check_derive_possible(target, trait_name);
175
176 match check_result {
177 CheckResult::Ok => {
178 CascadeResult::with_mutations(vec![CascadeMutation::AddDerive {
180 target: target.to_string(),
181 derives: vec![trait_name.to_string()],
182 }])
183 }
184 CheckResult::Warning(_) => {
185 CascadeResult::with_mutations(vec![CascadeMutation::AddDerive {
187 target: target.to_string(),
188 derives: vec![trait_name.to_string()],
189 }])
190 }
191 CheckResult::Error(errors) => handle_cascade_errors(
192 checker, target, trait_name, strategy, visited, depth, errors,
193 ),
194 }
195}
196
197fn handle_cascade_errors(
198 checker: &impl LightCheck,
199 target: &str,
200 trait_name: &str,
201 strategy: CascadeStrategy,
202 visited: &mut Vec<String>,
203 depth: usize,
204 errors: Vec<CheckError>,
205) -> CascadeResult {
206 let mut all_mutations = Vec::new();
207 let mut skipped = Vec::new();
208
209 for error in errors {
210 if let CheckError::DeriveFailed { missing_impls, .. } = error {
211 for missing in missing_impls {
212 match strategy {
213 CascadeStrategy::TryToAddDerive => {
214 let sub_result = cascade_add_derive_recursive(
216 checker,
217 &missing,
218 trait_name,
219 strategy,
220 visited,
221 depth + 1,
222 );
223
224 match sub_result.status {
225 CascadeStatus::Success => {
226 all_mutations.extend(sub_result.mutations);
227 }
228 CascadeStatus::Partial => {
229 all_mutations.extend(sub_result.mutations);
230 skipped.extend(sub_result.skipped);
231 }
232 CascadeStatus::Failed => {
233 skipped.push(missing);
234 }
235 }
236 }
237 CascadeStrategy::TryToCallNew => {
238 all_mutations.push(CascadeMutation::GenerateImpl {
240 target: missing,
241 trait_name: trait_name.to_string(),
242 call_new: true,
243 });
244 }
245 CascadeStrategy::SkipAndReport => {
246 skipped.push(missing);
247 }
248 CascadeStrategy::ImmediateError => {
249 return CascadeResult::failed(format!(
250 "{} does not implement {}",
251 missing, trait_name
252 ));
253 }
254 }
255 }
256 }
257 }
258
259 all_mutations.push(CascadeMutation::AddDerive {
261 target: target.to_string(),
262 derives: vec![trait_name.to_string()],
263 });
264
265 if skipped.is_empty() {
266 CascadeResult::with_mutations(all_mutations)
267 } else {
268 CascadeResult::partial(all_mutations, skipped)
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory `LightCheck` stub. A type exists iff it is listed in
    /// `symbols`, and every existing type is reported as derivable.
    struct MockChecker {
        /// Type names the mock reports as existing.
        symbols: Vec<&'static str>,
        /// `(type, trait)` pairs the mock reports as implemented.
        trait_impls: Vec<(&'static str, &'static str)>,
    }

    impl LightCheck for MockChecker {
        fn check_symbol_exists(&self, name: &str) -> bool {
            self.symbols.iter().any(|&symbol| symbol == name)
        }

        fn check_trait_impl(&self, type_name: &str, trait_name: &str) -> bool {
            self.trait_impls
                .iter()
                .any(|&(ty, tr)| ty == type_name && tr == trait_name)
        }

        fn check_derive_possible(&self, target: &str, _trait_name: &str) -> CheckResult {
            if self.check_symbol_exists(target) {
                CheckResult::Ok
            } else {
                CheckResult::Error(vec![CheckError::type_not_found(target)])
            }
        }
    }

    #[test]
    fn test_cascade_simple() {
        let checker = MockChecker {
            symbols: vec!["MyStruct"],
            trait_impls: Vec::new(),
        };

        let result =
            cascade_add_derive(&checker, "MyStruct", "Default", CascadeStrategy::default());

        assert!(result.is_success());
        assert_eq!(result.mutations.len(), 1);
    }

    #[test]
    fn test_cascade_strategy_immediate_error() {
        // With no symbols, the mock reports a type-not-found error for any target.
        let checker = MockChecker {
            symbols: Vec::new(),
            trait_impls: Vec::new(),
        };

        let result = cascade_add_derive(
            &checker,
            "NonExistent",
            "Default",
            CascadeStrategy::ImmediateError,
        );

        // Both outcomes are accepted: which one occurs depends on how
        // `handle_cascade_errors` treats errors other than `DeriveFailed`.
        assert!(matches!(
            result.status,
            CascadeStatus::Success | CascadeStatus::Failed
        ));
    }
}