agent-file-tools 0.35.4

Agent File Tools — tree-sitter powered code analysis for AI agents
Documentation
use regex::bytes::{Regex, RegexBuilder};

const DEFAULT_SIZE_LIMIT_BYTES: usize = 10 * 1024 * 1024;

#[derive(Clone, Debug)]
pub enum CompiledPattern {
    Literal(LiteralSearch),
    Regex {
        compiled: Regex,
        raw_pattern: String,
        case_insensitive: bool,
    },
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LiteralSearch {
    pub needle: Vec<u8>,
    pub case_insensitive_ascii: bool,
}

#[derive(Clone, Debug)]
pub struct CompileOpts {
    pub literal: bool,
    pub case_insensitive: bool,
    pub multi_line: bool,
    pub size_limit_bytes: usize,
}

impl Default for CompileOpts {
    fn default() -> Self {
        Self {
            literal: false,
            case_insensitive: false,
            multi_line: true,
            size_limit_bytes: DEFAULT_SIZE_LIMIT_BYTES,
        }
    }
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CompileResult {
    Ok(CompiledPattern),
    InvalidPattern { message: String, pattern: String },
    UnsupportedSyntax { feature: String, pattern: String },
}

impl PartialEq for CompiledPattern {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (CompiledPattern::Literal(left), CompiledPattern::Literal(right)) => left == right,
            (
                CompiledPattern::Regex {
                    raw_pattern: left_pattern,
                    case_insensitive: left_case,
                    ..
                },
                CompiledPattern::Regex {
                    raw_pattern: right_pattern,
                    case_insensitive: right_case,
                    ..
                },
            ) => left_pattern == right_pattern && left_case == right_case,
            _ => false,
        }
    }
}

impl Eq for CompiledPattern {}

impl CompiledPattern {
    pub fn is_literal(&self) -> bool {
        matches!(self, CompiledPattern::Literal(_))
    }

    pub fn case_insensitive(&self) -> bool {
        match self {
            CompiledPattern::Literal(literal) => literal.case_insensitive_ascii,
            CompiledPattern::Regex {
                case_insensitive, ..
            } => *case_insensitive,
        }
    }

    pub fn raw_pattern_for_trigrams(&self) -> String {
        match self {
            CompiledPattern::Literal(literal) => {
                String::from_utf8_lossy(&literal.needle).into_owned()
            }
            CompiledPattern::Regex { raw_pattern, .. } => raw_pattern.clone(),
        }
    }

    pub fn ripgrep_pattern(&self) -> String {
        match self {
            CompiledPattern::Literal(literal) => {
                String::from_utf8_lossy(&literal.needle).into_owned()
            }
            CompiledPattern::Regex { raw_pattern, .. } => raw_pattern.clone(),
        }
    }
}

pub fn compile(pattern: &str, opts: CompileOpts) -> CompileResult {
    if pattern.len() > opts.size_limit_bytes {
        return CompileResult::InvalidPattern {
            message: format!(
                "invalid regex: pattern exceeds size limit of {} bytes",
                opts.size_limit_bytes
            ),
            pattern: pattern.to_string(),
        };
    }

    if !opts.literal {
        if let Some(feature) = detect_unsupported_features(pattern) {
            return CompileResult::UnsupportedSyntax {
                feature,
                pattern: pattern.to_string(),
            };
        }
    }

    let has_regex_meta = has_regex_metachar(pattern);
    let ascii_safe_literal = opts.case_insensitive && pattern.is_ascii();
    if opts.literal || (!has_regex_meta && (!opts.case_insensitive || ascii_safe_literal)) {
        if !opts.case_insensitive || pattern.is_ascii() {
            let needle = if opts.case_insensitive {
                pattern
                    .as_bytes()
                    .iter()
                    .map(|byte| byte.to_ascii_lowercase())
                    .collect()
            } else {
                pattern.as_bytes().to_vec()
            };
            return CompileResult::Ok(CompiledPattern::Literal(LiteralSearch {
                needle,
                case_insensitive_ascii: opts.case_insensitive,
            }));
        }
    }

    let mut regex_pattern = if opts.literal || !has_regex_meta {
        regex::escape(pattern)
    } else {
        pattern.to_string()
    };
    let mut builder_case_insensitive = opts.case_insensitive;
    if opts.case_insensitive && !pattern.is_ascii() {
        regex_pattern = format!("(?i){regex_pattern}");
        builder_case_insensitive = false;
    }

    let mut builder = RegexBuilder::new(&regex_pattern);
    builder.case_insensitive(builder_case_insensitive);
    builder.multi_line(opts.multi_line);
    builder.size_limit(opts.size_limit_bytes);

    match builder.build() {
        Ok(compiled) => CompileResult::Ok(CompiledPattern::Regex {
            compiled,
            raw_pattern: regex_pattern,
            case_insensitive: opts.case_insensitive,
        }),
        Err(error) => CompileResult::InvalidPattern {
            message: format!("invalid regex: {error}"),
            pattern: pattern.to_string(),
        },
    }
}

