git_bonsai/
git.rs

1/*
2 * Copyright 2020 Aurélien Gâteau <mail@agateau.com>
3 *
4 * This file is part of git-bonsai.
5 *
6 * Git-bonsai is free software: you can redistribute it and/or modify it under
7 * the terms of the GNU General Public License as published by the Free
8 * Software Foundation, either version 3 of the License, or (at your option)
9 * any later version.
10 *
11 * This program is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
14 * more details.
15 *
16 * You should have received a copy of the GNU General Public License along with
17 * this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19use std::env;
20use std::fmt;
21use std::fs::File;
22use std::path::{Path, PathBuf};
23use std::process::Command;
24
/// Name of the environment variable that, when defined, makes `Repository::git`
/// echo every git command it executes to stderr.
const GIT_BONSAI_DEBUG: &str = "GB_DEBUG";

/// Prefix that `git branch` puts in front of a branch which is checked out in a
/// separate worktree.
const WORKTREE_BRANCH_PREFIX: &str = "+ ";
31
/// Errors reported by `Repository` operations.
#[derive(Debug, PartialEq, Eq)]
pub enum GitError {
    /// The `git` process could not be started.
    FailedToRunGit,
    /// git ran but exited with a non-zero exit code.
    CommandFailed { exit_code: i32 },
    /// git was terminated by a signal, so no exit code is available.
    TerminatedBySignal,
    /// git succeeded but its output did not have the expected format.
    UnexpectedOutput(String),
}

impl fmt::Display for GitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GitError::FailedToRunGit => write!(f, "Failed to run git"),
            GitError::CommandFailed { exit_code } => {
                write!(f, "Command exited with code {}", exit_code)
            }
            GitError::TerminatedBySignal => write!(f, "Terminated by signal"),
            GitError::UnexpectedOutput(message) => write!(f, "UnexpectedOutput: {}", message),
        }
    }
}

/// Makes `GitError` a standard error type, so it can be boxed as `dyn Error` and
/// propagated with `?` into functions returning `Box<dyn Error>`.
impl std::error::Error for GitError {}
58
59/**
60 * Restores the current git branch when dropped
61 * Assumes we are on a real branch
62 */
63pub struct BranchRestorer<'a> {
64    repository: &'a Repository,
65    branch: String,
66}
67
68impl BranchRestorer<'_> {
69    pub fn new(repo: &Repository) -> BranchRestorer {
70        let current_branch = repo.get_current_branch().expect("Can't get current branch");
71        BranchRestorer {
72            repository: repo,
73            branch: current_branch,
74        }
75    }
76}
77
78impl Drop for BranchRestorer<'_> {
79    fn drop(&mut self) {
80        if let Err(_x) = self.repository.checkout(&self.branch) {
81            println!("Failed to restore original branch {}", self.branch);
82        }
83    }
84}
85
86pub struct Repository {
87    pub path: PathBuf,
88}
89
90impl Repository {
91    pub fn new(path: &Path) -> Repository {
92        Repository {
93            path: path.to_path_buf(),
94        }
95    }
96
97    #[allow(dead_code)]
98    pub fn clone(path: &Path, url: &str) -> Result<Repository, GitError> {
99        let repo = Repository::new(path);
100        repo.git("clone", &[url, path.to_str().unwrap()])?;
101        Ok(repo)
102    }
103
104    pub fn git(&self, subcommand: &str, args: &[&str]) -> Result<String, GitError> {
105        let mut cmd = Command::new("git");
106        cmd.current_dir(&self.path);
107        cmd.env("LANG", "C");
108        cmd.arg(subcommand);
109        for arg in args {
110            cmd.arg(arg);
111        }
112        if env::var(GIT_BONSAI_DEBUG).is_ok() {
113            eprintln!(
114                "DEBUG: pwd={}: git {} {}",
115                self.path.to_str().unwrap(),
116                subcommand,
117                args.join(" ")
118            );
119        }
120        let output = match cmd.output() {
121            Ok(x) => x,
122            Err(_x) => {
123                println!("Failed to execute process");
124                return Err(GitError::FailedToRunGit);
125            }
126        };
127        if !output.status.success() {
128            // TODO: store error message in GitError
129            println!(
130                "{}",
131                String::from_utf8(output.stderr).expect("Failed to decode command stderr")
132            );
133            return match output.status.code() {
134                Some(code) => Err(GitError::CommandFailed { exit_code: code }),
135                None => Err(GitError::TerminatedBySignal),
136            };
137        }
138        let out = String::from_utf8(output.stdout).expect("Failed to decode command stdout");
139        Ok(out)
140    }
141
142    pub fn fetch(&self) -> Result<(), GitError> {
143        self.git("fetch", &["--prune"])?;
144        Ok(())
145    }
146
147    /// Reads config keys defined with `git config --add <key> <value>`
148    pub fn get_config_keys(&self, key: &str) -> Result<Vec<String>, GitError> {
149        let stdout = match self.git("config", &["--get-all", key]) {
150            Ok(x) => x,
151            Err(x) => match x {
152                GitError::CommandFailed { exit_code: 1 } => {
153                    // Happens when reading a non-existing key
154                    return Ok([].to_vec());
155                }
156                x => {
157                    return Err(x);
158                }
159            },
160        };
161
162        let values: Vec<String> = stdout.lines().map(|x| x.into()).collect();
163        Ok(values)
164    }
165
166    pub fn set_config_key(&self, key: &str, value: &str) -> Result<(), GitError> {
167        self.git("config", &[key, value])?;
168        Ok(())
169    }
170
171    pub fn find_default_branch(&self) -> Result<String, GitError> {
172        let stdout = self.git("ls-remote", &["--symref", "origin", "HEAD"])?;
173        /* Output looks like this:
174         *
175         * ref: refs/heads/master\tHEAD
176         * 960389f1c69e8b9c3fe06d29866d0d193375a6cb\tHEAD
177         *
178         * We want to extra "master" from the first line
179         */
180        let line = stdout.lines().next().ok_or_else(|| {
181            GitError::UnexpectedOutput("ls-remote returned an empty string".to_string())
182        })?;
183
184        let line = line
185            .strip_prefix("ref: refs/heads/")
186            .ok_or_else(|| GitError::UnexpectedOutput("missing prefix".to_string()))?;
187
188        let line = line
189            .strip_suffix("\tHEAD")
190            .ok_or_else(|| GitError::UnexpectedOutput("missing suffix".to_string()))?;
191
192        Ok(line.to_string())
193    }
194
195    pub fn list_branches(&self) -> Result<Vec<String>, GitError> {
196        self.list_branches_internal(&[])
197    }
198
199    pub fn list_branches_with_sha1s(&self) -> Result<Vec<(String, String)>, GitError> {
200        let mut list: Vec<(String, String)> = Vec::new();
201
202        let lines = self.list_branches_internal(&["-v"])?;
203
204        for line in lines {
205            let mut it = line.split_whitespace();
206            let branch = it.next().unwrap().to_string();
207            let sha1 = it.next().unwrap().to_string();
208            list.push((branch, sha1));
209        }
210        Ok(list)
211    }
212
213    fn list_branches_internal(&self, args: &[&str]) -> Result<Vec<String>, GitError> {
214        let mut branches: Vec<String> = Vec::new();
215
216        let stdout = self.git("branch", args)?;
217
218        for line in stdout.lines() {
219            if line.starts_with(WORKTREE_BRANCH_PREFIX) {
220                continue;
221            }
222            let branch = line.get(2..).expect("Invalid branch name");
223            branches.push(branch.to_string());
224        }
225        Ok(branches)
226    }
227
228    pub fn list_branches_containing(&self, commit: &str) -> Result<Vec<String>, GitError> {
229        self.list_branches_internal(&["--contains", commit])
230    }
231
232    pub fn list_tracking_branches(&self) -> Result<Vec<String>, GitError> {
233        let mut branches: Vec<String> = Vec::new();
234
235        let lines = self.list_branches_internal(&["-vv"])?;
236
237        for line in lines {
238            if line.contains("[origin/") && !line.contains(": gone]") {
239                let branch = line.split(' ').next();
240                branches.push(branch.unwrap().to_string());
241            }
242        }
243        Ok(branches)
244    }
245
246    pub fn checkout(&self, branch: &str) -> Result<(), GitError> {
247        self.git("checkout", &[branch])?;
248        Ok(())
249    }
250
251    pub fn delete_branch(&self, branch: &str) -> Result<(), GitError> {
252        self.git("branch", &["-D", branch])?;
253        Ok(())
254    }
255
256    pub fn get_current_branch(&self) -> Option<String> {
257        let stdout = self.git("branch", &[]);
258        if stdout.is_err() {
259            return None;
260        }
261        for line in stdout.unwrap().lines() {
262            if line.starts_with('*') {
263                return Some(line[2..].to_string());
264            }
265        }
266        None
267    }
268
269    pub fn update_branch(&self) -> Result<(), GitError> {
270        let out = self.git("merge", &["--ff-only"])?;
271        println!("{}", out);
272        Ok(())
273    }
274
275    pub fn has_changes(&self) -> Result<bool, GitError> {
276        let out = self.git("status", &["--short"])?;
277        Ok(!out.is_empty())
278    }
279
280    #[allow(dead_code)]
281    pub fn get_current_sha1(&self) -> Result<String, GitError> {
282        let out = self.git("show", &["--no-patch", "--oneline"])?;
283        let sha1 = out.split(' ').next().unwrap().to_string();
284        Ok(sha1)
285    }
286}
287
288// Used by test code
289#[allow(dead_code)]
290pub fn create_test_repository(path: &Path) -> Repository {
291    let repo = Repository::new(path);
292
293    repo.git("init", &[]).expect("init failed");
294    repo.git("config", &["user.name", "test"])
295        .expect("setting username failed");
296    repo.git("config", &["user.email", "test@example.com"])
297        .expect("setting email failed");
298
299    // Create a file so that we have more than the start commit
300    File::create(path.join("f")).unwrap();
301    repo.git("add", &["."]).expect("add failed");
302    repo.git("commit", &["-m", "init"]).expect("commit failed");
303
304    repo
305}
306
#[cfg(test)]
mod tests {
    extern crate assert_fs;

