Skip to main content

enya_analyzer/
diff.rs

1//! Diff parsing and semantic extraction.
2//!
3//! This module parses unified diffs and extracts semantic information
4//! about what changed (functions, metrics, imports).
5
6use std::path::Path;
7
8use crate::repo::DiffSemantics;
9
10/// Parses a unified diff and extracts semantic information.
11///
12/// This function analyzes the diff to identify:
13/// - Functions that were added, removed, or modified
14/// - Metric instrumentation changes
15/// - Import statement changes
16#[must_use]
17pub fn extract_semantics(diff: &str) -> DiffSemantics {
18    let mut semantics = DiffSemantics::default();
19
20    // Parse the diff into hunks
21    let hunks = parse_diff_hunks(diff);
22
23    for hunk in hunks {
24        // Extract function changes from this hunk
25        extract_function_changes(&hunk, &mut semantics);
26
27        // Extract metric changes
28        extract_metric_changes(&hunk, &mut semantics);
29
30        // Extract import changes
31        extract_import_changes(&hunk, &mut semantics);
32    }
33
34    // Deduplicate
35    semantics.functions_added.sort();
36    semantics.functions_added.dedup();
37    semantics.functions_removed.sort();
38    semantics.functions_removed.dedup();
39    semantics.functions_modified.sort();
40    semantics.functions_modified.dedup();
41    semantics.metrics_added.sort();
42    semantics.metrics_added.dedup();
43    semantics.metrics_removed.sort();
44    semantics.metrics_removed.dedup();
45    semantics.imports_added.sort();
46    semantics.imports_added.dedup();
47    semantics.imports_removed.sort();
48    semantics.imports_removed.dedup();
49
50    semantics
51}
52
/// One `@@` section of a unified diff, reduced to its changed lines.
#[derive(Debug, Default)]
struct DiffHunk {
    /// Path of the file the hunk applies to.
    file_path: String,
    /// Text of every added line, with the leading `+` stripped.
    added_lines: Vec<String>,
    /// Text of every removed line, with the leading `-` stripped.
    removed_lines: Vec<String>,
    /// Function name taken from the `@@` header, when one could be extracted.
    function_context: Option<String>,
}
65
/// Source language of a file, inferred from its extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Language {
    Rust,
    Go,
    Python,
    /// JavaScript and TypeScript, including the JSX/TSX variants.
    JavaScript,
    /// No extension, or an extension that is not recognised.
    Unknown,
}

