1use std::env;
20use std::fmt;
21use std::fs::File;
22use std::path::{Path, PathBuf};
23use std::process::Command;
24
/// Name of the environment variable that enables debug tracing: when it is
/// set, every git command is echoed to stderr before it is run.
const GIT_BONSAI_DEBUG: &str = "GB_DEBUG";

/// Prefix found at the start of a `git branch` output line for a branch that
/// is checked out in another worktree. Such lines are skipped when listing
/// branches.
const WORKTREE_BRANCH_PREFIX: &str = "+ ";
31
/// Errors reported when running a git command.
#[derive(Debug, PartialEq, Eq)]
pub enum GitError {
    /// The git process could not be started.
    FailedToRunGit,
    /// git ran but exited with a non-zero status; `exit_code` is that status.
    CommandFailed { exit_code: i32 },
    /// git ended without an exit code.
    TerminatedBySignal,
    /// git succeeded but its output could not be interpreted; the payload
    /// describes what was wrong.
    UnexpectedOutput(String),
}
39
40impl fmt::Display for GitError {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 GitError::FailedToRunGit => {
44 write!(f, "Failed to run git")
45 }
46 GitError::CommandFailed { exit_code: e } => {
47 write!(f, "Command exited with code {}", e)
48 }
49 GitError::TerminatedBySignal => {
50 write!(f, "Terminated by signal")
51 }
52 GitError::UnexpectedOutput(message) => {
53 write!(f, "UnexpectedOutput: {}", message)
54 }
55 }
56 }
57}
58
59pub struct BranchRestorer<'a> {
64 repository: &'a Repository,
65 branch: String,
66}
67
68impl BranchRestorer<'_> {
69 pub fn new(repo: &Repository) -> BranchRestorer {
70 let current_branch = repo.get_current_branch().expect("Can't get current branch");
71 BranchRestorer {
72 repository: repo,
73 branch: current_branch,
74 }
75 }
76}
77
78impl Drop for BranchRestorer<'_> {
79 fn drop(&mut self) {
80 if let Err(_x) = self.repository.checkout(&self.branch) {
81 println!("Failed to restore original branch {}", self.branch);
82 }
83 }
84}
85
/// Handle on a git repository, identified by its directory.
pub struct Repository {
    /// Directory in which git commands issued through this handle are run.
    pub path: PathBuf,
}
89
90impl Repository {
91 pub fn new(path: &Path) -> Repository {
92 Repository {
93 path: path.to_path_buf(),
94 }
95 }
96
97 #[allow(dead_code)]
98 pub fn clone(path: &Path, url: &str) -> Result<Repository, GitError> {
99 let repo = Repository::new(path);
100 repo.git("clone", &[url, path.to_str().unwrap()])?;
101 Ok(repo)
102 }
103
104 pub fn git(&self, subcommand: &str, args: &[&str]) -> Result<String, GitError> {
105 let mut cmd = Command::new("git");
106 cmd.current_dir(&self.path);
107 cmd.env("LANG", "C");
108 cmd.arg(subcommand);
109 for arg in args {
110 cmd.arg(arg);
111 }
112 if env::var(GIT_BONSAI_DEBUG).is_ok() {
113 eprintln!(
114 "DEBUG: pwd={}: git {} {}",
115 self.path.to_str().unwrap(),
116 subcommand,
117 args.join(" ")
118 );
119 }
120 let output = match cmd.output() {
121 Ok(x) => x,
122 Err(_x) => {
123 println!("Failed to execute process");
124 return Err(GitError::FailedToRunGit);
125 }
126 };
127 if !output.status.success() {
128 println!(
130 "{}",
131 String::from_utf8(output.stderr).expect("Failed to decode command stderr")
132 );
133 return match output.status.code() {
134 Some(code) => Err(GitError::CommandFailed { exit_code: code }),
135 None => Err(GitError::TerminatedBySignal),
136 };
137 }
138 let out = String::from_utf8(output.stdout).expect("Failed to decode command stdout");
139 Ok(out)
140 }
141
142 pub fn fetch(&self) -> Result<(), GitError> {
143 self.git("fetch", &["--prune"])?;
144 Ok(())
145 }
146
147 pub fn get_config_keys(&self, key: &str) -> Result<Vec<String>, GitError> {
149 let stdout = match self.git("config", &["--get-all", key]) {
150 Ok(x) => x,
151 Err(x) => match x {
152 GitError::CommandFailed { exit_code: 1 } => {
153 return Ok([].to_vec());
155 }
156 x => {
157 return Err(x);
158 }
159 },
160 };
161
162 let values: Vec<String> = stdout.lines().map(|x| x.into()).collect();
163 Ok(values)
164 }
165
166 pub fn set_config_key(&self, key: &str, value: &str) -> Result<(), GitError> {
167 self.git("config", &[key, value])?;
168 Ok(())
169 }
170
171 pub fn find_default_branch(&self) -> Result<String, GitError> {
172 let stdout = self.git("ls-remote", &["--symref", "origin", "HEAD"])?;
173 let line = stdout.lines().next().ok_or_else(|| {
181 GitError::UnexpectedOutput("ls-remote returned an empty string".to_string())
182 })?;
183
184 let line = line
185 .strip_prefix("ref: refs/heads/")
186 .ok_or_else(|| GitError::UnexpectedOutput("missing prefix".to_string()))?;
187
188 let line = line
189 .strip_suffix("\tHEAD")
190 .ok_or_else(|| GitError::UnexpectedOutput("missing suffix".to_string()))?;
191
192 Ok(line.to_string())
193 }
194
195 pub fn list_branches(&self) -> Result<Vec<String>, GitError> {
196 self.list_branches_internal(&[])
197 }
198
199 pub fn list_branches_with_sha1s(&self) -> Result<Vec<(String, String)>, GitError> {
200 let mut list: Vec<(String, String)> = Vec::new();
201
202 let lines = self.list_branches_internal(&["-v"])?;
203
204 for line in lines {
205 let mut it = line.split_whitespace();
206 let branch = it.next().unwrap().to_string();
207 let sha1 = it.next().unwrap().to_string();
208 list.push((branch, sha1));
209 }
210 Ok(list)
211 }
212
213 fn list_branches_internal(&self, args: &[&str]) -> Result<Vec<String>, GitError> {
214 let mut branches: Vec<String> = Vec::new();
215
216 let stdout = self.git("branch", args)?;
217
218 for line in stdout.lines() {
219 if line.starts_with(WORKTREE_BRANCH_PREFIX) {
220 continue;
221 }
222 let branch = line.get(2..).expect("Invalid branch name");
223 branches.push(branch.to_string());
224 }
225 Ok(branches)
226 }
227
228 pub fn list_branches_containing(&self, commit: &str) -> Result<Vec<String>, GitError> {
229 self.list_branches_internal(&["--contains", commit])
230 }
231
232 pub fn list_tracking_branches(&self) -> Result<Vec<String>, GitError> {
233 let mut branches: Vec<String> = Vec::new();
234
235 let lines = self.list_branches_internal(&["-vv"])?;
236
237 for line in lines {
238 if line.contains("[origin/") && !line.contains(": gone]") {
239 let branch = line.split(' ').next();
240 branches.push(branch.unwrap().to_string());
241 }
242 }
243 Ok(branches)
244 }
245
246 pub fn checkout(&self, branch: &str) -> Result<(), GitError> {
247 self.git("checkout", &[branch])?;
248 Ok(())
249 }
250
251 pub fn delete_branch(&self, branch: &str) -> Result<(), GitError> {
252 self.git("branch", &["-D", branch])?;
253 Ok(())
254 }
255
256 pub fn get_current_branch(&self) -> Option<String> {
257 let stdout = self.git("branch", &[]);
258 if stdout.is_err() {
259 return None;
260 }
261 for line in stdout.unwrap().lines() {
262 if line.starts_with('*') {
263 return Some(line[2..].to_string());
264 }
265 }
266 None
267 }
268
269 pub fn update_branch(&self) -> Result<(), GitError> {
270 let out = self.git("merge", &["--ff-only"])?;
271 println!("{}", out);
272 Ok(())
273 }
274
275 pub fn has_changes(&self) -> Result<bool, GitError> {
276 let out = self.git("status", &["--short"])?;
277 Ok(!out.is_empty())
278 }
279
280 #[allow(dead_code)]
281 pub fn get_current_sha1(&self) -> Result<String, GitError> {
282 let out = self.git("show", &["--no-patch", "--oneline"])?;
283 let sha1 = out.split(' ').next().unwrap().to_string();
284 Ok(sha1)
285 }
286}
287
288#[allow(dead_code)]
290pub fn create_test_repository(path: &Path) -> Repository {
291 let repo = Repository::new(path);
292
293 repo.git("init", &[]).expect("init failed");
294 repo.git("config", &["user.name", "test"])
295 .expect("setting username failed");
296 repo.git("config", &["user.email", "test@example.com"])
297 .expect("setting email failed");
298
299 File::create(path.join("f")).unwrap();
301 repo.git("add", &["."]).expect("add failed");
302 repo.git("commit", &["-m", "init"]).expect("commit failed");
303
304 repo
305}
306
#[cfg(test)]
mod tests {
    extern crate assert_fs;

