1use std::collections::HashSet;
23use std::fs;
24use std::io;
25use std::path::{Path, PathBuf};
26
27use crate::hash::{self, HEX_LEN, Hash};
28use crate::object::Object;
29use crate::store::ObjectStore;
30
/// Name of the directory, inside the repository metadata directory, that holds
/// the on-disk rebase state. `is_rebase_in_progress` treats its existence as the
/// "rebase in progress" marker.
pub const REBASE_DIR: &str = "rebase-apply";

// File names inside `REBASE_DIR`; each one backs a single `RebaseState` field.
const HEAD_NAME_FILE: &str = "head-name";
const ORIG_HEAD_FILE: &str = "orig-head";
const ONTO_FILE: &str = "onto";
const TODO_FILE: &str = "todo";
const DONE_FILE: &str = "done";
const ACTIONS_FILE: &str = "actions";

// Size caps (in bytes) applied when reading state files back from disk.
// `MAX_HEAD_NAME_BYTES` bounds the head-name file, `MAX_HASH_FILE_BYTES` bounds
// the single-hash files, and `MAX_HASH_LIST_BYTES` bounds the todo/done lists
// and is also reused for the actions file.
const MAX_HEAD_NAME_BYTES: u64 = 4096;
const MAX_HASH_FILE_BYTES: u64 = 128;
const MAX_HASH_LIST_BYTES: u64 = 1024 * 1024;

// Upper bound on the number of first-parent steps `collect_commits_to_replay` takes.
const MAX_REPLAY_DEPTH: usize = 10_000;

// Upper bound on the number of commits `collect_ancestor_set` adds in one call.
const MAX_ANCESTORS: usize = 10_000;
50
/// Errors reported by the rebase state helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum RebaseError {
    /// The rebase state directory does not exist (or is not a directory).
    #[error("no rebase in progress")]
    NoRebaseInProgress,
    /// A state file could not be read as expected or its contents failed to parse.
    #[error("rebase state on disk is malformed")]
    InvalidRebaseState,
    /// An underlying filesystem operation failed.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// An error converted from `crate::object::MkitError`.
    #[error(transparent)]
    Object(#[from] crate::object::MkitError),
    /// An error converted from `crate::store::StoreError`.
    #[error(transparent)]
    Store(#[from] crate::store::StoreError),
}
70
/// Convenience alias for results whose error type is [`RebaseError`].
pub type RebaseResult<T> = Result<T, RebaseError>;
73
/// Per-commit instruction stored in the rebase `actions` file.
///
/// Each variant maps to exactly one lowercase keyword (see [`RebaseAction::keyword`]).
// NOTE(review): the variant names mirror git's interactive-rebase verbs; the
// replay logic that gives them meaning lives outside this block — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RebaseAction {
    /// Serialized as `pick`; also the default when no action is recorded.
    Pick,
    /// Serialized as `reword`.
    Reword,
    /// Serialized as `squash`; folds into the previous entry.
    Squash,
    /// Serialized as `fixup`; folds into the previous entry.
    Fixup,
}

impl RebaseAction {
    /// Every variant, in declaration order; used to drive keyword lookup.
    const ALL: [RebaseAction; 4] = [Self::Pick, Self::Reword, Self::Squash, Self::Fixup];

