git_stack/
config.rs

1use std::str::FromStr;
2
3#[derive(Default, Clone, Debug)]
4pub struct RepoConfig {
5    pub editor: Option<String>,
6
7    pub protected_branches: Option<Vec<String>>,
8    pub protect_commit_count: Option<usize>,
9    pub protect_commit_age: Option<std::time::Duration>,
10    pub auto_base_commit_count: Option<usize>,
11    pub stack: Option<Stack>,
12    pub push_remote: Option<String>,
13    pub pull_remote: Option<String>,
14    pub show_format: Option<Format>,
15    pub show_commits: Option<ShowCommits>,
16    pub show_stacked: Option<bool>,
17    pub auto_fixup: Option<Fixup>,
18    pub auto_repair: Option<bool>,
19
20    pub capacity: Option<usize>,
21}
22
// Fully-qualified git config keys ("section.name") understood by git-stack.
const CORE_EDITOR: &str = "core.editor";
const PROTECTED_STACK_FIELD: &str = "stack.protected-branch";
const PROTECT_COMMIT_COUNT: &str = "stack.protect-commit-count";
const PROTECT_COMMIT_AGE: &str = "stack.protect-commit-age";
const AUTO_BASE_COMMIT_COUNT: &str = "stack.auto-base-commit-count";
const STACK_FIELD: &str = "stack.stack";
const PUSH_REMOTE_FIELD: &str = "stack.push-remote";
const PULL_REMOTE_FIELD: &str = "stack.pull-remote";
const FORMAT_FIELD: &str = "stack.show-format";
const SHOW_COMMITS_FIELD: &str = "stack.show-commits";
const STACKED_FIELD: &str = "stack.show-stacked";
const AUTO_FIXUP_FIELD: &str = "stack.auto-fixup";
const AUTO_REPAIR_FIELD: &str = "stack.auto-repair";
const BACKUP_CAPACITY_FIELD: &str = "branch-stash.capacity";

// Fallback values used when no configuration layer provides a setting.
#[cfg(windows)]
const DEFAULT_CORE_EDITOR: &str = "notepad.exe";
#[cfg(not(windows))]
const DEFAULT_CORE_EDITOR: &str = "vi";
const DEFAULT_PROTECTED_BRANCHES: [&str; 4] = ["main", "master", "dev", "stable"];
const DEFAULT_PROTECT_COMMIT_COUNT: usize = 50;
// Two weeks.
const DEFAULT_PROTECT_COMMIT_AGE: std::time::Duration =
    std::time::Duration::from_secs(14 * 24 * 60 * 60);