impl Language {
    /// Maps the extension of `path` to a language, ignoring ASCII case.
    fn from_path(path: &str) -> Self {
        let Some(ext) = Path::new(path).extension().and_then(|e| e.to_str()) else {
            return Self::Unknown;
        };

        match ext.to_ascii_lowercase().as_str() {
            "rs" => Self::Rust,
            "go" => Self::Go,
            "py" => Self::Python,
            "js" | "ts" | "jsx" | "tsx" => Self::JavaScript,
            _ => Self::Unknown,
        }
    }
}
95
96/// Parse a unified diff into individual hunks.
97fn parse_diff_hunks(diff: &str) -> Vec<DiffHunk> {
98    let mut hunks = Vec::new();
99    let mut current_file = String::new();
100    let mut current_hunk: Option<DiffHunk> = None;
101
102    for line in diff.lines() {
103        // New file header: +++ b/path/to/file.rs
104        if let Some(path) = line.strip_prefix("+++ b/") {
105            current_file = path.to_string();
106            continue;
107        }
108        if let Some(path) = line.strip_prefix("+++ ") {
109            // Handle +++ a/path format
110            current_file = path.strip_prefix("a/").unwrap_or(path).to_string();
111            continue;
112        }
113
114        // Hunk header: @@ -start,count +start,count @@ optional function context
115        if line.starts_with("@@") {
116            // Save previous hunk if any
117            if let Some(hunk) = current_hunk.take() {
118                if !hunk.added_lines.is_empty() || !hunk.removed_lines.is_empty() {
119                    hunks.push(hunk);
120                }
121            }
122
123            // Extract function context from @@ header
124            let function_context = extract_function_from_hunk_header(line);
125
126            current_hunk = Some(DiffHunk {
127                file_path: current_file.clone(),
128                function_context,
129                ..Default::default()
130            });
131            continue;
132        }
133
134        // Skip diff metadata lines
135        if line.starts_with("diff --git")
136            || line.starts_with("index ")
137            || line.starts_with("--- ")
138            || line.starts_with("Binary files")
139            || line.starts_with("new file mode")
140            || line.starts_with("deleted file mode")
141        {
142            continue;
143        }
144
145        // Collect added/removed lines
146        if let Some(hunk) = &mut current_hunk {
147            if let Some(added) = line.strip_prefix('+') {
148                hunk.added_lines.push(added.to_string());
149            } else if let Some(removed) = line.strip_prefix('-') {
150                hunk.removed_lines.push(removed.to_string());
151            }
152            // Context lines (starting with space) are ignored
153        }
154    }
155
156    // Don't forget the last hunk
157    if let Some(hunk) = current_hunk {
158        if !hunk.added_lines.is_empty() || !hunk.removed_lines.is_empty() {
159            hunks.push(hunk);
160        }
161    }
162
163    hunks
164}
165
166/// Extract function name from @@ hunk header.
167///
168/// Format: `@@ -start,count +start,count @@ fn function_name(...)`
169fn extract_function_from_hunk_header(line: &str) -> Option<String> {
170    // Find the second @@ and get everything after it
171    let parts: Vec<&str> = line.splitn(3, "@@").collect();
172    if parts.len() < 3 {
173        return None;
174    }
175
176    let context = parts[2].trim();
177    if context.is_empty() {
178        return None;
179    }
180
181    // Try to extract function name from various patterns
182    // Rust: fn function_name, pub fn function_name, async fn function_name
183    // Go: func FunctionName, func (r *Receiver) MethodName
184    // Python: def function_name
185    // JavaScript/TypeScript: function functionName
186
187    // Rust patterns
188    if let Some(idx) = context.find("fn ") {
189        let after_fn = &context[idx + 3..];
190        if let Some(name) = extract_identifier(after_fn) {
191            return Some(name);
192        }
193    }
194
195    // Go patterns
196    if let Some(idx) = context.find("func ") {
197        let after_func = &context[idx + 5..];
198        // Skip receiver: func (r *Receiver)
199        let name_part = if after_func.starts_with('(') {
200            // Find closing paren and get the name after
201            if let Some(close) = after_func.find(')') {
202                after_func[close + 1..].trim()
203            } else {
204                after_func
205            }
206        } else {
207            after_func
208        };
209        if let Some(name) = extract_identifier(name_part) {
210            return Some(name);
211        }
212    }
213
214    // Python patterns
215    if let Some(idx) = context.find("def ") {
216        let after_def = &context[idx + 4..];
217        if let Some(name) = extract_identifier(after_def) {
218            return Some(name);
219        }
220    }
221
222    // JavaScript/TypeScript function keyword
223    if let Some(idx) = context.find("function ") {
224        let after_function = &context[idx + 9..];
225        if let Some(name) = extract_identifier(after_function) {
226            return Some(name);
227        }
228    }
229
230    None
231}
232
/// Reads the identifier at the start of `s`, ignoring surrounding whitespace.
///
/// The identifier must begin with an alphabetic character or `_` and extends
/// over every following alphanumeric character or `_`. Returns `None` when
/// `s` is blank or starts with any other character.
fn extract_identifier(s: &str) -> Option<String> {
    let s = s.trim();

    let first = s.chars().next()?;
    if !(first.is_alphabetic() || first == '_') {
        return None;
    }

    // The identifier stops at the first character that cannot be part of it.
    let end = s
        .find(|c: char| !(c.is_alphanumeric() || c == '_'))
        .unwrap_or(s.len());

    Some(s[..end].to_string())
}
262
263/// Extract function changes from a hunk using pattern matching.
264fn extract_function_changes(hunk: &DiffHunk, semantics: &mut DiffSemantics) {
265    // If we have a function context from the @@ header, that function was modified
266    if let Some(ref func_name) = hunk.function_context {
267        // Only add to modified if we have actual content changes (not just context)
268        if !hunk.added_lines.is_empty() || !hunk.removed_lines.is_empty() {
269            semantics.functions_modified.push(func_name.clone());
270        }
271    }
272
273    let lang = Language::from_path(&hunk.file_path);
274
275    // Look for function definitions in added lines
276    for line in &hunk.added_lines {
277        if let Some(name) = extract_function_definition(line, lang) {
278            semantics.functions_added.push(name);
279        }
280    }
281
282    // Look for function definitions in removed lines
283    for line in &hunk.removed_lines {
284        if let Some(name) = extract_function_definition(line, lang) {
285            semantics.functions_removed.push(name);
286        }
287    }
288}
289
290/// Extract a function definition from a single line of code.
291fn extract_function_definition(line: &str, lang: Language) -> Option<String> {
292    let line = line.trim();
293
294    match lang {
295        Language::Rust => {
296            // Rust: fn name, pub fn name, async fn name, pub async fn name
297            if let Some(idx) = line.find("fn ") {
298                // Make sure 'fn' is at word boundary (not part of another word)
299                let is_word_boundary =
300                    idx == 0 || !line.chars().nth(idx - 1).is_some_and(char::is_alphanumeric);
301                if is_word_boundary {
302                    let after_fn = &line[idx + 3..];
303                    return extract_identifier(after_fn);
304                }
305            }
306        }
307        Language::Go => {
308            // Go: func Name or func (r *Receiver) Name
309            if let Some(rest) = line.strip_prefix("func ") {
310                // Skip receiver if present
311                let name_part = if rest.starts_with('(') {
312                    if let Some(close) = rest.find(')') {
313                        rest[close + 1..].trim()
314                    } else {
315                        rest
316                    }
317                } else {
318                    rest
319                };
320                return extract_identifier(name_part);
321            }
322        }
323        Language::Python => {
324            // Python: def name or async def name
325            if let Some(idx) = line.find("def ") {
326                let after_def = &line[idx + 4..];
327                return extract_identifier(after_def);
328            }
329        }
330        Language::JavaScript => {
331            // JavaScript/TypeScript: function name, async function name
332            if let Some(idx) = line.find("function ") {
333                let after_function = &line[idx + 9..];
334                return extract_identifier(after_function);
335            }
336            // Arrow function: const name = (...) => or let name = function
337            if line.contains(" = (") || line.contains(" = function") {
338                if let Some(rest) = line
339                    .strip_prefix("const ")
340                    .or_else(|| line.strip_prefix("let "))
341                    .or_else(|| line.strip_prefix("var "))
342                {
343                    return extract_identifier(rest);
344                }
345            }
346        }
347        Language::Unknown => {}
348    }
349
350    None
351}
352
353/// Extract metric instrumentation changes from a hunk.
354fn extract_metric_changes(hunk: &DiffHunk, semantics: &mut DiffSemantics) {
355    // Patterns that indicate metric instrumentation
356    let metric_patterns = [
357        // Prometheus/metrics-rs patterns
358        ".inc()",
359        ".inc_by(",
360        ".dec()",
361        ".dec_by(",
362        ".set(",
363        ".observe(",
364        ".record(",
365        ".add(",
366        // Counter/Gauge/Histogram constructors
367        "Counter::new(",
368        "Gauge::new(",
369        "Histogram::new(",
370        "IntCounter::new(",
371        "IntGauge::new(",
372        "register_counter!",
373        "register_gauge!",
374        "register_histogram!",
375        // Go prometheus patterns
376        "prometheus.NewCounter(",
377        "prometheus.NewGauge(",
378        "prometheus.NewHistogram(",
379        "promauto.NewCounter(",
380        "promauto.NewGauge(",
381        "promauto.NewHistogram(",
382        ".WithLabelValues(",
383        ".With(",
384        // Python prometheus patterns
385        "Counter(",
386        "Gauge(",
387        "Histogram(",
388        "Summary(",
389    ];
390
391    for line in &hunk.added_lines {
392        for pattern in &metric_patterns {
393            if line.contains(pattern) {
394                // Try to extract the metric name
395                if let Some(name) = extract_metric_name(line) {
396                    semantics.metrics_added.push(name);
397                }
398                break;
399            }
400        }
401    }
402
403    for line in &hunk.removed_lines {
404        for pattern in &metric_patterns {
405            if line.contains(pattern) {
406                if let Some(name) = extract_metric_name(line) {
407                    semantics.metrics_removed.push(name);
408                }
409                break;
410            }
411        }
412    }
413}
414
415/// Try to extract a metric name from a line of code.
416fn extract_metric_name(line: &str) -> Option<String> {
417    // Look for quoted strings that look like metric names
418    // e.g., "http_requests_total", 'grpc_latency_seconds'
419
420    let mut in_quote = false;
421    let mut quote_char = '"';
422    let mut current_string = String::new();
423    let mut found_strings = Vec::new();
424
425    for c in line.chars() {
426        if !in_quote && (c == '"' || c == '\'') {
427            in_quote = true;
428            quote_char = c;
429            current_string.clear();
430        } else if in_quote && c == quote_char {
431            in_quote = false;
432            if looks_like_metric_name(&current_string) {
433                found_strings.push(current_string.clone());
434            }
435        } else if in_quote {
436            current_string.push(c);
437        }
438    }
439
440    // Return the first string that looks like a metric name
441    found_strings.into_iter().next()
442}
443
/// Heuristically decides whether `s` is a Prometheus-style metric name.
///
/// The string must be at least three bytes long and contain an underscore.
/// It is then accepted, case-insensitively, when it either ends with a
/// conventional metric suffix (`_total`, `_seconds`, ...) or contains a
/// common metric word (`request`, `latency`, ...).
fn looks_like_metric_name(s: &str) -> bool {
    const SUFFIXES: &[&str] = &[
        "_total", "_count", "_sum", "_bucket", "_seconds", "_bytes", "_info", "_created", "_gauge",
        "_counter",
    ];
    const KEYWORDS: &[&str] = &[
        "request",
        "response",
        "error",
        "latency",
        "duration",
        "http",
        "grpc",
        "queue",
        "cache",
        "connection",
        "active",
        "pending",
    ];

    if s.len() < 3 || !s.contains('_') {
        return false;
    }

    let lowered = s.to_lowercase();

    SUFFIXES.iter().any(|&suffix| lowered.ends_with(suffix))
        || KEYWORDS.iter().any(|&keyword| lowered.contains(keyword))
}
496
497/// Extract import statement changes from a hunk.
498fn extract_import_changes(hunk: &DiffHunk, semantics: &mut DiffSemantics) {
499    let lang = Language::from_path(&hunk.file_path);
500
501    for line in &hunk.added_lines {
502        if let Some(import) = extract_import(line, lang) {
503            semantics.imports_added.push(import);
504        }
505    }
506
507    for line in &hunk.removed_lines {
508        if let Some(import) = extract_import(line, lang) {
509            semantics.imports_removed.push(import);
510        }
511    }
512}
513
514/// Extract an import statement from a line.
515fn extract_import(line: &str, lang: Language) -> Option<String> {
516    let line = line.trim();
517
518    match lang {
519        Language::Rust => {
520            if line.starts_with("use ") {
521                // Rust: use foo::bar;
522                let import = line
523                    .strip_prefix("use ")?
524                    .trim_end_matches(';')
525                    .trim()
526                    .to_string();
527                return Some(import);
528            }
529        }
530        Language::Go => {
531            if line.starts_with("import ") {
532                // Go: import "path/to/package" or import name "path"
533                if let Some(start) = line.find('"') {
534                    if let Some(end) = line[start + 1..].find('"') {
535                        return Some(line[start + 1..start + 1 + end].to_string());
536                    }
537                }
538            }
539        }
540        Language::Python => {
541            if line.starts_with("import ") {
542                // Python: import foo
543                return Some(line.strip_prefix("import ")?.trim().to_string());
544            }
545            if line.starts_with("from ") {
546                // Python: from foo import bar
547                let parts: Vec<&str> = line.split_whitespace().collect();
548                if parts.len() >= 2 {
549                    return Some(parts[1].to_string());
550                }
551            }
552        }
553        Language::JavaScript => {
554            if line.starts_with("import ") {
555                // JavaScript/TypeScript: import { foo } from 'bar'
556                if let Some(from_idx) = line.find(" from ") {
557                    let after_from = &line[from_idx + 6..];
558                    let path = after_from
559                        .trim()
560                        .trim_matches(|c| c == '\'' || c == '"' || c == ';');
561                    return Some(path.to_string());
562                }
563            }
564            if line.starts_with("require(") || line.contains("require(") {
565                // CommonJS: require('foo')
566                if let Some(start) = line.find("require(") {
567                    let after = &line[start + 8..];
568                    if let Some(quote_start) = after.find(['\'', '"']) {
569                        let quote_char = after.chars().nth(quote_start)?;
570                        let path_start = quote_start + 1;
571                        if let Some(end) = after[path_start..].find(quote_char) {
572                            return Some(after[path_start..path_start + end].to_string());
573                        }
574                    }
575                }
576            }
577        }
578        Language::Unknown => {}
579    }
580
581    None
582}
583
// Unit tests for the diff parser and the per-line extractors.
#[cfg(test)]
mod tests {
    use super::*;

