git_trim/
config.rs

1use std::convert::TryFrom;
2use std::fmt::Debug;
3use std::iter::FromIterator;
4use std::ops::Deref;
5use std::str::FromStr;
6
7use anyhow::{Context, Result};
8use git2::{BranchType, Config as GitConfig, Error, ErrorClass, ErrorCode, Remote, Repository};
9use log::*;
10
11use crate::args::{Args, DeleteFilter, DeleteRange};
12use crate::branch::{LocalBranch, RemoteTrackingBranchStatus};
13use std::collections::HashSet;
14
15type GitResult<T> = std::result::Result<T, git2::Error>;
16
/// Effective `git trim` settings, each wrapped in a [`ConfigValue`] that records whether the
/// value came from the command line, from git config, or from a built-in default.
#[derive(Debug)]
pub struct Config {
    /// Base branches, from the command line or `trim.bases` (comma-separated); by default the
    /// local branches tracking a remote HEAD target, or those targets themselves.
    pub bases: ConfigValue<HashSet<String>>,
    /// Protected entries, from the command line or `trim.protected` (comma-separated);
    /// empty by default.
    pub protected: ConfigValue<Vec<String>>,
    /// `trim.update`; defaults to `true`.
    pub update: ConfigValue<bool>,
    /// `trim.updateInterval`; defaults to `5`.
    pub update_interval: ConfigValue<u64>,
    /// `trim.confirm`; defaults to `true`.
    pub confirm: ConfigValue<bool>,
    /// `trim.detach`; defaults to `true`.
    pub detach: ConfigValue<bool>,
    /// `trim.delete` (comma-separated); defaults to `DeleteRange::merged_origin()`.
    pub delete: ConfigValue<DeleteFilter>,
}
27
28impl Config {
29    pub fn read(repo: &Repository, config: &GitConfig, args: &Args) -> Result<Self> {
30        fn non_empty<T>(x: Vec<T>) -> Option<Vec<T>> {
31            if x.is_empty() {
32                None
33            } else {
34                Some(x)
35            }
36        }
37
38        let bases = get_comma_separated_multi(config, "trim.bases")
39            .with_explicit(non_empty(args.bases.clone()))
40            .with_default(get_branches_tracks_remote_heads(repo, config)?)
41            .parses_and_collect::<HashSet<String>>()?;
42        let protected = get_comma_separated_multi(config, "trim.protected")
43            .with_explicit(non_empty(args.protected.clone()))
44            .parses_and_collect::<Vec<String>>()?;
45        let update = get(config, "trim.update")
46            .with_explicit(args.update())
47            .with_default(true)
48            .read()?
49            .expect("has default");
50        let update_interval = get(config, "trim.updateInterval")
51            .with_explicit(args.update_interval)
52            .with_default(5)
53            .read()?
54            .expect("has default");
55        let confirm = get(config, "trim.confirm")
56            .with_explicit(args.confirm())
57            .with_default(true)
58            .read()?
59            .expect("has default");
60        let detach = get(config, "trim.detach")
61            .with_explicit(args.detach())
62            .with_default(true)
63            .read()?
64            .expect("has default");
65        let delete = get_comma_separated_multi(config, "trim.delete")
66            .with_explicit(non_empty(args.delete.clone()))
67            .with_default(DeleteRange::merged_origin())
68            .parses_and_collect::<DeleteFilter>()?;
69
70        Ok(Config {
71            bases,
72            protected,
73            update,
74            update_interval,
75            confirm,
76            detach,
77            delete,
78        })
79    }
80}
81
82fn get_branches_tracks_remote_heads(repo: &Repository, config: &GitConfig) -> Result<Vec<String>> {
83    let mut local_bases = Vec::new();
84    let mut all_bases = Vec::new();
85
86    for reference in repo.references_glob("refs/remotes/*/HEAD")? {
87        let reference = reference?;
88        // git symbolic-ref refs/remotes/*/HEAD
89        let resolved = match reference.resolve() {
90            Ok(resolved) => resolved,
91            Err(_) => {
92                debug!(
93                    "Reference {:?} is expected to be an symbolic ref, but it isn't",
94                    reference.name()
95                );
96                continue;
97            }
98        };
99        let refname = resolved.name().context("non utf-8 reference name")?;
100        all_bases.push(refname.to_owned());
101
102        for branch in repo.branches(Some(BranchType::Local))? {
103            let (branch, _) = branch?;
104            let branch = LocalBranch::try_from(&branch)?;
105
106            if let RemoteTrackingBranchStatus::Exists(upstream) =
107                branch.fetch_upstream(repo, config)?
108            {
109                if upstream.refname == refname {
110                    local_bases.push(branch.short_name().to_owned());
111                }
112            }
113        }
114    }
115
116    if local_bases.is_empty() {
117        Ok(all_bases)
118    } else {
119        Ok(local_bases)
120    }
121}
122
/// A configuration value tagged with where it came from.
///
/// `Explicit` values were passed in by the caller (the command line in `Config::read`),
/// `GitConfig` values were read from git config, and `Implicit` values are built-in defaults.
#[derive(Debug, Eq, PartialEq)]
pub enum ConfigValue<T> {
    /// Supplied explicitly by the caller.
    Explicit(T),
    /// Read from git config.
    GitConfig(T),
    /// Fallback default used because nothing else was supplied.
    Implicit(T),
}

