Skip to main content

git_branch_status/
repository_ext.rs

// Copyright 2021 Akiomi Kamakura
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fs;
16
17use git2::{
18    DescribeFormatOptions, DescribeOptions, Error, ErrorCode, Oid, Repository, RepositoryState,
19    StatusOptions,
20};
21
22use crate::branch_status::BranchStatus;
23use crate::status_entry_ext::StatusEntryExt;
24
25pub trait RepositoryExt {
26    fn action(&self, state: RepositoryState) -> Option<&'static str>;
27    fn branch_name(&self) -> Result<String, Error>;
28    fn branch_status(&self) -> Result<BranchStatus, Error>;
29    fn rebase_head_name(&self, state: RepositoryState) -> Result<Option<String>, Error>;
30    fn unborn_branch_name(&self) -> Result<Option<String>, Error>;
31    fn tag_name(&self) -> Result<Option<String>, Error>;
32    fn to_short_oid(&self, oid: Oid) -> Result<Option<String>, Error>;
33}
34
35impl RepositoryExt for Repository {
36    fn action(&self, state: RepositoryState) -> Option<&'static str> {
37        match state {
38            RepositoryState::ApplyMailbox => Some("am"),
39            RepositoryState::ApplyMailboxOrRebase => Some("am/rebase"),
40            RepositoryState::Bisect => Some("bisect"),
41            RepositoryState::CherryPick => Some("cherry"),
42            RepositoryState::CherryPickSequence => Some("cherry-seq"),
43            RepositoryState::Merge => Some("merge"),
44            RepositoryState::Rebase => Some("rebase"),
45            RepositoryState::RebaseInteractive => Some("rebase-i"),
46            RepositoryState::RebaseMerge => Some("rebase-m"),
47            RepositoryState::Revert => Some("revert"),
48            RepositoryState::RevertSequence => Some("revert-seq"),
49            _ => None,
50        }
51    }
52
53    fn branch_name(&self) -> Result<String, Error> {
54        let head = match self.head() {
55            Ok(head) => Some(head),
56            Err(ref e)
57                if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound =>
58            {
59                None
60            }
61            Err(e) => return Err(e),
62        };
63
64        let detached = self.head_detached()?;
65        let state = self.state();
66
67        let branch = if let Some(name) = self.rebase_head_name(state)? {
68            name
69        } else if detached {
70            // Prefer a tag pointing at HEAD, falling back to the short hash.
71            if let Some(tag) = self.tag_name()? {
72                tag
73            } else {
74                let oid = head.as_ref().and_then(|h| h.target());
75                let short = match oid.map(|oid| self.to_short_oid(oid)) {
76                    Some(Ok(id)) => id,
77                    Some(Err(e)) => return Err(e),
78                    None => None,
79                };
80
81                short.unwrap_or_else(|| "HEAD (detached)".to_string())
82            }
83        } else if let Some(name) = head.as_ref().and_then(|h| h.shorthand().ok()) {
84            name.to_string()
85        } else {
86            // An unborn branch (e.g. a freshly initialized repository) has no
87            // HEAD commit, so resolve the branch name from the symbolic HEAD.
88            self.unborn_branch_name()?
89                .unwrap_or_else(|| "HEAD (no branch)".to_string())
90        };
91
92        match self.action(state) {
93            Some(action) => Ok(branch + ":" + action),
94            None => Ok(branch),
95        }
96    }
97
98    fn branch_status(&self) -> Result<BranchStatus, Error> {
99        let mut opts = StatusOptions::new();
100        opts.include_untracked(false)
101            .include_ignored(false)
102            .include_unmodified(false)
103            .exclude_submodules(true);
104
105        let stats = self.statuses(Some(&mut opts))?;
106
107        let status = stats.iter().fold(BranchStatus::NotChanged, |acc, s| {
108            if acc < BranchStatus::Conflicted && s.is_conflicted() {
109                BranchStatus::Conflicted
110            } else if acc < BranchStatus::Unstaged && s.is_unstaged() {
111                BranchStatus::Unstaged
112            } else if acc < BranchStatus::Staged && s.is_staged() {
113                BranchStatus::Staged
114            } else {
115                acc
116            }
117        });
118
119        Ok(status)
120    }
121
122    fn rebase_head_name(&self, state: RepositoryState) -> Result<Option<String>, Error> {
123        // The original branch name is recorded in a `head-name` file while a
124        // rebase is in progress. The merge backend (and interactive rebases)
125        // use `rebase-merge/`, while the apply backend uses `rebase-apply/`.
126        let dir = match state {
127            RepositoryState::RebaseInteractive | RepositoryState::RebaseMerge => "rebase-merge",
128            RepositoryState::Rebase => "rebase-apply",
129            _ => return Ok(None),
130        };
131
132        let path = self.path().join(dir).join("head-name");
133        let refname = match fs::read_to_string(&path) {
134            Ok(content) => content.trim().to_string(),
135            Err(_) => return Ok(None),
136        };
137
138        let name = match self.find_reference(&refname) {
139            Ok(reference) => reference.shorthand().unwrap_or(&refname).to_string(),
140            Err(_) => refname
141                .strip_prefix("refs/heads/")
142                .unwrap_or(&refname)
143                .to_string(),
144        };
145
146        Ok(Some(name))
147    }
148
149    fn unborn_branch_name(&self) -> Result<Option<String>, Error> {
150        let reference = self.find_reference("HEAD")?;
151        Ok(reference.symbolic_target()?.map(|target| {
152            target
153                .strip_prefix("refs/heads/")
154                .unwrap_or(target)
155                .to_string()
156        }))
157    }
158
159    fn tag_name(&self) -> Result<Option<String>, Error> {
160        // `max_candidates_tags(0)` makes this behave like `git describe
161        // --exact-match`: it resolves only a tag that points at HEAD and
162        // errors otherwise, which we treat as "no tag".
163        let mut opts = DescribeOptions::new();
164        opts.describe_tags().max_candidates_tags(0);
165
166        let describe = match self.describe(&opts) {
167            Ok(describe) => describe,
168            Err(_) => return Ok(None),
169        };
170
171        let mut format = DescribeFormatOptions::new();
172        format.abbreviated_size(0);
173
174        Ok(describe.format(Some(&format)).ok())
175    }
176
177    fn to_short_oid(&self, oid: Oid) -> Result<Option<String>, Error> {
178        let object = self.find_object(oid, None)?;
179        match object.short_id() {
180            Ok(id) => Ok(id.as_str().map(|i| i.to_string()).ok()),
181            Err(e) => Err(e),
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::RepositoryExt;
    use crate::branch_status::BranchStatus;
    use git2::{Oid, Repository, RepositoryInitOptions, RepositoryState, Signature};
    use std::fs;
    use tempfile::TempDir;

    /// Creates an empty repository whose unborn HEAD points at `main`.
    ///
    /// The `TempDir` is returned so the directory outlives the test body.
    fn init_repo() -> (TempDir, Repository) {
        let dir = TempDir::new().unwrap();
        let mut opts = RepositoryInitOptions::new();
        opts.initial_head("main");
        let repo = Repository::init_opts(dir.path(), &opts).unwrap();
        (dir, repo)
    }

    /// Commits the current index on top of HEAD and returns the new commit id.
    ///
    /// When HEAD is unborn, this creates a root commit.
    fn commit(repo: &Repository, message: &str) -> Oid {
        let sig = Signature::now("tester", "tester@example.com").unwrap();
        let tree_id = repo.index().unwrap().write_tree().unwrap();
        let tree = repo.find_tree(tree_id).unwrap();
        let parents = match repo.head() {
            Ok(head) => vec![head.peel_to_commit().unwrap()],
            Err(_) => vec![],
        };
        let parent_refs: Vec<&_> = parents.iter().collect();
        repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &parent_refs)
            .unwrap()
    }

    #[test]
    fn action_labels_in_progress_operations() {
        let (_dir, repo) = init_repo();
        assert_eq!(repo.action(RepositoryState::Clean), None);
        assert_eq!(repo.action(RepositoryState::Merge), Some("merge"));
        assert_eq!(repo.action(RepositoryState::RebaseInteractive), Some("rebase-i"));
    }

    #[test]
    fn branch_name_returns_branch_on_unborn_branch() {
        let (_dir, repo) = init_repo();
        assert_eq!(repo.branch_name().unwrap(), "main");
    }

    #[test]
    fn branch_name_returns_branch_on_normal_branch() {
        let (_dir, repo) = init_repo();
        commit(&repo, "initial");
        assert_eq!(repo.branch_name().unwrap(), "main");
    }

    #[test]
    fn branch_name_returns_short_hash_on_detached_head() {
        let (_dir, repo) = init_repo();
        let oid = commit(&repo, "initial");
        repo.set_head_detached(oid).unwrap();
        let short = repo.to_short_oid(oid).unwrap().unwrap();
        assert_eq!(repo.branch_name().unwrap(), short);
    }

    #[test]
    fn branch_name_returns_tag_on_detached_head_at_tag() {
        let (_dir, repo) = init_repo();
        let oid = commit(&repo, "initial");
        let object = repo.find_object(oid, None).unwrap();
        repo.tag_lightweight("v1.0.0", &object, false).unwrap();
        repo.set_head_detached(oid).unwrap();
        assert_eq!(repo.branch_name().unwrap(), "v1.0.0");
    }

    #[test]
    fn branch_name_appends_action_during_rebase() {
        let (_dir, repo) = init_repo();
        commit(&repo, "initial");
        // libgit2 reports `RebaseMerge` when `rebase-merge/` exists without an
        // `interactive` marker file.
        let rebase_dir = repo.path().join("rebase-merge");
        fs::create_dir_all(&rebase_dir).unwrap();
        fs::write(rebase_dir.join("head-name"), "refs/heads/feature\n").unwrap();

        assert_eq!(repo.state(), RepositoryState::RebaseMerge);
        assert_eq!(repo.branch_name().unwrap(), "feature:rebase-m");
    }

    #[test]
    fn branch_status_is_not_changed_on_clean_repository() {
        let (_dir, repo) = init_repo();
        commit(&repo, "initial");
        assert!(matches!(repo.branch_status().unwrap(), BranchStatus::NotChanged));
    }

    #[test]
    fn to_short_oid_returns_prefix_of_full_hash() {
        let (_dir, repo) = init_repo();
        let oid = commit(&repo, "initial");
        let short = repo.to_short_oid(oid).unwrap().unwrap();
        assert!(!short.is_empty());
        assert!(short.len() < 40);
        assert!(oid.to_string().starts_with(&short));
    }

    #[test]
    fn tag_name_returns_none_without_tag() {
        let (_dir, repo) = init_repo();
        let oid = commit(&repo, "initial");
        repo.set_head_detached(oid).unwrap();
        assert_eq!(repo.tag_name().unwrap(), None);
    }

    #[test]
    fn tag_name_returns_tag_at_head() {
        let (_dir, repo) = init_repo();
        let oid = commit(&repo, "initial");
        let object = repo.find_object(oid, None).unwrap();
        repo.tag_lightweight("v1.0.0", &object, false).unwrap();
        repo.set_head_detached(oid).unwrap();
        assert_eq!(repo.tag_name().unwrap(), Some("v1.0.0".to_string()));
    }

    #[test]
    fn unborn_branch_name_returns_symbolic_target() {
        let (_dir, repo) = init_repo();
        assert_eq!(repo.unborn_branch_name().unwrap(), Some("main".to_string()));
    }

    #[test]
    fn rebase_head_name_returns_none_when_not_rebasing() {
        let (_dir, repo) = init_repo();
        assert_eq!(repo.rebase_head_name(RepositoryState::Clean).unwrap(), None);
    }

    #[test]
    fn rebase_head_name_reads_apply_backend_head_name() {
        let (_dir, repo) = init_repo();
        let rebase_dir = repo.path().join("rebase-apply");
        fs::create_dir_all(&rebase_dir).unwrap();
        fs::write(rebase_dir.join("head-name"), "refs/heads/feature\n").unwrap();

        assert_eq!(
            repo.rebase_head_name(RepositoryState::Rebase).unwrap(),
            Some("feature".to_string())
        );
    }

    #[test]
    fn rebase_head_name_reads_merge_backend_head_name() {
        let (_dir, repo) = init_repo();
        let rebase_dir = repo.path().join("rebase-merge");
        fs::create_dir_all(&rebase_dir).unwrap();
        fs::write(rebase_dir.join("head-name"), "refs/heads/feature\n").unwrap();

        assert_eq!(
            repo.rebase_head_name(RepositoryState::RebaseInteractive)
                .unwrap(),
            Some("feature".to_string())
        );
        assert_eq!(
            repo.rebase_head_name(RepositoryState::RebaseMerge).unwrap(),
            Some("feature".to_string())
        );
    }
}