use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::mpsc::channel;
use futures::StreamExt as _;
use futures::TryStreamExt as _;
use futures::future::try_join_all;
use indexmap::IndexSet;
use jj_lib::backend::BackendError;
use jj_lib::backend::CommitId;
use jj_lib::backend::FileId;
use jj_lib::backend::TreeValue;
use jj_lib::commit::Commit;
use jj_lib::diff::ContentDiff;
use jj_lib::diff::DiffHunkKind;
use jj_lib::matchers::Matcher;
use jj_lib::merged_tree::TreeDiffEntry;
use jj_lib::merged_tree_builder::MergedTreeBuilder;
use jj_lib::repo::MutableRepo;
use jj_lib::repo::Repo as _;
use jj_lib::repo_path::RepoPathBuf;
use jj_lib::revset::RevsetExpression;
use jj_lib::rewrite::merge_commit_trees;
use jj_lib::store::Store;
use rayon::iter::IntoParallelIterator as _;
use rayon::prelude::ParallelIterator as _;
use crate::revset::RevsetEvaluationError;
use crate::revset::RevsetStreamExt as _;
/// Describes a single file handed to a [`FileFixer`].
///
/// `fix_files` builds one of these for every file term it finds in a commit's
/// tree and de-duplicates them in a `HashSet`, which is why the type derives
/// `Hash` and `Eq`.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct FileToFix {
/// Id of the file content to be fixed, taken from the commit's tree.
pub file_id: FileId,
/// Id of the file found at the same path on the base side of the comparison,
/// or `None` when no regular file was found there.
pub base_file_id: Option<FileId>,
/// Path of the file within the repository.
pub repo_path: RepoPathBuf,
}
/// Errors that can be returned while fixing files.
///
/// Every variant is `transparent`, so the message and source come from the
/// wrapped error.
#[derive(Debug, thiserror::Error)]
pub enum FixError {
/// A backend operation failed; converted automatically from `BackendError`.
#[error(transparent)]
Backend(#[from] BackendError),
/// Evaluating the revset of commits to fix failed; converted automatically
/// from `RevsetEvaluationError`.
#[error(transparent)]
RevsetEvaluation(#[from] RevsetEvaluationError),
/// An I/O operation failed; converted automatically from `std::io::Error`.
#[error(transparent)]
Io(#[from] std::io::Error),
/// An arbitrary boxed error. It has no `From` conversion and is not
/// constructed in this file.
// NOTE(review): presumably meant for `FileFixer` implementations to report
// failures while fixing file content — confirm against callers.
#[error(transparent)]
FixContent(Box<dyn std::error::Error + Send + Sync>),
}
/// Something that can fix a batch of files.
pub trait FileFixer {
/// Fixes the given set of files and returns a map from each input (borrowed
/// from `files_to_fix`) to the id of its fixed content.
///
/// The free function `fix_files` in this module leaves any file without an
/// entry in the returned map untouched, so implementations can omit inputs
/// that did not change.
///
/// `store` is the store handed through by the caller (`fix_files` passes the
/// repo's store).
fn fix_files<'a>(
&mut self,
store: &Store,
files_to_fix: &'a HashSet<FileToFix>,
) -> Result<HashMap<&'a FileToFix, FileId>, FixError>;
}
/// Outcome of a `fix_files` run.
#[derive(Debug, Default)]
pub struct FixSummary {
/// Maps the id of each commit that was written anew (because its files were
/// fixed or because its parents changed) to the id of its replacement.
pub rewrites: HashMap<CommitId, CommitId>,
/// Number of commits visited during the rewrite phase.
pub num_checked_commits: i32,
/// Number of visited commits in which at least one file value changed.
pub num_fixed_commits: i32,
}
/// A [`FileFixer`] built from a per-file function.
///
/// Its `FileFixer` implementation calls the function for every file through
/// rayon's parallel iterator.
pub struct ParallelFileFixer<T> {
/// Function invoked once per file to fix.
fix_fn: T,
}
impl<T> ParallelFileFixer<T>
where
T: Fn(&Store, &FileToFix) -> Result<Option<FileId>, FixError> + Sync + Send,
{
pub fn new(fix_fn: T) -> Self {
Self { fix_fn }
}
}
impl<T> FileFixer for ParallelFileFixer<T>
where
    T: Fn(&Store, &FileToFix) -> Result<Option<FileId>, FixError> + Sync + Send,
{
    /// Applies the wrapped function to every file in `files_to_fix` using
    /// rayon's parallel iterator and gathers the results.
    ///
    /// The returned map has an entry only for files where the function
    /// returned `Ok(Some(_))`. An `Err` from the function is propagated and no
    /// map is returned.
    fn fix_files<'a>(
        &mut self,
        store: &Store,
        files_to_fix: &'a HashSet<FileToFix>,
    ) -> Result<HashMap<&'a FileToFix, FileId>, FixError> {
        let (sender, receiver) = channel();
        files_to_fix.into_par_iter().try_for_each_init(
            // Each initialisation gets its own clone of the sending half.
            || sender.clone(),
            |worker_sender, candidate| -> Result<(), FixError> {
                if let Some(fixed_id) = (self.fix_fn)(store, candidate)? {
                    worker_sender.send((candidate, fixed_id)).unwrap();
                }
                Ok(())
            },
        )?;
        // Drop the last sender so that draining the receiver terminates once
        // every queued update has been read.
        drop(sender);
        Ok(receiver.into_iter().collect())
    }
}
/// Fixes files in `root_commits` and the commits selected by their
/// `descendants()` revset, then rewrites the affected commits in `repo_mut`.
///
/// The work happens in three phases:
///
/// 1. Walk every selected commit and collect the de-duplicated set of
///    [`FileToFix`] values found at paths accepted by `matcher`.
/// 2. Hand the whole set to `file_fixer` in a single `fix_files` call.
/// 3. Use `transform_descendants` to rewrite each commit, swapping in the
///    fixed file ids returned by the fixer.
///
/// When `include_unchanged_files` is `false`, a commit's candidate files are
/// found by diffing the merged tree of its base commits (see
/// `get_base_commit_map`) against the commit's tree. When it is `true`, the
/// diff is taken against the empty tree instead, so paths that are unchanged
/// relative to the base are considered too.
///
/// Returns a [`FixSummary`] with the old-to-new commit id mapping and counters.
///
/// # Errors
///
/// Returns a [`FixError`] if revset evaluation, a store/tree operation, or the
/// file fixer fails.
///
/// # Panics
///
/// Panics if `base_commit_map` has no entry for a collected commit, or if
/// `transform_descendants` visits a commit for which no paths were recorded
/// (both are `unwrap()`s below).
pub async fn fix_files(
root_commits: Vec<CommitId>,
matcher: &dyn Matcher,
include_unchanged_files: bool,
repo_mut: &mut MutableRepo,
file_fixer: &mut impl FileFixer,
) -> Result<FixSummary, FixError> {
let mut summary = FixSummary::default();
// Load every commit to process. `root_commits` is cloned here because it is
// needed again for `transform_descendants` further down.
let commits: Vec<_> = RevsetExpression::commits(root_commits.clone())
.descendants()
.evaluate(repo_mut)?
.stream()
.commits(repo_mut.store())
.try_collect()
.await?;
tracing::debug!(
?root_commits,
?commits,
"looking for files to fix in commits:"
);
// For each commit, the ids of the base commits whose merged tree it is
// compared against.
let base_commit_map = get_base_commit_map(&commits).await?;
// (commit id, path) -> id of the regular file found on the base side. Kept so
// the rewrite phase can rebuild exactly the same `FileToFix` keys.
let mut base_files: HashMap<(CommitId, RepoPathBuf), FileId> = HashMap::new();
// Fixer inputs, de-duplicated across all commits.
let mut unique_files_to_fix: HashSet<FileToFix> = HashSet::new();
// Paths to revisit in each commit during the rewrite phase.
let mut commit_paths: HashMap<CommitId, HashSet<RepoPathBuf>> = HashMap::new();
// Walk the revset output in reverse.
// NOTE(review): presumably so that ancestors are handled before their
// descendants — confirm the iteration order of the revset stream.
for commit in commits.iter().rev() {
let mut paths: HashSet<RepoPathBuf> = HashSet::new();
let mut base_commits = Vec::new();
let base_commit_ids = base_commit_map.get(commit.id()).unwrap();
for base_commit_id in base_commit_ids {
// NOTE(review): `get_base_commit_map` only ever stores ids of parents that
// are NOT in `commits`, while `commit_paths` is keyed only by commits that
// ARE in `commits`, so this lookup looks like it can never hit. Confirm
// whether inheriting paths from in-set parents was intended here.
if let Some(base_paths) = commit_paths.get(base_commit_id) {
paths.extend(base_paths.iter().cloned());
}
let base_commit = repo_mut.store().get_commit_async(base_commit_id).await?;
base_commits.push(base_commit);
}
// Merge the trees of all base commits into the one tree this commit is
// compared against.
let base_tree = merge_commit_trees(repo_mut, &base_commits).await?;
// With `include_unchanged_files`, diff against the empty tree instead of the
// base tree so that paths unchanged relative to the base still appear.
let diff_base_tree = if include_unchanged_files {
&repo_mut.store().empty_merged_tree()
} else {
&base_tree
};
let mut diff_stream = diff_base_tree.diff_stream(&commit.tree(), &matcher);
while let Some(TreeDiffEntry {
path: repo_path,
values,
}) = diff_stream.next().await
{
let values = values?;
// Nothing exists at this path in the commit, so there is no content to fix.
if values.after.is_absent() {
continue;
}
// Pick the base-side value for this path. When the diff was taken against
// the empty tree, `values.before` does not come from the base tree, so the
// base tree is queried directly. Only the first term of the value is used.
let before = if include_unchanged_files {
base_tree.path_value(&repo_path).await?.into_iter().next()
} else {
values.before.into_iter().next()
};
// Only a regular file on the base side is recorded; anything else yields
// `None`.
let before_file_id = if let Some(Some(TreeValue::File {
id: before_id,
executable: _,
copy_id: _,
})) = before
{
base_files.insert((commit.id().clone(), repo_path.clone()), before_id.clone());
Some(before_id.clone())
} else {
None
};
// Each file term of the commit-side value becomes its own fixer input, so a
// value with several terms can contribute several inputs. Non-file terms are
// skipped.
for after_term in values.after {
if let Some(TreeValue::File {
id,
executable: _,
copy_id: _,
}) = after_term
{
let file_to_fix = FileToFix {
file_id: id.clone(),
base_file_id: before_file_id.clone(),
repo_path: repo_path.clone(),
};
unique_files_to_fix.insert(file_to_fix.clone());
paths.insert(repo_path.clone());
}
}
}
commit_paths.insert(commit.id().clone(), paths);
}
tracing::debug!(
?include_unchanged_files,
?unique_files_to_fix,
"invoking file fixer on these files:"
);
// Run the fixer once over all unique inputs.
let fixed_file_ids = file_fixer.fix_files(repo_mut.store().as_ref(), &unique_files_to_fix)?;
tracing::debug!(?fixed_file_ids, "file fixer fixed these files:");
// Rewrite phase: visit each commit and substitute the fixed file ids.
repo_mut
.transform_descendants(root_commits, async |rewriter| {
let old_commit_id = rewriter.old_commit().id().clone();
let repo_paths = commit_paths.get(&old_commit_id).unwrap();
let old_tree = rewriter.old_commit().tree();
let mut tree_builder = MergedTreeBuilder::new(old_tree.clone());
let mut has_changes = false;
for repo_path in repo_paths {
let old_value = old_tree.path_value(repo_path).await?;
let base_file_id = base_files.get(&(old_commit_id.clone(), repo_path.clone()));
// Map every term of the value independently; terms that are not files, or
// files the fixer returned no entry for, are kept as they are.
let new_value = old_value.map(|old_term| {
if let Some(TreeValue::File {
id,
executable,
copy_id,
}) = old_term
{
// Rebuild the same key that was used when collecting fixer inputs.
let file_to_fix = FileToFix {
file_id: id.clone(),
base_file_id: base_file_id.cloned(),
repo_path: repo_path.clone(),
};
if let Some(new_id) = fixed_file_ids.get(&file_to_fix) {
// Only the content id changes; the executable flag and copy id carry over.
return Some(TreeValue::File {
id: new_id.clone(),
executable: *executable,
copy_id: copy_id.clone(),
});
}
}
old_term.clone()
});
if new_value != old_value {
tree_builder.set_or_remove(repo_path.clone(), new_value);
has_changes = true;
}
}
// Counted for every commit the rewriter visits, whether or not it changes.
summary.num_checked_commits += 1;
if has_changes {
summary.num_fixed_commits += 1;
let new_tree = tree_builder.write_tree().await?;
let builder = rewriter.reparent();
let new_commit = builder.set_tree(new_tree).write().await?;
summary
.rewrites
.insert(old_commit_id, new_commit.id().clone());
} else if rewriter.parents_changed() {
// No file changed in this commit, but its parents did: write a reparented
// commit with the tree left as is and record the rewrite.
let new_commit = rewriter.reparent().write().await?;
summary
.rewrites
.insert(old_commit_id, new_commit.id().clone());
}
Ok(())
})
.await?;
tracing::debug!(?summary);
Ok(summary)
}
/// The parts of a file that changed, as computed by `compute_changed_ranges`.
#[derive(Debug, PartialEq, Eq)]
pub enum RegionsToFormat {
/// Changed regions expressed as line ranges. `compute_changed_ranges` emits
/// them in ascending order with 1-based, inclusive line numbers.
LineRanges(Vec<LineRange>),
}
/// A span described by its first and last position.
///
/// The type itself does not validate or interpret the bounds. When produced by
/// `compute_changed_ranges`, both ends are 1-based line numbers and `last` is
/// inclusive.
#[derive(Debug, PartialEq, Eq)]
pub struct FormatRange {
    /// First position covered by the range.
    pub first: usize,
    /// Last position covered by the range.
    pub last: usize,
}

