use std::collections::HashMap;
use std::sync::mpsc;
use std::time::Duration;
use crate::core::graph_provider;
use crate::core::tokens::count_tokens;
use crate::tools::CrpMode;
/// Fallback time budget, in milliseconds, used when the environment does not
/// supply a usable override.
const DEFAULT_SEMANTIC_BUDGET_MS: u64 = 2500;

/// Returns the time budget for the semantic-ranking step.
///
/// The value is read from `LEAN_CTX_COMPOSE_BUDGET_MS` on every call. The
/// override is honoured only when it parses as a `u64` strictly greater than
/// zero; an unset, non-numeric, or zero value falls back to
/// [`DEFAULT_SEMANTIC_BUDGET_MS`].
fn semantic_budget() -> Duration {
    let millis = match std::env::var("LEAN_CTX_COMPOSE_BUDGET_MS") {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(value) if value > 0 => value,
            // Zero or unparsable: treat as "not configured".
            _ => DEFAULT_SEMANTIC_BUDGET_MS,
        },
        Err(_) => DEFAULT_SEMANTIC_BUDGET_MS,
    };
    Duration::from_millis(millis)
}
/// Fallback token budget used when the environment does not supply a usable
/// override.
const DEFAULT_SYMBOL_BUDGET_TOKENS: usize = 600;

/// Returns the token budget for the symbol-bodies section.
///
/// The value is read from `LEAN_CTX_COMPOSE_SYMBOL_TOKENS` on every call. The
/// override is honoured only when it parses as a `usize` strictly greater than
/// zero; otherwise [`DEFAULT_SYMBOL_BUDGET_TOKENS`] is returned.
fn symbol_budget_tokens() -> usize {
    let Ok(raw) = std::env::var("LEAN_CTX_COMPOSE_SYMBOL_TOKENS") else {
        return DEFAULT_SYMBOL_BUDGET_TOKENS;
    };
    match raw.parse::<usize>() {
        Ok(tokens) if tokens > 0 => tokens,
        // Zero or unparsable: treat as "not configured".
        _ => DEFAULT_SYMBOL_BUDGET_TOKENS,
    }
}
/// Fallback time budget, in milliseconds, used when the environment does not
/// supply a usable override.
const DEFAULT_GRAPH_BUDGET_MS: u64 = 1500;

