1use git2::{BranchType, Repository};
2use scopetime::scope_time;
3
4use super::{CommitId, RepoPath};
5use crate::{
6 error::{Error, Result},
7 sync::repository::repo,
8};
9
10pub fn rebase_branch(
12 repo_path: &RepoPath,
13 branch: &str,
14 branch_type: BranchType,
15) -> Result<RebaseState> {
16 scope_time!("rebase_branch");
17
18 let repo = repo(repo_path)?;
19
20 rebase_branch_repo(&repo, branch, branch_type)
21}
22
23fn rebase_branch_repo(
24 repo: &Repository,
25 branch_name: &str,
26 branch_type: BranchType,
27) -> Result<RebaseState> {
28 let branch = repo.find_branch(branch_name, branch_type)?;
29
30 let annotated = repo.reference_to_annotated_commit(&branch.into_reference())?;
31
32 rebase(repo, &annotated)
33}
34
35pub fn conflict_free_rebase(
38 repo: &git2::Repository,
39 commit: &git2::AnnotatedCommit,
40) -> Result<CommitId> {
41 let mut rebase = repo.rebase(None, Some(commit), None, None)?;
42 let signature = crate::sync::commit::signature_allow_undefined_name(repo)?;
43 let mut last_commit = None;
44 while let Some(op) = rebase.next() {
45 let _op = op?;
46
47 if repo.index()?.has_conflicts() {
48 rebase.abort()?;
49 return Err(Error::RebaseConflict);
50 }
51
52 let c = rebase.commit(None, &signature, None)?;
53
54 last_commit = Some(CommitId::from(c));
55 }
56
57 if repo.index()?.has_conflicts() {
58 rebase.abort()?;
59 return Err(Error::RebaseConflict);
60 }
61
62 rebase.finish(Some(&signature))?;
63
64 last_commit.ok_or_else(|| Error::Generic(String::from("no commit rebased")))
65}
66
/// Outcome of a rebase call that may stop for user intervention.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum RebaseState {
    /// All operations were applied and the rebase was finished.
    Finished,
    /// A step left conflicts in the index; the rebase is still in progress.
    Conflicted,
}
75
76pub fn rebase(repo: &git2::Repository, commit: &git2::AnnotatedCommit) -> Result<RebaseState> {
78 let mut rebase = repo.rebase(None, Some(commit), None, None)?;
79 let signature = crate::sync::commit::signature_allow_undefined_name(repo)?;
80
81 while let Some(op) = rebase.next() {
82 let _op = op?;
83 if repo.index()?.has_conflicts() {
86 return Ok(RebaseState::Conflicted);
87 }
88
89 rebase.commit(None, &signature, None)?;
90 }
91
92 if repo.index()?.has_conflicts() {
93 return Ok(RebaseState::Conflicted);
94 }
95
96 rebase.finish(Some(&signature))?;
97
98 Ok(RebaseState::Finished)
99}
100
101pub fn continue_rebase(repo: &git2::Repository) -> Result<RebaseState> {
103 let mut rebase = repo.open_rebase(None)?;
104 let signature = crate::sync::commit::signature_allow_undefined_name(repo)?;
105
106 if repo.index()?.has_conflicts() {
107 return Ok(RebaseState::Conflicted);
108 }
109
110 if !repo.index()?.is_empty() {
112 rebase.commit(None, &signature, None)?;
113 }
114
115 while let Some(op) = rebase.next() {
116 let _op = op?;
117 if repo.index()?.has_conflicts() {
120 return Ok(RebaseState::Conflicted);
121 }
122
123 rebase.commit(None, &signature, None)?;
124 }
125
126 if repo.index()?.has_conflicts() {
127 return Ok(RebaseState::Conflicted);
128 }
129
130 rebase.finish(Some(&signature))?;
131
132 Ok(RebaseState::Finished)
133}
134
/// Snapshot of an in-progress rebase, as reported by `get_rebase_progress`.
#[derive(PartialEq, Eq, Debug)]
pub struct RebaseProgress {
    /// Total number of operations in the rebase.
    pub steps: usize,
    /// Index of the current operation; `get_rebase_progress` reports `0` when
    /// no operation is current.
    pub current: usize,
    /// Id of the commit the current operation is applying, if any.
    pub current_commit: Option<CommitId>,
}
145
146pub fn get_rebase_progress(repo: &git2::Repository) -> Result<RebaseProgress> {
148 let mut rebase = repo.open_rebase(None)?;
149
150 let current_commit: Option<CommitId> = rebase
151 .operation_current()
152 .and_then(|idx| rebase.nth(idx))
153 .map(|op| op.id().into());
154
155 let progress = RebaseProgress {
156 steps: rebase.len(),
157 current: rebase.operation_current().unwrap_or_default(),
158 current_commit,
159 };
160
161 Ok(progress)
162}
163
164pub fn abort_rebase(repo: &git2::Repository) -> Result<()> {
166 let mut rebase = repo.open_rebase(None)?;
167
168 rebase.abort()?;
169
170 Ok(())
171}
172
#[cfg(test)]
mod test_conflict_free_rebase {
    use git2::{BranchType, Repository};

    use super::conflict_free_rebase;
    use crate::{
        error::{Error, Result},
        sync::{
            checkout_branch, create_branch,
            rebase::{rebase_branch, RebaseState},
            repo_state,
            repository::repo,
            tests::{repo_init, write_commit_file},
            CommitId, RepoPath, RepoState,
        },
    };

