branchless/core/rewrite/
rewrite_hooks.rs

1//! Hooks used to have Git call back into `git-branchless` for various functionality.
2
3use std::collections::{HashMap, HashSet};
4
5use std::fmt::Write;
6use std::fs::{self, File};
7use std::io::{self, stdin, BufRead, BufReader, Read, Write as WriteIo};
8use std::path::{Path, PathBuf};
9use std::str::FromStr;
10use std::time::SystemTime;
11
12use console::style;
13use eyre::Context;
14use itertools::Itertools;
15use tempfile::NamedTempFile;
16use tracing::instrument;
17
18use crate::core::check_out::CheckOutCommitOptions;
19use crate::core::config::{get_hint_enabled, print_hint_suppression_notice, Hint};
20use crate::core::dag::Dag;
21use crate::core::effects::Effects;
22use crate::core::eventlog::{Event, EventLogDb, EventReplayer};
23use crate::core::formatting::Pluralize;
24use crate::core::repo_ext::RepoExt;
25use crate::git::{
26    CategorizedReferenceName, GitRunInfo, MaybeZeroOid, NonZeroOid, ReferenceName, Repo,
27    ResolvedReferenceInfo,
28};
29
30use super::execute::check_out_updated_head;
31use super::{find_abandoned_children, move_branches};
32
33/// Get the path to the file which stores the list of "deferred commits".
34///
35/// During a rebase, we make new commits, but if we abort the rebase, we don't
36/// want those new commits to persist in the smartlog, etc. To address this, we
37/// instead queue up the list of created commits and only confirm them once the
38/// rebase has completed.
39///
40/// Note that this has the effect that if the user manually creates a commit
41/// during a rebase, and then aborts the rebase, the commit will not be
42/// available in the event log anywhere. This is probably acceptable.
43pub fn get_deferred_commits_path(repo: &Repo) -> PathBuf {
44    repo.get_rebase_state_dir_path().join("deferred-commits")
45}
46
47fn read_deferred_commits(repo: &Repo) -> eyre::Result<Vec<NonZeroOid>> {
48    let deferred_commits_path = get_deferred_commits_path(repo);
49    let contents = match fs::read_to_string(&deferred_commits_path) {
50        Ok(contents) => contents,
51        Err(err) if err.kind() == io::ErrorKind::NotFound => Default::default(),
52        Err(err) => {
53            return Err(err)
54                .with_context(|| format!("Reading deferred commits at {deferred_commits_path:?}"))
55        }
56    };
57    let commit_oids = contents.lines().map(NonZeroOid::from_str).try_collect()?;
58    Ok(commit_oids)
59}
60
61#[instrument(skip(stream))]
62fn read_rewritten_list_entries(
63    stream: &mut impl Read,
64) -> eyre::Result<Vec<(NonZeroOid, MaybeZeroOid)>> {
65    let mut rewritten_oids = Vec::new();
66    let reader = BufReader::new(stream);
67    for line in reader.lines() {
68        let line = line?;
69        let line = line.trim();
70        match *line.split(' ').collect::<Vec<_>>().as_slice() {
71            [old_commit_oid, new_commit_oid, ..] => {
72                let old_commit_oid: NonZeroOid = old_commit_oid.parse()?;
73                let new_commit_oid: MaybeZeroOid = new_commit_oid.parse()?;
74                rewritten_oids.push((old_commit_oid, new_commit_oid));
75            }
76            _ => eyre::bail!("Invalid rewrite line: {:?}", &line),
77        }
78    }
79    Ok(rewritten_oids)
80}
81
82#[instrument]
83fn write_rewritten_list(
84    tempfile_dir: &Path,
85    rewritten_list_path: &Path,
86    rewritten_oids: &[(NonZeroOid, MaybeZeroOid)],
87) -> eyre::Result<()> {
88    std::fs::create_dir_all(tempfile_dir).wrap_err("Creating tempfile dir")?;
89    let mut tempfile =
90        NamedTempFile::new_in(tempfile_dir).wrap_err("Creating temporary `rewritten-list` file")?;
91
92    let file = tempfile.as_file_mut();
93    for (old_commit_oid, new_commit_oid) in rewritten_oids {
94        writeln!(file, "{old_commit_oid} {new_commit_oid}")?;
95    }
96    tempfile
97        .persist(rewritten_list_path)
98        .wrap_err("Moving new rewritten-list into place")?;
99    Ok(())
100}
101
102#[instrument]
103fn add_rewritten_list_entries(
104    tempfile_dir: &Path,
105    rewritten_list_path: &Path,
106    entries: &[(NonZeroOid, MaybeZeroOid)],
107) -> eyre::Result<()> {
108    let current_entries = match File::open(rewritten_list_path) {
109        Ok(mut rewritten_list_file) => read_rewritten_list_entries(&mut rewritten_list_file)?,
110        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Default::default(),
111        Err(err) => return Err(err.into()),
112    };
113
114    let mut entries_to_add: HashMap<NonZeroOid, MaybeZeroOid> = entries.iter().copied().collect();
115    let mut new_entries = Vec::new();
116    for (old_commit_oid, new_commit_oid) in current_entries {
117        let new_entry = match entries_to_add.remove(&old_commit_oid) {
118            Some(new_commit_oid) => (old_commit_oid, new_commit_oid),
119            None => (old_commit_oid, new_commit_oid),
120        };
121        new_entries.push(new_entry);
122    }
123    new_entries.extend(entries_to_add.into_iter());
124
125    write_rewritten_list(tempfile_dir, rewritten_list_path, new_entries.as_slice())?;
126    Ok(())
127}
128
129/// Handle Git's `post-rewrite` hook.
130///
131/// See the man-page for `githooks(5)`.
132#[instrument]
133pub fn hook_post_rewrite(
134    effects: &Effects,
135    git_run_info: &GitRunInfo,
136    rewrite_type: &str,
137) -> eyre::Result<()> {
138    let now = SystemTime::now();
139    let timestamp = now.duration_since(SystemTime::UNIX_EPOCH)?.as_secs_f64();
140
141    let repo = Repo::from_current_dir()?;
142    let is_spurious_event = rewrite_type == "amend" && repo.is_rebase_underway()?;
143    if is_spurious_event {
144        return Ok(());
145    }
146
147    let conn = repo.get_db_conn()?;
148    let event_log_db = EventLogDb::new(&conn)?;
149    let event_tx_id = event_log_db.make_transaction_id(now, "hook-post-rewrite")?;
150
151    {
152        let deferred_commit_oids = read_deferred_commits(&repo)?;
153        let commit_events = deferred_commit_oids
154            .into_iter()
155            .map(|commit_oid| Event::CommitEvent {
156                timestamp,
157                event_tx_id,
158                commit_oid,
159            })
160            .collect_vec();
161        event_log_db.add_events(commit_events)?;
162    }
163
164    let (rewritten_oids, rewrite_events) = {
165        let rewritten_oids = read_rewritten_list_entries(&mut stdin().lock())?;
166        let events = rewritten_oids
167            .iter()
168            .copied()
169            .map(|(old_commit_oid, new_commit_oid)| Event::RewriteEvent {
170                timestamp,
171                event_tx_id,
172                old_commit_oid: old_commit_oid.into(),
173                new_commit_oid,
174            })
175            .collect_vec();
176        let rewritten_oids_map: HashMap<NonZeroOid, MaybeZeroOid> =
177            rewritten_oids.into_iter().collect();
178        (rewritten_oids_map, events)
179    };
180
181    let message_rewritten_commits = Pluralize {
182        determiner: None,
183        amount: rewritten_oids.len(),
184        unit: ("rewritten commit", "rewritten commits"),
185    }
186    .to_string();
187    writeln!(
188        effects.get_output_stream(),
189        "branchless: processing {message_rewritten_commits}"
190    )?;
191    event_log_db.add_events(rewrite_events)?;
192
193    if repo
194        .get_rebase_state_dir_path()
195        .join(EXTRA_POST_REWRITE_FILE_NAME)
196        .exists()
197    {
198        // Make sure to resolve `ORIG_HEAD` before we potentially delete the
199        // branch it points to, so that we can get the original OID of `HEAD`.
200        let previous_head_info = load_original_head_info(&repo)?;
201        move_branches(effects, git_run_info, &repo, event_tx_id, &rewritten_oids)?;
202
203        let skipped_head_updated_oid = load_updated_head_oid(&repo)?;
204        match check_out_updated_head(
205            effects,
206            git_run_info,
207            &repo,
208            &event_log_db,
209            event_tx_id,
210            &rewritten_oids,
211            &previous_head_info,
212            skipped_head_updated_oid,
213            &CheckOutCommitOptions::default(),
214        )? {
215            Ok(()) => {}
216            Err(_exit_code) => {
217                eyre::bail!("Could not check out your updated `HEAD` commit.");
218            }
219        }
220    }
221
222    let should_check_abandoned_commits = get_hint_enabled(&repo, Hint::RestackWarnAbandoned)?;
223    if should_check_abandoned_commits && !is_spurious_event {
224        let printed_hint = warn_abandoned(
225            effects,
226            &repo,
227            &conn,
228            &event_log_db,
229            rewritten_oids.keys().copied(),
230        )?;
231        if printed_hint {
232            print_hint_suppression_notice(effects, Hint::RestackWarnAbandoned)?;
233        }
234    }
235
236    Ok(())
237}
238
239#[instrument(skip(old_commit_oids))]
240fn warn_abandoned(
241    effects: &Effects,
242    repo: &Repo,
243    conn: &rusqlite::Connection,
244    event_log_db: &EventLogDb,
245    old_commit_oids: impl IntoIterator<Item = NonZeroOid>,
246) -> eyre::Result<bool> {
247    // The caller will have added events to the event log database, so make sure
248    // to construct a fresh `EventReplayer` here.
249    let references_snapshot = repo.get_references_snapshot()?;
250    let event_replayer = EventReplayer::from_event_log_db(effects, repo, event_log_db)?;
251    let event_cursor = event_replayer.make_default_cursor();
252    let dag = Dag::open_and_sync(
253        effects,
254        repo,
255        &event_replayer,
256        event_cursor,
257        &references_snapshot,
258    )?;
259
260    let (all_abandoned_children, all_abandoned_branches) = {
261        let mut all_abandoned_children: HashSet<NonZeroOid> = HashSet::new();
262        let mut all_abandoned_branches: HashSet<&str> = HashSet::new();
263        for old_commit_oid in old_commit_oids {
264            let abandoned_result =
265                find_abandoned_children(&dag, &event_replayer, event_cursor, old_commit_oid)?;
266            let (_rewritten_oid, abandoned_children) = match abandoned_result {
267                Some(abandoned_result) => abandoned_result,
268                None => continue,
269            };
270            all_abandoned_children.extend(abandoned_children.iter());
271            if let Some(branch_names) = references_snapshot.branch_oid_to_names.get(&old_commit_oid)
272            {
273                all_abandoned_branches
274                    .extend(branch_names.iter().map(|branch_name| branch_name.as_str()));
275            }
276        }
277        (all_abandoned_children, all_abandoned_branches)
278    };
279    let num_abandoned_children = all_abandoned_children.len();
280    let num_abandoned_branches = all_abandoned_branches.len();
281
282    if num_abandoned_children > 0 || num_abandoned_branches > 0 {
283        let warning_items = {
284            let mut warning_items = Vec::new();
285            if num_abandoned_children > 0 {
286                warning_items.push(
287                    Pluralize {
288                        determiner: None,
289                        amount: num_abandoned_children,
290                        unit: ("commit", "commits"),
291                    }
292                    .to_string(),
293                );
294            }
295            if num_abandoned_branches > 0 {
296                let abandoned_branch_count = Pluralize {
297                    determiner: None,
298                    amount: num_abandoned_branches,
299                    unit: ("branch", "branches"),
300                }
301                .to_string();
302
303                let mut all_abandoned_branches: Vec<String> = all_abandoned_branches
304                    .into_iter()
305                    .map(|branch_name| {
306                        CategorizedReferenceName::new(&branch_name.into()).render_suffix()
307                    })
308                    .collect();
309                all_abandoned_branches.sort_unstable();
310                let abandoned_branches_list = all_abandoned_branches.join(", ");
311                warning_items.push(format!(
312                    "{abandoned_branch_count} ({abandoned_branches_list})"
313                ));
314            }
315
316            warning_items
317        };
318
319        let warning_message = warning_items.join(" and ");
320        let warning_message = style(format!("This operation abandoned {warning_message}!"))
321            .bold()
322            .yellow();
323
324        print!(
325            "\
326branchless: {warning_message}
327branchless: Consider running one of the following:
328branchless:   - {git_restack}: re-apply the abandoned commits/branches
329branchless:     (this is most likely what you want to do)
330branchless:   - {git_smartlog}: assess the situation
331branchless:   - {git_hide} [<commit>...]: hide the commits from the smartlog
332branchless:   - {git_undo}: undo the operation
333",
334            warning_message = warning_message,
335            git_smartlog = style("git smartlog").bold(),
336            git_restack = style("git restack").bold(),
337            git_hide = style("git hide").bold(),
338            git_undo = style("git undo").bold(),
339        );
340        Ok(true)
341    } else {
342        Ok(false)
343    }
344}
345
346const ORIGINAL_HEAD_OID_FILE_NAME: &str = "branchless_original_head_oid";
347const ORIGINAL_HEAD_FILE_NAME: &str = "branchless_original_head";
348
349/// Save the name of the currently checked-out branch. This should be called as
350/// part of initializing the rebase.
351#[instrument]
352pub fn save_original_head_info(repo: &Repo, head_info: &ResolvedReferenceInfo) -> eyre::Result<()> {
353    let ResolvedReferenceInfo {
354        oid,
355        reference_name,
356    } = head_info;
357
358    if let Some(oid) = oid {
359        let dest_file_name = repo
360            .get_rebase_state_dir_path()
361            .join(ORIGINAL_HEAD_OID_FILE_NAME);
362        std::fs::write(dest_file_name, oid.to_string()).wrap_err("Writing head OID")?;
363    }
364
365    if let Some(head_name) = reference_name {
366        let dest_file_name = repo
367            .get_rebase_state_dir_path()
368            .join(ORIGINAL_HEAD_FILE_NAME);
369        std::fs::write(dest_file_name, head_name.as_str()).wrap_err("Writing head name")?;
370    }
371
372    Ok(())
373}
374
375#[instrument]
376fn load_original_head_info(repo: &Repo) -> eyre::Result<ResolvedReferenceInfo> {
377    let head_oid = {
378        let source_file_name = repo
379            .get_rebase_state_dir_path()
380            .join(ORIGINAL_HEAD_OID_FILE_NAME);
381        match std::fs::read_to_string(source_file_name) {
382            Ok(oid) => Some(oid.parse().wrap_err("Parsing original head OID")?),
383            Err(err) if err.kind() == std::io::ErrorKind::NotFound => None,
384            Err(err) => return Err(err.into()),
385        }
386    };
387
388    let head_name = {
389        let source_file_name = repo
390            .get_rebase_state_dir_path()
391            .join(ORIGINAL_HEAD_FILE_NAME);
392        match std::fs::read(source_file_name) {
393            Ok(reference_name) => Some(ReferenceName::from_bytes(reference_name)?),
394            Err(err) if err.kind() == std::io::ErrorKind::NotFound => None,
395            Err(err) => return Err(err.into()),
396        }
397    };
398
399    Ok(ResolvedReferenceInfo {
400        oid: head_oid,
401        reference_name: head_name,
402    })
403}
404
405const EXTRA_POST_REWRITE_FILE_NAME: &str = "branchless_do_extra_post_rewrite";
406
407/// In order to handle the case of a commit being skipped and its corresponding
408/// branch being deleted, we need to store our own copy of the original `HEAD`
409/// OID, and then replace it once the rebase is about to conclude. We can't do
410/// it earlier, because if the user aborts the rebase after the commit has been
411/// skipped, then they would be returned to the wrong commit.
412const UPDATED_HEAD_FILE_NAME: &str = "branchless_updated_head";
413
414#[instrument]
415fn save_updated_head_oid(repo: &Repo, updated_head_oid: NonZeroOid) -> eyre::Result<()> {
416    let dest_file_name = repo
417        .get_rebase_state_dir_path()
418        .join(UPDATED_HEAD_FILE_NAME);
419    std::fs::write(dest_file_name, updated_head_oid.to_string())?;
420    Ok(())
421}
422
423#[instrument]
424fn load_updated_head_oid(repo: &Repo) -> eyre::Result<Option<NonZeroOid>> {
425    let source_file_name = repo
426        .get_rebase_state_dir_path()
427        .join(UPDATED_HEAD_FILE_NAME);
428    match std::fs::read_to_string(source_file_name) {
429        Ok(result) => Ok(Some(result.parse()?)),
430        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
431        Err(err) => Err(err.into()),
432    }
433}
434
435/// Register extra cleanup actions for rebase.
436///
437/// For rebases, register that extra cleanup actions should be taken when the
438/// rebase finishes and calls the post-rewrite hook. We don't want to change the
439/// behavior of `git rebase` itself, except when called via `git-branchless`, so
440/// that the user's expectations aren't unexpectedly subverted.
441pub fn hook_register_extra_post_rewrite_hook() -> eyre::Result<()> {
442    let repo = Repo::from_current_dir()?;
443    let file_name = repo
444        .get_rebase_state_dir_path()
445        .join(EXTRA_POST_REWRITE_FILE_NAME);
446    File::create(file_name).wrap_err("Registering extra post-rewrite hook")?;
447
448    // This is the last step before the rebase concludes. Ordinarily, Git will
449    // use `head-name` as the name of the previously checked-out branch, and
450    // move that branch to point to the current commit (and check it out again).
451    // We want to suppress this behavior because we don't want the branch to
452    // move (or, if we do want it to move, we will handle that ourselves as part
453    // of the post-rewrite hook). So we update `head-name` to contain "detached
454    // HEAD" to indicate to Git that no branch was checked out prior to the
455    // rebase, so that it doesn't try to adjust any branches.
456    std::fs::write(
457        repo.get_rebase_state_dir_path().join("head-name"),
458        "detached HEAD",
459    )
460    .wrap_err("Setting `head-name` to detached HEAD")?;
461
462    Ok(())
463}
464
465/// For rebases, detect empty commits (which have probably been applied
466/// upstream) and write them to the `rewritten-list` file, so that they're later
467/// passed to the `post-rewrite` hook.
468pub fn hook_drop_commit_if_empty(
469    effects: &Effects,
470    old_commit_oid: NonZeroOid,
471) -> eyre::Result<()> {
472    let repo = Repo::from_current_dir()?;
473    let head_info = repo.get_head_info()?;
474    let head_oid = match head_info.oid {
475        Some(head_oid) => head_oid,
476        None => return Ok(()),
477    };
478    let head_commit = match repo.find_commit(head_oid)? {
479        Some(head_commit) => head_commit,
480        None => return Ok(()),
481    };
482
483    if !head_commit.is_empty() {
484        return Ok(());
485    }
486
487    let only_parent_oid = match head_commit.get_only_parent_oid() {
488        Some(only_parent_oid) => only_parent_oid,
489        None => return Ok(()),
490    };
491    writeln!(
492        effects.get_output_stream(),
493        "Skipped now-empty commit: {}",
494        effects
495            .get_glyphs()
496            .render(head_commit.friendly_describe(effects.get_glyphs())?)?
497    )?;
498    repo.set_head(only_parent_oid)?;
499
500    let orig_head_oid = match repo.find_reference(&"ORIG_HEAD".into())? {
501        Some(orig_head_reference) => orig_head_reference
502            .peel_to_commit()?
503            .map(|orig_head_commit| orig_head_commit.get_oid()),
504        None => None,
505    };
506    if Some(old_commit_oid) == orig_head_oid {
507        save_updated_head_oid(&repo, only_parent_oid)?;
508    }
509    add_rewritten_list_entries(
510        &repo.get_tempfile_dir()?,
511        &repo.get_rebase_state_dir_path().join("rewritten-list"),
512        &[
513            (old_commit_oid, MaybeZeroOid::Zero),
514            (head_commit.get_oid(), MaybeZeroOid::Zero),
515        ],
516    )?;
517
518    Ok(())
519}
520
521/// For rebases, if a commit is known to have been applied upstream, skip it
522/// without attempting to apply it.
523pub fn hook_skip_upstream_applied_commit(
524    effects: &Effects,
525    commit_oid: NonZeroOid,
526) -> eyre::Result<()> {
527    let repo = Repo::from_current_dir()?;
528    let commit = repo.find_commit_or_fail(commit_oid)?;
529    writeln!(
530        effects.get_output_stream(),
531        "Skipping commit (was already applied upstream): {}",
532        effects
533            .get_glyphs()
534            .render(commit.friendly_describe(effects.get_glyphs())?)?
535    )?;
536
537    if let Some(orig_head_reference) = repo.find_reference(&"ORIG_HEAD".into())? {
538        let resolved_orig_head = repo.resolve_reference(&orig_head_reference)?;
539        if let Some(original_head_oid) = resolved_orig_head.oid {
540            if original_head_oid == commit_oid {
541                let current_head_oid = repo.get_head_info()?.oid;
542                if let Some(current_head_oid) = current_head_oid {
543                    save_updated_head_oid(&repo, current_head_oid)?;
544                }
545            }
546        }
547    }
548    add_rewritten_list_entries(
549        &repo.get_tempfile_dir()?,
550        &repo.get_rebase_state_dir_path().join("rewritten-list"),
551        &[(commit_oid, MaybeZeroOid::Zero)],
552    )?;
553
554    Ok(())
555}