/// Returns the time budget for the graph-based (associative) step.
///
/// The value is read from `LEAN_CTX_COMPOSE_GRAPH_BUDGET_MS` on every call.
/// Only a value that parses as a positive `u64` overrides the default of
/// [`DEFAULT_GRAPH_BUDGET_MS`].
fn graph_budget() -> Duration {
    let configured = match std::env::var("LEAN_CTX_COMPOSE_GRAPH_BUDGET_MS") {
        Ok(raw) => raw.parse::<u64>().ok(),
        Err(_) => None,
    };
    let millis = match configured {
        Some(value) if value > 0 => value,
        // Unset, unparsable, or zero.
        _ => DEFAULT_GRAPH_BUDGET_MS,
    };
    Duration::from_millis(millis)
}
// Tuning constants passed, in this order, to
// `crate::core::spreading_activation::related_ranked` by
// `build_associative_block`. That function is not visible from this file, so
// the descriptions below are inferred from the names — NOTE(review): confirm
// against its definition.
/// Decay argument (presumably the per-hop attenuation applied to activation).
const SPREAD_DECAY: f64 = 0.6;
/// Hop-count argument (presumably the maximum propagation depth from a seed).
const SPREAD_HOPS: usize = 3;
/// Top-k argument (presumably the maximum number of ranked files returned).
const SPREAD_TOP_K: usize = 8;
/// Builds the "Related" markdown section for `keywords`, or an empty string
/// when there is nothing to report.
///
/// Steps, in order:
/// 1. Open (or build) the project graph; give up with `""` if unavailable.
/// 2. Collect the distinct files of every symbol matching any keyword, in
///    first-seen order. These are the seed files; no seeds means `""`.
/// 3. Record the seed files as a co-access event.
/// 4. Assemble an undirected weighted adjacency map from the graph edges
///    (non-positive weights become `1.0`) plus up to 16 co-access neighbours
///    of each seed file.
/// 5. Rank related files with spreading activation, seeding each seed file
///    with activation `1.0`, and render one bullet per ranked file.
fn build_associative_block(project_root: &str, keywords: &[String]) -> String {
    /// Records an undirected edge by appending it to both endpoints' lists.
    fn link(adjacency: &mut HashMap<String, Vec<(String, f64)>>, a: &str, b: &str, w: f64) {
        adjacency
            .entry(a.to_string())
            .or_default()
            .push((b.to_string(), w));
        adjacency
            .entry(b.to_string())
            .or_default()
            .push((a.to_string(), w));
    }

    let Some(open) = graph_provider::open_or_build(project_root) else {
        return String::new();
    };
    let gp = &open.provider;

    // Seed files: every file that defines a symbol matching a keyword,
    // de-duplicated while keeping discovery order.
    let mut seed_files: Vec<String> = Vec::new();
    for kw in keywords {
        for sym in gp.find_symbols(kw, None, None) {
            if seed_files.contains(&sym.file) {
                continue;
            }
            seed_files.push(sym.file);
        }
    }
    if seed_files.is_empty() {
        return String::new();
    }

    // Recorded before the co-access data is loaded below, as in the original
    // statement order.
    crate::core::cooccurrence::record_access(project_root, &seed_files);

    let mut adjacency: HashMap<String, Vec<(String, f64)>> = HashMap::new();
    for e in gp.edges() {
        // Edges without a positive weight count as unit weight.
        let weight = if e.weight > 0.0 { e.weight } else { 1.0 };
        link(&mut adjacency, &e.from, &e.to, weight);
    }

    let coaccess = crate::core::cooccurrence::load(project_root);
    for sf in &seed_files {
        for (nbr, w) in coaccess.related(sf, 16) {
            link(&mut adjacency, sf, &nbr, w);
        }
    }

    let seeds: HashMap<String, f64> = seed_files.iter().cloned().map(|f| (f, 1.0)).collect();
    let ranked = crate::core::spreading_activation::related_ranked(
        &seeds,
        &adjacency,
        SPREAD_DECAY,
        SPREAD_HOPS,
        SPREAD_TOP_K,
    );
    if ranked.is_empty() {
        return String::new();
    }

    let mut section =
        String::from("\n## Related (associative: import/call graph + learned co-access)\n");
    section.extend(
        ranked
            .into_iter()
            .map(|(file, activation)| format!("- {file} (activation {activation:.2})\n")),
    );
    section
}
/// Runs [`build_associative_block`] on a worker thread and waits at most
/// [`graph_budget`] for its result.
///
/// Returns an empty string when `keywords` is empty, when the worker does not
/// answer within the budget, or when the worker panics (a warning is logged in
/// that case). The worker thread is not joined; on timeout its eventual result
/// is simply discarded.
fn associative_block_budgeted(project_root: &str, keywords: &[String]) -> String {
    if keywords.is_empty() {
        return String::new();
    }

    // The worker needs owned copies because it may outlive this call.
    let root = project_root.to_owned();
    let kws = keywords.to_vec();
    let (sender, receiver) = mpsc::channel::<String>();

    std::thread::spawn(move || {
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            build_associative_block(&root, &kws)
        }));
        let block = match outcome {
            Ok(block) => block,
            Err(_) => {
                tracing::warn!("[ctx_compose: associative block panicked; omitting section]");
                String::new()
            }
        };
        // The receiver may already be gone after a timeout; that is fine.
        let _ = sender.send(block);
    });

    match receiver.recv_timeout(graph_budget()) {
        Ok(block) => block,
        Err(_) => String::new(),
    }
}
/// Lower-case ASCII words that are never returned as keywords by
/// [`extract_keywords`]: common English function words plus generic
/// programming nouns and verbs.
const STOPWORDS: &[&str] = &[
    "the",
    "and",
    "for",
    "with",
    "that",
    "this",
    "from",
    "into",
    "how",
    "where",
    "what",
    "does",
    "are",
    "was",
    "use",
    "used",
    "uses",
    "add",
    "all",
    "any",
    "can",
    "get",
    "set",
    "via",
    "out",
    "its",
    "his",
    "her",
    "you",
    "your",
    "our",
    "find",
    "show",
    "list",
    "make",
    "when",
    "then",
    "has",
    "have",
    "had",
    "not",
    "but",
    "see",
    "function",
    "method",
    "class",
    "code",
    "file",
    "files",
    "implement",
    "implementation",
];

