opencodesearchparser 0.1.3

Rust library for parsing and segmenting source code
Documentation
use anyhow::{anyhow, Context, Result};
use rayon::prelude::*;
use std::fs;
use tree_sitter::Parser;
use tree_sitter_c::LANGUAGE as language_c;
use tree_sitter_cpp::LANGUAGE as language_cpp;
use tree_sitter_javascript::LANGUAGE as language_javascript;
use tree_sitter_python::LANGUAGE as language_python;
use tree_sitter_rust::LANGUAGE as language_rust;
use walkdir::WalkDir;

/// Source languages that callers can request when parsing.
///
/// Only `C`, `Cpp`, `Python`, `JavaScript`/`Js` and `Rust` are wired to a
/// grammar in this file; passing any other variant to `parse_str` or
/// `parse_dir` yields an "is not supported" error.
///
/// `Eq` and `Hash` are derived so the enum can be used as a `HashMap` /
/// `HashSet` key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CodeLanguage {
    /// C source (`.c` files in `parse_dir`).
    C,
    /// C++ source (`.cpp` files in `parse_dir`).
    Cpp,
    /// Declared, but not supported by the parsing functions in this file.
    Go,
    /// Declared, but not supported by the parsing functions in this file.
    Java,
    /// Declared, but not supported by the parsing functions in this file.
    Kotlin,
    /// Python source (`.py` files in `parse_dir`).
    Python,
    /// JavaScript source (`.js` files in `parse_dir`).
    JavaScript,
    /// Handled exactly like [`CodeLanguage::JavaScript`] by the parsing functions.
    Js,
    /// Declared, but not supported by the parsing functions in this file.
    Ts,
    /// Declared, but not supported by the parsing functions in this file.
    Php,
    /// Declared, but not supported by the parsing functions in this file.
    Proto,
    /// Declared, but not supported by the parsing functions in this file.
    R,
    /// Rust source (`.rs` files in `parse_dir`).
    Rust,
}

/// Maps `lang` to the file extension (without a leading dot) that `parse_dir`
/// uses to pick files.
///
/// # Errors
///
/// Returns an error for every language that has no extension mapping here.
fn get_file_extension(lang: CodeLanguage) -> Result<&'static str> {
    let extension = match lang {
        CodeLanguage::Rust => "rs",
        CodeLanguage::Python => "py",
        CodeLanguage::Js | CodeLanguage::JavaScript => "js",
        CodeLanguage::Cpp => "cpp",
        CodeLanguage::C => "c",
        other => {
            return Err(anyhow!(
                "Language {:?} is not supported by parse_dir file-extension mapping",
                other
            ));
        }
    };
    Ok(extension)
}

/// Reports whether `kind` is one of the preprocessor node kinds that the
/// C and C++ segmenters keep as top-level segments.
///
/// The comparison is an exact, case-sensitive string match against a fixed
/// list of kind names.
fn is_c_cpp_preprocessor_kind(kind: &str) -> bool {
    const PREPROCESSOR_KINDS: [&str; 9] = [
        "preproc_include",
        "preproc_def",
        "preproc_function_def",
        "preproc_call",
        "preproc_if",
        "preproc_ifdef",
        "preproc_elif",
        "preproc_elifdef",
        "preproc_else",
    ];

    PREPROCESSOR_KINDS.contains(&kind)
}

/// Decides whether a direct child of the syntax-tree root with node kind
/// `kind` should be emitted as a segment for `lang`.
///
/// Languages that `parse_str` has no grammar for never match.
fn is_top_level_segment_kind(lang: CodeLanguage, kind: &str) -> bool {
    // NOTE(review): Python `decorated_definition` and JavaScript
    // `export_statement` are not in these lists, so decorated / exported
    // top-level definitions appear to be skipped — confirm this is intended.
    let wanted = match lang {
        // C and C++ share one allow-list, plus preprocessor directives.
        CodeLanguage::C | CodeLanguage::Cpp => {
            matches!(
                kind,
                "function_definition"
                    | "declaration"
                    | "struct_specifier"
                    | "class_specifier"
                    | "enum_specifier"
                    | "union_specifier"
                    | "type_definition"
            ) || is_c_cpp_preprocessor_kind(kind)
        }
        CodeLanguage::Python => matches!(
            kind,
            "function_definition"
                | "class_definition"
                | "expression_statement"
                | "assignment"
                | "global_statement"
        ),
        CodeLanguage::JavaScript | CodeLanguage::Js => matches!(
            kind,
            "function_declaration"
                | "class_declaration"
                | "lexical_declaration"
                | "variable_declaration"
                | "expression_statement"
        ),
        CodeLanguage::Rust => kind.ends_with("_item") || kind.ends_with("_definition"),
        CodeLanguage::Go
        | CodeLanguage::Java
        | CodeLanguage::Kotlin
        | CodeLanguage::Ts
        | CodeLanguage::Php
        | CodeLanguage::Proto
        | CodeLanguage::R => false,
    };

    wanted && !kind.contains("comment") && kind != "translation_unit"
}