    use super::*;
    use std::fs;

    /// Creates a `test` branch holding one extra commit, which adds a file named
    /// `test`, and leaves that branch checked out.
    fn create_test_branch(repo: &Repository) {
        repo.git("checkout", &["-b", "test"]).unwrap();
        File::create(repo.path.join("test")).unwrap();
        repo.git("add", &["test"]).unwrap();
        repo.git("commit", &["-m", "Create file"]).unwrap();
    }

    #[test]
    fn get_current_branch() {
        let dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(dir.path());
        assert_eq!(repo.get_current_branch().unwrap(), "master");

        repo.git("checkout", &["-b", "test"])
            .expect("create branch failed");
        assert_eq!(repo.get_current_branch().unwrap(), "test");
    }

    #[test]
    fn delete_branch() {
        // GIVEN a repository with a test branch containing unique content
        let dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(dir.path());
        assert_eq!(repo.get_current_branch().unwrap(), "master");

        create_test_branch(&repo);

        repo.checkout("master").unwrap();

        // WHEN I call delete_branch
        let result = repo.delete_branch("test");

        // THEN the branch is deleted
        assert_eq!(result, Ok(()));

        // AND only the master branch remains
        assert_eq!(repo.list_branches().unwrap(), &["master"]);
    }

    #[test]
    fn list_branches_with_sha1s() {
        // GIVEN a repository with two branches
        let dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(dir.path());
        create_test_branch(&repo);

        // WHEN I list branches with sha1
        let branches_with_sha1 = repo.list_branches_with_sha1s().unwrap();

        // THEN the list contains two entries
        assert_eq!(branches_with_sha1.len(), 2);

        // AND when switching to each branch, the current sha1 is the expected one
        for (branch, sha1) in branches_with_sha1 {
            repo.git("checkout", &[&branch]).unwrap();
            assert_eq!(repo.get_current_sha1().unwrap(), sha1);
        }
    }

