1use syn::{
7 spanned::Spanned, visit::Visit, Expr, ExprBlock, ExprStruct, File, ImplItem, ImplItemFn, Item,
8 ItemImpl, ReturnType, Stmt, Type,
9};
10
/// Dependency and size information for the initialization of a single struct field.
///
/// A record is produced for a field only when a local `let` binding with the
/// same name as the field exists in the analyzed function body.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct FieldDependency {
    /// Name of the struct field this record describes.
    pub field_name: String,

    /// Names of local bindings referenced by the initializer of the binding
    /// that shares the field's name (each name listed at most once).
    pub depends_on: Vec<String>,

    /// Number of counted source lines spanned by that initializer expression.
    pub initialization_complexity: usize,
}
23
/// Metrics collected for one impl method whose body contains a struct literal.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct StructInitPattern {
    /// Last path segment of the struct literal found in the function body.
    pub struct_name: String,

    /// Number of fields written in that struct literal.
    pub field_count: usize,

    /// Counted source lines within the function's span.
    pub function_lines: usize,

    /// Heuristic count of lines within the function's span that look like
    /// initialization code.
    pub initialization_lines: usize,

    /// `initialization_lines` divided by `function_lines`.
    pub initialization_ratio: f64,

    /// Average nesting depth measured over the function body's statements.
    pub avg_nesting_depth: f64,

    /// Deepest nesting level measured in the function body.
    pub max_nesting_depth: usize,

    /// Per-field dependency records for fields backed by a same-named local binding.
    pub field_dependencies: Vec<FieldDependency>,

    /// Fields whose same-named local binding has an initializer spanning more
    /// than 10 counted lines.
    pub complex_fields: Vec<String>,

    /// Estimated cyclomatic complexity of the function body.
    pub cyclomatic_complexity: usize,

    /// Whether the function's declared return type was recognized as a `Result`.
    pub is_result_wrapped: bool,

    /// Whether the body calls a path function named `new` or one whose name
    /// starts with `from_` or `with_`.
    pub calls_constructors: bool,
}
63
/// Summary of the struct literal (if any) found while scanning a function body.
#[derive(Debug, Clone)]
pub struct ReturnAnalysis {
    /// `true` when a struct literal was found (i.e. `struct_name` is `Some`).
    pub returns_struct: bool,
    /// Last path segment of the struct literal, when one was found.
    pub struct_name: Option<String>,
    /// Number of fields written in the struct literal.
    pub field_count: usize,
    /// Names of the literal's named fields (unnamed/indexed members are omitted).
    pub field_names: Vec<String>,
    /// Whether the function's declared return type was recognized as a `Result`.
    pub is_result_wrapped: bool,
}
73
/// Thresholds used by `detect` to decide which candidate patterns are reported.
pub struct StructInitPatternDetector {
    /// Candidates with fewer struct-literal fields than this are rejected.
    pub min_field_count: usize,
    /// Candidates whose initialization ratio is below this value are rejected.
    pub min_init_ratio: f64,
    /// Candidates whose maximum nesting depth exceeds this value are rejected.
    pub max_nesting_depth: usize,
}
80
81impl Default for StructInitPatternDetector {
82 fn default() -> Self {
83 Self {
84 min_field_count: 15,
85 min_init_ratio: 0.70,
86 max_nesting_depth: 4,
87 }
88 }
89}
90
91impl StructInitPatternDetector {
92 pub fn new() -> Self {
93 Self::default()
94 }
95
96 pub fn detect(&self, file: &File, file_content: &str) -> Option<StructInitPattern> {
98 let mut detector = StructInitVisitor::new(file_content);
99 detector.visit_file(file);
100
101 detector
103 .patterns
104 .into_iter()
105 .filter(|p| {
106 p.field_count >= self.min_field_count
107 && p.initialization_ratio >= self.min_init_ratio
108 && p.max_nesting_depth <= self.max_nesting_depth
109 })
110 .max_by(|a, b| a.field_count.cmp(&b.field_count))
111 }
112
113 pub fn calculate_init_complexity_score(&self, pattern: &StructInitPattern) -> f64 {
115 let field_score = match pattern.field_count {
116 0..=20 => 1.0,
117 21..=40 => 2.0,
118 41..=60 => 3.5,
119 _ => 5.0,
120 };
121
122 let nesting_penalty = pattern.max_nesting_depth as f64 * 0.5;
123 let complex_field_penalty = pattern.complex_fields.len() as f64 * 1.0;
124
125 field_score + nesting_penalty + complex_field_penalty
126 }
127
128 pub fn generate_recommendation(&self, pattern: &StructInitPattern) -> String {
130 if pattern.field_count > 50 {
131 "Consider builder pattern to reduce initialization complexity".to_string()
132 } else if pattern.complex_fields.len() > 5 {
133 "Extract complex field initializations into helper functions".to_string()
134 } else if pattern.max_nesting_depth > 3 {
135 "Reduce nesting depth in field initialization".to_string()
136 } else {
137 "Initialization is appropriately complex for field count".to_string()
138 }
139 }
140
141 pub fn confidence(&self, pattern: &StructInitPattern) -> f64 {
143 let mut confidence = 0.0;
144
145 if pattern.initialization_ratio > 0.85 {
147 confidence += 0.35;
148 } else if pattern.initialization_ratio > 0.75 {
149 confidence += 0.25;
150 } else if pattern.initialization_ratio > 0.65 {
151 confidence += 0.15;
152 } else {
153 confidence += 0.05;
154 }
155
156 confidence += (pattern.field_count as f64 / 50.0).min(0.25);
158
159 if pattern.max_nesting_depth <= 2 {
161 confidence += 0.20;
162 } else if pattern.max_nesting_depth <= 3 {
163 confidence += 0.10;
164 }
165
166 if pattern.struct_name.contains("Args")
168 || pattern.struct_name.contains("Config")
169 || pattern.struct_name.contains("Options")
170 {
171 confidence += 0.10;
172 }
173
174 if pattern.complex_fields.len() > pattern.field_count / 3 {
176 confidence *= 0.7;
177 }
178
179 confidence.min(1.0)
180 }
181}
182
/// AST visitor that collects a `StructInitPattern` for impl methods it analyzes.
struct StructInitVisitor<'a> {
    /// Patterns recorded so far, in visitation order.
    patterns: Vec<StructInitPattern>,
    /// Source text used for the line-based measurements.
    file_content: &'a str,
}
188
189impl<'a> StructInitVisitor<'a> {
190 fn new(file_content: &'a str) -> Self {
191 Self {
192 patterns: Vec::new(),
193 file_content,
194 }
195 }
196
197 fn analyze_function(&mut self, function: &ImplItemFn, _impl_block: &ItemImpl) {
198 let return_analysis = analyze_return_statement(function);
200
201 if !return_analysis.returns_struct || return_analysis.field_count == 0 {
202 return;
203 }
204
205 let span = function.span();
207 let start_line = span.start().line;
208 let end_line = span.end().line;
209 let function_lines = count_lines_in_span(self.file_content, start_line, end_line);
210
211 let initialization_lines =
213 estimate_initialization_lines(self.file_content, start_line, end_line);
214
215 let initialization_ratio = initialization_lines as f64 / function_lines as f64;
216
217 let (avg_nesting, max_nesting) = measure_nesting_depth(&function.block);
219
220 let cyclomatic = estimate_cyclomatic_complexity(&function.block);
222
223 let calls_constructors = detect_constructor_calls(&function.block);
225
226 let (field_dependencies, complex_fields) = analyze_field_dependencies_and_complexity(
228 &function.block,
229 &return_analysis.field_names,
230 self.file_content,
231 );
232
233 let pattern = StructInitPattern {
235 struct_name: return_analysis.struct_name.unwrap_or_default(),
236 field_count: return_analysis.field_count,
237 function_lines,
238 initialization_lines,
239 initialization_ratio,
240 avg_nesting_depth: avg_nesting,
241 max_nesting_depth: max_nesting,
242 field_dependencies,
243 complex_fields,
244 cyclomatic_complexity: cyclomatic,
245 is_result_wrapped: return_analysis.is_result_wrapped,
246 calls_constructors,
247 };
248
249 self.patterns.push(pattern);
250 }
251}
252
253impl<'a, 'ast> Visit<'ast> for StructInitVisitor<'a> {
254 fn visit_item(&mut self, item: &'ast Item) {
255 if let Item::Impl(item_impl) = item {
256 for impl_item in &item_impl.items {
257 if let ImplItem::Fn(method) = impl_item {
258 self.analyze_function(method, item_impl);
259 }
260 }
261 }
262 syn::visit::visit_item(self, item);
263 }
264}
265
266fn analyze_return_statement(function: &ImplItemFn) -> ReturnAnalysis {
268 let mut visitor = ReturnStructVisitor {
269 struct_name: None,
270 field_count: 0,
271 field_names: Vec::new(),
272 is_result_wrapped: false,
273 };
274
275 if let ReturnType::Type(_, ty) = &function.sig.output {
277 visitor.is_result_wrapped = is_result_type(ty);
278 }
279
280 visitor.visit_block(&function.block);
282
283 ReturnAnalysis {
284 returns_struct: visitor.struct_name.is_some(),
285 struct_name: visitor.struct_name,
286 field_count: visitor.field_count,
287 field_names: visitor.field_names,
288 is_result_wrapped: visitor.is_result_wrapped,
289 }
290}
291
/// Visitor state used while searching a function body for a struct literal.
struct ReturnStructVisitor {
    /// Last path segment of the most recently recorded struct literal.
    struct_name: Option<String>,
    /// Number of fields in that literal.
    field_count: usize,
    /// Names of that literal's named fields.
    field_names: Vec<String>,
    /// Whether the enclosing function's return type was recognized as a `Result`.
    is_result_wrapped: bool,
}
299
300impl<'ast> Visit<'ast> for ReturnStructVisitor {
301 fn visit_expr(&mut self, expr: &'ast Expr) {
302 match expr {
303 Expr::Struct(struct_expr) => {
304 self.extract_struct_info(struct_expr);
305 }
306 Expr::Call(call_expr) => {
307 if let Expr::Path(path) = &*call_expr.func {
309 if path
310 .path
311 .segments
312 .last()
313 .map(|s| s.ident == "Ok")
314 .unwrap_or(false)
315 {
316 if let Some(Expr::Struct(struct_expr)) = call_expr.args.first() {
317 self.extract_struct_info(struct_expr);
318 }
319 }
320 }
321 }
322 _ => {}
323 }
324 syn::visit::visit_expr(self, expr);
325 }
326}
327
328impl ReturnStructVisitor {
329 fn extract_struct_info(&mut self, struct_expr: &ExprStruct) {
330 if let Some(segment) = struct_expr.path.segments.last() {
332 self.struct_name = Some(segment.ident.to_string());
333 }
334
335 self.field_count = struct_expr.fields.len();
337
338 self.field_names = struct_expr
340 .fields
341 .iter()
342 .filter_map(|f| match &f.member {
343 syn::Member::Named(ident) => Some(ident.to_string()),
344 _ => None,
345 })
346 .collect();
347 }
348}
349
350fn is_result_type(ty: &Type) -> bool {
352 if let Type::Path(type_path) = ty {
353 type_path
354 .path
355 .segments
356 .first()
357 .map(|s| s.ident == "Result")
358 .unwrap_or(false)
359 } else {
360 false
361 }
362}
363
/// Counts the lines of `content` between `start_line` and `end_line`
/// (1-based, inclusive) that are neither blank nor start with `//` after
/// trimming.
///
/// At least one line is always examined, even when `end_line` is smaller
/// than `start_line`.
fn count_lines_in_span(content: &str, start_line: usize, end_line: usize) -> usize {
    let lines_to_skip = start_line.saturating_sub(1);
    let lines_to_take = end_line.saturating_sub(start_line) + 1;

    let mut counted = 0;
    for line in content.lines().skip(lines_to_skip).take(lines_to_take) {
        let text = line.trim();
        if text.is_empty() || text.starts_with("//") {
            continue;
        }
        counted += 1;
    }
    counted
}
377
/// Heuristically counts the lines of `content` between `start_line` and
/// `end_line` (1-based, inclusive) that look like initialization code.
///
/// A line counts when, after trimming, it contains `let `, `:`, `unwrap_or`
/// or `match`. Lines that start with `//` are skipped, so a comment such as
/// `// note: ...` is never counted and the result only ever includes
/// non-blank, non-comment lines.
fn estimate_initialization_lines(content: &str, start_line: usize, end_line: usize) -> usize {
    const MARKERS: [&str; 4] = ["let ", ":", "unwrap_or", "match"];

    content
        .lines()
        .skip(start_line.saturating_sub(1))
        .take(end_line.saturating_sub(start_line) + 1)
        .map(str::trim)
        .filter(|line| !line.starts_with("//"))
        .filter(|line| MARKERS.iter().any(|marker| line.contains(marker)))
        .count()
}
395
396fn measure_nesting_depth(block: &syn::Block) -> (f64, usize) {
398 let mut max_depth = 0;
399 let mut depth_sum = 0;
400 let mut node_count = 0;
401
402 measure_depth_recursive(
403 &block.stmts,
404 1,
405 &mut max_depth,
406 &mut depth_sum,
407 &mut node_count,
408 );
409
410 let avg_depth = if node_count > 0 {
411 depth_sum as f64 / node_count as f64
412 } else {
413 0.0
414 };
415
416 (avg_depth, max_depth)
417}
418
419fn measure_depth_recursive(
420 stmts: &[Stmt],
421 current_depth: usize,
422 max_depth: &mut usize,
423 depth_sum: &mut usize,
424 node_count: &mut usize,
425) {
426 *max_depth = (*max_depth).max(current_depth);
427 *depth_sum += current_depth * stmts.len();
428 *node_count += stmts.len();
429
430 for stmt in stmts {
431 match stmt {
432 Stmt::Expr(Expr::If(expr_if), _) => {
433 measure_depth_recursive(
434 &expr_if.then_branch.stmts,
435 current_depth + 1,
436 max_depth,
437 depth_sum,
438 node_count,
439 );
440 }
441 Stmt::Expr(Expr::Match(expr_match), _) => {
442 for arm in &expr_match.arms {
443 if let Expr::Block(ExprBlock { block, .. }) = &*arm.body {
444 measure_depth_recursive(
445 &block.stmts,
446 current_depth + 1,
447 max_depth,
448 depth_sum,
449 node_count,
450 );
451 }
452 }
453 }
454 _ => {}
455 }
456 }
457}
458
459fn estimate_cyclomatic_complexity(block: &syn::Block) -> usize {
461 let mut complexity = 1; count_decision_points(&block.stmts, &mut complexity);
463 complexity
464}
465
466fn count_decision_points(stmts: &[Stmt], complexity: &mut usize) {
467 for stmt in stmts {
468 match stmt {
469 Stmt::Expr(Expr::If(_), _) => {
470 *complexity += 1;
471 }
472 Stmt::Expr(Expr::Match(expr_match), _) => {
473 *complexity += expr_match.arms.len().saturating_sub(1);
474 }
475 Stmt::Expr(Expr::While(_), _) | Stmt::Expr(Expr::ForLoop(_), _) => {
476 *complexity += 1;
477 }
478 _ => {}
479 }
480 }
481}
482
483fn detect_constructor_calls(block: &syn::Block) -> bool {
485 let mut visitor = ConstructorCallVisitor {
486 calls_constructor: false,
487 };
488 visitor.visit_block(block);
489 visitor.calls_constructor
490}
491
/// Visitor that notes whether a constructor-like call was seen.
struct ConstructorCallVisitor {
    /// Set to `true` once a matching call expression has been visited.
    calls_constructor: bool,
}
495
496impl<'ast> Visit<'ast> for ConstructorCallVisitor {
497 fn visit_expr(&mut self, expr: &'ast Expr) {
498 if let Expr::Call(call_expr) = expr {
499 if let Expr::Path(path) = &*call_expr.func {
500 if let Some(segment) = path.path.segments.last() {
501 let name = segment.ident.to_string();
502 if name == "new" || name.starts_with("from_") || name.starts_with("with_") {
503 self.calls_constructor = true;
504 }
505 }
506 }
507 }
508 syn::visit::visit_expr(self, expr);
509 }
510}
511
512fn analyze_field_dependencies_and_complexity(
514 block: &syn::Block,
515 field_names: &[String],
516 file_content: &str,
517) -> (Vec<FieldDependency>, Vec<String>) {
518 let mut field_dependencies = Vec::new();
519 let mut complex_fields = Vec::new();
520
521 let local_bindings = extract_local_bindings(block);
523
524 for field_name in field_names {
526 if let Some(binding) = local_bindings.iter().find(|(name, _)| name == field_name) {
527 let (_name, expr) = binding;
528
529 let span = expr.span();
531 let start_line = span.start().line;
532 let end_line = span.end().line;
533 let init_lines = count_lines_in_span(file_content, start_line, end_line);
534
535 if init_lines > 10 {
537 complex_fields.push(field_name.clone());
538 }
539
540 let depends_on = find_variable_references(expr, &local_bindings);
542
543 if !depends_on.is_empty() || init_lines > 5 {
545 field_dependencies.push(FieldDependency {
546 field_name: field_name.clone(),
547 depends_on,
548 initialization_complexity: init_lines,
549 });
550 }
551 }
552 }
553
554 (field_dependencies, complex_fields)
555}
556
557fn extract_local_bindings(block: &syn::Block) -> Vec<(String, Expr)> {
559 let mut bindings = Vec::new();
560
561 for stmt in &block.stmts {
562 if let Stmt::Local(local) = stmt {
563 if let syn::Pat::Ident(pat_ident) = &local.pat {
564 let var_name = pat_ident.ident.to_string();
565 if let Some(init) = &local.init {
566 bindings.push((var_name, (*init.expr).clone()));
567 }
568 }
569 }
570 }
571
572 bindings
573}
574
575fn find_variable_references(expr: &Expr, local_bindings: &[(String, Expr)]) -> Vec<String> {
577 let mut visitor = VariableRefVisitor {
578 references: Vec::new(),
579 local_vars: local_bindings
580 .iter()
581 .map(|(name, _)| name.clone())
582 .collect(),
583 };
584 visitor.visit_expr(expr);
585 visitor.references
586}
587
/// Visitor that records which known local variables an expression mentions.
struct VariableRefVisitor {
    /// Names found so far, without duplicates, in first-seen order.
    references: Vec<String>,
    /// Names of the local bindings that count as references.
    local_vars: Vec<String>,
}
593
594impl<'ast> Visit<'ast> for VariableRefVisitor {
595 fn visit_expr(&mut self, expr: &'ast Expr) {
596 match expr {
597 Expr::Path(expr_path) => {
598 if let Some(ident) = expr_path.path.get_ident() {
599 let var_name = ident.to_string();
600 if self.local_vars.contains(&var_name) && !self.references.contains(&var_name) {
602 self.references.push(var_name);
603 }
604 }
605 }
606 Expr::Field(expr_field) => {
607 if let Expr::Path(base_path) = &*expr_field.base {
609 if let Some(ident) = base_path.path.get_ident() {
610 let var_name = ident.to_string();
611 if self.local_vars.contains(&var_name)
612 && !self.references.contains(&var_name)
613 {
614 self.references.push(var_name);
615 }
616 }
617 }
618 }
619 _ => {}
620 }
621 syn::visit::visit_expr(self, expr);
622 }
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` into a `syn::File`, panicking if it is not valid Rust.
    fn parse_rust_code(code: &str) -> File {
        syn::parse_str(code).expect("Failed to parse Rust code")
    }

    /// A method that builds a 15-field struct from local bindings and returns
    /// it wrapped in `Ok(..)` should be reported as an initialization pattern.
    #[test]
    fn test_detect_struct_init_basic() {
        let code = r#"
            pub struct HiArgs {
                patterns: String,
                paths: String,
                column: bool,
                heading: bool,
                timeout: u32,
                retries: u32,
                max_wait: u32,
                backoff: u32,
                host: String,
                port: u16,
                path: String,
                headers: Vec<String>,
                buffer_size: usize,
                enable_logging: bool,
                enable_metrics: bool,
            }

            impl HiArgs {
                pub fn from_low_args(low: LowArgs) -> Result<HiArgs> {
                    let column = low.column.unwrap_or(low.vimgrep);
                    let heading = match low.heading {
                        None => !low.vimgrep && true,
                        Some(false) => false,
                        Some(true) => !low.vimgrep,
                    };
                    let timeout = low.timeout.unwrap_or(30);
                    let retries = low.retries.unwrap_or(3);
                    let max_wait = timeout * retries;
                    let backoff = timeout / retries;
                    let host = low.host.unwrap_or_default();
                    let port = low.port.unwrap_or(8080);
                    let path = low.path.unwrap_or_else(|| "/".to_string());
                    let headers = low.headers.unwrap_or_default();
                    let buffer_size = low.buffer_size.unwrap_or(8192);
                    let enable_logging = low.enable_logging;
                    let enable_metrics = low.enable_metrics;

                    Ok(HiArgs {
                        patterns: "test".into(),
                        paths: "test".into(),
                        column,
                        heading,
                        timeout,
                        retries,
                        max_wait,
                        backoff,
                        host,
                        port,
                        path,
                        headers,
                        buffer_size,
                        enable_logging,
                        enable_metrics,
                    })
                }
            }

            pub struct LowArgs {
                pub column: Option<bool>,
                pub vimgrep: bool,
                pub heading: Option<bool>,
                pub timeout: Option<u32>,
                pub retries: Option<u32>,
                pub host: Option<String>,
                pub port: Option<u16>,
                pub path: Option<String>,
                pub headers: Option<Vec<String>>,
                pub buffer_size: Option<usize>,
                pub enable_logging: bool,
                pub enable_metrics: bool,
            }

            pub struct Result<T> {
                value: T,
            }
        "#;

        let file = parse_rust_code(code);
        // Thresholds looser than the defaults (15 fields / 0.70 ratio / depth 4).
        let detector = StructInitPatternDetector {
            min_field_count: 10,
            min_init_ratio: 0.40,
            max_nesting_depth: 5,
        };

        let pattern = detector.detect(&file, code);
        assert!(
            pattern.is_some(),
            "Should detect struct initialization pattern"
        );

        let pattern = pattern.unwrap();
        assert_eq!(pattern.struct_name, "HiArgs");
        assert!(pattern.field_count >= 15, "Should detect 15 fields");
        assert!(
            pattern.initialization_ratio > 0.40,
            "Initialization ratio should be > 0.40, got {:.2}",
            pattern.initialization_ratio
        );
    }

    /// The field-based score for a large but flat initializer should be far
    /// smaller than its cyclomatic complexity.
    #[test]
    fn test_field_based_complexity_lower_than_cyclomatic() {
        let pattern = StructInitPattern {
            struct_name: "HiArgs".into(),
            field_count: 40,
            function_lines: 214,
            initialization_lines: 180,
            initialization_ratio: 0.84,
            avg_nesting_depth: 1.8,
            max_nesting_depth: 3,
            field_dependencies: vec![],
            complex_fields: vec![],
            cyclomatic_complexity: 42,
            is_result_wrapped: true,
            calls_constructors: true,
        };

        let detector = StructInitPatternDetector::default();
        let field_score = detector.calculate_init_complexity_score(&pattern);

        // Expected score: 2.0 (21..=40 fields) + 3 * 0.5 (nesting) + 0 complex fields = 3.5.
        assert!(
            field_score < 10.0,
            "Field score {} should be < 10.0",
            field_score
        );
        assert!(
            field_score < pattern.cyclomatic_complexity as f64 / 4.0,
            "Field score {} should be < cyclomatic/4",
            field_score
        );
    }

    /// A method whose only struct literal is a one-field `Score` built inside
    /// an iterator closure must not be reported with the default thresholds.
    #[test]
    fn test_not_initialization_business_logic() {
        let code = r#"
            impl Calculator {
                pub fn calculate_scores(data: &[Item]) -> Vec<Score> {
                    data.iter()
                        .filter(|item| item.is_valid())
                        .map(|item| {
                            let base = item.value * 2;
                            let bonus = if item.premium { 10 } else { 0 };
                            Score { value: base + bonus }
                        })
                        .collect()
                }
            }

            pub struct Score {
                value: i32,
            }
        "#;

        let file = parse_rust_code(code);
        let detector = StructInitPatternDetector::default();

        let pattern = detector.detect(&file, code);
        // One field is well below the default minimum of 15 fields.
        assert!(
            pattern.is_none(),
            "Business logic should not be detected as initialization"
        );
    }

    /// A high-ratio, shallow, `*Args`-named pattern should score well above 0.70.
    #[test]
    fn test_confidence_calculation() {
        let detector = StructInitPatternDetector::default();

        let high_confidence = StructInitPattern {
            struct_name: "HttpClientArgs".into(),
            field_count: 35,
            function_lines: 150,
            initialization_lines: 130,
            initialization_ratio: 0.87,
            avg_nesting_depth: 1.5,
            max_nesting_depth: 2,
            field_dependencies: vec![],
            complex_fields: vec![],
            cyclomatic_complexity: 38,
            is_result_wrapped: true,
            calls_constructors: false,
        };

        // Expected: 0.35 (ratio) + 0.25 (field count cap) + 0.20 (nesting) + 0.10 (name) = 0.90.
        let confidence = detector.confidence(&high_confidence);
        assert!(
            confidence > 0.70,
            "High confidence pattern should score > 0.70, got {}",
            confidence
        );
    }
}