/// Splits `source` into top-level code segments.
///
/// The text is parsed with the tree-sitter grammar for `lang`; every direct
/// child of the root node whose kind is on the per-language allow-list is
/// returned as its own string, in child order. For C and C++, a `;` node that
/// directly follows a struct/class/enum/union specifier is folded into that
/// segment. Segments that are empty or whitespace-only are dropped.
///
/// `thread_num` is the number of worker threads used to copy the segment
/// text. A thread pool is only built when more than one thread is requested
/// and there is more than one segment; otherwise the copy runs on the calling
/// thread, which produces the same result without spawning threads.
///
/// # Errors
///
/// Returns an error when `lang` has no grammar wired in, when the grammar
/// cannot be loaded into the parser, when the parser produces no tree, or
/// when the worker pool cannot be built.
pub fn parse_str(source: &str, lang: CodeLanguage, thread_num: u16) -> Result<Vec<String>> {
    let mut parser = Parser::new();
    let language = match lang {
        CodeLanguage::C => language_c.into(),
        CodeLanguage::Cpp => language_cpp.into(),
        CodeLanguage::Python => language_python.into(),
        CodeLanguage::JavaScript | CodeLanguage::Js => language_javascript.into(),
        CodeLanguage::Rust => language_rust.into(),
        unsupported => {
            return Err(anyhow!(
                "Language {:?} is not supported by parse_str",
                unsupported
            ));
        }
    };
    parser
        .set_language(&language)
        .map_err(|e| anyhow!("Error loading language: {}", e))?;
    let tree = parser
        .parse(source, None)
        .ok_or_else(|| anyhow!("Failed to parse source"))?;
    let root_node = tree.root_node();

    let is_c_family = matches!(lang, CodeLanguage::C | CodeLanguage::Cpp);
    let mut segment_ranges: Vec<(usize, usize)> = Vec::new();
    for i in 0..root_node.child_count() {
        let node = match root_node.child(i as u32) {
            Some(node) => node,
            None => continue,
        };
        let kind = node.kind();
        if !is_top_level_segment_kind(lang, kind) || node.byte_range().is_empty() {
            continue;
        }

        let mut end = node.end_byte();
        let takes_trailing_semicolon = is_c_family
            && matches!(
                kind,
                "struct_specifier" | "class_specifier" | "enum_specifier" | "union_specifier"
            );
        if takes_trailing_semicolon {
            // The `;` sibling is never on an allow-list itself, so it needs no
            // explicit skip on the next iteration.
            if let Some(next_node) = root_node.child((i + 1) as u32) {
                if next_node.kind() == ";" {
                    end = next_node.end_byte();
                }
            }
        }

        segment_ranges.push((node.start_byte(), end));
    }

    // Checked slicing: a range that is out of bounds or not on a char boundary
    // is skipped instead of panicking.
    let extract = |&(start, end): &(usize, usize)| -> Option<String> {
        source
            .get(start..end)
            .filter(|text| !text.trim().is_empty())
            .map(str::to_owned)
    };

    let worker_count = usize::from(thread_num).max(1);
    let segments: Vec<String> = if worker_count > 1 && segment_ranges.len() > 1 {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(worker_count)
            .build()?;
        pool.install(|| segment_ranges.par_iter().filter_map(extract).collect())
    } else {
        segment_ranges.iter().filter_map(extract).collect()
    };

    Ok(segments)
}

/// Reads the file at `file_path` as UTF-8 text and segments it with
/// `parse_str`, forwarding `lang` and `thread_num` unchanged.
///
/// # Errors
///
/// Returns the I/O error if the file cannot be read as a string, or whatever
/// error `parse_str` reports for the file's contents.
pub fn parse_file(file_path: &str, lang: CodeLanguage, thread_num: u16) -> Result<Vec<String>> {
    let source_text = fs::read_to_string(file_path)?;
    let segments = parse_str(source_text.as_str(), lang, thread_num)?;
    Ok(segments)
}

/// Walks `dir_path` recursively, parses every file whose extension matches
/// `lang`, and returns all segments flattened into one list.
///
/// Files are parsed in parallel on a pool of `thread_num` workers (at least
/// one). Only paths that resolve to files are considered, so a directory
/// whose name happens to end in the language's extension (for example a
/// `chart.js/` folder) is walked into rather than read as a source file.
///
/// # Errors
///
/// Returns an error when `lang` has no file-extension mapping, when the
/// directory walk fails, when the worker pool cannot be built, or when any
/// matching file fails to read or parse (the failing path is attached as
/// context).
pub fn parse_dir(dir_path: &str, lang: CodeLanguage, thread_num: u16) -> Result<Vec<String>> {
    let extension = get_file_extension(lang)?;

    let mut files: Vec<String> = Vec::new();
    for entry in WalkDir::new(dir_path) {
        let entry =
            entry.map_err(|e| anyhow!("Failed to walk directory '{}': {}", dir_path, e))?;
        let path = entry.path();
        // Check the extension first so the filesystem is only queried for
        // candidate paths.
        if path.extension().map_or(false, |ext| ext == extension) && path.is_file() {
            files.push(path.to_string_lossy().to_string());
        }
    }

    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(usize::from(thread_num).max(1))
        .build()?;

    let parsed_segments: Result<Vec<Vec<String>>> = pool.install(|| {
        files
            .par_iter()
            .map(|file| {
                // File-level parallelism already occupies the `thread_num`
                // workers, so each file is parsed with a single inner thread
                // to avoid oversubscribing the machine.
                parse_file(file, lang, 1)
                    .with_context(|| format!("Failed to parse file '{}'", file))
            })
            .collect()
    });

    Ok(parsed_segments?.into_iter().flatten().collect())
}

pub mod recursive_character_text_splitter;

#[cfg(test)]
mod tests;