1use std::fs;
16
17use git2::{
18 DescribeFormatOptions, DescribeOptions, Error, ErrorCode, Oid, Repository, RepositoryState,
19 StatusOptions,
20};
21
22use crate::branch_status::BranchStatus;
23use crate::status_entry_ext::StatusEntryExt;
24
/// Extension methods for summarising what a repository's `HEAD` currently points at.
pub trait RepositoryExt {
    /// Returns a short label for the in-progress operation described by `state`,
    /// or `None` when the state has no label.
    fn action(&self, state: RepositoryState) -> Option<&'static str>;

    /// Returns a display name for the current `HEAD`, suffixed with `:<action>` when
    /// [`RepositoryExt::action`] yields a label for the repository's state.
    ///
    /// # Errors
    ///
    /// Returns any error reported while inspecting the repository.
    fn branch_name(&self) -> Result<String, Error>;

    /// Returns a summary of the working tree and index status as a single [`BranchStatus`].
    ///
    /// # Errors
    ///
    /// Returns an error when the status list cannot be obtained.
    fn branch_status(&self) -> Result<BranchStatus, Error>;

    /// Returns the name of the branch being rebased, or `None` when `state` is not a
    /// rebase state or the name is not available.
    fn rebase_head_name(&self, state: RepositoryState) -> Result<Option<String>, Error>;

    /// Returns the branch name that `HEAD` symbolically points at, or `None` when
    /// `HEAD` has no symbolic target.
    ///
    /// # Errors
    ///
    /// Returns an error when `HEAD` cannot be looked up or its target cannot be read.
    fn unborn_branch_name(&self) -> Result<Option<String>, Error>;

    /// Returns the name of a tag describing the current `HEAD`, or `None` when no
    /// description is available.
    fn tag_name(&self) -> Result<Option<String>, Error>;

