git_prole/
config.rs

1use std::process::Command;
2
3use camino::Utf8PathBuf;
4use clap::Parser;
5use miette::Context;
6use miette::IntoDiagnostic;
7use regex::Regex;
8use serde::de::Error;
9use serde::Deserialize;
10use unindent::unindent;
11use xdg::BaseDirectories;
12
13use crate::cli::Cli;
14use crate::fs;
15use crate::install_tracing::install_tracing;
16
17/// Configuration, both from the command-line and a user configuration file.
18#[derive(Debug)]
19pub struct Config {
20    /// User directories.
21    #[expect(dead_code)]
22    pub(crate) dirs: BaseDirectories,
23    /// User configuration file.
24    pub file: ConfigFile,
25    /// User configuration file path.
26    pub path: Utf8PathBuf,
27    /// Command-line options.
28    pub cli: Cli,
29}
30
31impl Config {
32    /// The contents of the default configuration file.
33    pub const DEFAULT: &str = include_str!("../config.toml");
34
35    pub fn new() -> miette::Result<Self> {
36        let cli = Cli::parse();
37        // TODO: add tracing settings to the config file
38        install_tracing(&cli.log)?;
39        let dirs = BaseDirectories::with_prefix("git-prole").into_diagnostic()?;
40        // TODO: Use `git config` for configuration?
41        let path = cli
42            .config
43            .as_ref()
44            .map(|path| Ok(path.to_owned()))
45            .unwrap_or_else(|| config_file_path(&dirs))?;
46        let file = {
47            if !path.exists() {
48                ConfigFile::default()
49            } else {
50                toml::from_str(
51                    &fs::read_to_string(&path).wrap_err("Failed to read configuration file")?,
52                )
53                .into_diagnostic()
54                .wrap_err("Failed to deserialize configuration file")?
55            }
56        };
57        Ok(Self {
58            dirs,
59            path,
60            file,
61            cli,
62        })
63    }
64
65    /// A fake stub config for testing.
66    #[cfg(test)]
67    pub fn test_stub() -> Self {
68        // TODO: Make this pure-er.
69        let dirs = BaseDirectories::new().unwrap();
70        let path = config_file_path(&dirs).unwrap();
71        Self {
72            dirs,
73            file: ConfigFile::default(),
74            path,
75            cli: Cli::test_stub(),
76        }
77    }
78}
79
80fn config_file_path(dirs: &BaseDirectories) -> miette::Result<Utf8PathBuf> {
81    dirs.get_config_file(ConfigFile::FILE_NAME)
82        .try_into()
83        .into_diagnostic()
84}
85
86/// Configuration file format.
87///
88/// Each configuration key should have two test cases:
89/// - `config_{key}` for setting the value.
90/// - `config_{key}_default` for the default value.
91///
92/// For documentation, see the default configuration file (`../config.toml`).
93///
94/// The default configuration file is accessible as [`Config::DEFAULT`].
95#[derive(Debug, Default, Deserialize, PartialEq, Eq)]
96#[serde(default, deny_unknown_fields)]
97pub struct ConfigFile {
98    remote_names: Vec<String>,
99    branch_names: Vec<String>,
100    pub clone: CloneConfig,
101    pub add: AddConfig,
102}
103
104impl ConfigFile {
105    pub const FILE_NAME: &str = "config.toml";
106
107    pub fn remote_names(&self) -> Vec<String> {
108        // Yeah this basically sucks. But how big could these lists really be?
109        if self.remote_names.is_empty() {
110            vec!["upstream".to_owned(), "origin".to_owned()]
111        } else {
112            self.remote_names.clone()
113        }
114    }
115
116    pub fn branch_names(&self) -> Vec<String> {
117        // Yeah this basically sucks. But how big could these lists really be?
118        if self.branch_names.is_empty() {
119            vec!["main".to_owned(), "master".to_owned(), "trunk".to_owned()]
120        } else {
121            self.branch_names.clone()
122        }
123    }
124}
125
126#[derive(Debug, Default, Deserialize, PartialEq, Eq)]
127#[serde(default)]
128pub struct CloneConfig {
129    enable_gh: Option<bool>,
130}
131
132impl CloneConfig {
133    pub fn enable_gh(&self) -> bool {
134        self.enable_gh.unwrap_or(false)
135    }
136}
137
138#[derive(Debug, Default, Deserialize, PartialEq, Eq)]
139#[serde(default)]
140pub struct AddConfig {
141    copy_untracked: Option<bool>,
142    copy_ignored: Option<bool>,
143    commands: Vec<ShellCommand>,
144    branch_replacements: Vec<BranchReplacement>,
145}
146
147impl AddConfig {
148    pub fn copy_ignored(&self) -> bool {
149        if let Some(copy_untracked) = self.copy_untracked {
150            tracing::warn!("`add.copy_untracked` has been replaced with `add.copy_ignored`");
151            return copy_untracked;
152        }
153        self.copy_ignored.unwrap_or(true)
154    }
155
156    pub fn commands(&self) -> &[ShellCommand] {
157        &self.commands
158    }
159
160    pub fn branch_replacements(&self) -> &[BranchReplacement] {
161        &self.branch_replacements
162    }
163}
164
165#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
166#[serde(untagged)]
167pub enum ShellCommand {
168    Simple(ShellArgs),
169    Shell { sh: String },
170}
171
172impl ShellCommand {
173    pub fn as_command(&self) -> Command {
174        match self {
175            ShellCommand::Simple(args) => {
176                let mut command = Command::new(&args.program);
177                command.args(&args.args);
178                command
179            }
180            ShellCommand::Shell { sh } => {
181                let mut command = Command::new("sh");
182                let sh = unindent(sh);
183                command.args(["-c", sh.trim_ascii()]);
184                command
185            }
186        }
187    }
188}
189
/// A command line split into a program and its arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShellArgs {
    /// The first word of the command line.
    program: String,
    /// The remaining words of the command line.
    args: Vec<String>,
}
195
196impl<'de> Deserialize<'de> for ShellArgs {
197    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198    where
199        D: serde::Deserializer<'de>,
200    {
201        let quoted: String = Deserialize::deserialize(deserializer)?;
202        let mut args = shell_words::split(&quoted).map_err(D::Error::custom)?;
203
204        if args.is_empty() {
205            return Err(D::Error::invalid_value(
206                serde::de::Unexpected::Str(&quoted),
207                // TODO: This error message doesn't actually get propagated upward
208                // correctly, so you get "data did not match any variant of untagged enum
209                // ShellCommand" instead.
210                &"a shell command (you are missing a program)",
211            ));
212        }
213
214        let program = args.remove(0);
215
216        Ok(Self { program, args })
217    }
218}
219
220#[derive(Clone, Debug, Deserialize)]
221pub struct BranchReplacement {
222    #[serde(deserialize_with = "deserialize_regex")]
223    pub find: Regex,
224    pub replace: String,
225    pub count: Option<usize>,
226}
227
228impl PartialEq for BranchReplacement {
229    fn eq(&self, other: &Self) -> bool {
230        self.replace == other.replace && self.find.as_str() == other.find.as_str()
231    }
232}
233
234impl Eq for BranchReplacement {}
235
236fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
237where
238    D: serde::Deserializer<'de>,
239{
240    let input: String = Deserialize::deserialize(deserializer)?;
241    Regex::new(&input).map_err(D::Error::custom)
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// The shipped default configuration file must parse, and must spell out
    /// exactly the values that an empty configuration file falls back to.
    #[test]
    fn test_default_config_file_parse() {
        let shipped: ConfigFile = toml::from_str(Config::DEFAULT).unwrap();
        let spelled_out = ConfigFile {
            remote_names: vec!["upstream".to_owned(), "origin".to_owned()],
            branch_names: vec!["main".to_owned(), "master".to_owned(), "trunk".to_owned()],
            clone: CloneConfig {
                enable_gh: Some(false),
            },
            add: AddConfig {
                copy_untracked: None,
                copy_ignored: Some(true),
                commands: vec![],
                branch_replacements: vec![],
            },
        };
        assert_eq!(shipped, spelled_out);

        // Rebuild the same structure from the accessors of an empty file, so
        // the defaults applied in code are checked against the shipped file.
        let empty: ConfigFile = toml::from_str("").unwrap();
        let effective = ConfigFile {
            remote_names: empty.remote_names(),
            branch_names: empty.branch_names(),
            clone: CloneConfig {
                enable_gh: Some(empty.clone.enable_gh()),
            },
            add: AddConfig {
                copy_untracked: None,
                copy_ignored: Some(empty.add.copy_ignored()),
                commands: empty.add.commands().to_vec(),
                branch_replacements: empty.add.branch_replacements().to_vec(),
            },
        };
        assert_eq!(shipped, effective);
    }
}