    use super::*;
    use std::fs;

    /// Switches `repo` to a new branch `test` and commits a new file on it.
    fn commit_on_new_test_branch(repo: &Repository) {
        repo.git("checkout", &["-b", "test"]).unwrap();
        File::create(repo.path.join("test")).unwrap();
        repo.git("add", &["test"]).unwrap();
        repo.git("commit", &["-m", "Create file"]).unwrap();
    }

    /// Creates a test repository in a new `source` directory under `parent`.
    fn create_source_repository(parent: &Path) -> Repository {
        let source_path = parent.join("source");
        fs::create_dir_all(&source_path).unwrap();
        create_test_repository(&source_path)
    }

    /// Clones `source` into a new `clone` directory under `parent`.
    fn clone_repository(parent: &Path, source: &Repository) -> Repository {
        let clone_path = parent.join("clone");
        fs::create_dir_all(&clone_path).unwrap();
        Repository::clone(&clone_path, source.path.to_str().unwrap()).unwrap()
    }

    #[test]
    fn get_current_branch() {
        let dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(dir.path());
        assert_eq!(repo.get_current_branch().unwrap(), "master");

        repo.git("checkout", &["-b", "test"])
            .expect("create branch failed");
        assert_eq!(repo.get_current_branch().unwrap(), "test");
    }

