use std::collections::{BTreeSet, HashMap, HashSet};
use crate::index::ProjectIndex;
/// A symbol identified by the file that contains it and its name.
///
/// The derived ordering compares `file` first and then `symbol`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SymbolRef {
    /// File holding the symbol, exactly as supplied by the caller.
    pub file: String,
    /// Name of the symbol.
    pub symbol: String,
}

impl SymbolRef {
    /// Builds a reference from any pair of values convertible into `String`.
    pub fn new(file: impl Into<String>, symbol: impl Into<String>) -> Self {
        let file: String = file.into();
        let symbol: String = symbol.into();
        Self { file, symbol }
    }
}
/// The symbols a subtask writes and reads, plus a flag marking the
/// footprint as uncertain.
#[derive(Debug, Clone, Default)]
pub struct SymbolFootprint {
    /// Symbols the subtask modifies.
    pub writes: HashSet<SymbolRef>,
    /// Symbols the subtask depends on without modifying.
    pub reads: HashSet<SymbolRef>,
    /// Set when the footprint cannot be trusted as complete.
    pub uncertain: bool,
}

impl SymbolFootprint {
    /// Creates a footprint that writes exactly `writes`, reads nothing, and is
    /// not marked uncertain.
    pub fn writing(writes: impl IntoIterator<Item = SymbolRef>) -> Self {
        let mut footprint = Self::default();
        footprint.writes.extend(writes);
        footprint
    }

    /// Returns the footprint with its read set replaced by `reads`; any reads
    /// it already had are discarded.
    pub fn with_reads(self, reads: impl IntoIterator<Item = SymbolRef>) -> Self {
        Self {
            reads: reads.into_iter().collect(),
            ..self
        }
    }
}
/// How a subtask's footprint is compared against the others when building a
/// plan with `analyze`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scheduling {
    /// Symbol-level write sets are compared directly.
    Precise,
    /// Only the files of the declared writes are compared.
    FileLevel,
    /// Treated as conflicting with every other subtask.
    Serialize,
}
/// A footprint after call-graph expansion, together with the set of files
/// named by its declared writes and the scheduling mode chosen for it.
#[derive(Debug, Clone)]
pub struct ExpandedFootprint {
    expanded: SymbolFootprint,
    declared_files: BTreeSet<String>,
    scheduling: Scheduling,
}

impl ExpandedFootprint {
    /// Wraps `footprint` unchanged, without consulting any index.
    ///
    /// The declared files are the files of `footprint.writes`. Scheduling is
    /// `Serialize` when `footprint.uncertain` is set and `Precise` otherwise.
    pub fn assume_expanded(footprint: SymbolFootprint) -> Self {
        let scheduling = match footprint.uncertain {
            true => Scheduling::Serialize,
            false => Scheduling::Precise,
        };
        let mut declared_files = BTreeSet::new();
        for write in &footprint.writes {
            declared_files.insert(write.file.clone());
        }
        Self {
            expanded: footprint,
            declared_files,
            scheduling,
        }
    }

    /// The wrapped footprint.
    pub fn inner(&self) -> &SymbolFootprint {
        &self.expanded
    }

    /// The scheduling mode chosen for this footprint.
    pub fn scheduling(&self) -> Scheduling {
        self.scheduling
    }