    /// Returns the parent ids of commit `c`.
    fn parent_ids(repo: &Repository, c: CommitId) -> Vec<CommitId> {
        repo.find_commit(c.into())
            .unwrap()
            .parent_ids()
            .map(CommitId::from)
            .collect()
    }

    /// Runs `conflict_free_rebase` of the checked-out branch onto the local
    /// branch `branch_name`, returning its result unmodified.
    fn try_rebase_branch_repo(repo_path: &RepoPath, branch_name: &str) -> Result<CommitId> {
        let repo = repo(repo_path).unwrap();

        let branch = repo.find_branch(branch_name, BranchType::Local).unwrap();

        let annotated = repo
            .reference_to_annotated_commit(&branch.into_reference())
            .unwrap();

        conflict_free_rebase(&repo, &annotated)
    }

    /// Like `try_rebase_branch_repo`, but panics on failure.
    fn test_rebase_branch_repo(repo_path: &RepoPath, branch_name: &str) -> CommitId {
        try_rebase_branch_repo(repo_path, branch_name).unwrap()
    }

    #[test]
    fn test_smoke() {
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        let c1 = write_commit_file(&repo, "test1.txt", "test", "commit1");

        create_branch(repo_path, "foo").unwrap();

        let c2 = write_commit_file(&repo, "test2.txt", "test", "commit2");

        assert_eq!(parent_ids(&repo, c2), vec![c1]);

        checkout_branch(repo_path, "master").unwrap();

        let c3 = write_commit_file(&repo, "test3.txt", "test", "commit3");

        checkout_branch(repo_path, "foo").unwrap();

        let r = test_rebase_branch_repo(repo_path, "master");

        assert_eq!(parent_ids(&repo, r), vec![c3]);
    }

    #[test]
    fn test_conflict() {
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        write_commit_file(&repo, "test.txt", "test1", "commit1");

        create_branch(repo_path, "foo").unwrap();

        write_commit_file(&repo, "test.txt", "test2", "commit2");

        checkout_branch(repo_path, "master").unwrap();

        write_commit_file(&repo, "test.txt", "test3", "commit3");

        checkout_branch(repo_path, "foo").unwrap();

        let res = rebase_branch(repo_path, "master", BranchType::Local);

        assert!(matches!(res.unwrap(), RebaseState::Conflicted));

        assert_eq!(repo_state(repo_path).unwrap(), RepoState::Rebase);
    }

    /// A conflicting conflict-free rebase must fail and leave the repository
    /// exactly as it was before the rebase started.
    #[test]
    fn test_conflict_aborts() {
        let (_td, repo) = repo_init().unwrap();
        let root = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &root.as_os_str().to_str().unwrap().into();

        write_commit_file(&repo, "test.txt", "test1", "commit1");

        create_branch(repo_path, "foo").unwrap();

        let foo_tip = write_commit_file(&repo, "test.txt", "test2", "commit2");

        checkout_branch(repo_path, "master").unwrap();

        write_commit_file(&repo, "test.txt", "test3", "commit3");

        checkout_branch(repo_path, "foo").unwrap();

        let res = try_rebase_branch_repo(repo_path, "master");

        assert!(matches!(res, Err(Error::RebaseConflict)));
        assert_eq!(repo_state(repo_path).unwrap(), RepoState::Clean);
        assert_eq!(
            repo.head().unwrap().target().map(CommitId::from),
            Some(foo_tip)
        );
    }
}
261
#[cfg(test)]
mod test_rebase {
    use git2::BranchType;

    use crate::sync::{
        checkout_branch, create_branch,
        rebase::{abort_rebase, get_rebase_progress, RebaseProgress, RebaseState},
        rebase_branch, repo_state,
        tests::{repo_init, write_commit_file},
        RepoPath, RepoState,
    };

    /// A conflicted rebase reports its progress and can be aborted back to a
    /// clean repository state.
    #[test]
    fn test_conflicted_abort() {
        let (_td, repo) = repo_init().unwrap();
        let workdir = repo.path().parent().unwrap();
        let repo_path: &RepoPath = &workdir.as_os_str().to_str().unwrap().into();

        // Shared base commit.
        write_commit_file(&repo, "test.txt", "test1", "commit1");

        // `foo` and `master` then change the same file in different ways.
        create_branch(repo_path, "foo").unwrap();
        let foo_tip = write_commit_file(&repo, "test.txt", "test2", "commit2");

        checkout_branch(repo_path, "master").unwrap();
        write_commit_file(&repo, "test.txt", "test3", "commit3");

        checkout_branch(repo_path, "foo").unwrap();

        // No rebase is running yet, so there is no progress to report.
        assert!(get_rebase_progress(&repo).is_err());

        let state = rebase_branch(repo_path, "master", BranchType::Local).unwrap();

        assert_eq!(state, RebaseState::Conflicted);
        assert_eq!(repo_state(repo_path).unwrap(), RepoState::Rebase);

        let expected_progress = RebaseProgress {
            current: 0,
            steps: 1,
            current_commit: Some(foo_tip),
        };
        assert_eq!(get_rebase_progress(&repo).unwrap(), expected_progress);

        abort_rebase(&repo).unwrap();

        assert_eq!(repo_state(repo_path).unwrap(), RepoState::Clean);
    }
}