    #[test]
    fn list_branches_skip_worktree_branches() {
        // GIVEN a source repository with two branches
        let tmp_dir = assert_fs::TempDir::new().unwrap();

        let source_path = tmp_dir.path().join("source");
        fs::create_dir_all(&source_path).unwrap();
        let source_repo = create_test_repository(&source_path);
        source_repo.git("branch", &["topic1"]).unwrap();

        // AND a clone of this repository
        let clone_path = tmp_dir.path().join("clone");
        fs::create_dir_all(&clone_path).unwrap();
        let clone_repo = Repository::clone(&clone_path, source_path.to_str().unwrap()).unwrap();

        // with the topic1 branch checked-out in a separate worktree
        let worktree_dir = assert_fs::TempDir::new().unwrap();
        let worktree_path_str = worktree_dir.path().to_str().unwrap();
        clone_repo
            .git("worktree", &["add", worktree_path_str, "topic1"])
            .unwrap();

        // WHEN I list branches
        let branches = clone_repo.list_branches().unwrap();

        // THEN it does not list worktree branches
        assert_eq!(branches.len(), 1);
        assert_eq!(branches, &["master"]);
    }

    #[test]
    fn list_tracking_branches_of_clone() {
        // GIVEN a source repository
        let tmp_dir = assert_fs::TempDir::new().unwrap();
        let source_path = tmp_dir.path().join("source");
        fs::create_dir_all(&source_path).unwrap();
        create_test_repository(&source_path);

        // AND a clone of this repository
        let clone_path = tmp_dir.path().join("clone");
        fs::create_dir_all(&clone_path).unwrap();
        let clone_repo = Repository::clone(&clone_path, source_path.to_str().unwrap()).unwrap();

        // AND a local branch in the clone which does not track anything
        clone_repo.git("branch", &["local-only"]).unwrap();

        // WHEN I list tracking branches
        let branches = clone_repo.list_tracking_branches().unwrap();

        // THEN only master, which tracks origin/master, is listed
        assert_eq!(branches, &["master"]);
    }

