1use std::collections::{HashMap, HashSet};
4
5use std::fmt::Write;
6use std::fs::{self, File};
7use std::io::{self, stdin, BufRead, BufReader, Read, Write as WriteIo};
8use std::path::{Path, PathBuf};
9use std::str::FromStr;
10use std::time::SystemTime;
11
12use console::style;
13use eyre::Context;
14use itertools::Itertools;
15use tempfile::NamedTempFile;
16use tracing::instrument;
17
18use crate::core::check_out::CheckOutCommitOptions;
19use crate::core::config::{get_hint_enabled, print_hint_suppression_notice, Hint};
20use crate::core::dag::Dag;
21use crate::core::effects::Effects;
22use crate::core::eventlog::{Event, EventLogDb, EventReplayer};
23use crate::core::formatting::Pluralize;
24use crate::core::repo_ext::RepoExt;
25use crate::git::{
26 CategorizedReferenceName, GitRunInfo, MaybeZeroOid, NonZeroOid, ReferenceName, Repo,
27 ResolvedReferenceInfo,
28};
29
30use super::execute::check_out_updated_head;
31use super::{find_abandoned_children, move_branches};
32
33pub fn get_deferred_commits_path(repo: &Repo) -> PathBuf {
44 repo.get_rebase_state_dir_path().join("deferred-commits")
45}
46
47fn read_deferred_commits(repo: &Repo) -> eyre::Result<Vec<NonZeroOid>> {
48 let deferred_commits_path = get_deferred_commits_path(repo);
49 let contents = match fs::read_to_string(&deferred_commits_path) {
50 Ok(contents) => contents,
51 Err(err) if err.kind() == io::ErrorKind::NotFound => Default::default(),
52 Err(err) => {
53 return Err(err)
54 .with_context(|| format!("Reading deferred commits at {deferred_commits_path:?}"))
55 }
56 };
57 let commit_oids = contents.lines().map(NonZeroOid::from_str).try_collect()?;
58 Ok(commit_oids)
59}
60
61#[instrument(skip(stream))]
62fn read_rewritten_list_entries(
63 stream: &mut impl Read,
64) -> eyre::Result<Vec<(NonZeroOid, MaybeZeroOid)>> {
65 let mut rewritten_oids = Vec::new();
66 let reader = BufReader::new(stream);
67 for line in reader.lines() {
68 let line = line?;
69 let line = line.trim();
70 match *line.split(' ').collect::<Vec<_>>().as_slice() {
71 [old_commit_oid, new_commit_oid, ..] => {
72 let old_commit_oid: NonZeroOid = old_commit_oid.parse()?;
73 let new_commit_oid: MaybeZeroOid = new_commit_oid.parse()?;
74 rewritten_oids.push((old_commit_oid, new_commit_oid));
75 }
76 _ => eyre::bail!("Invalid rewrite line: {:?}", &line),
77 }
78 }
79 Ok(rewritten_oids)
80}
81
82#[instrument]
83fn write_rewritten_list(
84 tempfile_dir: &Path,
85 rewritten_list_path: &Path,
86 rewritten_oids: &[(NonZeroOid, MaybeZeroOid)],
87) -> eyre::Result<()> {
88 std::fs::create_dir_all(tempfile_dir).wrap_err("Creating tempfile dir")?;
89 let mut tempfile =
90 NamedTempFile::new_in(tempfile_dir).wrap_err("Creating temporary `rewritten-list` file")?;
91
92 let file = tempfile.as_file_mut();
93 for (old_commit_oid, new_commit_oid) in rewritten_oids {
94 writeln!(file, "{old_commit_oid} {new_commit_oid}")?;
95 }
96 tempfile
97 .persist(rewritten_list_path)
98 .wrap_err("Moving new rewritten-list into place")?;
99 Ok(())
100}
101
102#[instrument]
103fn add_rewritten_list_entries(
104 tempfile_dir: &Path,
105 rewritten_list_path: &Path,
106 entries: &[(NonZeroOid, MaybeZeroOid)],
107) -> eyre::Result<()> {
108 let current_entries = match File::open(rewritten_list_path) {
109 Ok(mut rewritten_list_file) => read_rewritten_list_entries(&mut rewritten_list_file)?,
110 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Default::default(),
111 Err(err) => return Err(err.into()),
112 };
113
114 let mut entries_to_add: HashMap<NonZeroOid, MaybeZeroOid> = entries.iter().copied().collect();
115 let mut new_entries = Vec::new();
116 for (old_commit_oid, new_commit_oid) in current_entries {
117 let new_entry = match entries_to_add.remove(&old_commit_oid) {
118 Some(new_commit_oid) => (old_commit_oid, new_commit_oid),
119 None => (old_commit_oid, new_commit_oid),
120 };
121 new_entries.push(new_entry);
122 }
123 new_entries.extend(entries_to_add.into_iter());
124
125 write_rewritten_list(tempfile_dir, rewritten_list_path, new_entries.as_slice())?;
126 Ok(())
127}
128
129#[instrument]
133pub fn hook_post_rewrite(
134 effects: &Effects,
135 git_run_info: &GitRunInfo,
136 rewrite_type: &str,
137) -> eyre::Result<()> {
138 let now = SystemTime::now();
139 let timestamp = now.duration_since(SystemTime::UNIX_EPOCH)?.as_secs_f64();
140
141 let repo = Repo::from_current_dir()?;
142 let is_spurious_event = rewrite_type == "amend" && repo.is_rebase_underway()?;
143 if is_spurious_event {
144 return Ok(());
145 }
146
147 let conn = repo.get_db_conn()?;
148 let event_log_db = EventLogDb::new(&conn)?;
149 let event_tx_id = event_log_db.make_transaction_id(now, "hook-post-rewrite")?;
150
151 {
152 let deferred_commit_oids = read_deferred_commits(&repo)?;
153 let commit_events = deferred_commit_oids
154 .into_iter()
155 .map(|commit_oid| Event::CommitEvent {
156 timestamp,
157 event_tx_id,
158 commit_oid,
159 })
160 .collect_vec();
161 event_log_db.add_events(commit_events)?;
162 }
163
164 let (rewritten_oids, rewrite_events) = {
165 let rewritten_oids = read_rewritten_list_entries(&mut stdin().lock())?;
166 let events = rewritten_oids
167 .iter()
168 .copied()
169 .map(|(old_commit_oid, new_commit_oid)| Event::RewriteEvent {
170 timestamp,
171 event_tx_id,
172 old_commit_oid: old_commit_oid.into(),
173 new_commit_oid,
174 })
175 .collect_vec();
176 let rewritten_oids_map: HashMap<NonZeroOid, MaybeZeroOid> =
177 rewritten_oids.into_iter().collect();
178 (rewritten_oids_map, events)
179 };
180
181 let message_rewritten_commits = Pluralize {
182 determiner: None,
183 amount: rewritten_oids.len(),
184 unit: ("rewritten commit", "rewritten commits"),
185 }
186 .to_string();
187 writeln!(
188 effects.get_output_stream(),
189 "branchless: processing {message_rewritten_commits}"
190 )?;
191 event_log_db.add_events(rewrite_events)?;
192
193 if repo
194 .get_rebase_state_dir_path()
195 .join(EXTRA_POST_REWRITE_FILE_NAME)
196 .exists()
197 {
198 let previous_head_info = load_original_head_info(&repo)?;
201 move_branches(effects, git_run_info, &repo, event_tx_id, &rewritten_oids)?;
202
203 let skipped_head_updated_oid = load_updated_head_oid(&repo)?;
204 match check_out_updated_head(
205 effects,
206 git_run_info,
207 &repo,
208 &event_log_db,
209 event_tx_id,
210 &rewritten_oids,
211 &previous_head_info,
212 skipped_head_updated_oid,
213 &CheckOutCommitOptions::default(),
214 )? {
215 Ok(()) => {}
216 Err(_exit_code) => {
217 eyre::bail!("Could not check out your updated `HEAD` commit.");
218 }
219 }
220 }
221
222 let should_check_abandoned_commits = get_hint_enabled(&repo, Hint::RestackWarnAbandoned)?;
223 if should_check_abandoned_commits && !is_spurious_event {
224 let printed_hint = warn_abandoned(
225 effects,
226 &repo,
227 &conn,
228 &event_log_db,
229 rewritten_oids.keys().copied(),
230 )?;
231 if printed_hint {
232 print_hint_suppression_notice(effects, Hint::RestackWarnAbandoned)?;
233 }
234 }
235
236 Ok(())
237}
238
239#[instrument(skip(old_commit_oids))]
240fn warn_abandoned(
241 effects: &Effects,
242 repo: &Repo,
243 conn: &rusqlite::Connection,
244 event_log_db: &EventLogDb,
245 old_commit_oids: impl IntoIterator<Item = NonZeroOid>,
246) -> eyre::Result<bool> {
247 let references_snapshot = repo.get_references_snapshot()?;
250 let event_replayer = EventReplayer::from_event_log_db(effects, repo, event_log_db)?;
251 let event_cursor = event_replayer.make_default_cursor();
252 let dag = Dag::open_and_sync(
253 effects,
254 repo,
255 &event_replayer,
256 event_cursor,
257 &references_snapshot,
258 )?;
259
260 let (all_abandoned_children, all_abandoned_branches) = {
261 let mut all_abandoned_children: HashSet<NonZeroOid> = HashSet::new();
262 let mut all_abandoned_branches: HashSet<&str> = HashSet::new();
263 for old_commit_oid in old_commit_oids {
264 let abandoned_result =
265 find_abandoned_children(&dag, &event_replayer, event_cursor, old_commit_oid)?;
266 let (_rewritten_oid, abandoned_children) = match abandoned_result {
267 Some(abandoned_result) => abandoned_result,
268 None => continue,
269 };
270 all_abandoned_children.extend(abandoned_children.iter());
271 if let Some(branch_names) = references_snapshot.branch_oid_to_names.get(&old_commit_oid)
272 {
273 all_abandoned_branches
274 .extend(branch_names.iter().map(|branch_name| branch_name.as_str()));
275 }
276 }
277 (all_abandoned_children, all_abandoned_branches)
278 };
279 let num_abandoned_children = all_abandoned_children.len();
280 let num_abandoned_branches = all_abandoned_branches.len();
281
282 if num_abandoned_children > 0 || num_abandoned_branches > 0 {
283 let warning_items = {
284 let mut warning_items = Vec::new();
285 if num_abandoned_children > 0 {
286 warning_items.push(
287 Pluralize {
288 determiner: None,
289 amount: num_abandoned_children,
290 unit: ("commit", "commits"),
291 }
292 .to_string(),
293 );
294 }
295 if num_abandoned_branches > 0 {
296 let abandoned_branch_count = Pluralize {
297 determiner: None,
298 amount: num_abandoned_branches,
299 unit: ("branch", "branches"),
300 }
301 .to_string();
302
303 let mut all_abandoned_branches: Vec<String> = all_abandoned_branches
304 .into_iter()
305 .map(|branch_name| {
306 CategorizedReferenceName::new(&branch_name.into()).render_suffix()
307 })
308 .collect();
309 all_abandoned_branches.sort_unstable();
310 let abandoned_branches_list = all_abandoned_branches.join(", ");
311 warning_items.push(format!(
312 "{abandoned_branch_count} ({abandoned_branches_list})"
313 ));
314 }
315
316 warning_items
317 };
318
319 let warning_message = warning_items.join(" and ");
320 let warning_message = style(format!("This operation abandoned {warning_message}!"))
321 .bold()
322 .yellow();
323
324 print!(
325 "\
326branchless: {warning_message}
327branchless: Consider running one of the following:
328branchless: - {git_restack}: re-apply the abandoned commits/branches
329branchless: (this is most likely what you want to do)
330branchless: - {git_smartlog}: assess the situation
331branchless: - {git_hide} [<commit>...]: hide the commits from the smartlog
332branchless: - {git_undo}: undo the operation
333",
334 warning_message = warning_message,
335 git_smartlog = style("git smartlog").bold(),
336 git_restack = style("git restack").bold(),
337 git_hide = style("git hide").bold(),
338 git_undo = style("git undo").bold(),
339 );
340 Ok(true)
341 } else {
342 Ok(false)
343 }
344}
345
/// Name of the file, inside the rebase state directory, in which
/// `save_original_head_info` stores the OID of the original `HEAD`.
const ORIGINAL_HEAD_OID_FILE_NAME: &str = "branchless_original_head_oid";
/// Name of the file, inside the rebase state directory, in which
/// `save_original_head_info` stores the reference name of the original `HEAD`.
const ORIGINAL_HEAD_FILE_NAME: &str = "branchless_original_head";
348
349#[instrument]
352pub fn save_original_head_info(repo: &Repo, head_info: &ResolvedReferenceInfo) -> eyre::Result<()> {
353 let ResolvedReferenceInfo {
354 oid,
355 reference_name,
356 } = head_info;
357
358 if let Some(oid) = oid {
359 let dest_file_name = repo
360 .get_rebase_state_dir_path()
361 .join(ORIGINAL_HEAD_OID_FILE_NAME);
362 std::fs::write(dest_file_name, oid.to_string()).wrap_err("Writing head OID")?;
363 }
364
365 if let Some(head_name) = reference_name {
366 let dest_file_name = repo
367 .get_rebase_state_dir_path()
368 .join(ORIGINAL_HEAD_FILE_NAME);
369 std::fs::write(dest_file_name, head_name.as_str()).wrap_err("Writing head name")?;
370 }
371
372 Ok(())
373}
374
375#[instrument]
376fn load_original_head_info(repo: &Repo) -> eyre::Result<ResolvedReferenceInfo> {
377 let head_oid = {
378 let source_file_name = repo
379 .get_rebase_state_dir_path()
380 .join(ORIGINAL_HEAD_OID_FILE_NAME);
381 match std::fs::read_to_string(source_file_name) {
382 Ok(oid) => Some(oid.parse().wrap_err("Parsing original head OID")?),
383 Err(err) if err.kind() == std::io::ErrorKind::NotFound => None,
384 Err(err) => return Err(err.into()),
385 }
386 };
387
388 let head_name = {
389 let source_file_name = repo
390 .get_rebase_state_dir_path()
391 .join(ORIGINAL_HEAD_FILE_NAME);
392 match std::fs::read(source_file_name) {
393 Ok(reference_name) => Some(ReferenceName::from_bytes(reference_name)?),
394 Err(err) if err.kind() == std::io::ErrorKind::NotFound => None,
395 Err(err) => return Err(err.into()),
396 }
397 };
398
399 Ok(ResolvedReferenceInfo {
400 oid: head_oid,
401 reference_name: head_name,
402 })
403}
404
/// Name of the marker file, inside the rebase state directory, created by
/// `hook_register_extra_post_rewrite_hook`. `hook_post_rewrite` checks for its
/// existence to decide whether to move branches and check out the updated
/// `HEAD`.
const EXTRA_POST_REWRITE_FILE_NAME: &str = "branchless_do_extra_post_rewrite";

