use super::{
language_matches, license_matches, CodeReference, MiningError, MiningQuery, ReferenceMiner,
};
use async_trait::async_trait;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::Mutex;
/// Directory names that are never descended into while walking a checkout:
/// version-control metadata, dependency trees, build output, Python
/// environments, and tool/editor caches.
const SKIP_DIRS: &[&str] = &[
    // version control
    ".git",
    // JavaScript dependencies
    "node_modules",
    // build output
    "target",
    "dist",
    "build",
    // Python environments and bytecode
    ".venv",
    "venv",
    "__pycache__",
    // framework / tool caches
    ".next",
    ".cache",
    // editor settings
    ".idea",
    ".vscode",
];

/// Lower-case file extensions (without the leading dot) whose files are read
/// and scored; anything else is ignored by the walk.
const TEXT_EXTS: &[&str] = &[
    "rs", // Rust
    "ts", "tsx", "js", "jsx", "mjs", "cjs", // TypeScript / JavaScript
    "py", // Python
    "go", // Go
    "c", "h", "cpp", "cc", "cxx", "hpp", "hh", "hxx", // C / C++
    "java", "rb", "swift", "kt", "kts", // other languages
    "md", "toml", "yaml", "yml", "json", // docs and configuration
];

/// Files larger than this many bytes are skipped without being read.
const MAX_FILE_BYTES: u64 = 1000 * 1000;
/// A reference miner that searches local directory trees (typically cloned
/// repositories) for code matching a query.
#[derive(Debug)]
pub struct LocalCloneSource {
    /// Directories to search. Each may be a checkout itself or a directory
    /// whose immediate subdirectories are treated as separate checkouts.
    roots: Vec<PathBuf>,
    /// Per-repository metadata keyed by repository root path. It sits behind
    /// `Arc<Mutex<..>>` so it can be shared with the blocking search task.
    cache: Arc<Mutex<HashMap<PathBuf, RepoMeta>>>,
}