impl<T> ConfigValue<T> {
    /// Drops the source tag and returns the wrapped value.
    pub fn unwrap(self) -> T {
        match self {
            Self::Explicit(value) => value,
            Self::GitConfig(value) => value,
            Self::Implicit(value) => value,
        }
    }

    /// Returns `true` only for a built-in default.
    pub fn is_implicit(&self) -> bool {
        matches!(self, Self::Implicit(_))
    }
}

/// Read access to the wrapped value, whatever its source.
impl<T> Deref for ConfigValue<T> {
    type Target = T;

    fn deref(&self) -> &T {
        match self {
            Self::Explicit(value) => value,
            Self::GitConfig(value) => value,
            Self::Implicit(value) => value,
        }
    }
}
155
/// Builder that resolves one git config key, layering an optional explicit value and an
/// optional default around what git config contains.
pub struct ConfigBuilder<'a, T> {
    /// Git configuration to read from.
    config: &'a GitConfig,
    /// Config key to look up, e.g. `trim.bases`.
    key: &'a str,
    /// Value that takes precedence over git config when present.
    explicit: Option<T>,
    /// Value used when git config has nothing for `key`.
    default: Option<T>,
    /// Whether each config entry is split on `,` (only consulted by `parses_and_collect`).
    comma_separated: bool,
}
163
164pub fn get<'a, T>(config: &'a GitConfig, key: &'a str) -> ConfigBuilder<'a, T> {
165    ConfigBuilder {
166        config,
167        key,
168        explicit: None,
169        default: None,
170        comma_separated: false,
171    }
172}
173
174pub fn get_comma_separated_multi<'a, T>(
175    config: &'a GitConfig,
176    key: &'a str,
177) -> ConfigBuilder<'a, T> {
178    ConfigBuilder {
179        config,
180        key,
181        explicit: None,
182        default: None,
183        comma_separated: true,
184    }
185}
186
187impl<'a, T> ConfigBuilder<'a, T> {
188    fn with_explicit(self, value: Option<T>) -> ConfigBuilder<'a, T> {
189        if let Some(value) = value {
190            ConfigBuilder {
191                explicit: Some(value),
192                ..self
193            }
194        } else {
195            self
196        }
197    }
198
199    pub fn with_default(self, value: T) -> ConfigBuilder<'a, T> {
200        ConfigBuilder {
201            default: Some(value),
202            ..self
203        }
204    }
205}
206
207impl<'a, T> ConfigBuilder<'a, T>
208where
209    T: ConfigValues,
210{
211    pub fn read(self) -> GitResult<Option<ConfigValue<T>>> {
212        if let Some(value) = self.explicit {
213            return Ok(Some(ConfigValue::Explicit(value)));
214        }
215        match T::get_config_value(self.config, self.key) {
216            Ok(value) => Ok(Some(ConfigValue::GitConfig(value))),
217            Err(err) if config_not_exist(&err) => {
218                if let Some(default) = self.default {
219                    Ok(Some(ConfigValue::Implicit(default)))
220                } else {
221                    Ok(None)
222                }
223            }
224            Err(err) => Err(err),
225        }
226    }
227}
228
229impl<'a, T> ConfigBuilder<'a, T> {
230    fn parses_and_collect<U>(self) -> Result<ConfigValue<U>>
231    where
232        T: IntoIterator,
233        T::Item: FromStr,
234        <T::Item as FromStr>::Err: std::error::Error + Send + Sync + 'static,
235        U: FromIterator<<T as IntoIterator>::Item> + Default,
236    {
237        if let Some(value) = self.explicit {
238            return Ok(ConfigValue::Explicit(value.into_iter().collect()));
239        }
240
241        let result = match Vec::<String>::get_config_value(self.config, self.key) {
242            Ok(entries) if !entries.is_empty() => {
243                let mut result = Vec::new();
244                if self.comma_separated {
245                    for entry in entries {
246                        for item in entry.split(',') {
247                            if !item.is_empty() {
248                                let value = <T::Item>::from_str(item)?;
249                                result.push(value);
250                            }
251                        }
252                    }
253                } else {
254                    for entry in entries {
255                        let value = <T::Item>::from_str(&entry)?;
256                        result.push(value);
257                    }
258                }
259
260                ConfigValue::GitConfig(result.into_iter().collect())
261            }
262            Ok(_) => {
263                if let Some(default) = self.default {
264                    ConfigValue::Implicit(default.into_iter().collect())
265                } else {
266                    ConfigValue::Implicit(U::default())
267                }
268            }
269            Err(err) => return Err(err.into()),
270        };
271        Ok(result)
272    }
273}
274
/// Types that can be read from git config by key.
pub trait ConfigValues {
    /// Reads the value(s) stored under `key` in `config`, returning the raw `git2` error on
    /// failure.
    fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error>
    where
        Self: Sized;
}
280
281impl ConfigValues for String {
282    fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error> {
283        config.get_string(key)
284    }
285}
286
287impl ConfigValues for Vec<String> {
288    fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error> {
289        let mut result = Vec::new();
290        let mut entries = config.entries(Some(key))?;
291        while let Some(entry) = entries.next() {
292            let entry = entry?;
293            if let Some(value) = entry.value() {
294                result.push(value.to_owned());
295            } else {
296                warn!(
297                    "non utf-8 config entry {}",
298                    String::from_utf8_lossy(entry.name_bytes())
299                );
300            }
301        }
302        Ok(result)
303    }
304}
305
306impl ConfigValues for bool {
307    fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error> {
308        config.get_bool(key)
309    }
310}
311
312impl ConfigValues for u64 {
313    fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error> {
314        let value = config.get_i64(key)?;
315        if value >= 0 {
316            return Ok(value as u64);
317        }
318        panic!("`git config {}` cannot be negative value", key);
319    }
320}
321
322fn config_not_exist(err: &git2::Error) -> bool {
323    err.code() == ErrorCode::NotFound && err.class() == ErrorClass::Config
324}
325
326pub fn get_push_remote(config: &GitConfig, branch: &LocalBranch) -> Result<String> {
327    let push_remote_key = format!("branch.{}.pushRemote", branch.short_name());
328    if let Some(push_remote) = get::<String>(config, &push_remote_key).read()? {
329        return Ok(push_remote.unwrap());
330    }
331
332    if let Some(push_default) = get::<String>(config, "remote.pushDefault").read()? {
333        return Ok(push_default.unwrap());
334    }
335
336    Ok(get_remote_name(config, branch)?.unwrap_or_else(|| "origin".to_owned()))
337}
338
339pub fn get_remote_name(config: &GitConfig, branch: &LocalBranch) -> Result<Option<String>> {
340    let key = format!("branch.{}.remote", branch.short_name());
341    match config.get_string(&key) {
342        Ok(remote) => Ok(Some(remote)),
343        Err(err) if config_not_exist(&err) => Ok(None),
344        Err(err) => Err(err.into()),
345    }
346}
347
348pub fn get_remote<'a>(repo: &'a Repository, remote_name: &str) -> Result<Option<Remote<'a>>> {
349    fn error_is_missing_remote(err: &Error) -> bool {
350        err.class() == ErrorClass::Config && err.code() == ErrorCode::InvalidSpec
351    }
352
353    match repo.find_remote(remote_name) {
354        Ok(remote) => Ok(Some(remote)),
355        Err(err) if error_is_missing_remote(&err) => Ok(None),
356        Err(err) => Err(err.into()),
357    }
358}
359
360pub fn get_merge(config: &GitConfig, branch: &LocalBranch) -> Result<Option<String>> {
361    let key = format!("branch.{}.merge", branch.short_name());
362    match config.get_string(&key) {
363        Ok(merge) => Ok(Some(merge)),
364        Err(err) if config_not_exist(&err) => Ok(None),
365        Err(err) => Err(err.into()),
366    }
367}