Skip to main content

branchless/core/rewrite/
rewrite_hooks.rs

1//! Hooks used to have Git call back into `git-branchless` for various functionality.
2
3use std::collections::{HashMap, HashSet};
4
5use std::fmt::Write;
6use std::fs::File;
7use std::io::{BufRead, BufReader, Read, Write as WriteIo, stdin};
8use std::path::{Path, PathBuf};
9use std::time::SystemTime;
10
11use console::style;
12use eyre::Context;
13use itertools::Itertools;
14use tempfile::NamedTempFile;
15use tracing::instrument;
16
17use crate::core::check_out::CheckOutCommitOptions;
18use crate::core::config::{Hint, get_hint_enabled, print_hint_suppression_notice};
19use crate::core::dag::Dag;
20use crate::core::effects::Effects;
21use crate::core::eventlog::{Event, EventLogDb, EventReplayer};
22use crate::core::formatting::Pluralize;
23use crate::core::repo_ext::RepoExt;
24use crate::git::{
25    CategorizedReferenceName, GitRunInfo, MaybeZeroOid, NonZeroOid, ReferenceName, Repo,
26    ResolvedReferenceInfo,
27};
28
29use super::execute::check_out_updated_head;
30use super::{find_abandoned_children, move_branches};
31
32/// Get the path to the file which stores the list of "deferred commits".
33///
34/// During a rebase, we make new commits, but if we abort the rebase, we don't
35/// want those new commits to persist in the smartlog, etc. To address this, we
36/// instead queue up the list of created commits and only confirm them once the
37/// rebase has completed.
38///
39/// Note that this has the effect that if the user manually creates a commit
40/// during a rebase, and then aborts the rebase, the commit will not be
41/// available in the event log anywhere. This is probably acceptable.
42pub fn get_deferred_commits_path(repo: &Repo) -> PathBuf {
43    repo.get_rebase_state_dir_path().join("deferred-commits")
44}
45
46#[instrument(skip(stream))]
47fn read_rewritten_list_entries(
48    stream: &mut impl Read,
49) -> eyre::Result<Vec<(NonZeroOid, MaybeZeroOid)>> {
50    let mut rewritten_oids = Vec::new();
51    let reader = BufReader::new(stream);
52    for line in reader.lines() {
53        let line = line?;
54        let line = line.trim();
55        match *line.split(' ').collect::<Vec<_>>().as_slice() {
56            [old_commit_oid, new_commit_oid, ..] => {
57                let old_commit_oid: NonZeroOid = old_commit_oid.parse()?;
58                let new_commit_oid: MaybeZeroOid = new_commit_oid.parse()?;
59                rewritten_oids.push((old_commit_oid, new_commit_oid));
60            }
61            _ => eyre::bail!("Invalid rewrite line: {:?}", &line),
62        }
63    }
64    Ok(rewritten_oids)
65}
66
67#[instrument]
68fn write_rewritten_list(
69    tempfile_dir: &Path,
70    rewritten_list_path: &Path,
71    rewritten_oids: &[(NonZeroOid, MaybeZeroOid)],
72) -> eyre::Result<()> {
73    std::fs::create_dir_all(tempfile_dir).wrap_err("Creating tempfile dir")?;
74    let mut tempfile =
75        NamedTempFile::new_in(tempfile_dir).wrap_err("Creating temporary `rewritten-list` file")?;
76
77    let file = tempfile.as_file_mut();
78    for (old_commit_oid, new_commit_oid) in rewritten_oids {
79        writeln!(file, "{old_commit_oid} {new_commit_oid}")?;
80    }
81    tempfile
82        .persist(rewritten_list_path)
83        .wrap_err("Moving new rewritten-list into place")?;
84    Ok(())
85}
86
87#[instrument]
88fn add_rewritten_list_entries(
89    tempfile_dir: &Path,
90    rewritten_list_path: &Path,
91    entries: &[(NonZeroOid, MaybeZeroOid)],
92) -> eyre::Result<()> {
93    let current_entries = match File::open(rewritten_list_path) {
94        Ok(mut rewritten_list_file) => read_rewritten_list_entries(&mut rewritten_list_file)?,
95        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Default::default(),
96        Err(err) => return Err(err.into()),
97    };
98
99    let mut entries_to_add: HashMap<NonZeroOid, MaybeZeroOid> = entries.iter().copied().collect();
100    let mut new_entries = Vec::new();
101    for (old_commit_oid, new_commit_oid) in current_entries {
102        let new_entry = match entries_to_add.remove(&old_commit_oid) {
103            Some(new_commit_oid) => (old_commit_oid, new_commit_oid),
104            None => (old_commit_oid, new_commit_oid),
105        };
106        new_entries.push(new_entry);
107    }
108    new_entries.extend(entries_to_add.into_iter());
109
110    write_rewritten_list(tempfile_dir, rewritten_list_path, new_entries.as_slice())?;
111    Ok(())
112}
113
114/// Handle Git's `post-rewrite` hook.
115///
116/// See the man-page for `githooks(5)`.
117#[instrument]
118pub fn hook_post_rewrite(
119    effects: &Effects,
120    git_run_info: &GitRunInfo,
121    rewrite_type: &str,
122) -> eyre::Result<()> {
123    let now = SystemTime::now();
124    let timestamp = now.duration_since(SystemTime::UNIX_EPOCH)?.as_secs_f64();
125
126    let repo = Repo::from_current_dir()?;
127    let is_spurious_event = rewrite_type == "amend" && repo.is_rebase_underway()?;
128    if is_spurious_event {
129        return Ok(());
130    }
131
132    let conn = repo.get_db_conn()?;
133    let event_log_db = EventLogDb::new(&conn)?;
134    let event_tx_id = event_log_db.make_transaction_id(now, "hook-post-rewrite")?;
135
136    let (rewritten_oids, rewrite_events) = {
137        let rewritten_oids = read_rewritten_list_entries(&mut stdin().lock())?;
138        let events = rewritten_oids
139            .iter()
140            .copied()
141            .map(|(old_commit_oid, new_commit_oid)| Event::RewriteEvent {
142                timestamp,
143                event_tx_id,
144                old_commit_oid: old_commit_oid.into(),
145                new_commit_oid,
146            })
147            .collect_vec();
148        let rewritten_oids_map: HashMap<NonZeroOid, MaybeZeroOid> =
149            rewritten_oids.into_iter().collect();
150        (rewritten_oids_map, events)
151    };
152
153    let message_rewritten_commits = Pluralize {
154        determiner: None,
155        amount: rewritten_oids.len(),
156        unit: ("rewritten commit", "rewritten commits"),
157    }
158    .to_string();
159    writeln!(
160        effects.get_output_stream(),
161        "branchless: processing {message_rewritten_commits}"
162    )?;
163    event_log_db.add_events(rewrite_events)?;
164
165    if repo
166        .get_rebase_state_dir_path()
167        .join(EXTRA_POST_REWRITE_FILE_NAME)
168        .exists()
169    {
170        // Make sure to resolve `ORIG_HEAD` before we potentially delete the
171        // branch it points to, so that we can get the original OID of `HEAD`.
172        let previous_head_info = load_original_head_info(&repo)?;
173        move_branches(effects, git_run_info, &repo, event_tx_id, &rewritten_oids)?;
174
175        let skipped_head_updated_oid = load_updated_head_oid(&repo)?;
176        match check_out_updated_head(
177            effects,
178            git_run_info,
179            &repo,
180            &event_log_db,
181            event_tx_id,
182            &rewritten_oids,
183            &previous_head_info,
184            skipped_head_updated_oid,
185            &CheckOutCommitOptions::default(),
186        )? {
187            Ok(()) => {}
188            Err(_exit_code) => {
189                eyre::bail!("Could not check out your updated `HEAD` commit.");
190            }
191        }
192    }
193
194    let should_check_abandoned_commits = get_hint_enabled(&repo, Hint::RestackWarnAbandoned)?;
195    if should_check_abandoned_commits && !is_spurious_event {
196        let printed_hint = warn_abandoned(
197            effects,
198            &repo,
199            &conn,
200            &event_log_db,
201            rewritten_oids.keys().copied(),
202        )?;
203        if printed_hint {
204            print_hint_suppression_notice(effects, Hint::RestackWarnAbandoned)?;
205        }
206    }
207
208    Ok(())
209}
210
211#[instrument(skip(old_commit_oids))]
212fn warn_abandoned(
213    effects: &Effects,
214    repo: &Repo,
215    conn: &rusqlite::Connection,
216    event_log_db: &EventLogDb,
217    old_commit_oids: impl IntoIterator<Item = NonZeroOid>,
218) -> eyre::Result<bool> {
219    // The caller will have added events to the event log database, so make sure
220    // to construct a fresh `EventReplayer` here.
221    let references_snapshot = repo.get_references_snapshot()?;
222    let event_replayer = EventReplayer::from_event_log_db(effects, repo, event_log_db)?;
223    let event_cursor = event_replayer.make_default_cursor();
224    let dag = Dag::open_and_sync(
225        effects,
226        repo,
227        &event_replayer,
228        event_cursor,
229        &references_snapshot,
230    )?;
231
232    let (all_abandoned_children, all_abandoned_branches) = {
233        let mut all_abandoned_children: HashSet<NonZeroOid> = HashSet::new();
234        let mut all_abandoned_branches: HashSet<&str> = HashSet::new();
235        for old_commit_oid in old_commit_oids {
236            let abandoned_result =
237                find_abandoned_children(&dag, &event_replayer, event_cursor, old_commit_oid)?;
238            let (_rewritten_oid, abandoned_children) = match abandoned_result {
239                Some(abandoned_result) => abandoned_result,
240                None => continue,
241            };
242            all_abandoned_children.extend(abandoned_children.iter());
243            if let Some(branch_names) = references_snapshot.branch_oid_to_names.get(&old_commit_oid)
244            {
245                all_abandoned_branches
246                    .extend(branch_names.iter().map(|branch_name| branch_name.as_str()));
247            }
248        }
249        (all_abandoned_children, all_abandoned_branches)
250    };
251    let num_abandoned_children = all_abandoned_children.len();
252    let num_abandoned_branches = all_abandoned_branches.len();
253
254    if num_abandoned_children > 0 || num_abandoned_branches > 0 {
255        let warning_items = {
256            let mut warning_items = Vec::new();
257            if num_abandoned_children > 0 {
258                warning_items.push(
259                    Pluralize {
260                        determiner: None,
261                        amount: num_abandoned_children,
262                        unit: ("commit", "commits"),
263                    }
264                    .to_string(),
265                );
266            }
267            if num_abandoned_branches > 0 {
268                let abandoned_branch_count = Pluralize {
269                    determiner: None,
270                    amount: num_abandoned_branches,
271                    unit: ("branch", "branches"),
272                }
273                .to_string();
274
275                let mut all_abandoned_branches: Vec<String> = all_abandoned_branches
276                    .into_iter()
277                    .map(|branch_name| {
278                        CategorizedReferenceName::new(&branch_name.into()).render_suffix()
279                    })
280                    .collect();
281                all_abandoned_branches.sort_unstable();
282                let abandoned_branches_list = all_abandoned_branches.join(", ");
283                warning_items.push(format!(
284                    "{abandoned_branch_count} ({abandoned_branches_list})"
285                ));
286            }
287
288            warning_items
289        };
290
291        let warning_message = warning_items.join(" and ");
292        let warning_message = style(format!("This operation abandoned {warning_message}!"))
293            .bold()
294            .yellow();
295
296        print!(
297            "\
298branchless: {warning_message}
299branchless: Consider running one of the following:
300branchless:   - {git_restack}: re-apply the abandoned commits/branches
301branchless:     (this is most likely what you want to do)
302branchless:   - {git_smartlog}: assess the situation
303branchless:   - {git_hide} [<commit>...]: hide the commits from the smartlog
304branchless:   - {git_undo}: undo the operation
305",
306            warning_message = warning_message,
307            git_smartlog = style("git smartlog").bold(),
308            git_restack = style("git restack").bold(),
309            git_hide = style("git hide").bold(),
310            git_undo = style("git undo").bold(),
311        );
312        Ok(true)
313    } else {
314        Ok(false)
315    }
316}
317
318const ORIGINAL_HEAD_OID_FILE_NAME: &str = "branchless_original_head_oid";
319const ORIGINAL_HEAD_FILE_NAME: &str = "branchless_original_head";
320
321/// Save the name of the currently checked-out branch. This should be called as
322/// part of initializing the rebase.
323#[instrument]
324pub fn save_original_head_info(repo: &Repo, head_info: &ResolvedReferenceInfo) -> eyre::Result<()> {
325    let ResolvedReferenceInfo {
326        oid,
327        reference_name,
328    } = head_info;
329
330    if let Some(oid) = oid {
331        let dest_file_name = repo
332            .get_rebase_state_dir_path()
333            .join(ORIGINAL_HEAD_OID_FILE_NAME);
334        std::fs::write(dest_file_name, oid.to_string()).wrap_err("Writing head OID")?;
335    }
336
337    if let Some(head_name) = reference_name {
338        let dest_file_name = repo
339            .get_rebase_state_dir_path()
340            .join(ORIGINAL_HEAD_FILE_NAME);
341        std::fs::write(dest_file_name, head_name.as_str()).wrap_err("Writing head name")?;
342    }
343
344    Ok(())
345}
346
347#[instrument]
348fn load_original_head_info(repo: &Repo) -> eyre::Result<ResolvedReferenceInfo> {
349    let head_oid = {
350        let source_file_name = repo
351            .get_rebase_state_dir_path()
352            .join(ORIGINAL_HEAD_OID_FILE_NAME);
353        match std::fs::read_to_string(source_file_name) {
354            Ok(oid) => Some(oid.parse().wrap_err("Parsing original head OID")?),
355            Err(err) if err.kind() == std::io::ErrorKind::NotFound => None,
356            Err(err) => return Err(err.into()),
357        }
358    };
359
360    let head_name = {
361        let source_file_name = repo
362            .get_rebase_state_dir_path()
363            .join(ORIGINAL_HEAD_FILE_NAME);
364        match std::fs::read(source_file_name) {
365            Ok(reference_name) => Some(ReferenceName::from_bytes(reference_name)?),
366            Err(err) if err.kind() == std::io::ErrorKind::NotFound => None,
367            Err(err) => return Err(err.into()),
368        }
369    };
370
371    Ok(ResolvedReferenceInfo {
372        oid: head_oid,
373        reference_name: head_name,
374    })
375}
376
377const EXTRA_POST_REWRITE_FILE_NAME: &str = "branchless_do_extra_post_rewrite";
378
379/// In order to handle the case of a commit being skipped and its corresponding
380/// branch being deleted, we need to store our own copy of the original `HEAD`
381/// OID, and then replace it once the rebase is about to conclude. We can't do
382/// it earlier, because if the user aborts the rebase after the commit has been
383/// skipped, then they would be returned to the wrong commit.
384const UPDATED_HEAD_FILE_NAME: &str = "branchless_updated_head";
385
386#[instrument]
387fn save_updated_head_oid(repo: &Repo, updated_head_oid: NonZeroOid) -> eyre::Result<()> {
388    let dest_file_name = repo
389        .get_rebase_state_dir_path()
390        .join(UPDATED_HEAD_FILE_NAME);
391    std::fs::write(dest_file_name, updated_head_oid.to_string())?;
392    Ok(())
393}
394
395#[instrument]
396fn load_updated_head_oid(repo: &Repo) -> eyre::Result<Option<NonZeroOid>> {
397    let source_file_name = repo
398        .get_rebase_state_dir_path()
399        .join(UPDATED_HEAD_FILE_NAME);
400    match std::fs::read_to_string(source_file_name) {
401        Ok(result) => Ok(Some(result.parse()?)),
402        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
403        Err(err) => Err(err.into()),
404    }
405}
406
407/// Register extra cleanup actions for rebase.
408///
409/// For rebases, register that extra cleanup actions should be taken when the
410/// rebase finishes and calls the post-rewrite hook. We don't want to change the
411/// behavior of `git rebase` itself, except when called via `git-branchless`, so
412/// that the user's expectations aren't unexpectedly subverted.
413pub fn hook_register_extra_post_rewrite_hook() -> eyre::Result<()> {
414    let repo = Repo::from_current_dir()?;
415    let file_name = repo
416        .get_rebase_state_dir_path()
417        .join(EXTRA_POST_REWRITE_FILE_NAME);
418    File::create(file_name).wrap_err("Registering extra post-rewrite hook")?;
419
420    // This is the last step before the rebase concludes. Ordinarily, Git will
421    // use `head-name` as the name of the previously checked-out branch, and
422    // move that branch to point to the current commit (and check it out again).
423    // We want to suppress this behavior because we don't want the branch to
424    // move (or, if we do want it to move, we will handle that ourselves as part
425    // of the post-rewrite hook). So we update `head-name` to contain "detached
426    // HEAD" to indicate to Git that no branch was checked out prior to the
427    // rebase, so that it doesn't try to adjust any branches.
428    std::fs::write(
429        repo.get_rebase_state_dir_path().join("head-name"),
430        "detached HEAD",
431    )
432    .wrap_err("Setting `head-name` to detached HEAD")?;
433
434    Ok(())
435}
436
437/// For rebases, detect empty commits (which have probably been applied
438/// upstream) and write them to the `rewritten-list` file, so that they're later
439/// passed to the `post-rewrite` hook.
440pub fn hook_drop_commit_if_empty(
441    effects: &Effects,
442    old_commit_oid: NonZeroOid,
443) -> eyre::Result<()> {
444    let repo = Repo::from_current_dir()?;
445    let head_info = repo.get_head_info()?;
446    let head_oid = match head_info.oid {
447        Some(head_oid) => head_oid,
448        None => return Ok(()),
449    };
450    let head_commit = match repo.find_commit(head_oid)? {
451        Some(head_commit) => head_commit,
452        None => return Ok(()),
453    };
454
455    if !head_commit.is_empty() {
456        return Ok(());
457    }
458
459    let only_parent_oid = match head_commit.get_only_parent_oid() {
460        Some(only_parent_oid) => only_parent_oid,
461        None => return Ok(()),
462    };
463    writeln!(
464        effects.get_output_stream(),
465        "Skipped now-empty commit: {}",
466        effects
467            .get_glyphs()
468            .render(head_commit.friendly_describe(effects.get_glyphs())?)?
469    )?;
470    repo.set_head(only_parent_oid)?;
471
472    let orig_head_oid = match repo.find_reference(&"ORIG_HEAD".into())? {
473        Some(orig_head_reference) => orig_head_reference
474            .peel_to_commit()?
475            .map(|orig_head_commit| orig_head_commit.get_oid()),
476        None => None,
477    };
478    if Some(old_commit_oid) == orig_head_oid {
479        save_updated_head_oid(&repo, only_parent_oid)?;
480    }
481    add_rewritten_list_entries(
482        &repo.get_tempfile_dir()?,
483        &repo.get_rebase_state_dir_path().join("rewritten-list"),
484        &[
485            (old_commit_oid, MaybeZeroOid::Zero),
486            (head_commit.get_oid(), MaybeZeroOid::Zero),
487        ],
488    )?;
489
490    Ok(())
491}
492
493/// For rebases, if a commit is known to have been applied upstream, skip it
494/// without attempting to apply it.
495pub fn hook_skip_upstream_applied_commit(
496    effects: &Effects,
497    commit_oid: NonZeroOid,
498) -> eyre::Result<()> {
499    let repo = Repo::from_current_dir()?;
500    let commit = repo.find_commit_or_fail(commit_oid)?;
501    writeln!(
502        effects.get_output_stream(),
503        "Skipping commit (was already applied upstream): {}",
504        effects
505            .get_glyphs()
506            .render(commit.friendly_describe(effects.get_glyphs())?)?
507    )?;
508
509    if let Some(orig_head_reference) = repo.find_reference(&"ORIG_HEAD".into())? {
510        let resolved_orig_head = repo.resolve_reference(&orig_head_reference)?;
511        if let Some(original_head_oid) = resolved_orig_head.oid {
512            if original_head_oid == commit_oid {
513                let current_head_oid = repo.get_head_info()?.oid;
514                if let Some(current_head_oid) = current_head_oid {
515                    save_updated_head_oid(&repo, current_head_oid)?;
516                }
517            }
518        }
519    }
520    add_rewritten_list_entries(
521        &repo.get_tempfile_dir()?,
522        &repo.get_rebase_state_dir_path().join("rewritten-list"),
523        &[(commit_oid, MaybeZeroOid::Zero)],
524    )?;
525
526    Ok(())
527}