1#[derive(Clone, Default, Debug)]
2pub struct Script {
3 batches: Vec<Batch>,
4}
5
6impl Script {
7 pub fn new() -> Self {
8 Default::default()
9 }
10
11 pub fn is_branch_deleted(&self, name: &str) -> bool {
12 self.batches
13 .iter()
14 .flat_map(|b| b.commands.values())
15 .flatten()
16 .any(|c| {
17 if let Command::DeleteBranch(current) = c {
18 current == name
19 } else {
20 false
21 }
22 })
23 }
24
25 pub fn iter(&self) -> impl Iterator<Item = &'_ Batch> {
26 self.batches.iter()
27 }
28
29 pub fn display<'a>(&'a self, labels: &'a dyn Labels) -> impl std::fmt::Display + 'a {
30 ScriptDisplay {
31 script: self,
32 labels,
33 }
34 }
35
36 fn infer_marks(&mut self) {
37 let expected_marks = self
38 .batches
39 .iter()
40 .map(|b| b.onto_mark())
41 .collect::<Vec<_>>();
42 for expected_mark in expected_marks {
43 for batch in &mut self.batches {
44 batch.infer_mark(expected_mark);
45 }
46 }
47 }
48}
49
50impl From<Vec<Batch>> for Script {
51 fn from(batches: Vec<Batch>) -> Self {
52 let graph = gen_graph(&batches);
54 let batches = sort_batches(batches, &graph);
55 let mut script = Self { batches };
56 script.infer_marks();
57 script
58 }
59}
60
61impl<'s> IntoIterator for &'s Script {
62 type Item = &'s Batch;
63 type IntoIter = std::slice::Iter<'s, Batch>;
64
65 fn into_iter(self) -> Self::IntoIter {
66 self.batches.iter()
67 }
68}
69
70impl std::fmt::Display for Script {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 if !self.batches.is_empty() {
73 let onto_id = self.batches[0].onto_mark();
74 let labels = NamedLabels::new();
75 labels.register_onto(onto_id);
76 self.display(&labels).fmt(f)?;
77 }
78
79 Ok(())
80 }
81}
82
83impl PartialEq for Script {
84 fn eq(&self, other: &Self) -> bool {
85 self.batches == other.batches
86 }
87}
88
89impl Eq for Script {}
90
91struct ScriptDisplay<'a> {
92 script: &'a Script,
93 labels: &'a dyn Labels,
94}
95
96impl std::fmt::Display for ScriptDisplay<'_> {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 if !self.script.batches.is_empty() {
99 writeln!(f, "label onto")?;
100 for batch in &self.script.batches {
101 writeln!(f)?;
102 write!(f, "{}", batch.display(self.labels))?;
103 }
104 }
105
106 Ok(())
107 }
108}
109
110#[derive(Clone, Debug, PartialEq, Eq)]
111pub struct Batch {
112 onto_mark: git2::Oid,
113 commands: indexmap::IndexMap<git2::Oid, indexmap::IndexSet<Command>>,
114 marks: indexmap::IndexSet<git2::Oid>,
115}
116
117impl Batch {
118 pub fn new(onto_mark: git2::Oid) -> Self {
119 Self {
120 onto_mark,
121 commands: Default::default(),
122 marks: Default::default(),
123 }
124 }
125
126 pub fn is_empty(&self) -> bool {
127 self.commands.is_empty()
128 }
129
130 pub fn onto_mark(&self) -> git2::Oid {
131 self.onto_mark
132 }
133
134 pub fn branch(&self) -> Option<&str> {
135 for (_, commands) in self.commands.iter().rev() {
136 for command in commands.iter().rev() {
137 if let Command::CreateBranch(name) = command {
138 return Some(name);
139 }
140 }
141 }
142
143 None
144 }
145
146 pub fn push(&mut self, id: git2::Oid, command: Command) {
147 if let Command::RegisterMark(mark) = command {
148 self.marks.insert(mark);
149 }
150 self.commands.entry(id).or_default().insert(command);
151 if let Some((last_key, _)) = self.commands.last() {
152 assert_eq!(*last_key, id, "gaps aren't allowed between ids");
153 }
154 }
155
156 pub fn display<'a>(&'a self, labels: &'a dyn Labels) -> impl std::fmt::Display + 'a {
157 BatchDisplay {
158 batch: self,
159 labels,
160 }
161 }
162
163 fn id(&self) -> git2::Oid {
164 *self
165 .commands
166 .first()
167 .expect("called after filtering out empty")
168 .0
169 }
170
171 fn infer_mark(&mut self, mark: git2::Oid) {
172 if mark == self.onto_mark {
173 } else if let Some(commands) = self.commands.get_mut(&mark) {
174 self.marks.insert(mark);
175 commands.insert(Command::RegisterMark(mark));
176 }
177 }
178}
179
180fn gen_graph(batches: &[Batch]) -> petgraph::graphmap::DiGraphMap<(git2::Oid, bool), usize> {
181 let mut graph = petgraph::graphmap::DiGraphMap::new();
182 for batch in batches {
183 graph.add_edge((batch.onto_mark(), false), (batch.id(), true), 0);
184 for mark in &batch.marks {
185 graph.add_edge((batch.id(), true), (*mark, false), 0);
186 }
187 }
188 graph
189}
190
191fn sort_batches(
192 mut batches: Vec<Batch>,
193 graph: &petgraph::graphmap::DiGraphMap<(git2::Oid, bool), usize>,
194) -> Vec<Batch> {
195 let mut unsorted = batches
196 .drain(..)
197 .map(|b| (b.id(), b))
198 .collect::<std::collections::HashMap<_, _>>();
199 for id in petgraph::algo::toposort(&graph, None)
200 .unwrap()
201 .into_iter()
202 .filter_map(|(id, is_batch)| is_batch.then_some(id))
203 {
204 batches.push(unsorted.remove(&id).unwrap());
205 }
206 batches
207}
208
209struct BatchDisplay<'a> {
210 batch: &'a Batch,
211 labels: &'a dyn Labels,
212}
213
214impl std::fmt::Display for BatchDisplay<'_> {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 let label = self.labels.get(self.batch.onto_mark());
217 writeln!(f, "# Formerly {}", self.batch.onto_mark())?;
218 writeln!(f, "reset {label}")?;
219 for (_, commands) in &self.batch.commands {
220 for command in commands {
221 match command {
222 Command::RegisterMark(mark_oid) => {
223 let label = self.labels.get(*mark_oid);
224 writeln!(f, "label {label}")?;
225 }
226 Command::CherryPick(cherry_oid) => {
227 writeln!(f, "pick {cherry_oid}")?;
228 }
229 Command::Reword(_msg) => {
230 writeln!(f, "reword")?;
231 }
232 Command::Fixup(squash_oid) => {
233 writeln!(f, "fixup {squash_oid}")?;
234 }
235 Command::CreateBranch(name) => {
236 writeln!(f, "exec git switch --force-create {name}")?;
237 }
238 Command::DeleteBranch(name) => {
239 writeln!(f, "exec git branch -D {name}")?;
240 }
241 }
242 }
243 }
244 Ok(())
245 }
246}
247
248pub trait Labels {
249 fn get(&self, mark_id: git2::Oid) -> &str;
250}
251
252#[derive(Default)]
253pub struct NamedLabels {
254 generator: std::cell::RefCell<names::Generator<'static>>,
255 names: elsa::FrozenMap<git2::Oid, String>,
256}
257
258impl NamedLabels {
259 pub fn new() -> Self {
260 Default::default()
261 }
262
263 pub fn register_onto(&self, onto_id: git2::Oid) {
264 self.names.insert(onto_id, "onto".to_owned());
265 }
266
267 pub fn get(&self, mark_id: git2::Oid) -> &str {
268 if let Some(label) = self.names.get(&mark_id) {
269 return label;
270 }
271
272 let label = self.generator.borrow_mut().next().unwrap();
273 self.names.insert(mark_id, label)
274 }
275}
276
277impl Labels for NamedLabels {
278 fn get(&self, mark_id: git2::Oid) -> &str {
279 self.get(mark_id)
280 }
281}
282
283#[derive(Default)]
284#[non_exhaustive]
285pub struct OidLabels {
286 onto_id: std::cell::Cell<Option<git2::Oid>>,
287 names: elsa::FrozenMap<git2::Oid, String>,
288}
289
290impl OidLabels {
291 pub fn new() -> Self {
292 Default::default()
293 }
294
295 pub fn register_onto(&self, onto_id: git2::Oid) {
296 self.onto_id.set(Some(onto_id));
297 }
298
299 pub fn get(&self, mark_id: git2::Oid) -> &str {
300 if let Some(label) = self.names.get(&mark_id) {
301 return label;
302 }
303
304 let label = match self.onto_id.get() {
305 Some(onto_id) if onto_id == mark_id => "onto".to_owned(),
306 _ => mark_id.to_string(),
307 };
308
309 self.names.insert(mark_id, label)
310 }
311}
312
313impl Labels for OidLabels {
314 fn get(&self, mark_id: git2::Oid) -> &str {
315 self.get(mark_id)
316 }
317}
318
319#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
320pub enum Command {
321 RegisterMark(git2::Oid),
323 CherryPick(git2::Oid),
325 Reword(String),
327 Fixup(git2::Oid),
329 CreateBranch(String),
331 DeleteBranch(String),
333}
334
335pub struct Executor {
336 marks: std::collections::HashMap<git2::Oid, git2::Oid>,
337 branches: Vec<(git2::Oid, String)>,
338 delete_branches: Vec<String>,
339 post_rewrite: Vec<(git2::Oid, git2::Oid)>,
340 head_id: git2::Oid,
341 dry_run: bool,
342 detached: bool,
343}
344
345impl Executor {
346 pub fn new(dry_run: bool) -> Executor {
347 Self {
348 marks: Default::default(),
349 branches: Default::default(),
350 delete_branches: Default::default(),
351 post_rewrite: Default::default(),
352 head_id: git2::Oid::zero(),
353 dry_run,
354 detached: false,
355 }
356 }
357
358 pub fn run<'s>(
359 &mut self,
360 repo: &mut dyn crate::git::Repo,
361 script: &'s Script,
362 ) -> Vec<(git2::Error, &'s str, Vec<&'s str>)> {
363 let mut failures = Vec::new();
364
365 self.head_id = repo.head_commit().id;
366
367 let onto_id = script.batches[0].onto_mark();
368 let labels = NamedLabels::new();
369 labels.register_onto(onto_id);
370 for (i, batch) in script.batches.iter().enumerate() {
371 let branch_name = batch.branch().unwrap_or("detached");
372 if !failures.is_empty() {
373 log::trace!("Ignoring `{}`", branch_name);
374 log::trace!("Script:\n{}", batch.display(&labels));
375 continue;
376 }
377
378 log::trace!("Applying `{}`", branch_name);
379 log::trace!("Script:\n{}", batch.display(&labels));
380 let res = self.stage_batch(repo, batch);
381 match res.and_then(|_| self.commit(repo)) {
382 Ok(()) => {
383 log::trace!(" `{}` succeeded", branch_name);
384 }
385 Err(err) => {
386 log::trace!(" `{}` failed: {}", branch_name, err);
387 self.abandon();
388 let dependent_branches = script.batches[(i + 1)..]
389 .iter()
390 .filter_map(|b| b.branch())
391 .collect::<Vec<_>>();
392 failures.push((err, branch_name, dependent_branches));
393 }
394 }
395 }
396
397 failures
398 }
399
400 fn stage_batch(
401 &mut self,
402 repo: &mut dyn crate::git::Repo,
403 batch: &Batch,
404 ) -> Result<(), git2::Error> {
405 let onto_mark = batch.onto_mark();
406 let onto_id = self.marks.get(&onto_mark).copied().unwrap_or(onto_mark);
407 let commit = repo.find_commit(onto_id).ok_or_else(|| {
408 git2::Error::new(
409 git2::ErrorCode::NotFound,
410 git2::ErrorClass::Reference,
411 format!("could not find commit {onto_id:?}"),
412 )
413 })?;
414 log::trace!("git checkout {} # {}", onto_id, commit.summary);
415 let mut head_oid = onto_id;
416 for (_, commands) in &batch.commands {
417 for command in commands {
418 match command {
419 Command::RegisterMark(mark_oid) => {
420 let target_oid = head_oid;
421 self.marks.insert(*mark_oid, target_oid);
422 }
423 Command::CherryPick(cherry_oid) => {
424 let cherry_commit = repo.find_commit(*cherry_oid).ok_or_else(|| {
425 git2::Error::new(
426 git2::ErrorCode::NotFound,
427 git2::ErrorClass::Reference,
428 format!("could not find commit {cherry_oid:?}"),
429 )
430 })?;
431 log::trace!(
432 "git cherry-pick {} # {}",
433 cherry_oid,
434 cherry_commit.summary
435 );
436 let updated_oid = if self.dry_run {
437 *cherry_oid
438 } else {
439 repo.cherry_pick(head_oid, *cherry_oid)?
440 };
441 self.update_head(*cherry_oid, updated_oid);
442 self.post_rewrite.push((*cherry_oid, updated_oid));
443 head_oid = updated_oid;
444 }
445 Command::Reword(msg) => {
446 log::trace!("git commit --amend");
447 let updated_oid = if self.dry_run {
448 head_oid
449 } else {
450 repo.reword(head_oid, msg)?
451 };
452 self.update_head(head_oid, updated_oid);
453 for (_old_oid, new_oid) in &mut self.post_rewrite {
454 if *new_oid == head_oid {
455 *new_oid = updated_oid;
456 }
457 }
458 head_oid = updated_oid;
459 }
460 Command::Fixup(squash_oid) => {
461 let cherry_commit = repo.find_commit(*squash_oid).ok_or_else(|| {
462 git2::Error::new(
463 git2::ErrorCode::NotFound,
464 git2::ErrorClass::Reference,
465 format!("could not find commit {squash_oid:?}"),
466 )
467 })?;
468 log::trace!(
469 "git merge --squash {} # {}",
470 squash_oid,
471 cherry_commit.summary
472 );
473 let updated_oid = if self.dry_run {
474 *squash_oid
475 } else {
476 repo.squash(*squash_oid, head_oid)?
477 };
478 self.update_head(head_oid, updated_oid);
479 self.update_head(*squash_oid, updated_oid);
480 for (_old_oid, new_oid) in &mut self.post_rewrite {
481 if *new_oid == head_oid {
482 *new_oid = updated_oid;
483 }
484 }
485 self.post_rewrite.push((*squash_oid, updated_oid));
486 head_oid = updated_oid;
487 }
488 Command::CreateBranch(name) => {
489 let branch_oid = head_oid;
490 self.branches.push((branch_oid, name.to_owned()));
491 }
492 Command::DeleteBranch(name) => {
493 self.delete_branches.push(name.to_owned());
494 }
495 }
496 }
497 }
498
499 Ok(())
500 }
501
502 pub fn update_head(&mut self, old_id: git2::Oid, new_id: git2::Oid) {
503 if self.head_id == old_id && old_id != new_id {
504 log::trace!("head changed from {} to {}", old_id, new_id);
505 self.head_id = new_id;
506 }
507 }
508
509 pub fn commit(&mut self, repo: &mut dyn crate::git::Repo) -> Result<(), git2::Error> {
510 let hook_repo = repo.path().map(git2::Repository::open).transpose()?;
511 let hooks = if self.dry_run {
512 None
513 } else {
514 hook_repo
515 .as_ref()
516 .map(git2_ext::hooks::Hooks::with_repo)
517 .transpose()?
518 };
519
520 log::trace!("Running reference-transaction hook");
521 let reference_transaction = self.branches.clone();
522 let reference_transaction: Vec<(git2::Oid, git2::Oid, &str)> = reference_transaction
523 .iter()
524 .map(|(new_oid, name)| {
525 let old_oid = git2::Oid::zero();
528 (old_oid, *new_oid, name.as_str())
529 })
530 .collect();
531 let reference_transaction =
532 if let (Some(hook_repo), Some(hooks)) = (hook_repo.as_ref(), hooks.as_ref()) {
533 Some(
534 hooks
535 .run_reference_transaction(hook_repo, &reference_transaction)
536 .map_err(|err| {
537 git2::Error::new(
538 git2::ErrorCode::GenericError,
539 git2::ErrorClass::Os,
540 err.to_string(),
541 )
542 })?,
543 )
544 } else {
545 None
546 };
547
548 if !self.branches.is_empty() || !self.delete_branches.is_empty() {
549 if !self.dry_run {
551 repo.detach()?;
552 self.detached = true;
553 }
554
555 for (oid, name) in self.branches.iter() {
556 let commit = repo.find_commit(*oid).unwrap();
557 log::trace!("git checkout {} # {}", oid, commit.summary);
558 log::trace!("git switch --force-create {}", name);
559 if !self.dry_run {
560 repo.branch(name, *oid)?;
561 }
562 }
563 }
564 self.branches.clear();
565
566 for name in self.delete_branches.iter() {
567 log::trace!("git branch -D {}", name);
568 if !self.dry_run {
569 repo.delete_branch(name)?;
570 }
571 }
572 self.delete_branches.clear();
573
574 if let Some(tx) = reference_transaction {
575 tx.committed();
576 }
577 self.post_rewrite.retain(|(old, new)| old != new);
578 if !self.post_rewrite.is_empty() {
579 log::trace!("Running post-rewrite hook");
580 if let (Some(hook_repo), Some(hooks)) = (hook_repo.as_ref(), hooks.as_ref()) {
581 hooks.run_post_rewrite_rebase(hook_repo, &self.post_rewrite);
582 }
583 self.post_rewrite.clear();
584 }
585
586 Ok(())
587 }
588
589 pub fn abandon(&mut self) {
590 self.branches.clear();
591 self.delete_branches.clear();
592 self.post_rewrite.clear();
593 }
594
595 pub fn close(
596 &mut self,
597 repo: &mut dyn crate::git::Repo,
598 restore_branch: Option<&str>,
599 ) -> Result<(), git2::Error> {
600 assert_eq!(&self.branches, &[]);
601 assert_eq!(self.delete_branches, Vec::<String>::new());
602 if let Some(restore_branch) = restore_branch {
603 log::trace!("git switch {}", restore_branch);
604 if !self.dry_run && self.detached {
605 repo.switch_branch(restore_branch)?;
606 }
607 } else if self.head_id != git2::Oid::zero() {
608 log::trace!("git switch {}", self.head_id);
609 if !self.dry_run && self.detached {
610 repo.switch_commit(self.head_id)?;
611 }
612 }
613
614 Ok(())
615 }
616}