    /// Returns the on-disk keyword for this action.
    #[must_use]
    pub fn keyword(self) -> &'static str {
        match self {
            Self::Pick => "pick",
            Self::Reword => "reword",
            Self::Squash => "squash",
            Self::Fixup => "fixup",
        }
    }

    /// Returns `true` for the actions that are combined with the preceding
    /// entry (`Squash` and `Fixup`), `false` for `Pick` and `Reword`.
    #[must_use]
    pub fn folds_into_previous(self) -> bool {
        match self {
            Self::Squash | Self::Fixup => true,
            Self::Pick | Self::Reword => false,
        }
    }

    /// Parses an on-disk keyword back into an action.
    ///
    /// Matching is exact (case-sensitive, no trimming); anything that is not
    /// one of the four keywords yields `None`.
    #[must_use]
    pub fn from_keyword(s: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.keyword() == s)
    }
}
122
/// In-memory form of the rebase state persisted under [`REBASE_DIR`].
///
/// `write_state` stores each field in its own file and `read_state` loads them back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RebaseState {
    /// Contents of the `head-name` file (trailing whitespace is trimmed on read).
    // NOTE(review): presumably the name of the branch being rebased — confirm against callers.
    pub head_name: String,
    /// Hash stored in the `orig-head` file.
    // NOTE(review): presumably the branch tip before the rebase started — confirm.
    pub orig_head: Hash,
    /// Hash stored in the `onto` file.
    pub onto: Hash,
    /// Hashes from the `todo` file, in file order; `consume_front` removes from the front.
    pub todo: Vec<Hash>,
    /// Action for each `todo` entry; `read_state` pads or truncates this to `todo.len()`.
    pub actions: Vec<RebaseAction>,
    /// Hashes from the `done` file, in file order.
    pub done: Vec<Hash>,
}
140
141impl RebaseState {
142 #[must_use]
145 pub fn front_action(&self) -> RebaseAction {
146 self.actions.first().copied().unwrap_or(RebaseAction::Pick)
147 }
148
149 pub fn consume_front(&mut self) {
152 if !self.todo.is_empty() {
153 self.todo.remove(0);
154 }
155 if !self.actions.is_empty() {
156 self.actions.remove(0);
157 }
158 }
159}
160
161#[must_use]
163pub fn is_rebase_in_progress(mkit_dir: &Path) -> bool {
164 mkit_dir.join(REBASE_DIR).is_dir()
165}
166
167pub fn read_state(mkit_dir: &Path) -> RebaseResult<RebaseState> {
173 let dir = mkit_dir.join(REBASE_DIR);
174 if !dir.is_dir() {
175 return Err(RebaseError::NoRebaseInProgress);
176 }
177
178 let head_name = read_text_capped(&dir.join(HEAD_NAME_FILE), MAX_HEAD_NAME_BYTES)
179 .map_err(|_| RebaseError::InvalidRebaseState)?;
180 let head_name = trim_trailing(&head_name).to_string();
181
182 let orig_head = read_hex_hash(&dir.join(ORIG_HEAD_FILE))?;
183 let onto = read_hex_hash(&dir.join(ONTO_FILE))?;
184
185 let todo = read_hash_list(&dir.join(TODO_FILE))?;
186 let done = read_hash_list(&dir.join(DONE_FILE))?;
187 let actions = read_actions(&dir.join(ACTIONS_FILE), todo.len())?;
191
192 Ok(RebaseState {
193 head_name,
194 orig_head,
195 onto,
196 todo,
197 actions,
198 done,
199 })
200}
201
202pub fn write_state(mkit_dir: &Path, state: &RebaseState) -> RebaseResult<()> {
207 let dir = mkit_dir.join(REBASE_DIR);
208 fs::create_dir_all(&dir)?;
209
210 write_with_newline(&dir.join(HEAD_NAME_FILE), state.head_name.as_bytes())?;
211 write_with_newline(
212 &dir.join(ORIG_HEAD_FILE),
213 hash::to_hex(&state.orig_head).as_bytes(),
214 )?;
215 write_with_newline(&dir.join(ONTO_FILE), hash::to_hex(&state.onto).as_bytes())?;
216 write_hash_list(&dir.join(TODO_FILE), &state.todo)?;
217 write_actions(&dir.join(ACTIONS_FILE), &state.actions)?;
218 write_hash_list(&dir.join(DONE_FILE), &state.done)?;
219 Ok(())
220}
221
222pub fn cleanup_rebase(mkit_dir: &Path) -> RebaseResult<()> {
227 let dir = mkit_dir.join(REBASE_DIR);
228 match fs::remove_dir_all(&dir) {
229 Ok(()) => Ok(()),
230 Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(()),
231 Err(e) => Err(RebaseError::Io(e)),
232 }
233}
234
235pub fn collect_commits_to_replay(
242 store: &ObjectStore,
243 head_hash: Hash,
244 onto_hash: Hash,
245) -> RebaseResult<Vec<Hash>> {
246 if head_hash == onto_hash {
247 return Ok(Vec::new());
248 }
249
250 let mut onto_ancestors: HashSet<Hash> = HashSet::new();
251 collect_ancestor_set(store, onto_hash, &mut onto_ancestors);
252
253 let mut commits: Vec<Hash> = Vec::new();
254 let mut current = head_hash;
255 let mut depth = 0usize;
256
257 while depth < MAX_REPLAY_DEPTH {
258 if current == onto_hash || onto_ancestors.contains(¤t) {
259 break;
260 }
261 commits.push(current);
262
263 let Ok(obj) = store.read_object(¤t) else {
264 break;
265 };
266 let Object::Commit(commit) = obj else { break };
267 if commit.parents.is_empty() {
268 break;
269 }
270 current = commit.parents[0];
271 depth += 1;
272 }
273
274 commits.reverse();
275 Ok(commits)
276}
277
278fn collect_ancestor_set(store: &ObjectStore, start: Hash, set: &mut HashSet<Hash>) {
283 let mut stack: Vec<Hash> = vec![start];
284 let mut count = 0usize;
285 while let Some(current) = stack.pop() {
286 if count >= MAX_ANCESTORS {
287 break;
288 }
289 if !set.insert(current) {
290 continue;
291 }
292 count += 1;
293 let Ok(obj) = store.read_object(¤t) else {
294 continue;
295 };
296 if let Object::Commit(commit) = obj {
297 for p in commit.parents {
298 stack.push(p);
299 }
300 }
301 }
302}
303
304fn read_hex_hash(path: &Path) -> RebaseResult<Hash> {
305 let raw =
306 read_text_capped(path, MAX_HASH_FILE_BYTES).map_err(|_| RebaseError::InvalidRebaseState)?;
307 let trimmed = trim_trailing(&raw);
308 hash::from_hex(trimmed).map_err(|_| RebaseError::InvalidRebaseState)
309}
310
311fn read_hash_list(path: &Path) -> RebaseResult<Vec<Hash>> {
312 let raw = match read_text_capped(path, MAX_HASH_LIST_BYTES) {
313 Ok(s) => s,
314 Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(Vec::new()),
315 Err(e) => return Err(RebaseError::Io(e)),
316 };
317 let trimmed = trim_trailing(&raw);
318 if trimmed.is_empty() {
319 return Ok(Vec::new());
320 }
321 let mut out = Vec::new();
322 for line in trimmed.split('\n') {
323 let line = line.trim_end_matches(['\r', ' ']);
324 if line.is_empty() {
325 continue;
326 }
327 if line.len() != HEX_LEN {
328 return Err(RebaseError::InvalidRebaseState);
329 }
330 let h = hash::from_hex(line).map_err(|_| RebaseError::InvalidRebaseState)?;
331 out.push(h);
332 }
333 Ok(out)
334}
335
336fn read_actions(path: &Path, todo_len: usize) -> RebaseResult<Vec<RebaseAction>> {
342 let raw = match read_text_capped(path, MAX_HASH_LIST_BYTES) {
343 Ok(s) => s,
344 Err(e) if e.kind() == io::ErrorKind::NotFound => {
345 return Ok(vec![RebaseAction::Pick; todo_len]);
346 }
347 Err(e) => return Err(RebaseError::Io(e)),
348 };
349 let mut out = Vec::with_capacity(todo_len);
350 for line in trim_trailing(&raw).split('\n') {
351 let line = line.trim();
352 if line.is_empty() {
353 continue;
354 }
355 let action = RebaseAction::from_keyword(line).ok_or(RebaseError::InvalidRebaseState)?;
356 out.push(action);
357 }
358 out.resize(todo_len, RebaseAction::Pick);
359 Ok(out)
360}
361
362fn write_actions(path: &Path, actions: &[RebaseAction]) -> RebaseResult<()> {
363 if actions.is_empty() {
364 write_with_newline(path, b"")?;
365 return Ok(());
366 }
367 let mut buf = String::with_capacity(actions.len() * 8);
368 for a in actions {
369 buf.push_str(a.keyword());
370 buf.push('\n');
371 }
372 fs::write(path, buf.as_bytes())?;
373 Ok(())
374}
375
376fn write_hash_list(path: &Path, hashes: &[Hash]) -> RebaseResult<()> {
377 if hashes.is_empty() {
378 write_with_newline(path, b"")?;
379 return Ok(());
380 }
381 let mut buf = String::with_capacity(hashes.len() * (HEX_LEN + 1));
382 for h in hashes {
383 buf.push_str(&hash::to_hex(h));
384 buf.push('\n');
385 }
386 fs::write(path, buf.as_bytes())?;
387 Ok(())
388}
389
/// Writes `content` to `path`, ensuring the file ends with a newline.
///
/// Content that already ends in `\n` is written unchanged; otherwise a single
/// `\n` is appended (so empty content becomes a one-byte file).
///
/// # Errors
///
/// Propagates the error from `fs::write`.
fn write_with_newline(path: &Path, content: &[u8]) -> io::Result<()> {
    if content.ends_with(b"\n") {
        return fs::write(path, content);
    }
    let mut terminated: Vec<u8> = Vec::with_capacity(content.len() + 1);
    terminated.extend_from_slice(content);
    terminated.push(b'\n');
    fs::write(path, terminated)
}
398
/// Reads `path` as UTF-8 text, refusing to load more than `cap` bytes.
///
/// The size limit is enforced on the read itself, not only on the length the
/// filesystem reports. A file that grows after the metadata check, or one
/// whose reported length is zero while it still yields data, cannot exceed
/// the cap.
///
/// # Errors
///
/// * The error from opening or reading the file (e.g. `NotFound`).
/// * `InvalidData` with message `"file too large"` when the content exceeds `cap`.
/// * `InvalidData` with message `"non-utf8"` when the content is not valid UTF-8.
fn read_text_capped(path: &Path, cap: u64) -> io::Result<String> {
    let too_large = || io::Error::new(io::ErrorKind::InvalidData, "file too large");

    let file = fs::File::open(path)?;
    // Fast path: reject obviously oversized files without reading them.
    if file.metadata()?.len() > cap {
        return Err(too_large());
    }

    // Read at most one byte past the cap; getting that extra byte proves the
    // content is larger than allowed without buffering an unbounded amount.
    let mut limited = io::Read::take(file, cap.saturating_add(1));
    let mut raw = Vec::new();
    io::Read::read_to_end(&mut limited, &mut raw)?;
    // `usize` is at most 64 bits on supported targets, so this widening is lossless.
    if raw.len() as u64 > cap {
        return Err(too_large());
    }

    String::from_utf8(raw).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "non-utf8"))
}
407
/// Strips trailing newlines, carriage returns and spaces from `s`.
///
/// Other whitespace (such as tabs) and all leading characters are left untouched.
fn trim_trailing(s: &str) -> &str {
    s.trim_end_matches(|ch: char| matches!(ch, '\n' | '\r' | ' '))
}
411
412#[must_use]
414pub fn rebase_dir_path(mkit_dir: &Path) -> PathBuf {
415 mkit_dir.join(REBASE_DIR)
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::object::{Commit, Identity, Tree, TreeEntry};
    use crate::serialize;
    use tempfile::TempDir;

    /// Creates an object store rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the store so the caller controls its lifetime.
    fn fresh_store() -> (TempDir, ObjectStore) {
        let dir = TempDir::new().unwrap();
        let store = ObjectStore::init(dir.path()).unwrap();
        (dir, store)
    }

    /// Serializes a blob holding `data`, writes it to `store`, and returns its hash.
    fn put_blob(store: &ObjectStore, data: &[u8]) -> Hash {
        let bytes = serialize::serialize(&Object::Blob(crate::object::Blob {
            data: data.to_vec(),
        }))
        .unwrap();
        store.write(&bytes).unwrap()
    }

    /// Writes a tree with a single blob entry called `name` and returns its hash.
    fn put_tree(store: &ObjectStore, name: &str, blob_h: Hash) -> Hash {
        let tree = Object::Tree(Tree {
            entries: vec![TreeEntry {
                name: name.as_bytes().to_vec(),
                mode: crate::object::EntryMode::Blob,
                object_hash: blob_h,
            }],
        });
        let bytes = serialize::serialize(&tree).unwrap();
        store.write(&bytes).unwrap()
    }

    /// Writes a commit with the given tree and parents and returns its hash.
    ///
    /// Identity, signature and message are fixed placeholder values; `ts` is the
    /// only input that varies between otherwise identical commits.
    fn put_commit(store: &ObjectStore, tree_h: Hash, parents: Vec<Hash>, ts: u64) -> Hash {
        let commit = Object::Commit(Commit::new_unannotated(
            tree_h,
            parents,
            Identity::ed25519([0u8; 32]),
            [0u8; 32],
            b"msg".to_vec(),
            ts,
            [0u8; 64],
        ));
        let bytes = serialize::serialize(&commit).unwrap();
        store.write(&bytes).unwrap()
    }

    /// Writing a state creates every expected file and reading it back yields the same state.
    #[test]
    fn state_roundtrip_writes_recoverable_files() {
        let tmp = TempDir::new().unwrap();
        let mkit = tmp.path().join(".mkit");
        fs::create_dir_all(&mkit).unwrap();

        let state = RebaseState {
            head_name: "feature-branch".to_string(),
            orig_head: hash::hash(b"orig-head"),
            onto: hash::hash(b"onto"),
            todo: vec![hash::hash(b"t1"), hash::hash(b"t2")],
            actions: vec![RebaseAction::Pick, RebaseAction::Reword],
            done: vec![hash::hash(b"d1")],
        };
        write_state(&mkit, &state).unwrap();

        let dir = mkit.join(REBASE_DIR);
        // File names are spelled out literally to pin the on-disk layout.
        assert!(dir.join("head-name").is_file());
        assert!(dir.join("orig-head").is_file());
        assert!(dir.join("onto").is_file());
        assert!(dir.join("todo").is_file());
        assert!(dir.join("actions").is_file());
        assert!(dir.join("done").is_file());

        let read = read_state(&mkit).unwrap();
        assert_eq!(read, state);
    }

    /// A state directory without an `actions` file reads back as all-`Pick`.
    #[test]
    fn missing_actions_file_defaults_to_all_pick() {
        let tmp = TempDir::new().unwrap();
        let mkit = tmp.path().join(".mkit");
        fs::create_dir_all(&mkit).unwrap();
        let state = RebaseState {
            head_name: "main".to_string(),
            orig_head: hash::hash(b"head"),
            onto: hash::hash(b"onto"),
            todo: vec![hash::hash(b"t1"), hash::hash(b"t2")],
            actions: vec![RebaseAction::Pick, RebaseAction::Pick],
            done: Vec::new(),
        };
        write_state(&mkit, &state).unwrap();
        // Delete the actions file to simulate state that never had one.
        fs::remove_file(mkit.join(REBASE_DIR).join("actions")).unwrap();
        let read = read_state(&mkit).unwrap();
        assert_eq!(read.actions, vec![RebaseAction::Pick, RebaseAction::Pick]);
    }

    /// Keywords round-trip through `from_keyword`, and only squash/fixup fold.
    #[test]
    fn action_keyword_roundtrip_and_folds() {
        for a in [
            RebaseAction::Pick,
            RebaseAction::Reword,
            RebaseAction::Squash,
            RebaseAction::Fixup,
        ] {
            assert_eq!(RebaseAction::from_keyword(a.keyword()), Some(a));
        }
        assert!(RebaseAction::Squash.folds_into_previous());
        assert!(RebaseAction::Fixup.folds_into_previous());
        assert!(!RebaseAction::Pick.folds_into_previous());
        assert!(!RebaseAction::Reword.folds_into_previous());
        assert_eq!(RebaseAction::from_keyword("nope"), None);
    }

    /// `consume_front` removes the head of both `todo` and `actions` together.
    #[test]
    fn consume_front_keeps_todo_and_actions_aligned() {
        let mut state = RebaseState {
            head_name: "main".to_string(),
            orig_head: hash::hash(b"head"),
            onto: hash::hash(b"onto"),
            todo: vec![hash::hash(b"t1"), hash::hash(b"t2")],
            actions: vec![RebaseAction::Reword, RebaseAction::Pick],
            done: Vec::new(),
        };
        assert_eq!(state.front_action(), RebaseAction::Reword);
        state.consume_front();
        assert_eq!(state.todo, vec![hash::hash(b"t2")]);
        assert_eq!(state.actions, vec![RebaseAction::Pick]);
        assert_eq!(state.front_action(), RebaseAction::Pick);
    }

    /// Empty `todo`, `actions` and `done` lists survive a write/read round trip.
    #[test]
    fn state_roundtrip_with_empty_todo_and_done() {
        let tmp = TempDir::new().unwrap();
        let mkit = tmp.path().join(".mkit");
        fs::create_dir_all(&mkit).unwrap();

        let state = RebaseState {
            head_name: "main".to_string(),
            orig_head: hash::hash(b"head"),
            onto: hash::hash(b"onto"),
            todo: Vec::new(),
            actions: Vec::new(),
            done: Vec::new(),
        };
        write_state(&mkit, &state).unwrap();
        let read = read_state(&mkit).unwrap();
        assert_eq!(read, state);
    }

    /// The rebase directory's existence is what marks a rebase as in progress.
    #[test]
    fn is_rebase_in_progress_detection() {
        let tmp = TempDir::new().unwrap();
        let mkit = tmp.path().join(".mkit");
        fs::create_dir_all(&mkit).unwrap();
        assert!(!is_rebase_in_progress(&mkit));
        fs::create_dir_all(mkit.join(REBASE_DIR)).unwrap();
        assert!(is_rebase_in_progress(&mkit));
    }

    /// `cleanup_rebase` removes a populated state directory.
    #[test]
    fn cleanup_removes_state_dir() {
        let tmp = TempDir::new().unwrap();
        let mkit = tmp.path().join(".mkit");
        fs::create_dir_all(&mkit).unwrap();
        fs::create_dir_all(mkit.join(REBASE_DIR)).unwrap();
        fs::write(mkit.join(REBASE_DIR).join("head-name"), b"main\n").unwrap();
        assert!(is_rebase_in_progress(&mkit));
        cleanup_rebase(&mkit).unwrap();
        assert!(!is_rebase_in_progress(&mkit));
    }

    /// `cleanup_rebase` succeeds when there is nothing to remove.
    #[test]
    fn cleanup_on_missing_dir_is_noop() {
        let tmp = TempDir::new().unwrap();
        let mkit = tmp.path().join(".mkit");
        fs::create_dir_all(&mkit).unwrap();
        cleanup_rebase(&mkit).unwrap();
    }

    /// Reading state without a rebase directory reports `NoRebaseInProgress`.
    #[test]
    fn read_state_when_no_rebase_returns_error() {
        let tmp = TempDir::new().unwrap();
        let mkit = tmp.path().join(".mkit");
        fs::create_dir_all(&mkit).unwrap();
        let err = read_state(&mkit).unwrap_err();
        assert!(matches!(err, RebaseError::NoRebaseInProgress));
    }

    /// On the linear chain c1 <- c2 <- c3 <- c4, replaying c4 onto c2 yields c3 then c4.
    #[test]
    fn collect_commits_on_linear_chain() {
        let (_d, store) = fresh_store();
        let blob = put_blob(&store, b"data");
        let tree = put_tree(&store, "f.txt", blob);
        let c1 = put_commit(&store, tree, vec![], 1);
        let c2 = put_commit(&store, tree, vec![c1], 2);
        let c3 = put_commit(&store, tree, vec![c2], 3);
        let c4 = put_commit(&store, tree, vec![c3], 4);

        let res = collect_commits_to_replay(&store, c4, c2).unwrap();
        assert_eq!(res, vec![c3, c4]);
    }

    /// Head and onto being the same commit means there is nothing to replay.
    #[test]
    fn collect_commits_same_commit_returns_empty() {
        let (_d, store) = fresh_store();
        let blob = put_blob(&store, b"data");
        let tree = put_tree(&store, "f.txt", blob);
        let c1 = put_commit(&store, tree, vec![], 1);
        let res = collect_commits_to_replay(&store, c1, c1).unwrap();
        assert!(res.is_empty());
    }

    /// Two branches fork from c1 (c2 <- c3 and c4 <- c5); replaying c5 onto c3
    /// stops at the shared ancestor c1 and yields only c4 then c5.
    #[test]
    fn collect_commits_y_shape_stops_at_ancestor_of_onto() {
        let (_d, store) = fresh_store();
        let blob = put_blob(&store, b"data");
        let tree = put_tree(&store, "f.txt", blob);
        let c1 = put_commit(&store, tree, vec![], 1);
        let c2 = put_commit(&store, tree, vec![c1], 2);
        let c3 = put_commit(&store, tree, vec![c2], 3);
        let c4 = put_commit(&store, tree, vec![c1], 4);
        let c5 = put_commit(&store, tree, vec![c4], 5);

        let res = collect_commits_to_replay(&store, c5, c3).unwrap();
        // c1 is an ancestor of onto (c3), so it is excluded from the replay list.
        assert_eq!(res, vec![c4, c5]);
    }
}