impl FormatRange {
    /// Creates a range from `first` to `last`; the values are stored as given.
    pub fn new(first: usize, last: usize) -> Self {
        FormatRange { first, last }
    }
}
/// A [`FormatRange`] whose bounds are line numbers, as stored in
/// `RegionsToFormat::LineRanges`.
pub type LineRange = FormatRange;
pub fn compute_changed_ranges(base: &[u8], current: &[u8]) -> RegionsToFormat {
let mut ranges: Vec<LineRange> = Vec::new();
if current.is_empty() {
return RegionsToFormat::LineRanges(ranges);
}
let diff = ContentDiff::by_line([base, current]);
let mut current_line = 1;
for hunk in diff.hunks() {
let line_count = compute_file_line_count(hunk.contents[1]);
match hunk.kind {
DiffHunkKind::Matching => {}
DiffHunkKind::Different => {
if line_count > 0 {
ranges.push(LineRange {
first: current_line,
last: current_line + line_count - 1,
});
}
}
}
current_line += line_count;
}
RegionsToFormat::LineRanges(ranges)
}
/// Counts the lines in `text`.
///
/// Each `\n` byte terminates one line. A non-empty trailing fragment without a
/// final `\n` counts as one more line. Empty input has zero lines.
pub fn compute_file_line_count(text: &[u8]) -> usize {
    let newlines = text.iter().filter(|&&byte| byte == b'\n').count();
    match text.last() {
        // Empty input, or the final line is already newline-terminated.
        None | Some(b'\n') => newlines,
        // Unterminated final line.
        Some(_) => newlines + 1,
    }
}
/// Computes the base commits of every commit in `commits`.
///
/// A base commit is a parent of some commit in `commits` that is not itself in
/// `commits`. Each commit's entry holds the bases reached through its parents:
/// the bases already recorded for a parent, followed by the parent itself if
/// it is a base. Insertion order is preserved by the `IndexSet`.
///
/// `commits` is processed in reverse order, so a parent only contributes its
/// recorded bases if it appears later in the slice than its child.
///
/// # Errors
///
/// Returns an error if loading the parents of any commit fails.
pub async fn get_base_commit_map(
    commits: &[Commit],
) -> Result<HashMap<CommitId, IndexSet<CommitId>>, FixError> {
    let ids_in_set: HashSet<&CommitId> = commits.iter().map(|commit| commit.id()).collect();

    // Ids of all parents that fall outside the given commits.
    let loaded_parents = try_join_all(commits.iter().map(|commit| commit.parents())).await?;
    let mut outside_parent_ids: HashSet<CommitId> = HashSet::new();
    for parent in loaded_parents.into_iter().flatten() {
        if !ids_in_set.contains(parent.id()) {
            outside_parent_ids.insert(parent.id().clone());
        }
    }

    let mut bases_by_commit: HashMap<CommitId, IndexSet<CommitId>> = HashMap::new();
    for commit in commits.iter().rev() {
        let mut bases: IndexSet<CommitId> = IndexSet::new();
        for parent_id in commit.parent_ids() {
            // Inherit whatever was already recorded for this parent.
            if let Some(inherited) = bases_by_commit.get(parent_id) {
                bases.extend(inherited.iter().cloned());
            }
            // A parent outside the set is itself a base.
            if outside_parent_ids.contains(parent_id) {
                bases.insert(parent_id.clone());
            }
        }
        bases_by_commit.insert(commit.id().clone(), bases);
    }
    Ok(bases_by_commit)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks line counting for empty input and for text with and without a
    /// trailing newline.
    #[test]
    fn test_compute_file_line_count() {
        let cases: [(&str, usize); 5] = [
            ("", 0),
            ("a", 1),
            ("a\n", 1),
            ("a\nb", 2),
            ("a\nb\n", 2),
        ];
        for (text, expected) in cases {
            assert_eq!(compute_file_line_count(text.as_bytes()), expected);
        }
    }
}