    /// Files recorded from the declared writes, in sorted order.
    pub fn declared_files(&self) -> &BTreeSet<String> {
        &self.declared_files
    }
}
/// One unit of work to schedule: an identifier plus its expanded footprint.
#[derive(Debug, Clone)]
pub struct FootprintSubtask {
    /// Identifier reported back in the plan's levels, edges and conflicts.
    pub id: String,
    /// Footprint used by `analyze` to detect conflicts and dependencies.
    pub footprint: ExpandedFootprint,
}
/// Output of `analyze`: subtask ids grouped into ordered levels, plus the
/// dependency edges and conflicts the levels were derived from.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct DecompositionPlan {
    /// Subtask ids grouped into levels, in level order.
    pub levels: Vec<Vec<String>>,
    /// `(writer, reader)` id pairs: the reader reads a symbol the writer writes.
    pub edges: Vec<(String, String)>,
    /// Id pairs whose footprints were judged to conflict.
    pub conflicts: Vec<(String, String)>,
}
/// Expands `declared` through the call graph in `index` and chooses how the
/// result may be scheduled.
///
/// Starting from the declared writes, callers reported by
/// `ProjectIndex::callers_of` are added to the write set breadth-first. This
/// runs for at most `max_depth` rounds and stops early once a round adds
/// nothing new. Reads are copied through unchanged.
///
/// Scheduling is decided as follows:
/// - `Serialize` if `declared.uncertain` is set or `index.truncated` is true.
///   This is never relaxed.
/// - Otherwise `FileLevel` if `ProjectIndex::find` returns nothing for a
///   declared write, or for a caller that is itself expanded in a later round.
/// - Otherwise `Precise`.
///
/// The declared writes are always classified, even when `max_depth == 0`.
///
/// The returned footprint's `uncertain` flag is set whenever scheduling is not
/// `Precise`. `declared_files()` holds only the files of the declared writes,
/// not those of callers added by expansion.
pub fn expand_footprint(
    index: &ProjectIndex,
    declared: &SymbolFootprint,
    max_depth: usize,
) -> ExpandedFootprint {
    let declared_files: BTreeSet<String> =
        declared.writes.iter().map(|w| w.file.clone()).collect();
    let mut writes = declared.writes.clone();
    let mut scheduling = if declared.uncertain || index.truncated {
        Scheduling::Serialize
    } else {
        Scheduling::Precise
    };
    // Classify the declared writes before expanding. Otherwise an unknown
    // symbol would go unnoticed when `max_depth == 0` skips the loop below,
    // leaving a nonexistent symbol scheduled as `Precise`.
    // NOTE(review): lookups use only the symbol name (`w.symbol`), not the
    // file, so a same-named symbol in another file counts as known — confirm.
    if scheduling == Scheduling::Precise
        && declared.writes.iter().any(|w| index.find(&w.symbol).is_empty())
    {
        scheduling = Scheduling::FileLevel;
    }
    let mut frontier: Vec<SymbolRef> = declared.writes.iter().cloned().collect();
    for _ in 0..max_depth {
        let mut next = Vec::new();
        for w in &frontier {
            if index.find(&w.symbol).is_empty() {
                // Unresolvable symbol: downgrade Precise to FileLevel (never
                // relax Serialize) and do not look up its callers.
                if scheduling == Scheduling::Precise {
                    scheduling = Scheduling::FileLevel;
                }
                continue;
            }
            for cref in index.callers_of(&w.symbol) {
                let caller = SymbolRef::new(cref.from_file.clone(), cref.from_symbol.clone());
                // Only callers not already in the write set join the next
                // frontier, so each symbol is expanded at most once.
                if writes.insert(caller.clone()) {
                    next.push(caller);
                }
            }
        }
        if next.is_empty() {
            break;
        }
        frontier = next;
    }
    ExpandedFootprint {
        expanded: SymbolFootprint {
            writes,
            reads: declared.reads.clone(),
            uncertain: scheduling != Scheduling::Precise,
        },
        declared_files,
        scheduling,
    }
}
/// Builds a `DecompositionPlan` for `subtasks`.
///
/// Each pair of subtasks is checked for two relations:
/// - **Conflict.** The two must not share a level.
///   - A `Serialize` subtask conflicts with every other subtask.
///   - Two `Precise` subtasks conflict when their (expanded) write sets
///     intersect.
///   - Any other combination conflicts when their `declared_files()`
///     intersect.
/// - **Dependency edge `(writer, reader)`.** One subtask writes a symbol the
///   other reads. The reader is then placed in a later level than the writer.
///
/// Levels are built greedily. Each level takes, in input order, the subtasks
/// whose dependencies are all placed, skipping any that conflict with a
/// subtask already chosen for that level. If no subtask is ready, the
/// dependencies contain a cycle. A single subtask that lies on such a cycle is
/// then placed alone, so only an edge of the cycle itself is broken.
///
/// Ids within a level appear in input order. `edges` and `conflicts` are
/// listed in the order the pairs are examined.
pub fn analyze(subtasks: &[FootprintSubtask]) -> DecompositionPlan {
    let n = subtasks.len();
    let mut conflict_pairs: HashSet<(usize, usize)> = HashSet::new();
    let mut deps: HashMap<usize, BTreeSet<usize>> =
        (0..n).map(|i| (i, BTreeSet::new())).collect();
    let mut conflicts = Vec::new();
    let mut edges = Vec::new();
    for i in 0..n {
        for j in (i + 1)..n {
            let fa = &subtasks[i].footprint;
            let fb = &subtasks[j].footprint;
            let a = fa.inner();
            let b = fb.inner();
            let conflict = match (fa.scheduling(), fb.scheduling()) {
                (Scheduling::Serialize, _) | (_, Scheduling::Serialize) => true,
                (Scheduling::Precise, Scheduling::Precise) => !a.writes.is_disjoint(&b.writes),
                // NOTE(review): this arm compares `declared_files()` rather than
                // the files of the expanded `writes`, unlike the Precise arm
                // above — confirm callers added by expansion may be ignored here.
                (Scheduling::FileLevel, Scheduling::FileLevel)
                | (Scheduling::FileLevel, Scheduling::Precise)
                | (Scheduling::Precise, Scheduling::FileLevel) => {
                    !fa.declared_files().is_disjoint(fb.declared_files())
                }
            };
            if conflict {
                // Stored in both orientations so lookups need not normalise.
                conflict_pairs.insert((i, j));
                conflict_pairs.insert((j, i));
                conflicts.push((subtasks[i].id.clone(), subtasks[j].id.clone()));
            }
            if !a.writes.is_disjoint(&b.reads) {
                deps.get_mut(&j).unwrap().insert(i);
                edges.push((subtasks[i].id.clone(), subtasks[j].id.clone()));
            }
            if !b.writes.is_disjoint(&a.reads) {
                deps.get_mut(&i).unwrap().insert(j);
                edges.push((subtasks[j].id.clone(), subtasks[i].id.clone()));
            }
        }
    }
    let mut placed = vec![false; n];
    let mut levels: Vec<Vec<String>> = Vec::new();
    while placed.iter().any(|p| !p) {
        let ready: Vec<usize> = (0..n)
            .filter(|&i| !placed[i] && deps[&i].iter().all(|d| placed[*d]))
            .collect();
        let mut chosen: Vec<usize> = if ready.is_empty() {
            // Every unplaced subtask waits on another unplaced one: break a
            // cycle by placing one of its members alone.
            vec![pick_cycle_breaker(&placed, &deps)]
        } else {
            let mut level: Vec<usize> = Vec::new();
            for &i in &ready {
                if level.iter().all(|&k| !conflict_pairs.contains(&(i, k))) {
                    level.push(i);
                }
            }
            level
        };
        chosen.sort();
        for &i in &chosen {
            placed[i] = true;
        }
        levels.push(chosen.into_iter().map(|i| subtasks[i].id.clone()).collect());
    }
    DecompositionPlan {
        levels,
        edges,
        conflicts,
    }
}