    // A single-file, single-hunk diff: checks the file path, the number of
    // `+` and `-` lines collected, and the function name from the `@@` header.
    #[test]
    fn test_parse_diff_hunks() {
        let diff = r#"diff --git a/src/main.rs b/src/main.rs
index 1234567..abcdefg 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,6 +10,8 @@ fn main() {
     let x = 1;
+    let y = 2;
+    let z = 3;
     println!("Hello");
-    let old = 5;
 }
"#;

        let hunks = parse_diff_hunks(diff);
        assert_eq!(hunks.len(), 1);
        assert_eq!(hunks[0].file_path, "src/main.rs");
        assert_eq!(hunks[0].added_lines.len(), 2);
        assert_eq!(hunks[0].removed_lines.len(), 1);
        assert_eq!(hunks[0].function_context, Some("main".to_string()));
    }

    // Rust definitions with and without `pub` / `async` prefixes and with
    // leading indentation.
    #[test]
    fn test_extract_function_definition_rust() {
        assert_eq!(
            extract_function_definition("fn foo() {", Language::Rust),
            Some("foo".to_string())
        );
        assert_eq!(
            extract_function_definition("pub fn bar(x: i32) -> i32 {", Language::Rust),
            Some("bar".to_string())
        );
        assert_eq!(
            extract_function_definition("async fn baz() {", Language::Rust),
            Some("baz".to_string())
        );
        assert_eq!(
            extract_function_definition("    pub async fn qux() {", Language::Rust),
            Some("qux".to_string())
        );
    }