const DEFAULT_AUTO_BASE_COMMIT_COUNT: usize = 500;
const DEFAULT_CAPACITY: usize = 30;
48
49impl RepoConfig {
50    pub fn from_all(repo: &git2::Repository) -> eyre::Result<Self> {
51        log::trace!("Loading gitconfig");
52        let default_config = match git2::Config::open_default() {
53            Ok(config) => Some(config),
54            Err(err) => {
55                log::debug!("Failed to load git config: {}", err);
56                None
57            }
58        };
59        let config = Self::from_defaults_internal(default_config.as_ref());
60        let config = if let Some(default_config) = default_config.as_ref() {
61            config.update(Self::from_gitconfig(default_config))
62        } else {
63            config
64        };
65        let config = config.update(Self::from_workdir(repo)?);
66        let config = config.update(Self::from_repo(repo)?);
67        let config = config.update(Self::from_env());
68        Ok(config)
69    }
70
71    pub fn from_repo(repo: &git2::Repository) -> eyre::Result<Self> {
72        let config_path = git_dir_config(repo);
73        log::trace!("Loading {}", config_path.display());
74        if config_path.exists() {
75            match git2::Config::open(&config_path) {
76                Ok(config) => Ok(Self::from_gitconfig(&config)),
77                Err(err) => {
78                    log::debug!("Failed to load git config: {}", err);
79                    Ok(Default::default())
80                }
81            }
82        } else {
83            Ok(Default::default())
84        }
85    }
86
87    pub fn from_workdir(repo: &git2::Repository) -> eyre::Result<Self> {
88        let workdir = repo
89            .workdir()
90            .ok_or_else(|| eyre::eyre!("Cannot read config in bare repository."))?;
91        let config_path = workdir.join(".gitconfig");
92        log::trace!("Loading {}", config_path.display());
93        if config_path.exists() {
94            match git2::Config::open(&config_path) {
95                Ok(config) => Ok(Self::from_gitconfig(&config)),
96                Err(err) => {
97                    log::debug!("Failed to load git config: {}", err);
98                    Ok(Default::default())
99                }
100            }
101        } else {
102            Ok(Default::default())
103        }
104    }
105
106    pub fn from_env() -> Self {
107        let mut config = Self::default();
108
109        let params = git_config_env::ConfigParameters::new();
110        config = config.update(Self::from_env_iter(params.iter()));
111
112        let params = git_config_env::ConfigEnv::new();
113        config = config.update(Self::from_env_iter(
114            params.iter().map(|(k, v)| (k, Some(v))),
115        ));
116
117        config.editor = std::env::var("GIT_EDITOR").ok();
118
119        config
120    }
121
122    fn from_env_iter<'s>(
123        iter: impl Iterator<Item = (std::borrow::Cow<'s, str>, Option<std::borrow::Cow<'s, str>>)>,
124    ) -> Self {
125        let mut config = Self::default();
126
127        for (key, value) in iter {
128            log::trace!("Env config: {}={:?}", key, value);
129            if key == CORE_EDITOR {
130                if let Some(value) = value {
131                    config.editor = Some(value.into_owned());
132                }
133            } else if key == PROTECTED_STACK_FIELD {
134                if let Some(value) = value {
135                    config
136                        .protected_branches
137                        .get_or_insert_with(Vec::new)
138                        .push(value.into_owned());
139                }
140            } else if key == PROTECT_COMMIT_COUNT {
141                if let Some(value) = value.as_ref().and_then(|v| FromStr::from_str(v).ok()) {
142                    config.protect_commit_count = Some(value);
143                }
144            } else if key == PROTECT_COMMIT_AGE {
145                if let Some(value) = value
146                    .as_ref()
147                    .and_then(|v| humantime::parse_duration(v).ok())
148                {
149                    config.protect_commit_age = Some(value);
150                }
151            } else if key == AUTO_BASE_COMMIT_COUNT {
152                if let Some(value) = value.as_ref().and_then(|v| FromStr::from_str(v).ok()) {
153                    config.auto_base_commit_count = Some(value);
154                }
155            } else if key == STACK_FIELD {
156                if let Some(value) = value.as_ref().and_then(|v| FromStr::from_str(v).ok()) {
157                    config.stack = Some(value);
158                }
159            } else if key == PUSH_REMOTE_FIELD {
160                if let Some(value) = value {
161                    config.push_remote = Some(value.into_owned());
162                }
163            } else if key == PULL_REMOTE_FIELD {
164                if let Some(value) = value {
165                    config.pull_remote = Some(value.into_owned());
166                }
167            } else if key == FORMAT_FIELD {
168                if let Some(value) = value.as_ref().and_then(|v| FromStr::from_str(v).ok()) {
169                    config.show_format = Some(value);
170                }
171            } else if key == SHOW_COMMITS_FIELD {
172                if let Some(value) = value.as_ref().and_then(|v| FromStr::from_str(v).ok()) {
173                    config.show_commits = Some(value);
174                }
175            } else if key == STACKED_FIELD {
176                config.show_stacked = Some(value.as_ref().map(|v| v == "true").unwrap_or(true));
177            } else if key == AUTO_FIXUP_FIELD {
178                if let Some(value) = value.as_ref().and_then(|v| FromStr::from_str(v).ok()) {
179                    config.auto_fixup = Some(value);
180                }
181            } else if key == AUTO_REPAIR_FIELD {
182                config.auto_repair = Some(value.as_ref().map(|v| v == "true").unwrap_or(true));
183            } else if key == BACKUP_CAPACITY_FIELD {
184                config.capacity = value.as_deref().and_then(|s| s.parse::<usize>().ok());
185            } else {
186                log::warn!(
187                    "Unsupported config: {}={}",
188                    key,
189                    value.as_deref().unwrap_or("")
190                );
191            }
192        }
193
194        config
195    }
196
197    pub fn from_defaults() -> Self {
198        log::trace!("Loading gitconfig");
199        let config = match git2::Config::open_default() {
200            Ok(config) => Some(config),
201            Err(err) => {
202                log::debug!("Failed to load git config: {}", err);
203                None
204            }
205        };
206        Self::from_defaults_internal(config.as_ref())
207    }
208
209    fn from_defaults_internal(config: Option<&git2::Config>) -> Self {
210        let mut conf = Self::default();
211        conf.editor = std::env::var("VISUAL")
212            .or_else(|_err| std::env::var("EDITOR"))
213            .ok();
214        conf.protect_commit_count = Some(conf.protect_commit_count().unwrap_or(0));
215        conf.protect_commit_age = Some(conf.protect_commit_age());
216        conf.auto_base_commit_count = Some(conf.auto_base_commit_count().unwrap_or(0));
217        conf.stack = Some(conf.stack());
218        conf.push_remote = Some(conf.push_remote().to_owned());
219        conf.pull_remote = Some(conf.pull_remote().to_owned());
220        conf.show_format = Some(conf.show_format());
221        conf.show_commits = Some(conf.show_commits());
222        conf.show_stacked = Some(conf.show_stacked());
223        conf.auto_fixup = Some(conf.auto_fixup());
224        conf.capacity = Some(DEFAULT_CAPACITY);
225
226        let mut protected_branches: Vec<String> = Vec::new();
227
228        if let Some(config) = config {
229            let default_branch = default_branch(config);
230            let default_branch_ignore = default_branch.to_owned();
231            protected_branches.push(default_branch_ignore);
232        }
233        // Don't bother with removing duplicates if `default_branch` is the same as one of our
234        // default protected branches
235        protected_branches.extend(DEFAULT_PROTECTED_BRANCHES.iter().map(|s| (*s).to_owned()));
236        conf.protected_branches = Some(protected_branches);
237
238        conf
239    }
240
241    pub fn from_gitconfig(config: &git2::Config) -> Self {
242        let editor = config.get_string(CORE_EDITOR).ok();
243
244        let protected_branches = config
245            .multivar(PROTECTED_STACK_FIELD, None)
246            .map(|entries| {
247                let mut protected_branches = Vec::new();
248                entries
249                    .for_each(|entry| {
250                        if let Some(value) = entry.value() {
251                            protected_branches.push(value.to_owned());
252                        }
253                    })
254                    .unwrap();
255                if protected_branches.is_empty() {
256                    None
257                } else {
258                    Some(protected_branches)
259                }
260            })
261            .unwrap_or(None);
262
263        let protect_commit_count = config
264            .get_i64(PROTECT_COMMIT_COUNT)
265            .ok()
266            .map(|i| i.max(0) as usize);
267        let protect_commit_age = config
268            .get_string(PROTECT_COMMIT_AGE)
269            .ok()
270            .and_then(|s| humantime::parse_duration(&s).ok());
271
272        let auto_base_commit_count = config
273            .get_i64(AUTO_BASE_COMMIT_COUNT)
274            .ok()
275            .map(|i| i.max(0) as usize);
276
277        let push_remote = config
278            .get_string(PUSH_REMOTE_FIELD)
279            .ok()
280            .or_else(|| config.get_string("remote.pushDefault").ok());
281        let pull_remote = config.get_string(PULL_REMOTE_FIELD).ok();
282
283        let stack = config
284            .get_string(STACK_FIELD)
285            .ok()
286            .and_then(|s| FromStr::from_str(&s).ok());
287
288        let show_format = config
289            .get_string(FORMAT_FIELD)
290            .ok()
291            .and_then(|s| FromStr::from_str(&s).ok());
292
293        let show_commits = config
294            .get_string(SHOW_COMMITS_FIELD)
295            .ok()
296            .and_then(|s| FromStr::from_str(&s).ok());
297
298        let show_stacked = config.get_bool(STACKED_FIELD).ok();
299
300        let auto_fixup = config
301            .get_string(AUTO_FIXUP_FIELD)
302            .ok()
303            .and_then(|s| FromStr::from_str(&s).ok());
304
305        let auto_repair = config.get_bool(AUTO_REPAIR_FIELD).ok();
306
307        let capacity = config
308            .get_i64(BACKUP_CAPACITY_FIELD)
309            .map(|i| i as usize)
310            .ok();
311
312        Self {
313            editor,
314            protected_branches,
315            protect_commit_count,
316            protect_commit_age,
317            auto_base_commit_count,
318            stack,
319            push_remote,
320            pull_remote,
321            show_format,
322            show_commits,
323            show_stacked,
324            auto_fixup,
325            auto_repair,
326            capacity,
327        }
328    }
329
330    pub fn write_repo(&self, repo: &git2::Repository) -> eyre::Result<()> {
331        let config_path = git_dir_config(repo);
332        log::trace!("Loading {}", config_path.display());
333        let mut config = git2::Config::open(&config_path)?;
334        log::info!("Writing {}", config_path.display());
335        self.to_gitconfig(&mut config)?;
336        Ok(())
337    }
338
339    pub fn to_gitconfig(&self, config: &mut git2::Config) -> eyre::Result<()> {
340        if let Some(protected_branches) = self.protected_branches.as_ref() {
341            // Ignore errors if there aren't keys to remove
342            let _ = config.remove_multivar(PROTECTED_STACK_FIELD, ".*");
343            for branch in protected_branches {
344                config.set_multivar(PROTECTED_STACK_FIELD, "^$", branch)?;
345            }
346        }
347        Ok(())
348    }
349
350    pub fn update(mut self, other: Self) -> Self {
351        self.editor = other.editor.or(self.editor);
352        match (&mut self.protected_branches, other.protected_branches) {
353            (Some(lhs), Some(rhs)) => lhs.extend(rhs),
354            (None, Some(rhs)) => self.protected_branches = Some(rhs),
355            (_, _) => (),
356        }
357        self.protect_commit_count = other.protect_commit_count.or(self.protect_commit_count);
358        self.protect_commit_age = other.protect_commit_age.or(self.protect_commit_age);
359        self.auto_base_commit_count = other.auto_base_commit_count.or(self.auto_base_commit_count);
360        self.push_remote = other.push_remote.or(self.push_remote);
361        self.pull_remote = other.pull_remote.or(self.pull_remote);
362        self.stack = other.stack.or(self.stack);
363        self.show_format = other.show_format.or(self.show_format);
364        self.show_commits = other.show_commits.or(self.show_commits);
365        self.show_stacked = other.show_stacked.or(self.show_stacked);
366        self.auto_fixup = other.auto_fixup.or(self.auto_fixup);
367        self.auto_repair = other.auto_repair.or(self.auto_repair);
368        self.capacity = other.capacity.or(self.capacity);
369
370        self
371    }
372
373    pub fn editor(&self) -> &str {
374        self.editor.as_deref().unwrap_or(DEFAULT_CORE_EDITOR)
375    }
376
377    pub fn protected_branches(&self) -> &[String] {
378        self.protected_branches.as_deref().unwrap_or(&[])
379    }
380
381    pub fn protect_commit_count(&self) -> Option<usize> {
382        let protect_commit_count = self
383            .protect_commit_count
384            .unwrap_or(DEFAULT_PROTECT_COMMIT_COUNT);
385        (protect_commit_count != 0).then_some(protect_commit_count)
386    }
387
388    pub fn protect_commit_age(&self) -> std::time::Duration {
389        self.protect_commit_age
390            .unwrap_or(DEFAULT_PROTECT_COMMIT_AGE)
391    }
392
393    pub fn auto_base_commit_count(&self) -> Option<usize> {
394        let auto_base_commit_count = self
395            .auto_base_commit_count
396            .unwrap_or(DEFAULT_AUTO_BASE_COMMIT_COUNT);
397        (auto_base_commit_count != 0).then_some(auto_base_commit_count)
398    }
399
400    pub fn push_remote(&self) -> &str {
401        self.push_remote.as_deref().unwrap_or("origin")
402    }
403
404    pub fn pull_remote(&self) -> &str {
405        self.pull_remote
406            .as_deref()
407            .unwrap_or_else(|| self.push_remote())
408    }
409
410    pub fn stack(&self) -> Stack {
411        self.stack.unwrap_or_default()
412    }
413
414    pub fn show_format(&self) -> Format {
415        self.show_format.unwrap_or_default()
416    }
417
418    pub fn show_commits(&self) -> ShowCommits {
419        self.show_commits.unwrap_or_default()
420    }
421
422    pub fn show_stacked(&self) -> bool {
423        self.show_stacked.unwrap_or(true)
424    }
425
426    pub fn auto_fixup(&self) -> Fixup {
427        self.auto_fixup.unwrap_or_default()
428    }
429
430    pub fn auto_repair(&self) -> bool {
431        self.auto_repair.unwrap_or(true)
432    }
433
434    pub fn capacity(&self) -> Option<usize> {
435        let capacity = self.capacity.unwrap_or(DEFAULT_CAPACITY);
436        (capacity != 0).then_some(capacity)
437    }
438}
439
440impl std::fmt::Display for RepoConfig {
441    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
442        writeln!(f, "[{}]", CORE_EDITOR.split_once('.').unwrap().0)?;
443        writeln!(
444            f,
445            "\t{}={}",
446            CORE_EDITOR.split_once('.').unwrap().1,
447            self.editor()
448        )?;
449        writeln!(f, "[{}]", STACK_FIELD.split_once('.').unwrap().0)?;
450        for branch in self.protected_branches() {
451            writeln!(
452                f,
453                "\t{}={}",
454                PROTECTED_STACK_FIELD.split_once('.').unwrap().1,
455                branch
456            )?;
457        }
458        writeln!(
459            f,
460            "\t{}={}",
461            PROTECT_COMMIT_COUNT.split_once('.').unwrap().1,
462            self.protect_commit_count().unwrap_or(0)
463        )?;
464        writeln!(
465            f,
466            "\t{}={}",
467            PROTECT_COMMIT_AGE.split_once('.').unwrap().1,
468            humantime::format_duration(self.protect_commit_age())
469        )?;
470        writeln!(
471            f,
472            "\t{}={}",
473            AUTO_BASE_COMMIT_COUNT.split_once('.').unwrap().1,
474            self.auto_base_commit_count().unwrap_or(0)
475        )?;
476        writeln!(
477            f,
478            "\t{}={}",
479            STACK_FIELD.split_once('.').unwrap().1,
480            self.stack()
481        )?;
482        writeln!(
483            f,
484            "\t{}={}",
485            PUSH_REMOTE_FIELD.split_once('.').unwrap().1,
486            self.push_remote()
487        )?;
488        writeln!(
489            f,
490            "\t{}={}",
491            PULL_REMOTE_FIELD.split_once('.').unwrap().1,
492            self.pull_remote()
493        )?;
494        writeln!(
495            f,
496            "\t{}={}",
497            FORMAT_FIELD.split_once('.').unwrap().1,
498            self.show_format()
499        )?;
500        writeln!(
501            f,
502            "\t{}={}",
503            SHOW_COMMITS_FIELD.split_once('.').unwrap().1,
504            self.show_commits()
505        )?;
506        writeln!(
507            f,
508            "\t{}={}",
509            STACKED_FIELD.split_once('.').unwrap().1,
510            self.show_stacked()
511        )?;
512        writeln!(
513            f,
514            "\t{}={}",
515            AUTO_FIXUP_FIELD.split_once('.').unwrap().1,
516            self.auto_fixup()
517        )?;
518        writeln!(
519            f,
520            "\t{}={}",
521            AUTO_REPAIR_FIELD.split_once('.').unwrap().1,
522            self.auto_repair()
523        )?;
524        writeln!(f, "[{}]", BACKUP_CAPACITY_FIELD.split_once('.').unwrap().0)?;
525        writeln!(
526            f,
527            "\t{}={}",
528            BACKUP_CAPACITY_FIELD.split_once('.').unwrap().1,
529            self.capacity().unwrap_or(0)
530        )?;
531        Ok(())
532    }
533}
534
535fn git_dir_config(repo: &git2::Repository) -> std::path::PathBuf {
536    repo.path().join("config")
537}
538
539fn default_branch(config: &git2::Config) -> &str {
540    config.get_str("init.defaultBranch").ok().unwrap_or("main")
541}
542
543#[derive(Debug, Copy, Clone, PartialEq, Eq, clap::ValueEnum)]
544pub enum Format {
545    /// No output
546    Silent,
547    /// List branches in selected stacks
548    List,
549    /// Render a branch branch
550    Graph,
551    /// Internal data for debugging
552    Debug,
553}
554
555impl std::fmt::Display for Format {
556    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557        use clap::ValueEnum;
558        self.to_possible_value()
559            .expect("no values are skipped")
560            .get_name()
561            .fmt(f)
562    }
563}
564
565impl FromStr for Format {
566    type Err = String;
567
568    fn from_str(s: &str) -> Result<Self, Self::Err> {
569        use clap::ValueEnum;
570        for variant in Self::value_variants() {
571            if variant.to_possible_value().unwrap().matches(s, false) {
572                return Ok(*variant);
573            }
574        }
575        Err(format!("Invalid variant: {s}"))
576    }
577}
578
579impl Default for Format {
580    fn default() -> Self {
581        Self::Graph
582    }
583}
584
585#[derive(Debug, Copy, Clone, PartialEq, Eq, clap::ValueEnum)]
586pub enum ShowCommits {
587    None,
588    Unprotected,
589    All,
590}
591
592impl std::fmt::Display for ShowCommits {
593    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
594        use clap::ValueEnum;
595        self.to_possible_value()
596            .expect("no values are skipped")
597            .get_name()
598            .fmt(f)
599    }
600}
601
602impl FromStr for ShowCommits {
603    type Err = String;
604
605    fn from_str(s: &str) -> Result<Self, Self::Err> {
606        use clap::ValueEnum;
607        for variant in Self::value_variants() {
608            if variant.to_possible_value().unwrap().matches(s, false) {
609                return Ok(*variant);
610            }
611        }
612        Err(format!("Invalid variant: {s}"))
613    }
614}
615
616impl Default for ShowCommits {
617    fn default() -> Self {
618        Self::Unprotected
619    }
620}
621
622#[derive(clap::ValueEnum, Debug, Copy, Clone, PartialEq, Eq)]
623pub enum Stack {
624    /// Branches in BASE..HEAD
625    Current,
626    /// Branches in BASE..HEAD..
627    Dependents,
628    /// Branches in BASE..
629    Descendants,
630    /// Show all branches
631    All,
632}
633
634impl std::fmt::Display for Stack {
635    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636        use clap::ValueEnum;
637        self.to_possible_value()
638            .expect("no values are skipped")
639            .get_name()
640            .fmt(f)
641    }
642}
643
644impl FromStr for Stack {
645    type Err = String;
646
647    fn from_str(s: &str) -> Result<Self, Self::Err> {
648        use clap::ValueEnum;
649        for variant in Self::value_variants() {
650            if variant.to_possible_value().unwrap().matches(s, false) {
651                return Ok(*variant);
652            }
653        }
654        Err(format!("Invalid variant: {s}"))
655    }
656}
657
658impl Default for Stack {
659    fn default() -> Self {
660        Self::All
661    }
662}
663
664#[derive(Debug, Copy, Clone, PartialEq, Eq, clap::ValueEnum)]
665pub enum Fixup {
666    /// No special processing
667    Ignore,
668    /// Move them to after the commit they fix
669    Move,
670    /// Squash into the commit they fix
671    Squash,
672}
673
674impl Fixup {
675    pub fn variants() -> [&'static str; 3] {
676        ["ignore", "move", "squash"]
677    }
678}
679
680impl std::fmt::Display for Fixup {
681    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
682        use clap::ValueEnum;
683        self.to_possible_value()
684            .expect("no values are skipped")
685            .get_name()
686            .fmt(f)
687    }
688}
689
690impl FromStr for Fixup {
691    type Err = String;
692
693    fn from_str(s: &str) -> Result<Self, Self::Err> {
694        use clap::ValueEnum;
695        for variant in Self::value_variants() {
696            if variant.to_possible_value().unwrap().matches(s, false) {
697                return Ok(*variant);
698            }
699        }
700        Err(format!("Invalid variant: {s}"))
701    }
702}
703
704impl Default for Fixup {
705    fn default() -> Self {
706        Self::Move
707    }
708}