Skip to main content

grit_lib/
userdiff.rs

1//! User-defined and built-in diff function-name matching.
2//!
3//! This module implements the subset of Git's `userdiff` behavior needed for
4//! hunk-header function context extraction.
5
6use crate::attributes::{collect_attrs_for_path, AttrValue, MacroTable};
7use crate::config::ConfigSet;
8use crate::crlf::{get_file_attrs, AttrRule, DiffAttr};
9use regex::{Regex, RegexBuilder};
10use std::collections::BTreeMap;
11use std::io::Write;
12use std::process::{Command, Stdio};
13use std::sync::OnceLock;
14
/// Built-in diff driver funcname patterns (same strings as Git's userdiff builtin drivers).
///
/// Each entry is `(driver name, pattern, ignore_case)`:
///
/// * `pattern` holds one or more POSIX extended regular expressions separated
///   by newlines. They are compiled as one rule per line and tried in order. A
///   line starting with `!` is a negated rule: a source line it matches gets no
///   hunk-header text. The last line must not be negated.
/// * `ignore_case` requests case-insensitive matching for every rule of that
///   driver.
///
/// An entry is consulted only when neither `diff.<driver>.xfuncname` nor
/// `diff.<driver>.funcname` is configured for the driver.
const BUILTIN_PATTERN_DEFS: &[(&str, &str, bool)] = &[
    (
        "ada",
        r"!^(.*[ 	])?(is[ 	]+new|renames|is[ 	]+separate)([ 	].*)?$
!^[ 	]*with[ 	].*$
^[ 	]*((procedure|function)[ 	]+.*)$
^[ 	]*((package|protected|task)[ 	]+.*)$",
        true,
    ),
    (
        "bash",
        r"^[ 	]*((([a-zA-Z_][a-zA-Z0-9_]*[ 	]*\([ 	]*\))|(function[ 	]+[a-zA-Z_][a-zA-Z0-9_]*(([ 	]*\([ 	]*\))|([ 	]+)))).*$)",
        false,
    ),
    (
        "bibtex",
        r#"(@[a-zA-Z]{1,}[ 	]*\{{0,1}[ 	]*[^ 	"@',\#}{~%]*).*$"#,
        false,
    ),
    (
        "cpp",
        r"!^[ 	]*[A-Za-z_][A-Za-z_0-9]*:[[:space:]]*($|/[/*])
^((::[[:space:]]*)?[A-Za-z_].*)$",
        false,
    ),
    (
        "csharp",
        r"!(^|[ 	]+)(do|while|for|foreach|if|else|new|default|return|switch|case|throw|catch|using|lock|fixed)([ 	(]+|$)
^[ 	]*(([][[:alnum:]@_.](<[][[:alnum:]@_, 	<>]+>)?)+([ 	]+([][[:alnum:]@_.](<[][[:alnum:]@_, 	<>]+>)?)+)+[ 	]*\([^;]*)$
^[ 	]*(([][[:alnum:]@_.](<[][[:alnum:]@_, 	<>]+>)?)+([ 	]+([][[:alnum:]@_.](<[][[:alnum:]@_, 	<>]+>)?)+)+[^;=:,()]*)$
^[ 	]*(((static|public|internal|private|protected|new|unsafe|sealed|abstract|partial)[ 	]+)*(class|enum|interface|struct|record)[ 	]+.*)$
^[ 	]*(namespace[ 	]+.*)$",
        false,
    ),
    (
        "css",
        r"![:;][[:space:]]*$
^[:[@.#]?[_a-z0-9].*$",
        true,
    ),
    (
        "dts",
        r"!;
!=
^[ 	]*((/[ 	]*\{|&?[a-zA-Z_]).*)",
        false,
    ),
    (
        "elixir",
        r"^[ 	]*((def(macro|module|impl|protocol|p)?|test)[ 	].*)$",
        false,
    ),
    (
        "fortran",
        r#"!^([C*]|[ 	]*!)
!^[ 	]*MODULE[ 	]+PROCEDURE[ 	]
^[ 	]*((END[ 	]+)?(PROGRAM|MODULE|BLOCK[ 	]+DATA|([^!'" 	]+[ 	]+)*(SUBROUTINE|FUNCTION))[ 	]+[A-Z].*)$"#,
        true,
    ),
    (
        "fountain",
        r"^((\.[^.]|(int|ext|est|int\.?/ext|i/e)[. ]).*)$",
        true,
    ),
    (
        "golang",
        r"^[ 	]*(func[ 	]*.*(\{[ 	]*)?)
^[ 	]*(type[ 	].*(struct|interface)[ 	]*(\{[ 	]*)?)",
        false,
    ),
    ("html", r"^[ 	]*(<[Hh][1-6]([ 	].*)?>.*)$", false),
    ("ini", r"^[ 	]*\[[^]]+\]", false),
    (
        "java",
        r"!^[ 	]*(catch|do|for|if|instanceof|new|return|switch|throw|while)
^[ 	]*(([a-z-]+[ 	]+)*(class|enum|interface|record)[ 	]+.*)$
^[ 	]*(([A-Za-z_<>&][][?&<>.,A-Za-z_0-9]*[ 	]+)+[A-Za-z_][A-Za-z_0-9]*[ 	]*\([^;]*)$",
        false,
    ),
    (
        "kotlin",
        r"^[ 	]*(([a-z]+[ 	]+)*(fun|class|interface)[ 	]+.*)$",
        false,
    ),
    ("markdown", r"^ {0,3}#{1,6}[ 	].*", false),
    (
        "matlab",
        r"^[[:space:]]*((classdef|function)[[:space:]].*)$|^(%%%?|##)[[:space:]].*$",
        false,
    ),
    (
        "objc",
        r"!^[ 	]*(do|for|if|else|return|switch|while)
^[ 	]*([-+][ 	]*\([ 	]*[A-Za-z_][A-Za-z_0-9* 	]*\)[ 	]*[A-Za-z_].*)$
^[ 	]*(([A-Za-z_][A-Za-z_0-9]*[ 	]+)+[A-Za-z_][A-Za-z_0-9]*[ 	]*\([^;]*)$
^(@(implementation|interface|protocol)[ 	].*)$",
        false,
    ),
    (
        "pascal",
        r"^(((class[ 	]+)?(procedure|function)|constructor|destructor|interface|implementation|initialization|finalization)[ 	]*.*)$
^(.*=[ 	]*(class|record).*)$",
        false,
    ),
    (
        "perl",
        r"^package .*
^sub [[:alnum:]_':]+[ 	]*(\([^)]*\)[ 	]*)?(:[^;#]*)?(\{[ 	]*)?(#.*)?$
^(BEGIN|END|INIT|CHECK|UNITCHECK|AUTOLOAD|DESTROY)[ 	]*(\{[ 	]*)?(#.*)?$
^=head[0-9] .*",
        false,
    ),
    (
        "php",
        r"^[	 ]*(((public|protected|private|static|abstract|final)[	 ]+)*function.*)$
^[	 ]*((((final|abstract)[	 ]+)?class|enum|interface|trait).*)$",
        false,
    ),
    ("python", r"^[ 	]*((class|(async[ 	]+)?def)[ 	].*)$", false),
    // NOTE(review): `[a-zA-z]` below also admits `[`, `\`, `]`, `^`, `_` and
    // `` ` `` (the ASCII range `A`..`z`). It looks like a typo for `[a-zA-Z]`
    // but is presumably copied verbatim from Git's pattern for parity. Confirm
    // against upstream before changing it.
    (
        "r",
        r"^[ 	]*([a-zA-z][a-zA-Z0-9_.]*[ 	]*(<-|=)[ 	]*function.*)$",
        false,
    ),
    ("ruby", r"^[ 	]*((class|module|def)[ 	].*)$", false),
    (
        "rust",
        r#"^[	 ]*((pub(\([^\)]+\))?[	 ]+)?((async|const|unsafe|extern([	 ]+"[^"]+"))[	 ]+)?(struct|enum|union|mod|trait|fn|impl|macro_rules!)[< 	]+[^;]*)$"#,
        false,
    ),
    (
        "scheme",
        r"^[	 ]*(\(((define|def(struct|syntax|class|method|rules|record|proto|alias)?)[-*/ 	]|(library|module|struct|class)[*+ 	]).*)$",
        false,
    ),
    (
        "tex",
        r"^(\\((sub)*section|chapter|part)\*{0,1}\{.*)$",
        false,
    ),
];
157
158#[derive(Debug, Clone)]
159struct FuncRule {
160    matcher: RuleMatcher,
161    negate: bool,
162}
163
164#[derive(Debug, Clone)]
165enum RuleMatcher {
166    Rust(Regex),
167    Posix { pattern: String, ignore_case: bool },
168}
169
/// Funcname specification of one built-in diff driver.
#[derive(Clone, Debug)]
struct BuiltinPattern {
    /// Newline-separated ERE rules; a leading `!` marks a negated rule.
    pattern: String,
    /// Whether every rule matches case-insensitively.
    ignore_case: bool,
}
175
176/// Compiled function-name matcher used for diff hunk headers.
177#[derive(Debug, Clone)]
178pub struct FuncnameMatcher {
179    rules: Vec<FuncRule>,
180}
181
182impl FuncnameMatcher {
183    /// Match a source line against configured funcname rules.
184    ///
185    /// Returns the text to show after the hunk header when matched.
186    #[must_use]
187    pub fn match_line(&self, line: &str) -> Option<String> {
188        let mut text = line;
189        if let Some(stripped) = text.strip_suffix('\n') {
190            text = stripped;
191            if let Some(stripped_cr) = text.strip_suffix('\r') {
192                text = stripped_cr;
193            }
194        }
195
196        for rule in &self.rules {
197            let matched_text = match &rule.matcher {
198                RuleMatcher::Rust(regex) => {
199                    let Some(caps) = regex.captures(text) else {
200                        continue;
201                    };
202                    caps.get(1)
203                        .or_else(|| caps.get(0))
204                        .map(|m| m.as_str())
205                        .unwrap_or_default()
206                        .trim_end_matches(char::is_whitespace)
207                        .to_owned()
208                }
209                RuleMatcher::Posix {
210                    pattern,
211                    ignore_case,
212                } => {
213                    if !posix_line_matches(pattern, *ignore_case, text) {
214                        continue;
215                    }
216                    text.trim_end_matches(char::is_whitespace).to_owned()
217                }
218            };
219            if rule.negate {
220                return None;
221            }
222            return Some(matched_text);
223        }
224        None
225    }
226}
227
228/// Resolve a function-name matcher for `rel_path` from attributes + config.
229///
230/// Returns `Ok(None)` when no diff driver is configured for the path.
231pub fn matcher_for_path(
232    config: &ConfigSet,
233    rules: &[AttrRule],
234    rel_path: &str,
235) -> Result<Option<FuncnameMatcher>, String> {
236    let attrs = get_file_attrs(rules, rel_path, false, config);
237    let DiffAttr::Driver(ref driver) = attrs.diff_attr else {
238        return Ok(None);
239    };
240    matcher_for_driver(config, driver)
241}
242
243/// Like [`matcher_for_path`] but uses parsed `.gitattributes` rules from [`crate::attributes`].
244pub fn matcher_for_path_parsed(
245    config: &ConfigSet,
246    rules: &[crate::attributes::AttrRule],
247    macros: &MacroTable,
248    rel_path: &str,
249    ignore_case: bool,
250) -> Result<Option<FuncnameMatcher>, String> {
251    let map = collect_attrs_for_path(rules, macros, rel_path, ignore_case);
252    let Some(AttrValue::Value(driver)) = map.get("diff") else {
253        return Ok(None);
254    };
255    matcher_for_driver(config, driver.as_str())
256}
257
258/// Resolve a function-name matcher for a named diff driver.
259///
260/// Returns `Ok(None)` when the driver has no built-in or configured funcname
261/// pattern.
262pub fn matcher_for_driver(
263    config: &ConfigSet,
264    driver: &str,
265) -> Result<Option<FuncnameMatcher>, String> {
266    if let Some(pattern) = config.get(&format!("diff.{driver}.xfuncname")) {
267        return compile_matcher(&pattern, true, false).map(Some);
268    }
269    if let Some(pattern) = config.get(&format!("diff.{driver}.funcname")) {
270        return compile_matcher(&pattern, false, false).map(Some);
271    }
272    if let Some(builtin) = builtin_patterns().get(driver) {
273        return compile_matcher(&builtin.pattern, true, builtin.ignore_case).map(Some);
274    }
275    Ok(None)
276}
277
278fn compile_matcher(
279    pattern: &str,
280    extended: bool,
281    ignore_case: bool,
282) -> Result<FuncnameMatcher, String> {
283    let lines: Vec<&str> = pattern.split('\n').collect();
284    if lines.is_empty() {
285        return Ok(FuncnameMatcher { rules: Vec::new() });
286    }
287
288    let mut rules = Vec::with_capacity(lines.len());
289    for (idx, raw) in lines.iter().enumerate() {
290        let mut line = *raw;
291        let negate = line.starts_with('!');
292        if negate {
293            if idx == lines.len() - 1 {
294                return Err(format!("Last expression must not be negated: {line}"));
295            }
296            line = &line[1..];
297        }
298
299        let rust_pattern = if extended {
300            fix_charclass_escapes(line)
301        } else {
302            bre_to_ere(line)
303        };
304        let posix_pattern = if extended {
305            line.to_owned()
306        } else {
307            bre_to_ere(line)
308        };
309
310        validate_posix_regex_via_grep(&posix_pattern, ignore_case)
311            .map_err(|_| format!("Invalid regexp to look for hunk header: {line}"))?;
312
313        let matcher = RegexBuilder::new(&rust_pattern)
314            .case_insensitive(ignore_case)
315            .build()
316            .map(RuleMatcher::Rust)
317            .unwrap_or_else(|_| RuleMatcher::Posix {
318                pattern: posix_pattern,
319                ignore_case,
320            });
321        rules.push(FuncRule { matcher, negate });
322    }
323
324    Ok(FuncnameMatcher { rules })
325}
326
327fn builtin_patterns() -> &'static BTreeMap<String, BuiltinPattern> {
328    static BUILTIN_PATTERNS: OnceLock<BTreeMap<String, BuiltinPattern>> = OnceLock::new();
329    BUILTIN_PATTERNS.get_or_init(parse_builtin_patterns)
330}
331
332fn parse_builtin_patterns() -> BTreeMap<String, BuiltinPattern> {
333    BUILTIN_PATTERN_DEFS
334        .iter()
335        .filter(|(name, _, _)| !name.is_empty() && *name != "default")
336        .map(|(name, pattern, ignore_case)| {
337            (
338                (*name).to_owned(),
339                BuiltinPattern {
340                    pattern: (*pattern).to_owned(),
341                    ignore_case: *ignore_case,
342                },
343            )
344        })
345        .collect()
346}
347
/// Convert a POSIX basic regular expression (BRE) into an extended one (ERE).
///
/// Outside bracket expressions, `+ ? { } ( ) |` swap roles. A bare character
/// becomes an escaped literal, and a backslash-escaped one (a GNU BRE
/// operator) becomes the bare ERE operator. All other escapes are copied
/// unchanged.
///
/// Bracket expressions are copied through with one adjustment: a backslash
/// followed by an ASCII letter is doubled. This keeps `[\d]` meaning "a
/// literal `\` or `d`" when the result is also compiled by the Rust engine.
///
/// Only `^` negates a bracket expression; `!` is glob syntax, not regex
/// syntax. A `[:name:]` class is copied whole, so its `]` does not end the
/// enclosing bracket expression early.
fn bre_to_ere(pattern: &str) -> String {
    let chars: Vec<char> = pattern.chars().collect();
    let mut result = String::with_capacity(pattern.len());
    let mut i = 0usize;
    let mut in_bracket = false;

    while i < chars.len() {
        let c = chars[i];
        if in_bracket {
            match c {
                ']' => {
                    result.push(']');
                    in_bracket = false;
                    i += 1;
                }
                '[' => {
                    // `[:name:]` inside a bracket expression is a character class.
                    let class_end = if chars.get(i + 1) == Some(&':') {
                        (i + 2..chars.len().saturating_sub(1))
                            .find(|&j| chars[j] == ':' && chars[j + 1] == ']')
                    } else {
                        None
                    };
                    match class_end {
                        Some(j) => {
                            result.extend(&chars[i..j + 2]);
                            i = j + 2;
                        }
                        None => {
                            result.push('[');
                            i += 1;
                        }
                    }
                }
                '\\' => match chars.get(i + 1) {
                    // Keep the backslash literal for the Rust engine too
                    // (`\d` would otherwise become a digit class there).
                    Some(&next) if next.is_ascii_alphabetic() => {
                        result.push_str("\\\\");
                        result.push(next);
                        i += 2;
                    }
                    Some(&next) => {
                        result.push('\\');
                        result.push(next);
                        i += 2;
                    }
                    None => {
                        result.push('\\');
                        i += 1;
                    }
                },
                _ => {
                    result.push(c);
                    i += 1;
                }
            }
        } else if c == '[' {
            result.push('[');
            in_bracket = true;
            i += 1;
            if chars.get(i) == Some(&'^') {
                result.push('^');
                i += 1;
            }
            // A `]` first in the list is a member, not the terminator.
            if chars.get(i) == Some(&']') {
                result.push(']');
                i += 1;
            }
        } else if c == '\\' && i + 1 < chars.len() {
            let next = chars[i + 1];
            if matches!(next, '+' | '?' | '{' | '}' | '(' | ')' | '|') {
                // GNU BRE operator -> bare ERE operator.
                result.push(next);
            } else {
                result.push('\\');
                result.push(next);
            }
            i += 2;
        } else if matches!(c, '+' | '?' | '{' | '}' | '(' | ')' | '|') {
            // Literal in BRE, so it must be escaped in ERE.
            result.push('\\');
            result.push(c);
            i += 1;
        } else {
            result.push(c);
            i += 1;
        }
    }

    result
}
423
/// Translate a POSIX ERE into an equivalent pattern for the Rust `regex` crate.
///
/// Outside bracket expressions the pattern is copied verbatim, and escape
/// pairs are kept together. Inside a bracket expression POSIX treats almost
/// everything literally, whereas the Rust engine gives meaning to `\`, `[`,
/// `&&` and `~~`. Those are escaped so the Rust class has the POSIX
/// membership:
///
/// * `\` becomes `\\`, because a backslash is a literal member in POSIX.
/// * `[` becomes `\[`, unless it opens a `[:name:]` class, which is copied
///   whole. Collating symbols and equivalence classes (`[.x.]`, `[=x=]`) are
///   not supported and are treated as a literal `[`.
/// * `&` and `~` are escaped to avoid Rust's class intersection and
///   symmetric-difference operators. `-` is left alone because it also
///   forms ranges.
/// * A `]` placed first (after an optional `^`) is emitted as `\]`.
///
/// Only `^` negates a bracket expression; `!` is an ordinary member.
fn fix_charclass_escapes(pattern: &str) -> String {
    let chars: Vec<char> = pattern.chars().collect();
    let mut result = String::with_capacity(pattern.len() + 8);
    let mut i = 0usize;

    while i < chars.len() {
        match chars[i] {
            '\\' => {
                // Outside brackets both dialects agree: copy the escape pair.
                result.push('\\');
                if let Some(&next) = chars.get(i + 1) {
                    result.push(next);
                    i += 2;
                } else {
                    i += 1;
                }
            }
            '[' => i = translate_bracket_expression(&chars, i, &mut result),
            c => {
                result.push(c);
                i += 1;
            }
        }
    }

    result
}

/// Append the Rust form of the bracket expression opened at `chars[open]` to
/// `out`, and return the index just past its closing `]`. For an
/// unterminated expression, the end of input is returned.
fn translate_bracket_expression(chars: &[char], open: usize, out: &mut String) -> usize {
    out.push('[');
    let mut i = open + 1;
    if chars.get(i) == Some(&'^') {
        out.push('^');
        i += 1;
    }
    if chars.get(i) == Some(&']') {
        // A `]` first in the list is a member, not the terminator.
        out.push_str("\\]");
        i += 1;
    }

    while i < chars.len() {
        match chars[i] {
            ']' => {
                out.push(']');
                return i + 1;
            }
            '[' => {
                let class_end = if chars.get(i + 1) == Some(&':') {
                    (i + 2..chars.len().saturating_sub(1))
                        .find(|&j| chars[j] == ':' && chars[j + 1] == ']')
                } else {
                    None
                };
                match class_end {
                    Some(j) => {
                        // `[:name:]` is understood by both engines.
                        out.extend(&chars[i..j + 2]);
                        i = j + 2;
                    }
                    None => {
                        out.push_str("\\[");
                        i += 1;
                    }
                }
            }
            '\\' => {
                out.push_str("\\\\");
                i += 1;
            }
            c @ ('&' | '~') => {
                out.push('\\');
                out.push(c);
                i += 1;
            }
            c => {
                out.push(c);
                i += 1;
            }
        }
    }
    i
}
478
/// Check that `pattern` is a valid POSIX extended regular expression by
/// running `grep -E -q` (plus `-i` when `ignore_case`) on empty input.
///
/// Exit status 0 (match) or 1 (no match) means grep accepted the pattern.
///
/// # Errors
///
/// * `ErrorKind::InvalidInput`: grep ran but exited with any other status,
///   or did not exit normally, so the pattern is treated as invalid.
/// * Any other kind: grep could not be started at all (for example, it is
///   not installed). Callers can use the kind to tell "bad pattern" apart
///   from "no grep".
fn validate_posix_regex_via_grep(pattern: &str, ignore_case: bool) -> std::io::Result<()> {
    let mut cmd = Command::new("grep");
    cmd.arg("-E").arg("-q");
    if ignore_case {
        cmd.arg("-i");
    }
    // Read the empty input from stdin rather than `/dev/null`, which does not
    // exist on every platform. Also silence grep's own diagnostics: callers
    // report invalid patterns with their own message.
    cmd.arg("--")
        .arg(pattern)
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null());

    let status = cmd.status()?;
    match status.code() {
        Some(0 | 1) => Ok(()),
        _ => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "invalid regex",
        )),
    }
}
496
/// Report whether `line` matches the POSIX ERE `pattern`, using `grep -E -q`.
///
/// The line, followed by a newline, is written to grep's stdin, and the
/// result is taken from grep's exit status. Returns `false` when grep cannot
/// be spawned or waited on. Write failures are deliberately ignored.
fn posix_line_matches(pattern: &str, ignore_case: bool, line: &str) -> bool {
    let mut grep = Command::new("grep");
    grep.args(["-E", "-q"]);
    if ignore_case {
        grep.arg("-i");
    }
    grep.args(["--", pattern])
        .stdin(Stdio::piped())
        .stdout(Stdio::null())
        .stderr(Stdio::null());

    let mut child = match grep.spawn() {
        Ok(child) => child,
        Err(_) => return false,
    };

    if let Some(mut input) = child.stdin.take() {
        let _ = input.write_all(line.as_bytes());
        let _ = input.write_all(b"\n");
        // `input` is dropped here, closing the pipe so grep sees end of input.
    }

    matches!(child.wait(), Ok(status) if status.success())
}