    // Go: a plain function and a method whose receiver must be skipped.
    #[test]
    fn test_extract_function_definition_go() {
        assert_eq!(
            extract_function_definition("func Foo() {", Language::Go),
            Some("Foo".to_string())
        );
        assert_eq!(
            extract_function_definition("func (s *Server) Handle() {", Language::Go),
            Some("Handle".to_string())
        );
    }

    // Python: `def` and `async def`.
    #[test]
    fn test_extract_function_definition_python() {
        assert_eq!(
            extract_function_definition("def foo():", Language::Python),
            Some("foo".to_string())
        );
        assert_eq!(
            extract_function_definition("async def bar():", Language::Python),
            Some("bar".to_string())
        );
    }

    // The metric name is the first quoted string that looks like one.
    #[test]
    fn test_extract_metric_name() {
        assert_eq!(
            extract_metric_name(r#"counter.with_label_values(&["http_requests_total"]).inc()"#),
            Some("http_requests_total".to_string())
        );
        assert_eq!(
            extract_metric_name(r#"histogram.observe("grpc_latency_seconds", 0.5)"#),
            Some("grpc_latency_seconds".to_string())
        );
    }

    // Accepts names with a known suffix or keyword; rejects names without an
    // underscore and names shorter than three bytes.
    #[test]
    fn test_looks_like_metric_name() {
        assert!(looks_like_metric_name("http_requests_total"));
        assert!(looks_like_metric_name("grpc_latency_seconds"));
        assert!(looks_like_metric_name("cache_hits_count"));
        assert!(!looks_like_metric_name("foo")); // No underscore
        assert!(!looks_like_metric_name("ab")); // Too short
    }

    // End to end: a function that appears only on `+` lines is reported as added.
    #[test]
    fn test_extract_semantics_function_added() {
        let diff = r#"diff --git a/src/lib.rs b/src/lib.rs
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,7 @@
 fn existing() {}
+
+fn new_function() {
+    println!("hello");
+}
"#;

        let semantics = extract_semantics(diff);
        assert!(
            semantics
                .functions_added
                .contains(&"new_function".to_string())
        );
    }

    // End to end: an added `.inc()` call with a quoted metric name is
    // reported under `metrics_added`.
    #[test]
    fn test_extract_semantics_metric_change() {
        let diff = r#"diff --git a/src/server.rs b/src/server.rs
--- a/src/server.rs
+++ b/src/server.rs
@@ -10,6 +10,7 @@ fn handle_request() {
     process();
+    counter.with_label_values(&["http_requests_total"]).inc();
 }
"#;

        let semantics = extract_semantics(diff);
        assert!(
            semantics
                .metrics_added
                .contains(&"http_requests_total".to_string())
        );
    }

    // Rust `use` statements lose the keyword and the trailing semicolon.
    #[test]
    fn test_extract_import_rust() {
        assert_eq!(
            extract_import("use std::collections::HashMap;", Language::Rust),
            Some("std::collections::HashMap".to_string())
        );
    }

    // Python: both `import x` and `from x import y` yield the module `x`.
    #[test]
    fn test_extract_import_python() {
        assert_eq!(
            extract_import("import prometheus_client", Language::Python),
            Some("prometheus_client".to_string())
        );
        assert_eq!(
            extract_import("from prometheus_client import Counter", Language::Python),
            Some("prometheus_client".to_string())
        );
    }

    // JavaScript: the module path after `from`, without quotes or semicolon.
    #[test]
    fn test_extract_import_javascript() {
        assert_eq!(
            extract_import(
                "import { Counter } from 'prom-client';",
                Language::JavaScript
            ),
            Some("prom-client".to_string())
        );
    }
}