    /// Returns the abbreviated form of `oid`, or `None` when the abbreviation cannot
    /// be represented as a string.
    ///
    /// # Errors
    ///
    /// Returns an error when the object cannot be found or abbreviated.
    fn to_short_oid(&self, oid: Oid) -> Result<Option<String>, Error>;
}
34
35impl RepositoryExt for Repository {
36 fn action(&self, state: RepositoryState) -> Option<&'static str> {
37 match state {
38 RepositoryState::ApplyMailbox => Some("am"),
39 RepositoryState::ApplyMailboxOrRebase => Some("am/rebase"),
40 RepositoryState::Bisect => Some("bisect"),
41 RepositoryState::CherryPick => Some("cherry"),
42 RepositoryState::CherryPickSequence => Some("cherry-seq"),
43 RepositoryState::Merge => Some("merge"),
44 RepositoryState::Rebase => Some("rebase"),
45 RepositoryState::RebaseInteractive => Some("rebase-i"),
46 RepositoryState::RebaseMerge => Some("rebase-m"),
47 RepositoryState::Revert => Some("revert"),
48 RepositoryState::RevertSequence => Some("revert-seq"),
49 _ => None,
50 }
51 }
52
53 fn branch_name(&self) -> Result<String, Error> {
54 let head = match self.head() {
55 Ok(head) => Some(head),
56 Err(ref e)
57 if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound =>
58 {
59 None
60 }
61 Err(e) => return Err(e),
62 };
63
64 let detached = self.head_detached()?;
65 let state = self.state();
66
67 let branch = if let Some(name) = self.rebase_head_name(state)? {
68 name
69 } else if detached {
70 if let Some(tag) = self.tag_name()? {
72 tag
73 } else {
74 let oid = head.as_ref().and_then(|h| h.target());
75 let short = match oid.map(|oid| self.to_short_oid(oid)) {
76 Some(Ok(id)) => id,
77 Some(Err(e)) => return Err(e),
78 None => None,
79 };
80
81 short.unwrap_or_else(|| "HEAD (detached)".to_string())
82 }
83 } else if let Some(name) = head.as_ref().and_then(|h| h.shorthand().ok()) {
84 name.to_string()
85 } else {
86 self.unborn_branch_name()?
89 .unwrap_or_else(|| "HEAD (no branch)".to_string())
90 };
91
92 match self.action(state) {
93 Some(action) => Ok(branch + ":" + action),
94 None => Ok(branch),
95 }
96 }
97
98 fn branch_status(&self) -> Result<BranchStatus, Error> {
99 let mut opts = StatusOptions::new();
100 opts.include_untracked(false)
101 .include_ignored(false)
102 .include_unmodified(false)
103 .exclude_submodules(true);
104
105 let stats = self.statuses(Some(&mut opts))?;
106
107 let status = stats.iter().fold(BranchStatus::NotChanged, |acc, s| {
108 if acc < BranchStatus::Conflicted && s.is_conflicted() {
109 BranchStatus::Conflicted
110 } else if acc < BranchStatus::Unstaged && s.is_unstaged() {
111 BranchStatus::Unstaged
112 } else if acc < BranchStatus::Staged && s.is_staged() {
113 BranchStatus::Staged
114 } else {
115 acc
116 }
117 });
118
119 Ok(status)
120 }
121
122 fn rebase_head_name(&self, state: RepositoryState) -> Result<Option<String>, Error> {
123 let dir = match state {
127 RepositoryState::RebaseInteractive | RepositoryState::RebaseMerge => "rebase-merge",
128 RepositoryState::Rebase => "rebase-apply",
129 _ => return Ok(None),
130 };
131
132 let path = self.path().join(dir).join("head-name");
133 let refname = match fs::read_to_string(&path) {
134 Ok(content) => content.trim().to_string(),
135 Err(_) => return Ok(None),
136 };
137
138 let name = match self.find_reference(&refname) {
139 Ok(reference) => reference.shorthand().unwrap_or(&refname).to_string(),
140 Err(_) => refname
141 .strip_prefix("refs/heads/")
142 .unwrap_or(&refname)
143 .to_string(),
144 };
145
146 Ok(Some(name))
147 }
148
149 fn unborn_branch_name(&self) -> Result<Option<String>, Error> {
150 let reference = self.find_reference("HEAD")?;
151 Ok(reference.symbolic_target()?.map(|target| {
152 target
153 .strip_prefix("refs/heads/")
154 .unwrap_or(target)
155 .to_string()
156 }))
157 }
158
159 fn tag_name(&self) -> Result<Option<String>, Error> {
160 let mut opts = DescribeOptions::new();
164 opts.describe_tags().max_candidates_tags(0);
165
166 let describe = match self.describe(&opts) {
167 Ok(describe) => describe,
168 Err(_) => return Ok(None),
169 };
170
171 let mut format = DescribeFormatOptions::new();
172 format.abbreviated_size(0);
173
174 Ok(describe.format(Some(&format)).ok())
175 }
176
177 fn to_short_oid(&self, oid: Oid) -> Result<Option<String>, Error> {
178 let object = self.find_object(oid, None)?;
179 match object.short_id() {
180 Ok(id) => Ok(id.as_str().map(|i| i.to_string()).ok()),
181 Err(e) => Err(e),
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::RepositoryExt;
    use git2::{Oid, Repository, RepositoryInitOptions, RepositoryState, Signature};
    use std::fs;
    use tempfile::TempDir;

    /// Initialises an empty repository in a fresh temporary directory with `main`
    /// as the initial head.
    ///
    /// The `TempDir` is returned too so the caller can keep it alive for the whole test.
    fn init_repo() -> (TempDir, Repository) {
        let workdir = TempDir::new().unwrap();
        let mut init_opts = RepositoryInitOptions::new();
        init_opts.initial_head("main");
        let repo = Repository::init_opts(workdir.path(), &init_opts).unwrap();
        (workdir, repo)
    }

    /// Commits the current index onto `HEAD`, using the existing head commit as the
    /// parent when there is one, and returns the new commit id.
    fn commit(repo: &Repository, message: &str) -> Oid {
        let author = Signature::now("tester", "tester@example.com").unwrap();
        let mut index = repo.index().unwrap();
        let tree = repo.find_tree(index.write_tree().unwrap()).unwrap();

        let parent = repo
            .head()
            .ok()
            .map(|head| head.peel_to_commit().unwrap());
        let parents: Vec<&_> = parent.iter().collect();

        repo.commit(Some("HEAD"), &author, &author, message, &tree, &parents)
            .unwrap()
    }

    /// Creates one commit and detaches `HEAD` at it, returning the commit id.
    fn commit_and_detach(repo: &Repository) -> Oid {
        let oid = commit(repo, "initial");
        repo.set_head_detached(oid).unwrap();
        oid
    }

    /// Creates one commit, puts a lightweight `v1.0.0` tag on it, then detaches
    /// `HEAD` at it.
    fn commit_tag_and_detach(repo: &Repository) {
        let oid = commit(repo, "initial");
        let target = repo.find_object(oid, None).unwrap();
        repo.tag_lightweight("v1.0.0", &target, false).unwrap();
        repo.set_head_detached(oid).unwrap();
    }

    /// Writes a `head-name` file naming `refs/heads/feature` into `backend_dir`
    /// inside the repository's git directory.
    fn write_head_name(repo: &Repository, backend_dir: &str) {
        let state_dir = repo.path().join(backend_dir);
        fs::create_dir_all(&state_dir).unwrap();
        fs::write(state_dir.join("head-name"), "refs/heads/feature\n").unwrap();
    }

    #[test]
    fn branch_name_returns_branch_on_unborn_branch() {
        let (_workdir, repo) = init_repo();

        assert_eq!(repo.branch_name().unwrap(), "main");
    }

    #[test]
    fn branch_name_returns_branch_on_normal_branch() {
        let (_workdir, repo) = init_repo();
        commit(&repo, "initial");

        assert_eq!(repo.branch_name().unwrap(), "main");
    }

    #[test]
    fn branch_name_returns_short_hash_on_detached_head() {
        let (_workdir, repo) = init_repo();
        let oid = commit_and_detach(&repo);

        let expected = repo.to_short_oid(oid).unwrap().unwrap();
        assert_eq!(repo.branch_name().unwrap(), expected);
    }

    #[test]
    fn branch_name_returns_tag_on_detached_head_at_tag() {
        let (_workdir, repo) = init_repo();
        commit_tag_and_detach(&repo);

        assert_eq!(repo.branch_name().unwrap(), "v1.0.0");
    }

    #[test]
    fn tag_name_returns_none_without_tag() {
        let (_workdir, repo) = init_repo();
        commit_and_detach(&repo);

        assert_eq!(repo.tag_name().unwrap(), None);
    }

    #[test]
    fn tag_name_returns_tag_at_head() {
        let (_workdir, repo) = init_repo();
        commit_tag_and_detach(&repo);

        assert_eq!(repo.tag_name().unwrap(), Some("v1.0.0".to_string()));
    }

    #[test]
    fn unborn_branch_name_returns_symbolic_target() {
        let (_workdir, repo) = init_repo();

        assert_eq!(repo.unborn_branch_name().unwrap(), Some("main".to_string()));
    }

    #[test]
    fn rebase_head_name_returns_none_when_not_rebasing() {
        let (_workdir, repo) = init_repo();

        assert_eq!(repo.rebase_head_name(RepositoryState::Clean).unwrap(), None);
    }

    #[test]
    fn rebase_head_name_reads_apply_backend_head_name() {
        let (_workdir, repo) = init_repo();
        write_head_name(&repo, "rebase-apply");

        assert_eq!(
            repo.rebase_head_name(RepositoryState::Rebase).unwrap(),
            Some("feature".to_string())
        );
    }

    #[test]
    fn rebase_head_name_reads_merge_backend_head_name() {
        let (_workdir, repo) = init_repo();
        write_head_name(&repo, "rebase-merge");

        // Both merge-backend states read the same directory.
        assert_eq!(
            repo.rebase_head_name(RepositoryState::RebaseInteractive)
                .unwrap(),
            Some("feature".to_string())
        );
        assert_eq!(
            repo.rebase_head_name(RepositoryState::RebaseMerge).unwrap(),
            Some("feature".to_string())
        );
    }
}