1use std::{fmt, path};
2
3use anyhow::*;
4use git2::build::CheckoutBuilder;
5use std::result::Result::Ok;
6
/// Outcome of bringing a local branch up to date with its remote-tracking branch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeResult {
    /// Nothing had to be integrated.
    UpToDate,
    /// The local branch was moved forward to the remote commit without a new commit.
    FastForward,
    /// A merge commit with the local and remote commits as parents was created.
    Merged,
    /// The local commits were replayed on top of the remote commit.
    Rebased,
    /// Conflicts were detected and the operation was not completed.
    Conflicts,
}
15
/// Relationship between the local HEAD commit and a remote-tracking ref.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemoteComparison {
    /// Both sides point at the same commit.
    UpToDate,
    /// The local side has this many commits the remote side lacks.
    Ahead(usize),
    /// The remote side has this many commits the local side lacks.
    Behind(usize),
    /// Both sides have unique commits: `(ahead, behind)`.
    Diverged(usize, usize),
    /// The remote-tracking ref could not be resolved.
    NoRemote,
}
24
25pub struct Repo {
26 pub git_repo: git2::Repository,
27 pub work_dir: path::PathBuf,
28 pub head: String,
29 pub subrepos: Vec<Repo>,
30}
31
32impl Repo {
33 pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
34 let git_repo = git2::Repository::open(work_dir)
35 .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
36
37 let head = match head_name {
38 Some(name) => String::from(name),
39 None => {
40 if git_repo.head_detached().with_context(|| {
41 format!(
42 "Cannot determine head state for repo at `{}`",
43 work_dir.display()
44 )
45 })? {
46 bail!(
47 "Cannot operate on a detached head for repo at `{}`",
48 work_dir.display()
49 )
50 }
51
52 String::from(git_repo.head().with_context(|| {
53 format!(
54 "Cannot find the head branch for repo at `{}`. Is it detached?",
55 work_dir.display()
56 )
57 })?.shorthand().with_context(|| {
58 format!(
59 "Cannot find a human readable representation of the head ref for repo at `{}`",
60 work_dir.display(),
61 )
62 })?)
63 },
64 };
65
66 let subrepos = git_repo
67 .submodules()
68 .with_context(|| {
69 format!(
70 "Cannot load submodules for repo at `{}`",
71 work_dir.display()
72 )
73 })?
74 .iter()
75 .map(|submodule| Repo::new(&work_dir.join(submodule.path()), Some(&head)))
76 .collect::<Result<Vec<Repo>>>()?;
77
78 Ok(Repo {
79 git_repo,
80 work_dir: path::PathBuf::from(work_dir),
81 head,
82 subrepos,
83 })
84 }
85
86 pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
87 self.subrepos
88 .iter()
89 .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
90 }
91
92 pub fn sync(&self) -> Result<()> {
93 self.switch(&self.head)?;
94 Ok(())
95 }
96
97 pub fn switch(&self, head: &str) -> Result<()> {
98 self.git_repo.set_head(&self.resolve_reference(head)?)?;
99 self.git_repo.checkout_head(None)?;
100 Ok(())
101 }
102
103 pub fn fetch(&self) -> Result<()> {
104 let head_ref = self.git_repo.head()?;
106 let branch_name = head_ref.shorthand().with_context(|| {
107 format!(
108 "Cannot get branch name for repo at `{}`",
109 self.work_dir.display()
110 )
111 })?;
112
113 let tracking = match self.tracking_branch(branch_name)? {
114 Some(tracking) => tracking,
115 None => {
116 return Ok(());
118 },
119 };
120
121 match self.git_repo.find_remote(&tracking.remote) {
123 Ok(mut remote) => {
124 let mut fetch_options = git2::FetchOptions::new();
125 fetch_options.remote_callbacks(self.remote_callbacks()?);
126
127 remote
128 .fetch::<&str>(&[], Some(&mut fetch_options), None)
129 .with_context(|| {
130 format!(
131 "Failed to fetch from remote '{}' for repo at `{}`",
132 tracking.remote,
133 self.work_dir.display()
134 )
135 })?;
136 },
137 Err(_) => {
138 return Ok(());
140 },
141 }
142
143 Ok(())
144 }
145
146 fn rebase(
147 &self,
148 _branch_name: &str,
149 remote_commit: &git2::Commit,
150 ) -> Result<MergeResult> {
151 let _local_commit = self.git_repo.head()?.peel_to_commit()?;
152 let remote_oid = remote_commit.id();
153
154 let remote_annotated = self.git_repo.find_annotated_commit(remote_oid)?;
156
157 let signature = self.git_repo.signature()?;
159 let mut rebase = self.git_repo.rebase(
160 None, Some(&remote_annotated), None, None, )?;
165
166 let mut has_conflicts = false;
168 while let Some(op) = rebase.next() {
169 match op {
170 Ok(_rebase_op) => {
171 let index = self.git_repo.index()?;
173 if index.has_conflicts() {
174 has_conflicts = true;
175 break;
176 }
177
178 if rebase.commit(None, &signature, None).is_err() {
180 has_conflicts = true;
181 break;
182 }
183 },
184 Err(_) => {
185 has_conflicts = true;
186 break;
187 },
188 }
189 }
190
191 if has_conflicts {
192 return Ok(MergeResult::Conflicts);
194 }
195
196 rebase.finish(Some(&signature))?;
198
199 Ok(MergeResult::Rebased)
200 }
201
202 pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
203 self.fetch()?;
205
206 let tracking = match self.tracking_branch(branch_name)? {
208 Some(tracking) => tracking,
209 None => {
210 return Ok(MergeResult::UpToDate);
212 },
213 };
214
215 let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
217 {
218 Ok(oid) => oid,
219 Err(_) => {
220 return Ok(MergeResult::UpToDate);
222 },
223 };
224
225 let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
226 let local_commit = self.git_repo.head()?.peel_to_commit()?;
227
228 if local_commit.id() == remote_commit.id() {
230 return Ok(MergeResult::UpToDate);
231 }
232
233 if self
235 .git_repo
236 .graph_descendant_of(remote_commit.id(), local_commit.id())?
237 {
238 self.git_repo.reference(
240 &format!("refs/heads/{}", branch_name),
241 remote_commit.id(),
242 true,
243 &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
244 )?;
245 self.git_repo
246 .set_head(&format!("refs/heads/{}", branch_name))?;
247 let mut checkout = CheckoutBuilder::new();
248 checkout.force();
249 self.git_repo.checkout_head(Some(&mut checkout))?;
250 return Ok(MergeResult::FastForward);
251 }
252
253 let pull_strategy = self.get_pull_strategy(branch_name)?;
255
256 match pull_strategy {
257 PullStrategy::Rebase => {
258 self.rebase(branch_name, &remote_commit)
260 },
261 PullStrategy::Merge => {
262 self.do_merge(branch_name, &local_commit, &remote_commit, &tracking)
264 },
265 }
266 }
267
268 fn do_merge(
269 &self,
270 branch_name: &str,
271 local_commit: &git2::Commit,
272 remote_commit: &git2::Commit,
273 tracking: &TrackingBranch,
274 ) -> Result<MergeResult> {
275 let mut merge_opts = git2::MergeOptions::new();
277 merge_opts.fail_on_conflict(false); let _merge_result = self.git_repo.merge_commits(
280 local_commit,
281 remote_commit,
282 Some(&merge_opts),
283 )?;
284
285 let mut index = self.git_repo.index()?;
287 let has_conflicts = index.has_conflicts();
288
289 if !has_conflicts {
290 let signature = self.git_repo.signature()?;
292 let tree_id = index.write_tree()?;
293 let tree = self.git_repo.find_tree(tree_id)?;
294
295 self.git_repo.commit(
296 Some(&format!("refs/heads/{}", branch_name)),
297 &signature,
298 &signature,
299 &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
300 &tree,
301 &[local_commit, remote_commit],
302 )?;
303
304 self.git_repo.cleanup_state()?;
305
306 Ok(MergeResult::Merged)
307 } else {
308 Ok(MergeResult::Conflicts)
310 }
311 }
312
313 pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
314 if let Some(tracking) = self.tracking_branch(branch_name)? {
315 Ok(tracking.remote)
316 } else {
317 Ok("origin".to_string())
319 }
320 }
321
322 pub fn get_remote_comparison(
324 &self,
325 branch_name: &str,
326 ) -> Result<Option<RemoteComparison>> {
327 let tracking = match self.tracking_branch(branch_name)? {
329 Some(tracking) => tracking,
330 None => return Ok(None), };
332
333 let remote_oid = match self.git_repo.refname_to_id(&tracking.remote_ref) {
335 Ok(oid) => oid,
336 Err(_) => {
337 return Ok(Some(RemoteComparison::NoRemote));
339 },
340 };
341
342 let local_oid = self.git_repo.head()?.peel_to_commit()?.id();
344
345 if local_oid == remote_oid {
347 return Ok(Some(RemoteComparison::UpToDate));
348 }
349
350 let (ahead, behind) =
352 self.git_repo.graph_ahead_behind(local_oid, remote_oid)?;
353
354 if ahead > 0 && behind > 0 {
355 Ok(Some(RemoteComparison::Diverged(ahead, behind)))
356 } else if ahead > 0 {
357 Ok(Some(RemoteComparison::Ahead(ahead)))
358 } else if behind > 0 {
359 Ok(Some(RemoteComparison::Behind(behind)))
360 } else {
361 Ok(Some(RemoteComparison::UpToDate))
362 }
363 }
364
365 pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
366 let config = self.git_repo.config()?;
367
368 let mut callbacks = git2::RemoteCallbacks::new();
369 callbacks.credentials(move |url, username_from_url, allowed| {
370 if allowed.contains(git2::CredentialType::SSH_KEY)
371 && let Some(username) = username_from_url
372 && let Ok(cred) = git2::Cred::ssh_key_from_agent(username)
373 {
374 return Ok(cred);
375 }
376
377 if (allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
378 || allowed.contains(git2::CredentialType::SSH_KEY)
379 || allowed.contains(git2::CredentialType::DEFAULT))
380 && let Ok(cred) =
381 git2::Cred::credential_helper(&config, url, username_from_url)
382 {
383 return Ok(cred);
384 }
385
386 if allowed.contains(git2::CredentialType::USERNAME) {
387 if let Some(username) = username_from_url {
388 return git2::Cred::username(username);
389 } else {
390 return git2::Cred::username("git");
391 }
392 }
393
394 git2::Cred::default()
395 });
396
397 Ok(callbacks)
398 }
399
400 fn resolve_reference(&self, short_name: &str) -> Result<String> {
401 Ok(self
402 .git_repo
403 .resolve_reference_from_short_name(short_name)?
404 .name()
405 .with_context(|| {
406 format!(
407 "Cannot resolve head reference for repo at `{}`",
408 self.work_dir.display()
409 )
410 })?
411 .to_owned())
412 }
413
414 pub fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
415 let config = self.git_repo.config()?;
416
417 let remote_key = format!("branch.{}.remote", branch_name);
418 let merge_key = format!("branch.{}.merge", branch_name);
419
420 let remote = match config.get_string(&remote_key) {
421 Ok(name) => name,
422 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
423 Err(err) => return Err(err.into()),
424 };
425
426 let merge_ref = match config.get_string(&merge_key) {
427 Ok(name) => name,
428 Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
429 Err(err) => return Err(err.into()),
430 };
431
432 let branch_short = merge_ref
433 .strip_prefix("refs/heads/")
434 .unwrap_or(&merge_ref)
435 .to_owned();
436
437 let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
438
439 Ok(Some(TrackingBranch { remote, remote_ref }))
440 }
441
442 fn get_pull_strategy(&self, branch_name: &str) -> Result<PullStrategy> {
443 let config = self.git_repo.config()?;
444
445 let branch_rebase_key = format!("branch.{}.rebase", branch_name);
447 if let Ok(value) = config.get_string(&branch_rebase_key) {
448 return Ok(parse_rebase_config(&value));
449 }
450
451 if let Ok(value) = config.get_string("pull.rebase") {
453 return Ok(parse_rebase_config(&value));
454 }
455
456 if let Ok(value) = config.get_bool("pull.rebase") {
458 return Ok(if value {
459 PullStrategy::Rebase
460 } else {
461 PullStrategy::Merge
462 });
463 }
464
465 Ok(PullStrategy::Merge)
467 }
468}
469
470impl fmt::Debug for Repo {
471 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
472 f.debug_struct("Repo")
473 .field("work_dir", &self.work_dir)
474 .field("head", &self.head)
475 .field("subrepos", &self.subrepos)
476 .finish()
477 }
478}
479
/// Upstream configuration of a local branch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingBranch {
    /// Name of the remote the branch tracks.
    pub remote: String,
    /// Full name of the remote-tracking reference, e.g. `refs/remotes/<remote>/<branch>`.
    pub remote_ref: String,
}
484
/// How diverged history is integrated when pulling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PullStrategy {
    /// Create a merge commit.
    Merge,
    /// Replay local commits on top of the remote commit.
    Rebase,
}

/// Maps a `pull.rebase` / `branch.<name>.rebase` config value to a strategy.
///
/// Matching is case-insensitive. The true spellings of a git boolean
/// (`true`, `yes`, `on`, `1`) and the rebase modes `interactive`/`i` and
/// `merges`/`m` select [`PullStrategy::Rebase`]. Everything else, including
/// the false spellings and unknown values, selects [`PullStrategy::Merge`].
fn parse_rebase_config(value: &str) -> PullStrategy {
    match value.to_ascii_lowercase().as_str() {
        "true" | "yes" | "on" | "1" | "interactive" | "i" | "merges" | "m" => {
            PullStrategy::Rebase
        },
        _ => PullStrategy::Merge,
    }
}