    #[test]
    fn delete_branch() {
        // GIVEN a repository with a `test` branch holding unmerged changes
        let dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(dir.path());
        assert_eq!(repo.get_current_branch().unwrap(), "master");

        commit_on_new_test_branch(&repo);
        repo.checkout("master").unwrap();

        // WHEN the branch is deleted
        let result = repo.delete_branch("test");

        // THEN the deletion succeeds and only `master` is left
        assert_eq!(result, Ok(()));
        assert_eq!(repo.list_branches().unwrap(), &["master"]);
    }

    #[test]
    fn list_branches_with_sha1s() {
        // GIVEN a repository with two branches
        let dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(dir.path());
        commit_on_new_test_branch(&repo);

        // WHEN the branches are listed with their sha1s
        let branches_with_sha1 = repo.list_branches_with_sha1s().unwrap();

        // THEN each branch is reported with the sha1 it is checked out at
        assert_eq!(branches_with_sha1.len(), 2);
        for (branch, sha1) in branches_with_sha1 {
            repo.git("checkout", &[&branch]).unwrap();
            assert_eq!(repo.get_current_sha1().unwrap(), sha1);
        }
    }

    #[test]
    fn list_branches_skip_worktree_branches() {
        // GIVEN a clone of a repository which has a `topic1` branch
        let tmp_dir = assert_fs::TempDir::new().unwrap();
        let source_repo = create_source_repository(tmp_dir.path());
        source_repo.git("branch", &["topic1"]).unwrap();
        let clone_repo = clone_repository(tmp_dir.path(), &source_repo);

        // AND a worktree of the clone with `topic1` checked out
        let worktree_dir = assert_fs::TempDir::new().unwrap();
        let worktree_path_str = worktree_dir.path().to_str().unwrap();
        clone_repo
            .git("worktree", &["add", worktree_path_str, "topic1"])
            .unwrap();

        // WHEN the branches of the clone are listed
        let branches = clone_repo.list_branches().unwrap();

        // THEN `topic1` is not part of the list
        assert_eq!(branches.len(), 1);
        assert_eq!(branches, &["master"]);
    }

    #[test]
    fn find_default_branch_happy_path() {
        // GIVEN a clone of a repository
        let tmp_dir = assert_fs::TempDir::new().unwrap();
        let source_repo = create_source_repository(tmp_dir.path());
        let clone_repo = clone_repository(tmp_dir.path(), &source_repo);

        // WHEN the default branch is looked up
        let branch = clone_repo.find_default_branch();

        // THEN it is the branch of the source repository
        assert_eq!(branch, Ok("master".to_string()));
    }

    #[test]
    fn find_default_branch_no_remote() {
        // GIVEN a repository without a remote
        let tmp_dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(tmp_dir.path());

        // WHEN the default branch is looked up
        let branch = repo.find_default_branch();

        // THEN git's failure is reported
        assert_eq!(branch, Err(GitError::CommandFailed { exit_code: 128 }));
    }
}