/// Extracts up to `max` distinct keywords from a free-form task description.
///
/// The task is split on every character that is neither alphanumeric nor `_`,
/// so identifiers such as `ctx_search` survive intact. A token is kept when:
///
/// * it is at least 3 bytes long (`str::len` counts bytes, not characters),
/// * it does not match an entry of [`STOPWORDS`], compared ASCII
///   case-insensitively, and
/// * the exact same token (case-sensitive) has not been kept already.
///
/// Keywords are returned in order of first appearance with their original
/// casing. `max == 0` yields an empty vector.
fn extract_keywords(task: &str, max: usize) -> Vec<String> {
    // Borrowed tokens are enough for de-duplication; only survivors are copied.
    let mut seen = std::collections::HashSet::new();
    task.split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|tok| tok.len() >= 3)
        // Every stopword is lower-case ASCII, so this equals lowercasing the
        // token first, without allocating.
        .filter(|tok| !STOPWORDS.iter().any(|stop| tok.eq_ignore_ascii_case(stop)))
        .filter(|tok| seen.insert(*tok))
        // `take` stops pulling tokens once the cap is reached and, unlike a
        // post-push length check, also honours a cap of zero.
        .take(max)
        .map(String::from)
        .collect()
}
/// Runs the semantic search for `task` on a worker thread and waits at most
/// [`semantic_budget`] for the ranking.
///
/// The calling thread's semantic-search cache is fetched up front and, when
/// present, installed on the worker before the search runs. On success the
/// trimmed search output is returned. If the worker panics, a warning is
/// logged and an empty string is returned. If the budget elapses first, a
/// fixed "deferred" notice is returned and the worker is left running.
fn ranked_files_budgeted(task: &str, project_root: &str, crp_mode: CrpMode) -> String {
    // Must be read on the calling thread, before the worker is spawned.
    let shared_cache = crate::tools::ctx_semantic_search::get_thread_cache();

    let task_owned = task.to_owned();
    let root_owned = project_root.to_owned();
    let (sender, receiver) = mpsc::channel::<String>();

    std::thread::spawn(move || {
        if let Some(cache) = shared_cache {
            crate::tools::ctx_semantic_search::set_thread_cache(cache);
        }
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            crate::tools::ctx_semantic_search::handle(
                &task_owned,
                &root_owned,
                8,
                crp_mode,
                None,
                None,
                None,
                Some(false),
                Some(false),
            )
        }));
        let ranked = match outcome {
            Ok(ranked) => ranked,
            Err(_) => {
                tracing::warn!("[ctx_compose: semantic ranking panicked; omitting section]");
                String::new()
            }
        };
        // The receiver may already be gone after a timeout; that is fine.
        let _ = sender.send(ranked);
    });

    receiver
        .recv_timeout(semantic_budget())
        .map(|ranked| ranked.trim().to_string())
        .unwrap_or_else(|_| {
            "(deferred — semantic index is warming; the exact matches below are \
             authoritative for this call, and ranking will be instant on the next call)"
                .to_string()
        })
}
/// Composes a multi-section context report for `task`.
///
/// Returns the report text together with its token count as computed by
/// `count_tokens`. A blank task yields `("ERROR: task is required", 0)`.
///
/// Sections, in output order:
/// 1. `TASK` / `KEYWORDS` header (up to 6 extracted keywords).
/// 2. Semantically ranked files, produced under a time budget.
/// 3. Exact matches for the first keyword, if any keyword was extracted.
/// 4. Bodies of the best-matching symbols, chosen by greedy max-coverage
///    under the symbol token budget and de-duplicated by rendered text.
/// 5. The associative "Related" block, produced under a time budget.
pub fn handle(task: &str, project_root: &str, crp_mode: CrpMode) -> (String, usize) {
    use crate::core::context_packing::{greedy_max_coverage, CoverageItem};

    let task = task.trim();
    if task.is_empty() {
        return ("ERROR: task is required".to_string(), 0);
    }

    let keywords = extract_keywords(task, 6);
    let allow_secret = crate::core::roles::active_role().io.allow_secret_paths;

    // Header.
    let mut out = format!("TASK: {task}\n");
    if keywords.is_empty() {
        out += "KEYWORDS: (none extracted — using full task for ranking)\n";
    } else {
        out += &format!("KEYWORDS: {}\n", keywords.join(", "));
    }

    // Semantic ranking.
    out += "\n## Ranked files (semantic)\n";
    out += &ranked_files_budgeted(task, project_root, crp_mode);
    out.push('\n');

    // Exact matches for the primary (first) keyword only.
    if let Some(primary) = keywords.first() {
        let (grep, _tokens) = crate::tools::ctx_search::handle(
            primary,
            project_root,
            None,
            10,
            crp_mode,
            true,
            allow_secret,
        );
        out += &format!("\n## Exact matches: '{primary}'\n");
        out += grep.trim();
        out.push('\n');
    }

    // One candidate snippet per keyword. `snippets[i]` is the rendered text
    // for `items[i]`; the two vectors are kept in lock-step.
    let mut snippets: Vec<String> = Vec::new();
    let mut items: Vec<CoverageItem> = Vec::new();
    for kw in &keywords {
        let Some((rendered, toks)) =
            crate::tools::ctx_symbol::best_symbol_snippet(kw, project_root)
        else {
            continue;
        };
        // A snippet covers its own keyword plus every other keyword whose
        // text occurs in the rendered body.
        let mut terms: std::collections::HashSet<String> =
            std::collections::HashSet::from([kw.clone()]);
        terms.extend(
            keywords
                .iter()
                .filter(|other| *other != kw && rendered.contains(other.as_str()))
                .cloned(),
        );
        items.push(CoverageItem {
            terms,
            cost: toks.max(1),
        });
        snippets.push(rendered);
    }

    if !items.is_empty() {
        let chosen = greedy_max_coverage(&items, symbol_budget_tokens(), |_| 1.0);
        // Gather the distinct, non-blank bodies first so the heading is only
        // written when at least one body follows it.
        let mut emitted = std::collections::HashSet::new();
        let mut bodies: Vec<&str> = Vec::new();
        for idx in chosen {
            let body = snippets[idx].trim();
            if !body.is_empty() && emitted.insert(body) {
                bodies.push(body);
            }
        }
        if !bodies.is_empty() {
            out += "\n## Top symbols (bodies)\n";
            for body in bodies {
                out += body;
                out.push('\n');
            }
        }
    }

    out += &associative_block_budgeted(project_root, &keywords);

    let sent = count_tokens(&out);
    (out, sent)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Identifiers and ordinary words are kept; stopwords (in any casing used
    /// here) are not.
    #[test]
    fn extract_keywords_drops_stopwords_and_short_tokens() {
        let keywords = extract_keywords("How does the BM25Index cache work for ctx_search?", 6);
        for expected in ["BM25Index", "cache", "ctx_search"] {
            assert!(keywords.iter().any(|k| k == expected));
        }
        for rejected in ["the", "How", "for"] {
            assert!(!keywords.iter().any(|k| k == rejected));
        }
    }

    /// Repeated tokens are reported once and the result is truncated to the
    /// requested maximum, keeping first-seen order.
    #[test]
    fn extract_keywords_dedups_and_caps() {
        let keywords = extract_keywords("alpha alpha beta gamma delta epsilon zeta eta", 3);
        assert_eq!(keywords.len(), 3);
        assert_eq!(keywords.first().map(String::as_str), Some("alpha"));
    }

    /// A whitespace-only task is rejected with an error message and a zero
    /// token count.
    #[test]
    fn empty_task_is_rejected() {
        let (message, tokens) = handle(" ", "/tmp", CrpMode::Off);
        assert!(message.starts_with("ERROR"));
        assert_eq!(tokens, 0);
    }
}