git_stack/rewrite/
mod.rs

/// An ordered list of [`Batch`]es to replay.
///
/// When built through `From<Vec<Batch>>`, the batches are topologically sorted by their
/// mark dependencies (a batch's `onto` mark must be produced before the batch runs).
#[derive(Clone, Default, Debug)]
pub struct Script {
    /// Batches in execution order.
    batches: Vec<Batch>,
}
5
6impl Script {
7    pub fn new() -> Self {
8        Default::default()
9    }
10
11    pub fn is_branch_deleted(&self, name: &str) -> bool {
12        self.batches
13            .iter()
14            .flat_map(|b| b.commands.values())
15            .flatten()
16            .any(|c| {
17                if let Command::DeleteBranch(current) = c {
18                    current == name
19                } else {
20                    false
21                }
22            })
23    }
24
25    pub fn iter(&self) -> impl Iterator<Item = &'_ Batch> {
26        self.batches.iter()
27    }
28
29    pub fn display<'a>(&'a self, labels: &'a dyn Labels) -> impl std::fmt::Display + 'a {
30        ScriptDisplay {
31            script: self,
32            labels,
33        }
34    }
35
36    fn infer_marks(&mut self) {
37        let expected_marks = self
38            .batches
39            .iter()
40            .map(|b| b.onto_mark())
41            .collect::<Vec<_>>();
42        for expected_mark in expected_marks {
43            for batch in &mut self.batches {
44                batch.infer_mark(expected_mark);
45            }
46        }
47    }
48}
49
50impl From<Vec<Batch>> for Script {
51    fn from(batches: Vec<Batch>) -> Self {
52        // TODO: we should partition so its not all-or-nothing
53        let graph = gen_graph(&batches);
54        let batches = sort_batches(batches, &graph);
55        let mut script = Self { batches };
56        script.infer_marks();
57        script
58    }
59}
60
61impl<'s> IntoIterator for &'s Script {
62    type Item = &'s Batch;
63    type IntoIter = std::slice::Iter<'s, Batch>;
64
65    fn into_iter(self) -> Self::IntoIter {
66        self.batches.iter()
67    }
68}
69
70impl std::fmt::Display for Script {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        if !self.batches.is_empty() {
73            let onto_id = self.batches[0].onto_mark();
74            let labels = NamedLabels::new();
75            labels.register_onto(onto_id);
76            self.display(&labels).fmt(f)?;
77        }
78
79        Ok(())
80    }
81}
82
83impl PartialEq for Script {
84    fn eq(&self, other: &Self) -> bool {
85        self.batches == other.batches
86    }
87}
88
89impl Eq for Script {}
90
/// `Display` adapter returned by `Script::display`, pairing a script with the labels used
/// to name its marks.
struct ScriptDisplay<'a> {
    script: &'a Script,
    labels: &'a dyn Labels,
}
95
96impl std::fmt::Display for ScriptDisplay<'_> {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        if !self.script.batches.is_empty() {
99            writeln!(f, "label onto")?;
100            for batch in &self.script.batches {
101                writeln!(f)?;
102                write!(f, "{}", batch.display(self.labels))?;
103            }
104        }
105
106        Ok(())
107    }
108}
109
/// A run of commands replayed on top of a single base commit (`onto_mark`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Batch {
    /// Commit this batch starts from; at execution time it is resolved through any mark
    /// recorded by an earlier batch, falling back to the commit itself.
    onto_mark: git2::Oid,
    /// Commands grouped by the commit id they were pushed for, in push order.
    commands: indexmap::IndexMap<git2::Oid, indexmap::IndexSet<Command>>,
    /// Marks this batch registers (explicitly pushed or inferred); used to order batches.
    marks: indexmap::IndexSet<git2::Oid>,
}
116
117impl Batch {
118    pub fn new(onto_mark: git2::Oid) -> Self {
119        Self {
120            onto_mark,
121            commands: Default::default(),
122            marks: Default::default(),
123        }
124    }
125
126    pub fn is_empty(&self) -> bool {
127        self.commands.is_empty()
128    }
129
130    pub fn onto_mark(&self) -> git2::Oid {
131        self.onto_mark
132    }
133
134    pub fn branch(&self) -> Option<&str> {
135        for (_, commands) in self.commands.iter().rev() {
136            for command in commands.iter().rev() {
137                if let Command::CreateBranch(name) = command {
138                    return Some(name);
139                }
140            }
141        }
142
143        None
144    }
145
146    pub fn push(&mut self, id: git2::Oid, command: Command) {
147        if let Command::RegisterMark(mark) = command {
148            self.marks.insert(mark);
149        }
150        self.commands.entry(id).or_default().insert(command);
151        if let Some((last_key, _)) = self.commands.last() {
152            assert_eq!(*last_key, id, "gaps aren't allowed between ids");
153        }
154    }
155
156    pub fn display<'a>(&'a self, labels: &'a dyn Labels) -> impl std::fmt::Display + 'a {
157        BatchDisplay {
158            batch: self,
159            labels,
160        }
161    }
162
163    fn id(&self) -> git2::Oid {
164        *self
165            .commands
166            .first()
167            .expect("called after filtering out empty")
168            .0
169    }
170
171    fn infer_mark(&mut self, mark: git2::Oid) {
172        if mark == self.onto_mark {
173        } else if let Some(commands) = self.commands.get_mut(&mark) {
174            self.marks.insert(mark);
175            commands.insert(Command::RegisterMark(mark));
176        }
177    }
178}
179
180fn gen_graph(batches: &[Batch]) -> petgraph::graphmap::DiGraphMap<(git2::Oid, bool), usize> {
181    let mut graph = petgraph::graphmap::DiGraphMap::new();
182    for batch in batches {
183        graph.add_edge((batch.onto_mark(), false), (batch.id(), true), 0);
184        for mark in &batch.marks {
185            graph.add_edge((batch.id(), true), (*mark, false), 0);
186        }
187    }
188    graph
189}
190
191fn sort_batches(
192    mut batches: Vec<Batch>,
193    graph: &petgraph::graphmap::DiGraphMap<(git2::Oid, bool), usize>,
194) -> Vec<Batch> {
195    let mut unsorted = batches
196        .drain(..)
197        .map(|b| (b.id(), b))
198        .collect::<std::collections::HashMap<_, _>>();
199    for id in petgraph::algo::toposort(&graph, None)
200        .unwrap()
201        .into_iter()
202        .filter_map(|(id, is_batch)| is_batch.then_some(id))
203    {
204        batches.push(unsorted.remove(&id).unwrap());
205    }
206    batches
207}
208
/// `Display` adapter returned by `Batch::display`, pairing a batch with the labels used to
/// name its marks.
struct BatchDisplay<'a> {
    batch: &'a Batch,
    labels: &'a dyn Labels,
}
213
214impl std::fmt::Display for BatchDisplay<'_> {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        let label = self.labels.get(self.batch.onto_mark());
217        writeln!(f, "# Formerly {}", self.batch.onto_mark())?;
218        writeln!(f, "reset {label}")?;
219        for (_, commands) in &self.batch.commands {
220            for command in commands {
221                match command {
222                    Command::RegisterMark(mark_oid) => {
223                        let label = self.labels.get(*mark_oid);
224                        writeln!(f, "label {label}")?;
225                    }
226                    Command::CherryPick(cherry_oid) => {
227                        writeln!(f, "pick {cherry_oid}")?;
228                    }
229                    Command::Reword(_msg) => {
230                        writeln!(f, "reword")?;
231                    }
232                    Command::Fixup(squash_oid) => {
233                        writeln!(f, "fixup {squash_oid}")?;
234                    }
235                    Command::CreateBranch(name) => {
236                        writeln!(f, "exec git switch --force-create {name}")?;
237                    }
238                    Command::DeleteBranch(name) => {
239                        writeln!(f, "exec git branch -D {name}")?;
240                    }
241                }
242            }
243        }
244        Ok(())
245    }
246}
247
/// Maps a mark (commit id) to the label text used when rendering a script.
pub trait Labels {
    /// Label for `mark_id`; implementations in this module return the same label for
    /// repeated calls with the same id.
    fn get(&self, mark_id: git2::Oid) -> &str;
}
251
/// [`Labels`] that assigns generated human-readable names (from the `names` crate) to marks.
#[derive(Default)]
pub struct NamedLabels {
    /// Name source; in a `RefCell` so `get(&self)` can draw new names.
    generator: std::cell::RefCell<names::Generator<'static>>,
    /// Names assigned so far; insert-only so `get(&self)` can hand out `&str` into it.
    names: elsa::FrozenMap<git2::Oid, String>,
}
257
258impl NamedLabels {
259    pub fn new() -> Self {
260        Default::default()
261    }
262
263    pub fn register_onto(&self, onto_id: git2::Oid) {
264        self.names.insert(onto_id, "onto".to_owned());
265    }
266
267    pub fn get(&self, mark_id: git2::Oid) -> &str {
268        if let Some(label) = self.names.get(&mark_id) {
269            return label;
270        }
271
272        let label = self.generator.borrow_mut().next().unwrap();
273        self.names.insert(mark_id, label)
274    }
275}
276
277impl Labels for NamedLabels {
278    fn get(&self, mark_id: git2::Oid) -> &str {
279        self.get(mark_id)
280    }
281}
282
/// [`Labels`] that uses each mark's commit id as its label, except the registered base
/// commit, which is labelled `onto`.
#[derive(Default)]
#[non_exhaustive]
pub struct OidLabels {
    /// Commit to label `onto`, if registered.
    onto_id: std::cell::Cell<Option<git2::Oid>>,
    /// Labels computed so far; insert-only so `get(&self)` can hand out `&str` into it.
    names: elsa::FrozenMap<git2::Oid, String>,
}
289
290impl OidLabels {
291    pub fn new() -> Self {
292        Default::default()
293    }
294
295    pub fn register_onto(&self, onto_id: git2::Oid) {
296        self.onto_id.set(Some(onto_id));
297    }
298
299    pub fn get(&self, mark_id: git2::Oid) -> &str {
300        if let Some(label) = self.names.get(&mark_id) {
301            return label;
302        }
303
304        let label = match self.onto_id.get() {
305            Some(onto_id) if onto_id == mark_id => "onto".to_owned(),
306            _ => mark_id.to_string(),
307        };
308
309        self.names.insert(mark_id, label)
310    }
311}
312
313impl Labels for OidLabels {
314    fn get(&self, mark_id: git2::Oid) -> &str {
315        self.get(mark_id)
316    }
317}
318
/// A single step in a [`Batch`].
///
/// Variant order is significant: it determines the derived `PartialOrd`/`Ord`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Command {
    /// Mark the current commit with an `Oid` for future reference
    RegisterMark(git2::Oid),
    /// Cherry-pick an existing commit
    CherryPick(git2::Oid),
    /// Change the wording of a commit message
    Reword(String),
    /// Squash a commit into prior commit, keeping the parent commits identity
    Fixup(git2::Oid),
    /// Mark a branch for creation at the current commit
    CreateBranch(String),
    /// Mark a branch for deletion
    DeleteBranch(String),
}
334
/// Applies a [`Script`] to a repository, batch by batch.
pub struct Executor {
    /// Mark (original commit id) -> commit id it was rewritten to.
    marks: std::collections::HashMap<git2::Oid, git2::Oid>,
    /// Branches to point at new commits on the next `commit`.
    branches: Vec<(git2::Oid, String)>,
    /// Branches to delete on the next `commit`.
    delete_branches: Vec<String>,
    /// `(old, new)` commit pairs reported to the post-rewrite hook on the next `commit`.
    post_rewrite: Vec<(git2::Oid, git2::Oid)>,
    /// Commit HEAD should end up on, following rewrites; zero until `run` starts.
    head_id: git2::Oid,
    /// Log repository-mutating operations instead of performing them.
    dry_run: bool,
    /// Whether `commit` detached HEAD; `close` only switches back when set.
    detached: bool,
}
344
345impl Executor {
346    pub fn new(dry_run: bool) -> Executor {
347        Self {
348            marks: Default::default(),
349            branches: Default::default(),
350            delete_branches: Default::default(),
351            post_rewrite: Default::default(),
352            head_id: git2::Oid::zero(),
353            dry_run,
354            detached: false,
355        }
356    }
357
358    pub fn run<'s>(
359        &mut self,
360        repo: &mut dyn crate::git::Repo,
361        script: &'s Script,
362    ) -> Vec<(git2::Error, &'s str, Vec<&'s str>)> {
363        let mut failures = Vec::new();
364
365        self.head_id = repo.head_commit().id;
366
367        let onto_id = script.batches[0].onto_mark();
368        let labels = NamedLabels::new();
369        labels.register_onto(onto_id);
370        for (i, batch) in script.batches.iter().enumerate() {
371            let branch_name = batch.branch().unwrap_or("detached");
372            if !failures.is_empty() {
373                log::trace!("Ignoring `{}`", branch_name);
374                log::trace!("Script:\n{}", batch.display(&labels));
375                continue;
376            }
377
378            log::trace!("Applying `{}`", branch_name);
379            log::trace!("Script:\n{}", batch.display(&labels));
380            let res = self.stage_batch(repo, batch);
381            match res.and_then(|_| self.commit(repo)) {
382                Ok(()) => {
383                    log::trace!("         `{}` succeeded", branch_name);
384                }
385                Err(err) => {
386                    log::trace!("         `{}` failed: {}", branch_name, err);
387                    self.abandon();
388                    let dependent_branches = script.batches[(i + 1)..]
389                        .iter()
390                        .filter_map(|b| b.branch())
391                        .collect::<Vec<_>>();
392                    failures.push((err, branch_name, dependent_branches));
393                }
394            }
395        }
396
397        failures
398    }
399
400    fn stage_batch(
401        &mut self,
402        repo: &mut dyn crate::git::Repo,
403        batch: &Batch,
404    ) -> Result<(), git2::Error> {
405        let onto_mark = batch.onto_mark();
406        let onto_id = self.marks.get(&onto_mark).copied().unwrap_or(onto_mark);
407        let commit = repo.find_commit(onto_id).ok_or_else(|| {
408            git2::Error::new(
409                git2::ErrorCode::NotFound,
410                git2::ErrorClass::Reference,
411                format!("could not find commit {onto_id:?}"),
412            )
413        })?;
414        log::trace!("git checkout {}  # {}", onto_id, commit.summary);
415        let mut head_oid = onto_id;
416        for (_, commands) in &batch.commands {
417            for command in commands {
418                match command {
419                    Command::RegisterMark(mark_oid) => {
420                        let target_oid = head_oid;
421                        self.marks.insert(*mark_oid, target_oid);
422                    }
423                    Command::CherryPick(cherry_oid) => {
424                        let cherry_commit = repo.find_commit(*cherry_oid).ok_or_else(|| {
425                            git2::Error::new(
426                                git2::ErrorCode::NotFound,
427                                git2::ErrorClass::Reference,
428                                format!("could not find commit {cherry_oid:?}"),
429                            )
430                        })?;
431                        log::trace!(
432                            "git cherry-pick {}  # {}",
433                            cherry_oid,
434                            cherry_commit.summary
435                        );
436                        let updated_oid = if self.dry_run {
437                            *cherry_oid
438                        } else {
439                            repo.cherry_pick(head_oid, *cherry_oid)?
440                        };
441                        self.update_head(*cherry_oid, updated_oid);
442                        self.post_rewrite.push((*cherry_oid, updated_oid));
443                        head_oid = updated_oid;
444                    }
445                    Command::Reword(msg) => {
446                        log::trace!("git commit --amend");
447                        let updated_oid = if self.dry_run {
448                            head_oid
449                        } else {
450                            repo.reword(head_oid, msg)?
451                        };
452                        self.update_head(head_oid, updated_oid);
453                        for (_old_oid, new_oid) in &mut self.post_rewrite {
454                            if *new_oid == head_oid {
455                                *new_oid = updated_oid;
456                            }
457                        }
458                        head_oid = updated_oid;
459                    }
460                    Command::Fixup(squash_oid) => {
461                        let cherry_commit = repo.find_commit(*squash_oid).ok_or_else(|| {
462                            git2::Error::new(
463                                git2::ErrorCode::NotFound,
464                                git2::ErrorClass::Reference,
465                                format!("could not find commit {squash_oid:?}"),
466                            )
467                        })?;
468                        log::trace!(
469                            "git merge --squash {}  # {}",
470                            squash_oid,
471                            cherry_commit.summary
472                        );
473                        let updated_oid = if self.dry_run {
474                            *squash_oid
475                        } else {
476                            repo.squash(*squash_oid, head_oid)?
477                        };
478                        self.update_head(head_oid, updated_oid);
479                        self.update_head(*squash_oid, updated_oid);
480                        for (_old_oid, new_oid) in &mut self.post_rewrite {
481                            if *new_oid == head_oid {
482                                *new_oid = updated_oid;
483                            }
484                        }
485                        self.post_rewrite.push((*squash_oid, updated_oid));
486                        head_oid = updated_oid;
487                    }
488                    Command::CreateBranch(name) => {
489                        let branch_oid = head_oid;
490                        self.branches.push((branch_oid, name.to_owned()));
491                    }
492                    Command::DeleteBranch(name) => {
493                        self.delete_branches.push(name.to_owned());
494                    }
495                }
496            }
497        }
498
499        Ok(())
500    }
501
502    pub fn update_head(&mut self, old_id: git2::Oid, new_id: git2::Oid) {
503        if self.head_id == old_id && old_id != new_id {
504            log::trace!("head changed from {} to {}", old_id, new_id);
505            self.head_id = new_id;
506        }
507    }
508
509    pub fn commit(&mut self, repo: &mut dyn crate::git::Repo) -> Result<(), git2::Error> {
510        let hook_repo = repo.path().map(git2::Repository::open).transpose()?;
511        let hooks = if self.dry_run {
512            None
513        } else {
514            hook_repo
515                .as_ref()
516                .map(git2_ext::hooks::Hooks::with_repo)
517                .transpose()?
518        };
519
520        log::trace!("Running reference-transaction hook");
521        let reference_transaction = self.branches.clone();
522        let reference_transaction: Vec<(git2::Oid, git2::Oid, &str)> = reference_transaction
523            .iter()
524            .map(|(new_oid, name)| {
525                // HACK: relying on "force updating the reference regardless of its current value" part
526                // of rules rather than tracking the old value
527                let old_oid = git2::Oid::zero();
528                (old_oid, *new_oid, name.as_str())
529            })
530            .collect();
531        let reference_transaction =
532            if let (Some(hook_repo), Some(hooks)) = (hook_repo.as_ref(), hooks.as_ref()) {
533                Some(
534                    hooks
535                        .run_reference_transaction(hook_repo, &reference_transaction)
536                        .map_err(|err| {
537                            git2::Error::new(
538                                git2::ErrorCode::GenericError,
539                                git2::ErrorClass::Os,
540                                err.to_string(),
541                            )
542                        })?,
543                )
544            } else {
545                None
546            };
547
548        if !self.branches.is_empty() || !self.delete_branches.is_empty() {
549            // In case we are changing the branch HEAD is attached to
550            if !self.dry_run {
551                repo.detach()?;
552                self.detached = true;
553            }
554
555            for (oid, name) in self.branches.iter() {
556                let commit = repo.find_commit(*oid).unwrap();
557                log::trace!("git checkout {}  # {}", oid, commit.summary);
558                log::trace!("git switch --force-create {}", name);
559                if !self.dry_run {
560                    repo.branch(name, *oid)?;
561                }
562            }
563        }
564        self.branches.clear();
565
566        for name in self.delete_branches.iter() {
567            log::trace!("git branch -D {}", name);
568            if !self.dry_run {
569                repo.delete_branch(name)?;
570            }
571        }
572        self.delete_branches.clear();
573
574        if let Some(tx) = reference_transaction {
575            tx.committed();
576        }
577        self.post_rewrite.retain(|(old, new)| old != new);
578        if !self.post_rewrite.is_empty() {
579            log::trace!("Running post-rewrite hook");
580            if let (Some(hook_repo), Some(hooks)) = (hook_repo.as_ref(), hooks.as_ref()) {
581                hooks.run_post_rewrite_rebase(hook_repo, &self.post_rewrite);
582            }
583            self.post_rewrite.clear();
584        }
585
586        Ok(())
587    }
588
589    pub fn abandon(&mut self) {
590        self.branches.clear();
591        self.delete_branches.clear();
592        self.post_rewrite.clear();
593    }
594
595    pub fn close(
596        &mut self,
597        repo: &mut dyn crate::git::Repo,
598        restore_branch: Option<&str>,
599    ) -> Result<(), git2::Error> {
600        assert_eq!(&self.branches, &[]);
601        assert_eq!(self.delete_branches, Vec::<String>::new());
602        if let Some(restore_branch) = restore_branch {
603            log::trace!("git switch {}", restore_branch);
604            if !self.dry_run && self.detached {
605                repo.switch_branch(restore_branch)?;
606            }
607        } else if self.head_id != git2::Oid::zero() {
608            log::trace!("git switch {}", self.head_id);
609            if !self.dry_run && self.detached {
610                repo.switch_commit(self.head_id)?;
611            }
612        }
613
614        Ok(())
615    }
616}