/// Returns an unplaced subtask that lies on a dependency cycle.
///
/// This is called only when no subtask is ready. In that state every unplaced
/// subtask has at least one unplaced dependency. Starting from the first
/// unplaced subtask and repeatedly following its lowest-index unplaced
/// dependency therefore never dead-ends. Within `n` steps the walk revisits a
/// subtask, and that subtask is on a cycle.
fn pick_cycle_breaker(placed: &[bool], deps: &HashMap<usize, BTreeSet<usize>>) -> usize {
    let mut visited = vec![false; placed.len()];
    let mut current = placed
        .iter()
        .position(|&p| !p)
        .expect("called only while some subtask is still unplaced");
    while !visited[current] {
        visited[current] = true;
        match deps[&current].iter().copied().find(|&d| !placed[d]) {
            Some(next) => current = next,
            // Not expected when nothing is ready; placing this subtask is
            // still safe because all of its dependencies are placed.
            None => return current,
        }
    }
    current
}
// Unit tests for `analyze` (plan construction) and `expand_footprint`
// (call-graph expansion against a real `ProjectIndex`).
#[cfg(test)]
mod tests {
use super::*;
// Builds a subtask whose footprint is taken as already expanded via
// `ExpandedFootprint::assume_expanded` (Serialize if uncertain, else Precise).
fn sub(id: &str, fp: SymbolFootprint) -> FootprintSubtask {
FootprintSubtask {
id: id.to_string(),
footprint: ExpandedFootprint::assume_expanded(fp),
}
}
// Builds a subtask with an explicit scheduling mode, bypassing expansion;
// its declared files are the files of `fp.writes`.
fn sub_sched(id: &str, fp: SymbolFootprint, scheduling: Scheduling) -> FootprintSubtask {
let declared_files = fp.writes.iter().map(|w| w.file.clone()).collect();
FootprintSubtask {
id: id.to_string(),
footprint: ExpandedFootprint {
expanded: fp,
declared_files,
scheduling,
},
}
}
// Shorthand for `SymbolRef::new`.
fn w(file: &str, sym: &str) -> SymbolRef {
SymbolRef::new(file, sym)
}
#[test]
fn disjoint_writes_run_in_one_level() {
let plan = analyze(&[
sub("a", SymbolFootprint::writing([w("a.rs", "fa")])),
sub("b", SymbolFootprint::writing([w("b.rs", "fb")])),
]);
assert_eq!(plan.levels, vec![vec!["a".to_string(), "b".to_string()]]);
assert!(plan.conflicts.is_empty());
}
#[test]
fn overlapping_writes_serialize() {
let plan = analyze(&[
sub("a", SymbolFootprint::writing([w("lib.rs", "shared")])),
sub("b", SymbolFootprint::writing([w("lib.rs", "shared")])),
]);
assert_eq!(plan.levels.len(), 2, "{plan:?}");
assert_eq!(plan.conflicts.len(), 1);
}
#[test]
fn write_read_dependency_orders_levels() {
// `b` reads the symbol `a` writes, so `b` must land in a later level.
let a = SymbolFootprint::writing([w("lib.rs", "api")]);
let b = SymbolFootprint::default().with_reads([w("lib.rs", "api")]);
let plan = analyze(&[sub("a", a), sub("b", b)]);
assert_eq!(plan.levels, vec![vec!["a".to_string()], vec!["b".to_string()]]);
assert_eq!(plan.edges, vec![("a".to_string(), "b".to_string())]);
}
#[test]
fn uncertain_subtask_conflicts_with_everything() {
// `assume_expanded` maps `uncertain == true` to `Scheduling::Serialize`.
let mut uncertain = SymbolFootprint::writing([w("x.rs", "fx")]);
uncertain.uncertain = true;
let plan = analyze(&[
sub("u", uncertain),
sub("a", SymbolFootprint::writing([w("a.rs", "fa")])),
sub("b", SymbolFootprint::writing([w("b.rs", "fb")])),
]);
for level in &plan.levels {
if level.contains(&"u".to_string()) {
assert_eq!(level.len(), 1, "uncertain subtask is isolated: {plan:?}");
}
}
assert_eq!(plan.conflicts.len(), 2, "u conflicts with both");
}
#[test]
fn file_level_disjoint_files_run_in_parallel() {
let plan = analyze(&[
sub_sched("a", SymbolFootprint::writing([w("a.rs", "new_a")]), Scheduling::FileLevel),
sub_sched("b", SymbolFootprint::writing([w("b.rs", "new_b")]), Scheduling::FileLevel),
]);
assert_eq!(plan.levels, vec![vec!["a".to_string(), "b".to_string()]], "{plan:?}");
assert!(plan.conflicts.is_empty(), "disjoint files do not conflict: {plan:?}");
}
#[test]
fn file_level_same_file_serializes() {
// Different symbols, same file: at FileLevel only the file is compared.
let plan = analyze(&[
sub_sched("a", SymbolFootprint::writing([w("lib.rs", "new_a")]), Scheduling::FileLevel),
sub_sched("b", SymbolFootprint::writing([w("lib.rs", "new_b")]), Scheduling::FileLevel),
]);
assert_eq!(plan.levels.len(), 2, "same file → separate levels: {plan:?}");
assert_eq!(plan.conflicts.len(), 1);
}
#[test]
fn mixed_precise_and_file_level_compares_at_file_level() {
let parallel = analyze(&[
sub_sched("a", SymbolFootprint::writing([w("a.rs", "fa")]), Scheduling::Precise),
sub_sched("b", SymbolFootprint::writing([w("b.rs", "new_b")]), Scheduling::FileLevel),
]);
assert_eq!(parallel.levels, vec![vec!["a".to_string(), "b".to_string()]], "{parallel:?}");
let serial = analyze(&[
sub_sched("a", SymbolFootprint::writing([w("lib.rs", "fa")]), Scheduling::Precise),
sub_sched("b", SymbolFootprint::writing([w("lib.rs", "new_b")]), Scheduling::FileLevel),
]);
assert_eq!(serial.levels.len(), 2, "{serial:?}");
}
#[test]
fn serialize_conflicts_with_everything_even_disjoint_files() {
let plan = analyze(&[
sub_sched("s", SymbolFootprint::writing([w("s.rs", "fs")]), Scheduling::Serialize),
sub_sched("a", SymbolFootprint::writing([w("a.rs", "fa")]), Scheduling::FileLevel),
sub_sched("b", SymbolFootprint::writing([w("b.rs", "fb")]), Scheduling::Precise),
]);
for level in &plan.levels {
if level.contains(&"s".to_string()) {
assert_eq!(level.len(), 1, "Serialize subtask is isolated: {plan:?}");
}
}
assert_eq!(plan.conflicts.len(), 2, "s conflicts with both a and b");
}
#[test]
fn file_level_still_honors_read_write_edges() {
// Dependency edges come from symbol-level reads/writes regardless of mode.
let writer = SymbolFootprint::writing([w("a.rs", "api")]);
let reader = SymbolFootprint::default().with_reads([w("a.rs", "api")]);
let plan = analyze(&[
sub_sched("writer", writer, Scheduling::FileLevel),
sub_sched("reader", reader, Scheduling::FileLevel),
]);
assert_eq!(
plan.levels,
vec![vec!["writer".to_string()], vec!["reader".to_string()]],
"reader runs after writer: {plan:?}"
);
}
#[test]
fn cyclic_dependency_is_broken_by_serializing() {
// `a` writes X and reads Y; `b` writes Y and reads X: a two-node cycle
// that `analyze` must break by placing one subtask alone.
let a = SymbolFootprint {
writes: [w("lib.rs", "X")].into_iter().collect(),
reads: [w("lib.rs", "Y")].into_iter().collect(),
uncertain: false,
};
let b = SymbolFootprint {
writes: [w("lib.rs", "Y")].into_iter().collect(),
reads: [w("lib.rs", "X")].into_iter().collect(),
uncertain: false,
};
let plan = analyze(&[sub("a", a), sub("b", b)]);
assert_eq!(plan.levels.len(), 2, "{plan:?}");
}
// The remaining tests build a real `ProjectIndex` over a scratch directory
// created with the `tempfile` crate.
#[test]
fn expand_marks_unknown_symbol_file_level() {
let dir = tempfile::tempdir().unwrap();
std::fs::write(dir.path().join("lib.rs"), "pub fn known() {}\n").unwrap();
let index = ProjectIndex::build(dir.path());
let declared = SymbolFootprint::writing([w("lib.rs", "does_not_exist")]);
let expanded = expand_footprint(&index, &declared, 3);
assert_eq!(expanded.scheduling(), Scheduling::FileLevel);
assert!(expanded.inner().uncertain, "unknown symbol is still 'uncertain' for inspection");
assert!(expanded.declared_files().contains("lib.rs"));
}
#[test]
fn expand_known_symbol_is_precise() {
let dir = tempfile::tempdir().unwrap();
std::fs::write(dir.path().join("lib.rs"), "pub fn known() {}\n").unwrap();
let index = ProjectIndex::build(dir.path());
let declared = SymbolFootprint::writing([w("lib.rs", "known")]);
let expanded = expand_footprint(&index, &declared, 3);
assert_eq!(expanded.scheduling(), Scheduling::Precise);
assert!(!expanded.inner().uncertain);
}
#[test]
fn truncated_index_forces_serialize_even_for_known_symbol() {
let dir = tempfile::tempdir().unwrap();
std::fs::write(dir.path().join("lib.rs"), "pub fn known() {}\n").unwrap();
let mut index = ProjectIndex::build(dir.path());
assert!(!index.truncated, "small build is not truncated");
// Set the flag by hand to simulate a truncated index.
index.truncated = true;
let declared = SymbolFootprint::writing([w("lib.rs", "known")]);
let expanded = expand_footprint(&index, &declared, 3);
assert_eq!(
expanded.scheduling(),
Scheduling::Serialize,
"truncated index must stay fail-closed, not relax to file-level"
);
}
#[test]
fn greenfield_disjoint_files_parallelize_end_to_end() {
// Neither `alpha` nor `beta` is defined yet, so both expand to FileLevel
// and, touching different files, may share a level.
let dir = tempfile::tempdir().unwrap();
std::fs::write(dir.path().join("a.rs"), "// implement alpha\n").unwrap();
std::fs::write(dir.path().join("b.rs"), "// implement beta\n").unwrap();
let index = ProjectIndex::build(dir.path());
let fa = expand_footprint(&index, &SymbolFootprint::writing([w("a.rs", "alpha")]), 3);
let fb = expand_footprint(&index, &SymbolFootprint::writing([w("b.rs", "beta")]), 3);
assert_eq!(fa.scheduling(), Scheduling::FileLevel);
assert_eq!(fb.scheduling(), Scheduling::FileLevel);
let plan = analyze(&[
FootprintSubtask { id: "a".into(), footprint: fa },
FootprintSubtask { id: "b".into(), footprint: fb },
]);
assert_eq!(
plan.levels,
vec![vec!["a".to_string(), "b".to_string()]],
"greenfield disjoint-file subtasks now run in parallel: {plan:?}"
);
}
}