    #[test]
    fn get_config_keys_missing_key() {
        // GIVEN a repository
        let dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(dir.path());

        // WHEN I read a key which has never been set
        let values = repo.get_config_keys("bonsai.missing");

        // THEN it returns an empty list instead of an error
        assert_eq!(values, Ok(vec![]));
    }

    #[test]
    fn set_config_key_then_get_config_keys() {
        // GIVEN a repository
        let dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(dir.path());

        // WHEN I set a config key
        repo.set_config_key("bonsai.key", "value").unwrap();

        // THEN reading it back returns its value
        assert_eq!(
            repo.get_config_keys("bonsai.key"),
            Ok(vec!["value".to_string()])
        );
    }

    #[test]
    fn find_default_branch_happy_path() {
        // GIVEN a source repository
        let tmp_dir = assert_fs::TempDir::new().unwrap();
        let source_path = tmp_dir.path().join("source");
        fs::create_dir_all(&source_path).unwrap();
        create_test_repository(&source_path);

        // AND a clone of this repository
        let clone_path = tmp_dir.path().join("clone");
        fs::create_dir_all(&clone_path).unwrap();
        let clone_repo = Repository::clone(&clone_path, source_path.to_str().unwrap()).unwrap();

        // WHEN I call find_default_branch() on the clone
        let branch = clone_repo.find_default_branch();

        // THEN it finds the default branch name
        assert_eq!(branch, Ok("master".to_string()));
    }

    #[test]
    fn find_default_branch_no_remote() {
        // GIVEN a repository without a remote
        let tmp_dir = assert_fs::TempDir::new().unwrap();
        let repo = create_test_repository(tmp_dir.path());

        // WHEN I call find_default_branch()
        let branch = repo.find_default_branch();

        // THEN it fails
        assert_eq!(branch, Err(GitError::CommandFailed { exit_code: 128 }));
    }
}