1use std::convert::TryFrom;
2use std::fmt::Debug;
3use std::iter::FromIterator;
4use std::ops::Deref;
5use std::str::FromStr;
6
7use anyhow::{Context, Result};
8use git2::{BranchType, Config as GitConfig, Error, ErrorClass, ErrorCode, Remote, Repository};
9use log::*;
10
11use crate::args::{Args, DeleteFilter, DeleteRange};
12use crate::branch::{LocalBranch, RemoteTrackingBranchStatus};
13use std::collections::HashSet;
14
15type GitResult<T> = std::result::Result<T, git2::Error>;
16
17#[derive(Debug)]
18pub struct Config {
19 pub bases: ConfigValue<HashSet<String>>,
20 pub protected: ConfigValue<Vec<String>>,
21 pub update: ConfigValue<bool>,
22 pub update_interval: ConfigValue<u64>,
23 pub confirm: ConfigValue<bool>,
24 pub detach: ConfigValue<bool>,
25 pub delete: ConfigValue<DeleteFilter>,
26}
27
28impl Config {
29 pub fn read(repo: &Repository, config: &GitConfig, args: &Args) -> Result<Self> {
30 fn non_empty<T>(x: Vec<T>) -> Option<Vec<T>> {
31 if x.is_empty() {
32 None
33 } else {
34 Some(x)
35 }
36 }
37
38 let bases = get_comma_separated_multi(config, "trim.bases")
39 .with_explicit(non_empty(args.bases.clone()))
40 .with_default(get_branches_tracks_remote_heads(repo, config)?)
41 .parses_and_collect::<HashSet<String>>()?;
42 let protected = get_comma_separated_multi(config, "trim.protected")
43 .with_explicit(non_empty(args.protected.clone()))
44 .parses_and_collect::<Vec<String>>()?;
45 let update = get(config, "trim.update")
46 .with_explicit(args.update())
47 .with_default(true)
48 .read()?
49 .expect("has default");
50 let update_interval = get(config, "trim.updateInterval")
51 .with_explicit(args.update_interval)
52 .with_default(5)
53 .read()?
54 .expect("has default");
55 let confirm = get(config, "trim.confirm")
56 .with_explicit(args.confirm())
57 .with_default(true)
58 .read()?
59 .expect("has default");
60 let detach = get(config, "trim.detach")
61 .with_explicit(args.detach())
62 .with_default(true)
63 .read()?
64 .expect("has default");
65 let delete = get_comma_separated_multi(config, "trim.delete")
66 .with_explicit(non_empty(args.delete.clone()))
67 .with_default(DeleteRange::merged_origin())
68 .parses_and_collect::<DeleteFilter>()?;
69
70 Ok(Config {
71 bases,
72 protected,
73 update,
74 update_interval,
75 confirm,
76 detach,
77 delete,
78 })
79 }
80}
81
82fn get_branches_tracks_remote_heads(repo: &Repository, config: &GitConfig) -> Result<Vec<String>> {
83 let mut local_bases = Vec::new();
84 let mut all_bases = Vec::new();
85
86 for reference in repo.references_glob("refs/remotes/*/HEAD")? {
87 let reference = reference?;
88 let resolved = match reference.resolve() {
90 Ok(resolved) => resolved,
91 Err(_) => {
92 debug!(
93 "Reference {:?} is expected to be an symbolic ref, but it isn't",
94 reference.name()
95 );
96 continue;
97 }
98 };
99 let refname = resolved.name().context("non utf-8 reference name")?;
100 all_bases.push(refname.to_owned());
101
102 for branch in repo.branches(Some(BranchType::Local))? {
103 let (branch, _) = branch?;
104 let branch = LocalBranch::try_from(&branch)?;
105
106 if let RemoteTrackingBranchStatus::Exists(upstream) =
107 branch.fetch_upstream(repo, config)?
108 {
109 if upstream.refname == refname {
110 local_bases.push(branch.short_name().to_owned());
111 }
112 }
113 }
114 }
115
116 if local_bases.is_empty() {
117 Ok(all_bases)
118 } else {
119 Ok(local_bases)
120 }
121}
122
/// A setting's value, tagged with the source it was resolved from.
#[derive(Debug, Eq, PartialEq)]
pub enum ConfigValue<T> {
    /// Supplied explicitly by the caller (via `ConfigBuilder::with_explicit`).
    Explicit(T),
    /// Read from git config.
    GitConfig(T),
    /// Neither of the above was present; a default value was used.
    Implicit(T),
}

