1use crate::cherry_pick::{CherryPickError, cherry_pick};
7use crate::commit::CommitsTable;
8use crate::history::log;
9use crate::object_store::GitObjectStore;
10
11#[derive(Debug, thiserror::Error)]
13pub enum RebaseError {
14 #[error("Cherry-pick failed: {0}")]
15 CherryPick(#[from] CherryPickError),
16
17 #[error("Commit not found: {0}")]
18 CommitNotFound(String),
19
20 #[error("Nothing to rebase (start equals onto)")]
21 NothingToRebase,
22}
23
/// Outcome of a successful [`rebase`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RebaseResult {
    /// Id of the last commit produced by the replay, i.e. the new tip of the
    /// rebased range.
    pub new_head: String,
    /// Number of commits that were cherry-picked onto the new base.
    pub replayed: usize,
}
31
32pub fn rebase(
40 obj_store: &mut GitObjectStore,
41 commits_table: &mut CommitsTable,
42 start_commit_id: &str, end_commit_id: &str, onto_commit_id: &str, author: &str,
46) -> Result<RebaseResult, RebaseError> {
47 if start_commit_id == onto_commit_id {
48 return Err(RebaseError::NothingToRebase);
49 }
50
51 let all = log(commits_table, end_commit_id, 0);
52 let mut to_replay: Vec<String> = Vec::new();
53 for commit in &all {
54 if commit.commit_id == start_commit_id {
55 break;
56 }
57 to_replay.push(commit.commit_id.clone());
58 }
59
60 to_replay.reverse();
62
63 if to_replay.is_empty() {
64 return Err(RebaseError::NothingToRebase);
65 }
66
67 let mut current_head = onto_commit_id.to_string();
69 let mut replayed = 0;
70
71 for commit_id in &to_replay {
72 let new_id = cherry_pick(obj_store, commits_table, commit_id, ¤t_head, author)?;
73 current_head = new_id;
74 replayed += 1;
75 }
76
77 Ok(RebaseResult {
78 new_head: current_head,
79 replayed,
80 })
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CommitsTable, GitObjectStore, checkout, create_commit};
    use arrow_graph_core::Triple;

    /// Builds a triple with the given subject/predicate/object and confidence 1.0.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple {
            subject: s.to_string(),
            predicate: p.to_string(),
            object: o.to_string(),
            graph: None,
            confidence: Some(1.0),
            source_document: None,
            source_chunk_id: None,
            extracted_by: None,
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    /// Adds one triple to the working store, using the same `"world"` /
    /// `Some(1u8)` arguments every test in this module uses.
    fn add(obj: &mut GitObjectStore, s: &str, p: &str, o: &str) {
        obj.store
            .add_triple(&make_triple(s, p, o), "world", Some(1u8))
            .unwrap();
    }

    #[test]
    fn test_rebase_linear_chain() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add(&mut obj, "a", "r", "1");
        let base = create_commit(&obj, &mut commits, vec![], "base", "test").unwrap();

        add(&mut obj, "b", "r", "2");
        let c1 = create_commit(
            &obj,
            &mut commits,
            vec![base.commit_id.clone()],
            "c1",
            "test",
        )
        .unwrap();

        add(&mut obj, "c", "r", "3");
        let c2 =
            create_commit(&obj, &mut commits, vec![c1.commit_id.clone()], "c2", "test").unwrap();

        checkout(&mut obj, &commits, &base.commit_id).unwrap();
        add(&mut obj, "d", "r", "4");
        let new_base = create_commit(
            &obj,
            &mut commits,
            vec![base.commit_id.clone()],
            "new_base",
            "test",
        )
        .unwrap();

        let result = rebase(
            &mut obj,
            &mut commits,
            &base.commit_id,
            &c2.commit_id,
            &new_base.commit_id,
            "test",
        )
        .unwrap();

        assert_eq!(result.replayed, 2);
        assert_ne!(result.new_head, c2.commit_id);
        assert_ne!(result.new_head, new_base.commit_id);
    }

    #[test]
    fn test_rebase_nothing_to_rebase() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add(&mut obj, "a", "r", "1");
        let base = create_commit(&obj, &mut commits, vec![], "base", "test").unwrap();

        let result = rebase(
            &mut obj,
            &mut commits,
            &base.commit_id,
            &base.commit_id,
            &base.commit_id,
            "test",
        );
        assert!(matches!(result, Err(RebaseError::NothingToRebase)));
    }

    #[test]
    fn test_rebase_empty_range() {
        // start == end but onto differs: the range holds no commits to replay.
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add(&mut obj, "a", "r", "1");
        let base = create_commit(&obj, &mut commits, vec![], "base", "test").unwrap();

        add(&mut obj, "b", "r", "2");
        let other = create_commit(
            &obj,
            &mut commits,
            vec![base.commit_id.clone()],
            "other",
            "test",
        )
        .unwrap();

        let result = rebase(
            &mut obj,
            &mut commits,
            &base.commit_id,
            &base.commit_id,
            &other.commit_id,
            "test",
        );
        assert!(matches!(result, Err(RebaseError::NothingToRebase)));
    }

    #[test]
    fn test_rebase_single_commit() {
        let tmp = tempfile::tempdir().unwrap();
        let mut obj = GitObjectStore::with_snapshot_dir(tmp.path());
        let mut commits = CommitsTable::new();

        add(&mut obj, "a", "r", "1");
        let base = create_commit(&obj, &mut commits, vec![], "base", "test").unwrap();

        add(&mut obj, "b", "r", "2");
        let c1 = create_commit(
            &obj,
            &mut commits,
            vec![base.commit_id.clone()],
            "c1",
            "test",
        )
        .unwrap();

        checkout(&mut obj, &commits, &base.commit_id).unwrap();
        add(&mut obj, "x", "r", "9");
        let new_base = create_commit(
            &obj,
            &mut commits,
            vec![base.commit_id.clone()],
            "new_base",
            "test",
        )
        .unwrap();

        let result = rebase(
            &mut obj,
            &mut commits,
            &base.commit_id,
            &c1.commit_id,
            &new_base.commit_id,
            "test",
        )
        .unwrap();

        assert_eq!(result.replayed, 1);
        assert_ne!(result.new_head, c1.commit_id);
    }
}