#![cfg_attr(coverage_nightly, coverage(off))]
use anyhow::Result;
use std::collections::HashMap;
use std::path::Path;
use std::time::Duration;
use tracing::debug;
/// Outcome of single-language detection: the chosen language plus a confidence score.
#[derive(Debug, Clone, PartialEq)]
pub struct LanguageDetection {
/// Language identifier such as `"rust"` or `"cpp"`; `"unknown"` when nothing was detected.
pub language: String,
/// Confidence score; the functions in this file produce values between 0.0 and 100.0.
pub confidence: f64,
}
/// Outcome of multi-language detection: one entry per reported language plus the primary one.
#[derive(Debug, Clone)]
pub struct MultiLanguageDetection {
/// Per-language details. `detect_all_languages` sorts these by descending percentage;
/// `override_multiple_languages` keeps the order supplied by the caller.
pub languages: Vec<LanguageInfo>,
/// Language of the first entry in `languages`, or `"unknown"` when that list is empty.
pub primary: String,
}
/// Statistics for one language found (or forced) in a project.
#[derive(Debug, Clone, PartialEq)]
pub struct LanguageInfo {
/// Language identifier such as `"rust"` or `"python"`.
pub language: String,
/// Confidence score; the functions in this file produce values no greater than 100.0.
pub confidence: f64,
/// Number of scanned files whose extension maps to this language.
pub file_count: usize,
/// Share of all counted files that belong to this language, on a 0-100 scale.
pub percentage: f64,
}
/// Detects the dominant language of the project rooted at `path`.
///
/// Each language accumulates a score from two sources:
///
/// 1. **Manifest files** directly under `path` (`Cargo.toml`, `CMakeLists.txt`,
///    `package.json`, `pyproject.toml`, `go.mod`, `lakefile.lean` / `lean-toolchain`)
///    add a fixed boost to the matching language(s).
/// 2. **File share**: for every extension returned by `count_files_by_extension`
///    that maps to a language, the percentage of counted files with that
///    extension is added to the language's score.
///
/// The language with the highest score wins. When several languages have exactly
/// the same score, the alphabetically first name is chosen so the result does not
/// depend on `HashMap` iteration order.
///
/// # Returns
///
/// A [`LanguageDetection`] whose confidence is the winning score capped at 100.0.
/// If the scan counts no files, or no language received a score, the language is
/// `"unknown"` with confidence 0.0.
pub fn detect_project_language_enhanced(path: &Path) -> LanguageDetection {
    /// Adds `amount` to the running score of `lang`.
    fn boost(scores: &mut HashMap<String, f64>, lang: &str, amount: f64) {
        *scores.entry(lang.to_string()).or_default() += amount;
    }

    debug!("Detecting project language at: {:?}", path);
    let mut scores: HashMap<String, f64> = HashMap::new();

    // Manifest files are strong hints and are weighted far above raw file share.
    if path.join("Cargo.toml").exists() {
        boost(&mut scores, "rust", 90.0);
        debug!("Found Cargo.toml - boosting Rust confidence by 90");
    }
    if path.join("CMakeLists.txt").exists() {
        boost(&mut scores, "cpp", 85.0);
        debug!("Found CMakeLists.txt - boosting C++ confidence by 85");
    }
    if path.join("package.json").exists() {
        // package.json cannot distinguish JS from TS, so both get the same modest boost
        // and the file-share pass below decides between them.
        boost(&mut scores, "javascript", 30.0);
        boost(&mut scores, "typescript", 30.0);
        debug!("Found package.json - boosting JS/TS confidence by 30");
    }
    if path.join("pyproject.toml").exists() {
        boost(&mut scores, "python", 50.0);
        debug!("Found pyproject.toml - boosting Python confidence by 50");
    }
    if path.join("go.mod").exists() {
        boost(&mut scores, "go", 90.0);
        debug!("Found go.mod - boosting Go confidence by 90");
    }
    if path.join("lakefile.lean").exists() || path.join("lean-toolchain").exists() {
        boost(&mut scores, "lean", 90.0);
        debug!("Found lakefile.lean/lean-toolchain - boosting Lean confidence by 90");
    }

    let file_counts = count_files_by_extension(path);
    let total_files: usize = file_counts.values().sum();
    if total_files == 0 {
        debug!("No files found, returning unknown");
        return LanguageDetection {
            language: "unknown".to_string(),
            confidence: 0.0,
        };
    }
    debug!("Total files: {}, counts: {:?}", total_files, file_counts);

    // Each recognised extension contributes its share (in percent) of all counted files.
    for (ext, count) in &file_counts {
        let percentage = (*count as f64 / total_files as f64) * 100.0;
        if let Some(lang) = extension_to_language(ext) {
            boost(&mut scores, lang, percentage);
            debug!(
                "Extension {} ({} files, {:.1}%) maps to {}, adding {:.1} to score",
                ext, count, percentage, lang, percentage
            );
        }
    }

    // Highest score wins; equal scores fall back to the alphabetically first name
    // (the reversed name comparison makes the smaller name the "maximum").
    let (best_lang, best_score) = scores
        .iter()
        .max_by(|a, b| a.1.total_cmp(b.1).then_with(|| b.0.cmp(a.0)))
        .map(|(lang, score)| (lang.clone(), *score))
        .unwrap_or_else(|| ("unknown".to_string(), 0.0));
    debug!("Best language: {} with score: {:.1}", best_lang, best_score);

    LanguageDetection {
        language: best_lang,
        confidence: best_score.min(100.0),
    }
}
/// Returns the extra confidence awarded to `lang` when its build manifest sits
/// directly under `path`.
///
/// * `"rust"` with `Cargo.toml` -> 10.0
/// * `"cpp"` or `"c"` with `CMakeLists.txt` -> 10.0
/// * `"javascript"` or `"typescript"` with `package.json` -> 5.0
/// * anything else -> 0.0
fn compute_confidence_boost(lang: &str, path: &Path) -> f64 {
    // Only the manifest relevant to `lang` is probed on disk.
    let manifest_present = |file_name: &str| path.join(file_name).exists();
    match lang {
        "rust" if manifest_present("Cargo.toml") => 10.0,
        "cpp" | "c" if manifest_present("CMakeLists.txt") => 10.0,
        "javascript" | "typescript" if manifest_present("package.json") => 5.0,
        _ => 0.0,
    }
}
pub fn detect_all_languages(path: &Path) -> MultiLanguageDetection {
debug!("Detecting all languages at: {:?}", path);
let file_counts = count_files_by_extension(path);
let total_files: usize = file_counts.values().sum();
if total_files == 0 {
return MultiLanguageDetection {
languages: vec![],
primary: "unknown".to_string(),
};
}
let mut languages = Vec::new();
let mut lang_counts: HashMap<String, usize> = HashMap::new();
for (ext, count) in file_counts.iter() {
if let Some(lang) = extension_to_language(ext) {
*lang_counts.entry(lang.to_string()).or_insert(0) += count;
}
}
for (lang, count) in lang_counts.iter() {
let percentage = (*count as f64 / total_files as f64) * 100.0;
if percentage > 5.0 {
let mut confidence = percentage;
confidence += compute_confidence_boost(lang, path);
languages.push(LanguageInfo {
language: lang.clone(),
confidence: confidence.min(100.0),
file_count: *count,
percentage,
});
}
}
languages.sort_by(|a, b| {
b.percentage
.partial_cmp(&a.percentage)
.expect("internal error")
});
let primary = languages
.first()
.map(|l| l.language.clone())
.unwrap_or_else(|| "unknown".to_string());
debug!(
"Detected {} languages, primary: {}",
languages.len(),
primary
);
MultiLanguageDetection { languages, primary }
}
/// Runs [`detect_project_language_enhanced`] on `path` and wraps the result in `Ok`.
///
/// The `_timeout` argument is accepted but not used by this implementation, so the
/// call always runs to completion and never returns `Err`.
pub fn detect_project_language_with_timeout(
    path: &Path,
    _timeout: Duration,
) -> Result<LanguageDetection> {
    let detection = detect_project_language_enhanced(path);
    Ok(detection)
}
/// Builds a detection result for a caller-chosen `language` without inspecting the
/// file system.
///
/// `_path` is unused. The returned confidence is always 100.0.
pub fn override_language_detection(_path: &Path, language: &str) -> LanguageDetection {
    let language = language.to_owned();
    LanguageDetection {
        language,
        confidence: 100.0,
    }
}
pub fn override_multiple_languages(path: &Path, languages: Vec<String>) -> MultiLanguageDetection {
let file_counts = count_files_by_extension(path);
let total_files: usize = file_counts.values().sum();
let language_infos: Vec<LanguageInfo> = languages
.into_iter()
.map(|lang| {
let count = file_counts
.iter()
.filter(|(ext, _)| {
extension_to_language(ext).map(|s| s.to_string()) == Some(lang.clone())
})
.map(|(_, c)| *c)
.sum();
let percentage = if total_files > 0 {
(count as f64 / total_files as f64) * 100.0
} else {
0.0
};
LanguageInfo {
language: lang,
confidence: 100.0, file_count: count,
percentage,
}
})
.collect();
let primary = language_infos
.first()
.map(|l| l.language.clone())
.unwrap_or_else(|| "unknown".to_string());
MultiLanguageDetection {
languages: language_infos,
primary,
}
}
/// Walks the tree under `path` and tallies regular files by file extension.
///
/// The walk descends at most 10 levels and does not follow symbolic links. Entries
/// that cannot be read are skipped, as are files with no extension or with an
/// extension that is not valid UTF-8. Extensions are stored exactly as they appear
/// on disk (no case folding).
fn count_files_by_extension(path: &Path) -> HashMap<String, usize> {
    use walkdir::WalkDir;

    WalkDir::new(path)
        .max_depth(10)
        .follow_links(false)
        .into_iter()
        .filter_map(Result::ok)
        .filter(|item| item.file_type().is_file())
        .filter_map(|item| {
            item.path()
                .extension()
                .and_then(|raw| raw.to_str())
                .map(str::to_owned)
        })
        .fold(HashMap::new(), |mut tally, extension| {
            *tally.entry(extension).or_insert(0) += 1;
            tally
        })
}
/// Maps a file extension (without the leading dot) to a language identifier.
///
/// Matching is exact and case-sensitive. Returns `None` for extensions that are
/// not in the table.
fn extension_to_language(ext: &str) -> Option<&'static str> {
    /// Language identifier paired with every extension that maps to it.
    const EXTENSION_TABLE: &[(&str, &[&str])] = &[
        ("rust", &["rs"]),
        ("python", &["py"]),
        ("javascript", &["js", "jsx"]),
        ("typescript", &["ts", "tsx"]),
        ("c", &["c", "h"]),
        ("cpp", &["cc", "cpp", "cxx", "hpp", "hxx", "h++", "c++"]),
        ("go", &["go"]),
        ("java", &["java"]),
        ("kotlin", &["kt", "kts"]),
        ("ruby", &["rb"]),
        ("php", &["php"]),
        ("swift", &["swift"]),
        ("csharp", &["cs"]),
        ("bash", &["sh", "bash"]),
        ("lua", &["lua"]),
        ("lean", &["lean"]),
        ("sql", &["sql", "ddl", "dml"]),
        ("scala", &["scala", "sc", "sbt"]),
        ("markdown", &["md", "mdx", "markdown"]),
        ("yaml", &["yaml", "yml"]),
    ];

    EXTENSION_TABLE
        .iter()
        .find(|(_, extensions)| extensions.iter().any(|candidate| *candidate == ext))
        .map(|&(language, _)| language)
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `contents` to `relative` under `root`, creating parent directories first.
    fn write_file(root: &std::path::Path, relative: &str, contents: &str) {
        let target = root.join(relative);
        if let Some(parent) = target.parent() {
            std::fs::create_dir_all(parent).expect("internal error");
        }
        std::fs::write(target, contents).expect("internal error");
    }

    // A single known extension maps to its language.
    #[test]
    fn test_extension_to_language_rust() {
        assert_eq!(extension_to_language("rs"), Some("rust"));
    }

    // All common C++ source extensions map to "cpp".
    #[test]
    fn test_extension_to_language_cpp() {
        for ext in ["cpp", "cc", "cxx"] {
            assert_eq!(extension_to_language(ext), Some("cpp"));
        }
    }

    // Cargo.toml plus a Rust source file yields a high-confidence Rust detection.
    #[test]
    fn test_detect_rust_project_with_cargo_toml() {
        let project = TempDir::new().expect("internal error");
        write_file(project.path(), "Cargo.toml", "[package]\nname = \"test\"\n");
        write_file(project.path(), "src/main.rs", "fn main() {}");

        let detection = detect_project_language_enhanced(project.path());
        assert_eq!(detection.language, "rust");
        assert!(detection.confidence >= 90.0);
    }

    // CMakeLists.txt plus a C++ source file yields a high-confidence C++ detection.
    #[test]
    fn test_detect_cpp_project_with_cmake() {
        let project = TempDir::new().expect("internal error");
        write_file(
            project.path(),
            "CMakeLists.txt",
            "cmake_minimum_required(VERSION 3.10)\n",
        );
        write_file(project.path(), "src/main.cpp", "int main() {}");

        let detection = detect_project_language_enhanced(project.path());
        assert_eq!(detection.language, "cpp");
        assert!(detection.confidence >= 85.0);
    }

    // 50 Rust files and 30 Python files: both are reported, Rust is primary.
    #[test]
    fn test_multi_language_detection() {
        let project = TempDir::new().expect("internal error");
        std::fs::create_dir_all(project.path().join("src")).expect("internal error");
        (0..50).for_each(|i| {
            write_file(
                project.path(),
                &format!("src/file_{}.rs", i),
                "fn main() {}",
            )
        });
        (0..30).for_each(|i| {
            write_file(
                project.path(),
                &format!("src/tool_{}.py", i),
                "print('hello')",
            )
        });

        let detection = detect_all_languages(project.path());
        assert_eq!(detection.languages.len(), 2);
        assert_eq!(detection.primary, "rust");
    }
}