/// Metadata attached to every hit that comes from a single repository.
#[derive(Clone, Debug)]
struct RepoMeta {
    /// Display identifier of the form `local:<directory name>`.
    repo_id: String,
    /// Commit hash from `git rev-parse HEAD`, or the literal `"HEAD"` when
    /// that is unavailable.
    commit: String,
    /// License identifier detected from a conventional license file, if any.
    license: Option<String>,
}
impl LocalCloneSource {
    /// Creates a miner that searches the given root directories.
    ///
    /// Each root may be a single checkout or a directory that contains
    /// several checkouts. Repository metadata (commit, license) is cached per
    /// instance, so it is computed at most once per repository root.
    pub fn new(roots: Vec<PathBuf>) -> Self {
        Self {
            roots,
            cache: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Splits a free-text query into lower-cased search terms.
    ///
    /// - Tokens are maximal runs of alphanumeric characters and `_`.
    /// - Tokens of one byte (single ASCII characters) are dropped.
    /// - Common English stop words are dropped.
    /// - Repeated terms are kept only once, at their first position.
    ///   Duplicates would otherwise double-weight that term in
    ///   `best_match`'s coverage score and be listed twice in its
    ///   explanation.
    ///
    /// Lower-casing is ASCII-only, matching how candidate lines are
    /// lower-cased during matching.
    fn tokenize(query: &str) -> Vec<String> {
        const STOP: &[&str] = &[
            "the", "a", "an", "of", "to", "and", "or", "for", "in", "on", "is", "how",
        ];
        let mut seen = std::collections::HashSet::new();
        query
            .split(|c: char| !c.is_alphanumeric() && c != '_')
            // `len` is in bytes: this also rejects empty tokens.
            .filter(|t| t.len() > 1)
            .map(str::to_ascii_lowercase)
            .filter(|t| !STOP.contains(&t.as_str()))
            .filter(|t| seen.insert(t.clone()))
            .collect()
    }
}
#[async_trait]
impl ReferenceMiner for LocalCloneSource {
    fn name(&self) -> &str {
        "local_clone"
    }

    /// Walks every configured root and returns the best-matching snippet per
    /// file, ranked by score (highest first).
    ///
    /// The directory walk, file reads and `git` invocations run inside
    /// `tokio::task::spawn_blocking`, so they do not block the async executor.
    ///
    /// Ordering is by descending score. Ties are broken by repository id and
    /// then by path, so results do not depend on directory-iteration order.
    /// `filters.max_results == 0` means "no limit".
    ///
    /// # Errors
    ///
    /// * `MiningError::InvalidQuery` if the query is blank or yields no
    ///   search terms after tokenization.
    /// * `MiningError::Other` if the blocking task fails to complete (for
    ///   example because it panicked).
    async fn search(&self, query: &MiningQuery) -> Result<Vec<CodeReference>, MiningError> {
        if query.query.trim().is_empty() {
            return Err(MiningError::InvalidQuery("empty query".into()));
        }
        let terms = Self::tokenize(&query.query);
        if terms.is_empty() {
            return Err(MiningError::InvalidQuery(
                "query contained no meaningful terms".into(),
            ));
        }
        // Owned copies that are moved into the blocking task.
        let roots = self.roots.clone();
        let filters = query.filters.clone();
        let cache = Arc::clone(&self.cache);
        let query_text = query.query.clone();

        tokio::task::spawn_blocking(move || -> Result<Vec<CodeReference>, MiningError> {
            let mut out: Vec<CodeReference> = Vec::new();
            for root in roots.iter().filter(|r| r.exists()) {
                for repo_root in discover_repo_roots(root) {
                    let meta = load_repo_meta(&repo_root, &cache);
                    // The license belongs to the repository, not the file, so a
                    // disallowed repository is rejected before its tree is walked.
                    if !license_matches(&filters.license_allowlist, meta.license.as_deref()) {
                        continue;
                    }
                    let mut files = Vec::new();
                    collect_files(&repo_root, &mut files);
                    for path in files {
                        // Repository-relative path with forward slashes on every platform.
                        let rel = path
                            .strip_prefix(&repo_root)
                            .unwrap_or(&path)
                            .to_string_lossy()
                            .replace('\\', "/");
                        if !language_matches(&filters.languages, &rel) {
                            continue;
                        }
                        // Check the size via metadata so oversized files are never read.
                        let Ok(md) = std::fs::metadata(&path) else {
                            continue;
                        };
                        if md.len() > MAX_FILE_BYTES {
                            continue;
                        }
                        // Files that are not valid UTF-8 are skipped.
                        let Ok(content) = std::fs::read_to_string(&path) else {
                            continue;
                        };
                        if let Some((snippet, score, why)) =
                            best_match(&content, &terms, &query_text)
                        {
                            out.push(CodeReference {
                                repo: meta.repo_id.clone(),
                                commit: meta.commit.clone(),
                                path: rel,
                                snippet,
                                score,
                                license: meta.license.clone(),
                                why_relevant: why,
                            });
                        }
                    }
                }
            }
            out.sort_by(|a, b| {
                b.score
                    .total_cmp(&a.score)
                    .then_with(|| a.repo.cmp(&b.repo))
                    .then_with(|| a.path.cmp(&b.path))
            });
            if filters.max_results > 0 {
                out.truncate(filters.max_results);
            }
            Ok(out)
        })
        .await
        .map_err(|e| MiningError::Other(anyhow::anyhow!("walk task panicked: {e}")))?
    }
}
/// Expands a configured search root into the repository roots to scan.
///
/// - If `root` itself looks like a repository, it is the only result.
/// - Otherwise every immediate subdirectory not named in `SKIP_DIRS` is
///   returned.
/// - If there are no such subdirectories (or `root` cannot be read), `root`
///   itself is returned.
fn discover_repo_roots(root: &Path) -> Vec<PathBuf> {
    if looks_like_repo(root) {
        return vec![root.to_path_buf()];
    }
    let children: Vec<PathBuf> = std::fs::read_dir(root)
        .map(|listing| {
            listing
                .flatten()
                .map(|entry| entry.path())
                .filter(|candidate| {
                    let dir_name = candidate
                        .file_name()
                        .and_then(|s| s.to_str())
                        .unwrap_or("");
                    candidate.is_dir() && !SKIP_DIRS.contains(&dir_name)
                })
                .collect()
        })
        .unwrap_or_default();
    if children.is_empty() {
        vec![root.to_path_buf()]
    } else {
        children
    }
}
/// Heuristic: a directory is treated as a repository root if it contains a
/// `.git` entry (file or directory) or one of the conventional license files.
fn looks_like_repo(p: &Path) -> bool {
    const LICENSE_MARKERS: [&str; 5] =
        ["LICENSE", "LICENSE.md", "LICENSE.txt", "COPYING", "COPYING.md"];
    p.join(".git").exists() || LICENSE_MARKERS.iter().any(|marker| p.join(marker).exists())
}
/// Recursively appends to `out` every file under `dir` whose extension
/// (ASCII-lower-cased) is in `TEXT_EXTS`. Directories named in `SKIP_DIRS`
/// are not descended into.
///
/// Entry kinds come from `DirEntry::file_type`, which does not follow
/// symbolic links. A symlinked directory is therefore never descended into,
/// which prevents unbounded recursion on link cycles (e.g. a link pointing at
/// an ancestor directory). A symlink that resolves to a regular file is still
/// collected, like a plain file.
///
/// This is a best-effort walk. Unreadable directories, entries whose type
/// cannot be determined, and entries with non-UTF-8 names are skipped
/// silently.
fn collect_files(dir: &Path, out: &mut Vec<PathBuf>) {
    let Ok(rd) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in rd.flatten() {
        let path = entry.path();
        let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
            continue;
        };
        let Ok(file_type) = entry.file_type() else {
            continue;
        };
        if file_type.is_dir() {
            if !SKIP_DIRS.contains(&name) {
                collect_files(&path, out);
            }
        } else if file_type.is_file() || (file_type.is_symlink() && path.is_file()) {
            let ext = path
                .extension()
                .and_then(|s| s.to_str())
                .unwrap_or("")
                .to_ascii_lowercase();
            if TEXT_EXTS.contains(&ext.as_str()) {
                out.push(path);
            }
        }
    }
}
/// Returns the cached `RepoMeta` for `repo_root`, computing and caching it on
/// first use.
///
/// The cache lock is held only for the lookup and for the final insert. The
/// `git` subprocess and the license-file reads run with the lock released, so
/// concurrent searches are not serialized behind slow I/O. If two callers
/// race on the same uncached root, both compute metadata but the first insert
/// wins and both return that value.
///
/// A poisoned lock is recovered instead of panicking. The critical sections
/// here only look up an entry or insert an already-built value, so the map
/// stays usable, and one earlier panic does not break every later search.
fn load_repo_meta(repo_root: &Path, cache: &Arc<Mutex<HashMap<PathBuf, RepoMeta>>>) -> RepoMeta {
    let cached = cache
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .get(repo_root)
        .cloned();
    if let Some(meta) = cached {
        return meta;
    }

    // Slow path: runs without holding the lock.
    let commit = read_head_commit(repo_root).unwrap_or_else(|| "HEAD".to_string());
    let license = detect_license(repo_root);
    let repo_id = format!(
        "local:{}",
        repo_root
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("<unknown>")
    );
    let meta = RepoMeta {
        repo_id,
        commit,
        license,
    };

    let mut guard = cache
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    guard
        .entry(repo_root.to_path_buf())
        .or_insert(meta)
        .clone()
}
/// Returns the trimmed stdout of `git rev-parse HEAD` run with `repo_root` as
/// the working directory.
///
/// Returns `None` in three cases:
/// - the process cannot be spawned (e.g. `git` is missing or `repo_root`
///   does not exist);
/// - it exits unsuccessfully;
/// - it prints only whitespace.
fn read_head_commit(repo_root: &Path) -> Option<String> {
    let output = std::process::Command::new("git")
        .arg("rev-parse")
        .arg("HEAD")
        .current_dir(repo_root)
        .output()
        .ok()
        .filter(|done| done.status.success())?;
    let hash = String::from_utf8_lossy(&output.stdout).trim().to_owned();
    (!hash.is_empty()).then_some(hash)
}
/// Looks for a license identifier in the conventional license files of
/// `repo_root`.
///
/// Files are tried in a fixed order. The result comes from the first file
/// that can be read as UTF-8 and for which `spdx_from_text` recognises an
/// identifier. Missing or unreadable files are skipped.
fn detect_license(repo_root: &Path) -> Option<String> {
    const CANDIDATES: [&str; 5] = ["LICENSE", "LICENSE.md", "LICENSE.txt", "COPYING", "COPYING.md"];
    CANDIDATES
        .iter()
        .filter_map(|file| std::fs::read_to_string(repo_root.join(file)).ok())
        .find_map(|text| spdx_from_text(&text))
}
/// Best-effort extraction of a license identifier from license-file text.
///
/// 1. If one of the first 20 lines contains `SPDX-License-Identifier:`, the
///    rest of that line is returned. Surrounding whitespace and a trailing
///    comment terminator (`*/`, `-->`, or stray `*`) are stripped, so
///    `/* SPDX-License-Identifier: GPL-2.0 */` yields `GPL-2.0`.
/// 2. Otherwise a few well-known license preambles are matched
///    (case-insensitively for ASCII) within roughly the first 4096 bytes.
///    LGPL and AGPL are checked before plain GPL, because their texts refer
///    to the "GNU General Public License" and would otherwise be reported as
///    GPL.
///
/// Returns `None` when nothing is recognised.
fn spdx_from_text(text: &str) -> Option<String> {
    const HEADER_LINES: usize = 20;
    const HEURISTIC_WINDOW: usize = 4096;

    for line in text.lines().take(HEADER_LINES) {
        if let Some((_, rest)) = line.split_once("SPDX-License-Identifier:") {
            let id = rest
                .trim()
                .trim_end_matches("*/")
                .trim_end_matches("-->")
                .trim_end_matches('*')
                .trim();
            if !id.is_empty() {
                return Some(id.to_string());
            }
        }
    }

    let lower = text.to_ascii_lowercase();
    // Slicing a `str` at a byte offset panics unless the offset is a char
    // boundary. Back off to the nearest boundary so non-ASCII text (e.g. "©"
    // or accented author names) near the window edge cannot cause a panic.
    let mut end = lower.len().min(HEURISTIC_WINDOW);
    while !lower.is_char_boundary(end) {
        end -= 1;
    }
    let haystack = &lower[..end];

    if haystack.contains("apache license") && haystack.contains("version 2.0") {
        return Some("Apache-2.0".into());
    }
    if haystack.contains("mit license") {
        return Some("MIT".into());
    }
    if haystack.contains("gnu affero general public license") && haystack.contains("version 3") {
        return Some("AGPL-3.0".into());
    }
    if haystack.contains("gnu lesser general public license") {
        if haystack.contains("version 3") {
            return Some("LGPL-3.0".into());
        }
        if haystack.contains("version 2.1") {
            return Some("LGPL-2.1".into());
        }
    }
    if haystack.contains("gnu general public license") {
        if haystack.contains("version 3") {
            return Some("GPL-3.0".into());
        }
        if haystack.contains("version 2") {
            return Some("GPL-2.0".into());
        }
    }
    if haystack.contains("mozilla public license") && haystack.contains("2.0") {
        return Some("MPL-2.0".into());
    }
    if haystack.contains("bsd 3-clause") {
        return Some("BSD-3-Clause".into());
    }
    if haystack.contains("bsd 2-clause") {
        return Some("BSD-2-Clause".into());
    }
    None
}
/// Finds the line in `content` that contains the most query terms and builds
/// a scored snippet around it.
///
/// Matching is an ASCII-case-insensitive substring test per term. On ties,
/// the earliest line wins. The snippet is that line plus up to 5 lines of
/// context on each side.
///
/// The score is `0.7 * coverage + 0.3 * density`, clamped to `[0, 1]`:
/// - `coverage` is hits / number of terms;
/// - `density` is hits capped at 5, divided by 5.
///
/// Returns `None` if no line contains any term. Otherwise returns
/// `(snippet, score, human-readable explanation)`.
fn best_match(content: &str, terms: &[String], query: &str) -> Option<(String, f32, String)> {
    const CONTEXT: usize = 5;

    let lines: Vec<&str> = content.lines().collect();
    let count_hits =
        |lowered: &str| terms.iter().filter(|t| lowered.contains(t.as_str())).count();

    // `min_by_key` over `Reverse` selects the maximum and keeps the first
    // line on ties.
    let (idx, hits) = lines
        .iter()
        .enumerate()
        .map(|(i, line)| (i, count_hits(line.to_ascii_lowercase().as_str())))
        .filter(|&(_, n)| n > 0)
        .min_by_key(|&(_, n)| std::cmp::Reverse(n))?;

    let lo = idx.saturating_sub(CONTEXT);
    let hi = (idx + CONTEXT + 1).min(lines.len());
    let snippet = lines[lo..hi].join("\n");

    let coverage = hits as f32 / terms.len().max(1) as f32;
    let density = (hits as f32).min(5.0) / 5.0;
    let score = (0.7 * coverage + 0.3 * density).clamp(0.0, 1.0);

    let focus = lines[idx].to_ascii_lowercase();
    let matched = terms
        .iter()
        .map(String::as_str)
        .filter(|t| focus.contains(*t))
        .collect::<Vec<&str>>()
        .join(", ");
    let why = format!(
        "Line {} matched {}/{} terms from \"{}\": [{}]",
        idx + 1,
        hits,
        terms.len(),
        truncate(query, 60),
        matched
    );
    Some((snippet, score, why))
}
/// Returns `s` unchanged if it has at most `n` characters. Otherwise returns
/// its first `n` characters followed by `…`. Counting is by `char`, so
/// multi-byte text is never split mid-character.
fn truncate(s: &str, n: usize) -> String {
    match s.char_indices().nth(n) {
        None => s.to_owned(),
        Some((cut, _)) => {
            let mut clipped = String::with_capacity(cut + '…'.len_utf8());
            clipped.push_str(&s[..cut]);
            clipped.push('…');
            clipped
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit SPDX header line is reported verbatim.
    #[test]
    fn spdx_header_detected() {
        let header = "// SPDX-License-Identifier: MIT\nrest";
        assert_eq!(spdx_from_text(header).as_deref(), Some("MIT"));
    }

    /// The Apache 2.0 preamble is recognised without an SPDX header.
    #[test]
    fn apache_preamble_detected() {
        let preamble = " Apache License\n Version 2.0, January 2004\n";
        assert_eq!(spdx_from_text(preamble).as_deref(), Some("Apache-2.0"));
    }

    /// The line that contains every term is chosen and scores well.
    #[test]
    fn best_match_picks_densest_line() {
        let content = "first line\nrate limiter token bucket here\nunrelated\n";
        let terms: Vec<String> = ["rate", "limiter", "token"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let (snippet, score, _why) = best_match(content, &terms, "rate limiter")
            .expect("the second line contains every term");
        assert!(
            snippet.contains("rate limiter token bucket"),
            "unexpected snippet: {snippet}"
        );
        assert!(score > 0.5, "score too low: {score}");
    }
}