impl<T> ConfigValue<T> {
    /// Consumes the wrapper and returns the inner value, discarding its source.
    pub fn unwrap(self) -> T {
        match self {
            Self::Explicit(inner) => inner,
            Self::GitConfig(inner) => inner,
            Self::Implicit(inner) => inner,
        }
    }

    /// Returns `true` only when the value is a fallback (`Implicit`).
    pub fn is_implicit(&self) -> bool {
        match self {
            Self::Implicit(_) => true,
            Self::Explicit(_) | Self::GitConfig(_) => false,
        }
    }
}

impl<T> Deref for ConfigValue<T> {
    type Target = T;

    /// Borrows the inner value regardless of where it came from.
    fn deref(&self) -> &T {
        match *self {
            Self::Explicit(ref inner) | Self::GitConfig(ref inner) | Self::Implicit(ref inner) => {
                inner
            }
        }
    }
}
155
156pub struct ConfigBuilder<'a, T> {
157 config: &'a GitConfig,
158 key: &'a str,
159 explicit: Option<T>,
160 default: Option<T>,
161 comma_separated: bool,
162}
163
164pub fn get<'a, T>(config: &'a GitConfig, key: &'a str) -> ConfigBuilder<'a, T> {
165 ConfigBuilder {
166 config,
167 key,
168 explicit: None,
169 default: None,
170 comma_separated: false,
171 }
172}
173
174pub fn get_comma_separated_multi<'a, T>(
175 config: &'a GitConfig,
176 key: &'a str,
177) -> ConfigBuilder<'a, T> {
178 ConfigBuilder {
179 config,
180 key,
181 explicit: None,
182 default: None,
183 comma_separated: true,
184 }
185}
186
187impl<'a, T> ConfigBuilder<'a, T> {
188 fn with_explicit(self, value: Option<T>) -> ConfigBuilder<'a, T> {
189 if let Some(value) = value {
190 ConfigBuilder {
191 explicit: Some(value),
192 ..self
193 }
194 } else {
195 self
196 }
197 }
198
199 pub fn with_default(self, value: T) -> ConfigBuilder<'a, T> {
200 ConfigBuilder {
201 default: Some(value),
202 ..self
203 }
204 }
205}
206
207impl<'a, T> ConfigBuilder<'a, T>
208where
209 T: ConfigValues,
210{
211 pub fn read(self) -> GitResult<Option<ConfigValue<T>>> {
212 if let Some(value) = self.explicit {
213 return Ok(Some(ConfigValue::Explicit(value)));
214 }
215 match T::get_config_value(self.config, self.key) {
216 Ok(value) => Ok(Some(ConfigValue::GitConfig(value))),
217 Err(err) if config_not_exist(&err) => {
218 if let Some(default) = self.default {
219 Ok(Some(ConfigValue::Implicit(default)))
220 } else {
221 Ok(None)
222 }
223 }
224 Err(err) => Err(err),
225 }
226 }
227}
228
229impl<'a, T> ConfigBuilder<'a, T> {
230 fn parses_and_collect<U>(self) -> Result<ConfigValue<U>>
231 where
232 T: IntoIterator,
233 T::Item: FromStr,
234 <T::Item as FromStr>::Err: std::error::Error + Send + Sync + 'static,
235 U: FromIterator<<T as IntoIterator>::Item> + Default,
236 {
237 if let Some(value) = self.explicit {
238 return Ok(ConfigValue::Explicit(value.into_iter().collect()));
239 }
240
241 let result = match Vec::<String>::get_config_value(self.config, self.key) {
242 Ok(entries) if !entries.is_empty() => {
243 let mut result = Vec::new();
244 if self.comma_separated {
245 for entry in entries {
246 for item in entry.split(',') {
247 if !item.is_empty() {
248 let value = <T::Item>::from_str(item)?;
249 result.push(value);
250 }
251 }
252 }
253 } else {
254 for entry in entries {
255 let value = <T::Item>::from_str(&entry)?;
256 result.push(value);
257 }
258 }
259
260 ConfigValue::GitConfig(result.into_iter().collect())
261 }
262 Ok(_) => {
263 if let Some(default) = self.default {
264 ConfigValue::Implicit(default.into_iter().collect())
265 } else {
266 ConfigValue::Implicit(U::default())
267 }
268 }
269 Err(err) => return Err(err.into()),
270 };
271 Ok(result)
272 }
273}
274
275pub trait ConfigValues {
276 fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error>
277 where
278 Self: Sized;
279}
280
281impl ConfigValues for String {
282 fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error> {
283 config.get_string(key)
284 }
285}
286
287impl ConfigValues for Vec<String> {
288 fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error> {
289 let mut result = Vec::new();
290 let mut entries = config.entries(Some(key))?;
291 while let Some(entry) = entries.next() {
292 let entry = entry?;
293 if let Some(value) = entry.value() {
294 result.push(value.to_owned());
295 } else {
296 warn!(
297 "non utf-8 config entry {}",
298 String::from_utf8_lossy(entry.name_bytes())
299 );
300 }
301 }
302 Ok(result)
303 }
304}
305
306impl ConfigValues for bool {
307 fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error> {
308 config.get_bool(key)
309 }
310}
311
312impl ConfigValues for u64 {
313 fn get_config_value(config: &GitConfig, key: &str) -> Result<Self, git2::Error> {
314 let value = config.get_i64(key)?;
315 if value >= 0 {
316 return Ok(value as u64);
317 }
318 panic!("`git config {}` cannot be negative value", key);
319 }
320}
321
322fn config_not_exist(err: &git2::Error) -> bool {
323 err.code() == ErrorCode::NotFound && err.class() == ErrorClass::Config
324}
325
326pub fn get_push_remote(config: &GitConfig, branch: &LocalBranch) -> Result<String> {
327 let push_remote_key = format!("branch.{}.pushRemote", branch.short_name());
328 if let Some(push_remote) = get::<String>(config, &push_remote_key).read()? {
329 return Ok(push_remote.unwrap());
330 }
331
332 if let Some(push_default) = get::<String>(config, "remote.pushDefault").read()? {
333 return Ok(push_default.unwrap());
334 }
335
336 Ok(get_remote_name(config, branch)?.unwrap_or_else(|| "origin".to_owned()))
337}
338
339pub fn get_remote_name(config: &GitConfig, branch: &LocalBranch) -> Result<Option<String>> {
340 let key = format!("branch.{}.remote", branch.short_name());
341 match config.get_string(&key) {
342 Ok(remote) => Ok(Some(remote)),
343 Err(err) if config_not_exist(&err) => Ok(None),
344 Err(err) => Err(err.into()),
345 }
346}
347
348pub fn get_remote<'a>(repo: &'a Repository, remote_name: &str) -> Result<Option<Remote<'a>>> {
349 fn error_is_missing_remote(err: &Error) -> bool {
350 err.class() == ErrorClass::Config && err.code() == ErrorCode::InvalidSpec
351 }
352
353 match repo.find_remote(remote_name) {
354 Ok(remote) => Ok(Some(remote)),
355 Err(err) if error_is_missing_remote(&err) => Ok(None),
356 Err(err) => Err(err.into()),
357 }
358}
359
360pub fn get_merge(config: &GitConfig, branch: &LocalBranch) -> Result<Option<String>> {
361 let key = format!("branch.{}.merge", branch.short_name());
362 match config.get_string(&key) {
363 Ok(merge) => Ok(Some(merge)),
364 Err(err) if config_not_exist(&err) => Ok(None),
365 Err(err) => Err(err.into()),
366 }
367}