use crate::{
Bank, BankConfig, Error, Result,
parser::{
CppParser, FileUnit, GoParser, LanguageParser, LanguageType, PythonParser, RustParser,
TypeScriptParser, formatter::Formatter,
},
};
use ignore::WalkBuilder;
use regex::Regex;
use std::cell::OnceCell;
use std::fs;
use std::{ffi::OsStr, path::Path};
/// Cell that `generate` fills with the regex used to collapse whitespace runs ending in a
/// newline before the final output is returned.
///
/// NOTE(review): this is a `const`, not a `static`, so every mention of `REGEX` evaluates to a
/// brand-new, empty `OnceCell`; the compiled regex is therefore not shared between uses.
// The `allow` below silences clippy's warning about exactly that interior-mutable `const` pattern.
#[allow(clippy::declare_interior_mutable_const)]
const REGEX: OnceCell<Regex> = OnceCell::new();
/// Code bank generator that owns one parser per supported language.
///
/// The parser used for a file is selected from the file's extension
/// (see `detect_language` and `parse_file`).
pub struct CodeBank {
/// Parser used for files detected as Rust (`.rs`).
rust_parser: RustParser,
/// Parser used for files detected as Python (`.py`).
python_parser: PythonParser,
/// Parser used for files detected as TypeScript (`.ts`, `.tsx`, `.js`, `.jsx`).
typescript_parser: TypeScriptParser,
/// Parser used for files detected as C/C++ (`.c`, `.h`, `.cpp`, `.hpp`).
c_parser: CppParser,
/// Parser used for files detected as Go (`.go`).
go_parser: GoParser,
}
impl CodeBank {
pub fn try_new() -> Result<Self> {
let rust_parser = RustParser::try_new()?;
let python_parser = PythonParser::try_new()?;
let typescript_parser = TypeScriptParser::try_new()?;
let c_parser = CppParser::try_new()?;
let go_parser = GoParser::try_new()?;
Ok(Self {
rust_parser,
python_parser,
typescript_parser,
c_parser,
go_parser,
})
}
fn detect_language(&self, path: &Path) -> Option<LanguageType> {
match path.extension().and_then(OsStr::to_str) {
Some("rs") => Some(LanguageType::Rust),
Some("py") => Some(LanguageType::Python),
Some("ts") | Some("tsx") | Some("js") | Some("jsx") => Some(LanguageType::TypeScript),
Some("c") | Some("h") | Some("cpp") | Some("hpp") => Some(LanguageType::Cpp),
Some("go") => Some(LanguageType::Go),
_ => Some(LanguageType::Unknown),
}
}
fn parse_file(&mut self, file_path: &Path) -> Result<Option<FileUnit>> {
match self.detect_language(file_path) {
Some(LanguageType::Rust) => self.rust_parser.parse_file(file_path).map(Some),
Some(LanguageType::Python) => self.python_parser.parse_file(file_path).map(Some),
Some(LanguageType::TypeScript) => {
self.typescript_parser.parse_file(file_path).map(Some)
}
Some(LanguageType::Cpp) => self.c_parser.parse_file(file_path).map(Some),
Some(LanguageType::Go) => self.go_parser.parse_file(file_path).map(Some),
Some(LanguageType::Unknown) => Ok(None),
None => Ok(None),
}
}
fn find_and_read_package_file(&self, root_dir: &Path) -> Result<Option<String>> {
const PACKAGE_FILES: &[&str] = &[
"Cargo.toml",
"pyproject.toml",
"setup.py",
"requirements.txt",
"package.json",
"CMakeLists.txt",
"Makefile",
"go.mod",
];
const MAX_DEPTH: usize = 3;
let mut current_dir = root_dir.to_path_buf();
for _ in 0..=MAX_DEPTH {
for filename in PACKAGE_FILES {
let package_path = current_dir.join(filename);
if package_path.is_file() {
match fs::read_to_string(&package_path) {
Ok(content) => return Ok(Some(content)),
Err(e) => return Err(Error::Io(e)),
}
}
}
if !current_dir.pop() {
break; }
}
Ok(None) }
}
impl Bank for CodeBank {
fn generate(&self, config: &BankConfig) -> Result<String> {
let root_dir = &config.root_dir;
if !root_dir.exists() {
return Err(Error::DirectoryNotFound(root_dir.to_path_buf()));
}
if !root_dir.is_dir() {
return Err(Error::InvalidConfig(format!(
"{} is not a directory",
root_dir.display()
)));
}
let mut output = String::new();
output.push_str("# Code Bank\n\n");
match self.find_and_read_package_file(root_dir) {
Ok(Some(content)) => {
output.push_str("## Package File\n\n");
output.push_str("```toml\n"); output.push_str(&content);
output.push_str("\n```\n\n");
}
Ok(None) => { }
Err(e) => {
eprintln!("Warning: Failed to read package file: {}", e);
}
}
let mut code_bank = self.try_clone()?;
let mut file_units = Vec::new();
let walker = WalkBuilder::new(root_dir);
for entry in walker.build().filter_map(|e| e.ok()) {
let path = entry.path();
let should_ignore = config.ignore_dirs.iter().any(|ignored_dir_name| {
path.ancestors().any(|ancestor| {
ancestor
.strip_prefix(root_dir)
.is_ok_and(|p| p.ends_with(ignored_dir_name))
})
});
if should_ignore {
continue;
}
if path.is_file() {
if let Ok(Some(file_unit)) = code_bank.parse_file(path) {
file_units.push(file_unit);
}
}
}
file_units.sort_by(|a, b| a.path.cmp(&b.path));
for file_unit in &file_units {
let relative_path = file_unit
.path
.strip_prefix(root_dir)
.map(|p| p.display().to_string())
.unwrap_or_else(|_| file_unit.path.display().to_string());
let lang = code_bank
.detect_language(&file_unit.path)
.unwrap_or(LanguageType::Unknown);
let formatted_content = file_unit.format(&config.strategy, lang)?;
if !formatted_content.is_empty() {
output.push_str(&format!("## {}\n", relative_path));
output.push_str(&format!("```{}\n", lang.as_str()));
output.push_str(&formatted_content);
output.push_str("```\n\n");
}
}
let regex = REGEX;
let regex = regex.get_or_init(|| Regex::new(r"\n*\s*\n+").unwrap());
output = regex.replace_all(&output, "\n").to_string();
Ok(output)
}
}
impl CodeBank {
    /// Produces an independent `CodeBank` by constructing a brand-new instance.
    ///
    /// No state is copied from `self`; the result is whatever `try_new` returns.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `try_new`.
    fn try_clone(&self) -> Result<Self> {
        Self::try_new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Each recognised extension maps to its language; anything else maps to `Unknown`.
    #[test]
    fn test_detect_language() {
        let code_bank = CodeBank::try_new().unwrap();

        let cases = [
            ("test.rs", LanguageType::Rust),
            ("test.py", LanguageType::Python),
            ("test.ts", LanguageType::TypeScript),
            ("test.tsx", LanguageType::TypeScript),
            ("test.js", LanguageType::TypeScript),
            ("test.jsx", LanguageType::TypeScript),
            ("test.c", LanguageType::Cpp),
            ("test.h", LanguageType::Cpp),
            ("test.go", LanguageType::Go),
            ("test.txt", LanguageType::Unknown),
        ];

        for (file_name, expected) in cases {
            let path = PathBuf::from(file_name);
            assert_eq!(code_bank.detect_language(&path), Some(expected));
        }
    }

    /// The detected language exposes the expected short name via `as_str`.
    #[test]
    fn test_get_language_name() {
        let code_bank = CodeBank::try_new().unwrap();

        let cases = [
            ("test.rs", "rust"),
            ("test.py", "python"),
            ("test.ts", "ts"),
            ("test.c", "cpp"),
            ("test.go", "go"),
            ("test.txt", "unknown"),
        ];

        for (file_name, expected_name) in cases {
            let path = PathBuf::from(file_name);
            let lang = code_bank.detect_language(&path).unwrap();
            assert_eq!(lang.as_str(), expected_name);
        }
    }
}