/// Name of the file, inside the rebase state directory, in which
/// `save_updated_head_oid` stores an OID that `load_updated_head_oid` reads
/// back.
const UPDATED_HEAD_FILE_NAME: &str = "branchless_updated_head";
413
414#[instrument]
415fn save_updated_head_oid(repo: &Repo, updated_head_oid: NonZeroOid) -> eyre::Result<()> {
416 let dest_file_name = repo
417 .get_rebase_state_dir_path()
418 .join(UPDATED_HEAD_FILE_NAME);
419 std::fs::write(dest_file_name, updated_head_oid.to_string())?;
420 Ok(())
421}
422
423#[instrument]
424fn load_updated_head_oid(repo: &Repo) -> eyre::Result<Option<NonZeroOid>> {
425 let source_file_name = repo
426 .get_rebase_state_dir_path()
427 .join(UPDATED_HEAD_FILE_NAME);
428 match std::fs::read_to_string(source_file_name) {
429 Ok(result) => Ok(Some(result.parse()?)),
430 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
431 Err(err) => Err(err.into()),
432 }
433}
434
435pub fn hook_register_extra_post_rewrite_hook() -> eyre::Result<()> {
442 let repo = Repo::from_current_dir()?;
443 let file_name = repo
444 .get_rebase_state_dir_path()
445 .join(EXTRA_POST_REWRITE_FILE_NAME);
446 File::create(file_name).wrap_err("Registering extra post-rewrite hook")?;
447
448 std::fs::write(
457 repo.get_rebase_state_dir_path().join("head-name"),
458 "detached HEAD",
459 )
460 .wrap_err("Setting `head-name` to detached HEAD")?;
461
462 Ok(())
463}
464
465pub fn hook_drop_commit_if_empty(
469 effects: &Effects,
470 old_commit_oid: NonZeroOid,
471) -> eyre::Result<()> {
472 let repo = Repo::from_current_dir()?;
473 let head_info = repo.get_head_info()?;
474 let head_oid = match head_info.oid {
475 Some(head_oid) => head_oid,
476 None => return Ok(()),
477 };
478 let head_commit = match repo.find_commit(head_oid)? {
479 Some(head_commit) => head_commit,
480 None => return Ok(()),
481 };
482
483 if !head_commit.is_empty() {
484 return Ok(());
485 }
486
487 let only_parent_oid = match head_commit.get_only_parent_oid() {
488 Some(only_parent_oid) => only_parent_oid,
489 None => return Ok(()),
490 };
491 writeln!(
492 effects.get_output_stream(),
493 "Skipped now-empty commit: {}",
494 effects
495 .get_glyphs()
496 .render(head_commit.friendly_describe(effects.get_glyphs())?)?
497 )?;
498 repo.set_head(only_parent_oid)?;
499
500 let orig_head_oid = match repo.find_reference(&"ORIG_HEAD".into())? {
501 Some(orig_head_reference) => orig_head_reference
502 .peel_to_commit()?
503 .map(|orig_head_commit| orig_head_commit.get_oid()),
504 None => None,
505 };
506 if Some(old_commit_oid) == orig_head_oid {
507 save_updated_head_oid(&repo, only_parent_oid)?;
508 }
509 add_rewritten_list_entries(
510 &repo.get_tempfile_dir()?,
511 &repo.get_rebase_state_dir_path().join("rewritten-list"),
512 &[
513 (old_commit_oid, MaybeZeroOid::Zero),
514 (head_commit.get_oid(), MaybeZeroOid::Zero),
515 ],
516 )?;
517
518 Ok(())
519}
520
521pub fn hook_skip_upstream_applied_commit(
524 effects: &Effects,
525 commit_oid: NonZeroOid,
526) -> eyre::Result<()> {
527 let repo = Repo::from_current_dir()?;
528 let commit = repo.find_commit_or_fail(commit_oid)?;
529 writeln!(
530 effects.get_output_stream(),
531 "Skipping commit (was already applied upstream): {}",
532 effects
533 .get_glyphs()
534 .render(commit.friendly_describe(effects.get_glyphs())?)?
535 )?;
536
537 if let Some(orig_head_reference) = repo.find_reference(&"ORIG_HEAD".into())? {
538 let resolved_orig_head = repo.resolve_reference(&orig_head_reference)?;
539 if let Some(original_head_oid) = resolved_orig_head.oid {
540 if original_head_oid == commit_oid {
541 let current_head_oid = repo.get_head_info()?.oid;
542 if let Some(current_head_oid) = current_head_oid {
543 save_updated_head_oid(&repo, current_head_oid)?;
544 }
545 }
546 }
547 }
548 add_rewritten_list_entries(
549 &repo.get_tempfile_dir()?,
550 &repo.get_rebase_state_dir_path().join("rewritten-list"),
551 &[(commit_oid, MaybeZeroOid::Zero)],
552 )?;
553
554 Ok(())
555}