1use std::process::Command;
2
3use camino::Utf8PathBuf;
4use clap::Parser;
5use miette::Context;
6use miette::IntoDiagnostic;
7use regex::Regex;
8use serde::de::Error;
9use serde::Deserialize;
10use unindent::unindent;
11use xdg::BaseDirectories;
12
13use crate::cli::Cli;
14use crate::fs;
15use crate::install_tracing::install_tracing;
16
17#[derive(Debug)]
19pub struct Config {
20 #[expect(dead_code)]
22 pub(crate) dirs: BaseDirectories,
23 pub file: ConfigFile,
25 pub path: Utf8PathBuf,
27 pub cli: Cli,
29}
30
31impl Config {
32 pub const DEFAULT: &str = include_str!("../config.toml");
34
35 pub fn new() -> miette::Result<Self> {
36 let cli = Cli::parse();
37 install_tracing(&cli.log)?;
39 let dirs = BaseDirectories::with_prefix("git-prole").into_diagnostic()?;
40 let path = cli
42 .config
43 .as_ref()
44 .map(|path| Ok(path.to_owned()))
45 .unwrap_or_else(|| config_file_path(&dirs))?;
46 let file = {
47 if !path.exists() {
48 ConfigFile::default()
49 } else {
50 toml::from_str(
51 &fs::read_to_string(&path).wrap_err("Failed to read configuration file")?,
52 )
53 .into_diagnostic()
54 .wrap_err("Failed to deserialize configuration file")?
55 }
56 };
57 Ok(Self {
58 dirs,
59 path,
60 file,
61 cli,
62 })
63 }
64
65 #[cfg(test)]
67 pub fn test_stub() -> Self {
68 let dirs = BaseDirectories::new().unwrap();
70 let path = config_file_path(&dirs).unwrap();
71 Self {
72 dirs,
73 file: ConfigFile::default(),
74 path,
75 cli: Cli::test_stub(),
76 }
77 }
78}
79
80fn config_file_path(dirs: &BaseDirectories) -> miette::Result<Utf8PathBuf> {
81 dirs.get_config_file(ConfigFile::FILE_NAME)
82 .try_into()
83 .into_diagnostic()
84}
85
86#[derive(Debug, Default, Deserialize, PartialEq, Eq)]
96#[serde(default, deny_unknown_fields)]
97pub struct ConfigFile {
98 remote_names: Vec<String>,
99 branch_names: Vec<String>,
100 pub clone: CloneConfig,
101 pub add: AddConfig,
102}
103
104impl ConfigFile {
105 pub const FILE_NAME: &str = "config.toml";
106
107 pub fn remote_names(&self) -> Vec<String> {
108 if self.remote_names.is_empty() {
110 vec!["upstream".to_owned(), "origin".to_owned()]
111 } else {
112 self.remote_names.clone()
113 }
114 }
115
116 pub fn branch_names(&self) -> Vec<String> {
117 if self.branch_names.is_empty() {
119 vec!["main".to_owned(), "master".to_owned(), "trunk".to_owned()]
120 } else {
121 self.branch_names.clone()
122 }
123 }
124}
125
126#[derive(Debug, Default, Deserialize, PartialEq, Eq)]
127#[serde(default)]
128pub struct CloneConfig {
129 enable_gh: Option<bool>,
130}
131
132impl CloneConfig {
133 pub fn enable_gh(&self) -> bool {
134 self.enable_gh.unwrap_or(false)
135 }
136}
137
138#[derive(Debug, Default, Deserialize, PartialEq, Eq)]
139#[serde(default)]
140pub struct AddConfig {
141 copy_untracked: Option<bool>,
142 copy_ignored: Option<bool>,
143 commands: Vec<ShellCommand>,
144 branch_replacements: Vec<BranchReplacement>,
145}
146
147impl AddConfig {
148 pub fn copy_ignored(&self) -> bool {
149 if let Some(copy_untracked) = self.copy_untracked {
150 tracing::warn!("`add.copy_untracked` has been replaced with `add.copy_ignored`");
151 return copy_untracked;
152 }
153 self.copy_ignored.unwrap_or(true)
154 }
155
156 pub fn commands(&self) -> &[ShellCommand] {
157 &self.commands
158 }
159
160 pub fn branch_replacements(&self) -> &[BranchReplacement] {
161 &self.branch_replacements
162 }
163}
164
165#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
166#[serde(untagged)]
167pub enum ShellCommand {
168 Simple(ShellArgs),
169 Shell { sh: String },
170}
171
172impl ShellCommand {
173 pub fn as_command(&self) -> Command {
174 match self {
175 ShellCommand::Simple(args) => {
176 let mut command = Command::new(&args.program);
177 command.args(&args.args);
178 command
179 }
180 ShellCommand::Shell { sh } => {
181 let mut command = Command::new("sh");
182 let sh = unindent(sh);
183 command.args(["-c", sh.trim_ascii()]);
184 command
185 }
186 }
187 }
188}
189
/// A program and its arguments, deserialized from a single shell-quoted string
/// such as `"cargo build --release"`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellArgs {
    /// The first word of the string: the program to execute.
    program: String,
    /// The remaining words, passed as arguments.
    args: Vec<String>,
}
195
196impl<'de> Deserialize<'de> for ShellArgs {
197 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198 where
199 D: serde::Deserializer<'de>,
200 {
201 let quoted: String = Deserialize::deserialize(deserializer)?;
202 let mut args = shell_words::split("ed).map_err(D::Error::custom)?;
203
204 if args.is_empty() {
205 return Err(D::Error::invalid_value(
206 serde::de::Unexpected::Str("ed),
207 &"a shell command (you are missing a program)",
211 ));
212 }
213
214 let program = args.remove(0);
215
216 Ok(Self { program, args })
217 }
218}
219
220#[derive(Clone, Debug, Deserialize)]
221pub struct BranchReplacement {
222 #[serde(deserialize_with = "deserialize_regex")]
223 pub find: Regex,
224 pub replace: String,
225 pub count: Option<usize>,
226}
227
228impl PartialEq for BranchReplacement {
229 fn eq(&self, other: &Self) -> bool {
230 self.replace == other.replace && self.find.as_str() == other.find.as_str()
231 }
232}
233
234impl Eq for BranchReplacement {}
235
236fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
237where
238 D: serde::Deserializer<'de>,
239{
240 let input: String = Deserialize::deserialize(deserializer)?;
241 Regex::new(&input).map_err(D::Error::custom)
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// The bundled default config parses, and its values match what an empty config
    /// resolves to through the accessor methods.
    #[test]
    fn test_default_config_file_parse() {
        let default_config = toml::from_str::<ConfigFile>(Config::DEFAULT).unwrap();
        assert_eq!(
            default_config,
            ConfigFile {
                remote_names: vec!["upstream".to_owned(), "origin".to_owned()],
                branch_names: vec!["main".to_owned(), "master".to_owned(), "trunk".to_owned()],
                clone: CloneConfig {
                    enable_gh: Some(false)
                },
                add: AddConfig {
                    copy_untracked: None,
                    copy_ignored: Some(true),
                    commands: vec![],
                    branch_replacements: vec![],
                }
            }
        );

        let empty_config = toml::from_str::<ConfigFile>("").unwrap();
        assert_eq!(
            default_config,
            ConfigFile {
                remote_names: empty_config.remote_names(),
                branch_names: empty_config.branch_names(),
                clone: CloneConfig {
                    enable_gh: Some(empty_config.clone.enable_gh()),
                },
                add: AddConfig {
                    copy_untracked: None,
                    copy_ignored: Some(empty_config.add.copy_ignored()),
                    commands: empty_config.add.commands().to_vec(),
                    branch_replacements: empty_config.add.branch_replacements().to_vec(),
                },
            }
        );
    }

    /// `ConfigFile` uses `deny_unknown_fields`, so a stray top-level key is an error.
    #[test]
    fn test_unknown_top_level_key_is_rejected() {
        assert!(toml::from_str::<ConfigFile>("not_a_real_key = true").is_err());
    }

    /// A plain string command is split like a shell would split it.
    #[test]
    fn test_simple_command_is_shell_split() {
        let config =
            toml::from_str::<ConfigFile>("[add]\ncommands = [\"cargo build '--features=a b'\"]")
                .unwrap();
        assert_eq!(
            config.add.commands().to_vec(),
            vec![ShellCommand::Simple(ShellArgs {
                program: "cargo".to_owned(),
                args: vec!["build".to_owned(), "--features=a b".to_owned()],
            })]
        );
    }

    /// A table with an `sh` key deserializes to the `Shell` variant.
    #[test]
    fn test_sh_command_table() {
        let config =
            toml::from_str::<ConfigFile>("[add]\ncommands = [{ sh = \"echo hi\" }]").unwrap();
        assert_eq!(
            config.add.commands().to_vec(),
            vec![ShellCommand::Shell {
                sh: "echo hi".to_owned()
            }]
        );
    }

    /// A command string with no words has no program and is rejected.
    #[test]
    fn test_empty_command_is_rejected() {
        assert!(toml::from_str::<ConfigFile>("[add]\ncommands = [\"\"]").is_err());
    }

    /// Invalid regex patterns fail at deserialization time.
    #[test]
    fn test_invalid_branch_replacement_regex_is_rejected() {
        let input = "[[add.branch_replacements]]\nfind = \"(\"\nreplace = \"\"\n";
        assert!(toml::from_str::<ConfigFile>(input).is_err());
    }

    /// The deprecated `copy_untracked` key is still honored on its own.
    #[test]
    fn test_deprecated_copy_untracked_is_honored() {
        let config = toml::from_str::<ConfigFile>("[add]\ncopy_untracked = false").unwrap();
        assert!(!config.add.copy_ignored());
    }
}