pub fn detect_unsupported_features(pattern: &str) -> Option<String> {
    if pattern.contains("(?=")
        || pattern.contains("(?!")
        || pattern.contains("(?<=")
        || pattern.contains("(?<!")
    {
        return Some("lookaround".to_string());
    }
    if pattern.contains("(?P=") || contains_numeric_backreference(pattern) {
        return Some("backreference".to_string());
    }
    if pattern.contains("*+") || pattern.contains("++") || pattern.contains("?+") {
        return Some("possessive quantifier".to_string());
    }
    if pattern.contains("(?>") {
        return Some("atomic group".to_string());
    }
    None
}

fn has_regex_metachar(pattern: &str) -> bool {
    pattern.chars().any(|c| {
        matches!(
            c,
            '.' | '*' | '+' | '?' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '^' | '$' | '\\'
        )
    })
}

fn contains_numeric_backreference(pattern: &str) -> bool {
    let mut escaped = false;
    for ch in pattern.chars() {
        if escaped {
            if ('1'..='9').contains(&ch) {
                return true;
            }
            escaped = false;
            continue;
        }
        escaped = ch == '\\';
    }
    false
}

#[cfg(test)]
mod tests {
    use super::*;

    fn assert_literal(pattern: &str, case_insensitive: bool, expected: &[u8]) {
        let result = compile(
            pattern,
            CompileOpts {
                case_insensitive,
                ..CompileOpts::default()
            },
        );
        match result {
            CompileResult::Ok(CompiledPattern::Literal(literal)) => {
                assert_eq!(literal.needle, expected);
                assert_eq!(literal.case_insensitive_ascii, case_insensitive);
            }
            other => panic!("expected literal, got {other:?}"),
        }
    }

    #[test]
    fn literal_pattern_without_metachars_uses_fast_path() {
        assert_literal("needle", false, b"needle");
    }

    #[test]
    fn ascii_case_insensitive_literal_uses_lowercase_fast_path() {
        assert_literal("Needle", true, b"needle");
    }

    #[test]
    fn non_ascii_case_insensitive_literal_forces_regex_with_inline_flag() {
        let result = compile(
            "Äbc",
            CompileOpts {
                case_insensitive: true,
                ..CompileOpts::default()
            },
        );
        match result {
            CompileResult::Ok(CompiledPattern::Regex {
                raw_pattern,
                case_insensitive,
                ..
            }) => {
                assert!(raw_pattern.starts_with("(?i)"));
                assert!(case_insensitive);
            }
            other => panic!("expected regex, got {other:?}"),
        }
    }

    #[test]
    fn regex_pattern_retains_raw_pattern_and_compiles_bytes_regex() {
        let result = compile("foo.*bar", CompileOpts::default());
        match result {
            CompileResult::Ok(CompiledPattern::Regex {
                compiled,
                raw_pattern,
                ..
            }) => {
                assert_eq!(raw_pattern, "foo.*bar");
                assert!(compiled.is_match(b"foo middle bar"));
            }
            other => panic!("expected regex, got {other:?}"),
        }
    }

    #[test]
    fn invalid_pattern_surfaces_compile_error() {
        let result = compile("[", CompileOpts::default());
        assert!(matches!(result, CompileResult::InvalidPattern { .. }));
    }

    #[test]
    fn pattern_exceeding_size_limit_is_invalid() {
        let result = compile(
            "abcd",
            CompileOpts {
                size_limit_bytes: 3,
                ..CompileOpts::default()
            },
        );
        assert!(matches!(result, CompileResult::InvalidPattern { .. }));
    }

    #[test]
    fn unsupported_syntax_is_detected_before_compile() {
        for pattern in [
            "(?=foo)",
            "(?!foo)",
            "(?<=foo)",
            "(?<!foo)",
            "(?P=name)",
            r"\1",
            "foo*+",
            "(?>foo)",
        ] {
            assert!(
                matches!(
                    compile(pattern, CompileOpts::default()),
                    CompileResult::UnsupportedSyntax { .. }
                ),
                "{pattern}"
            );
        }
    }

    #[test]
    fn forced_literal_honors_regex_characters() {
        let result = compile(
            "foo.*bar",
            CompileOpts {
                literal: true,
                ..CompileOpts::default()
            },
        );
        match result {
            CompileResult::Ok(CompiledPattern::Literal(literal)) => {
                assert_eq!(literal.needle, b"foo.*bar");
            }
            other => panic!("expected literal, got {other:?}"),
        }
    }
}