use std::collections::BTreeMap;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::merge::{Merge, merge_optional};
use crate::config::HooksConfig;
use crate::config::commands::CommandConfig;
// Which changes get staged before a commit is created.
//
// Serialized in kebab-case (`all`, `tracked`, `none`) and also usable as a clap value
// enum. `All` is the `Default`. The precise staging behaviour of each variant lives in
// the commit code. NOTE(review): confirm and document it there.
// (Plain `//` comments are used on purpose: `///` here would leak into the JsonSchema
// description and the clap possible-value help text.)
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    Eq,
    PartialEq,
    JsonSchema,
    clap::ValueEnum,
    serde::Deserialize,
    serde::Serialize,
)]
#[serde(rename_all = "kebab-case")]
pub enum StageMode {
    #[default]
    All,
    Tracked,
    None,
}
// Commit-message generation settings: the generation `command` plus optional templates
// for regular and squash commits, each supplied either inline or through a `*-file` key.
// The inline/file members of each template pair are treated as mutually exclusive when
// configuration layers are merged.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct CommitGenerationConfig {
    #[serde(default)]
    pub command: Option<String>,
    #[serde(default)]
    pub template: Option<String>,
    #[serde(rename = "template-file", default)]
    pub template_file: Option<String>,
    #[serde(rename = "squash-template", default)]
    pub squash_template: Option<String>,
    #[serde(rename = "squash-template-file", default)]
    pub squash_template_file: Option<String>,
}
impl CommitGenerationConfig {
    /// Returns `true` if `command` holds something other than blank text.
    ///
    /// A missing command, an empty string, and a whitespace-only string all count as
    /// "not configured".
    pub fn is_configured(&self) -> bool {
        matches!(self.command.as_deref(), Some(cmd) if !cmd.trim().is_empty())
    }
}
/// Chooses the winning `(inline, file)` pair when layering `over` on top of `base`.
///
/// If `over` sets the inline value, that value wins and the file is cleared. Otherwise,
/// if `over` sets the file, the file wins and the inline value is cleared. When `over`
/// sets neither, the `base` pair is kept unchanged.
fn resolve_inline_or_file(
    base: (Option<&String>, Option<&String>),
    over: (Option<&String>, Option<&String>),
) -> (Option<String>, Option<String>) {
    match over {
        (Some(inline), _) => (Some(inline.clone()), None),
        (None, Some(file)) => (None, Some(file.clone())),
        (None, None) => (base.0.cloned(), base.1.cloned()),
    }
}
impl Merge for CommitGenerationConfig {
    /// Layers `other` over `self`.
    ///
    /// `command` is taken from `other` when it is set. Each template pair
    /// (`template`/`template-file` and `squash-template`/`squash-template-file`) is
    /// resolved as a unit, so a higher layer that picks one form replaces the lower
    /// layer's choice instead of mixing with it.
    fn merge_with(&self, other: &Self) -> Self {
        let (template, template_file) = resolve_inline_or_file(
            (self.template.as_ref(), self.template_file.as_ref()),
            (other.template.as_ref(), other.template_file.as_ref()),
        );
        let (squash_template, squash_template_file) = resolve_inline_or_file(
            (
                self.squash_template.as_ref(),
                self.squash_template_file.as_ref(),
            ),
            (
                other.squash_template.as_ref(),
                other.squash_template_file.as_ref(),
            ),
        );
        Self {
            command: other.command.as_ref().or(self.command.as_ref()).cloned(),
            template,
            template_file,
            squash_template,
            squash_template_file,
        }
    }
}
// Settings for the list view. Every field is optional so that configuration layers can
// be merged field by field; the accessor methods supply the effective defaults.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ListConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub full: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub branches: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remotes: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "task-timeout-ms")]
    pub task_timeout_ms: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "timeout-ms")]
    pub timeout_ms: Option<u64>,
}
impl ListConfig {
    /// Effective `full` flag; `false` unless explicitly enabled.
    pub fn full(&self) -> bool {
        self.full == Some(true)
    }
    /// Effective `branches` flag; `false` unless explicitly enabled.
    pub fn branches(&self) -> bool {
        self.branches == Some(true)
    }
    /// Effective `remotes` flag; `false` unless explicitly enabled.
    pub fn remotes(&self) -> bool {
        self.remotes == Some(true)
    }
    /// Effective `summary` flag; `false` unless explicitly enabled.
    pub fn summary(&self) -> bool {
        self.summary == Some(true)
    }
    /// `task-timeout-ms` as a `Duration`. Unset or `0` both mean "no timeout" (`None`).
    pub fn task_timeout(&self) -> Option<std::time::Duration> {
        match self.task_timeout_ms {
            Some(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
            _ => None,
        }
    }
    /// `timeout-ms` as a `Duration`. Unset or `0` both mean "no timeout" (`None`).
    pub fn timeout(&self) -> Option<std::time::Duration> {
        match self.timeout_ms {
            Some(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
            _ => None,
        }
    }
}
impl Merge for ListConfig {
    /// Field-wise overlay: each value set in `other` wins, otherwise `self`'s is kept.
    fn merge_with(&self, other: &Self) -> Self {
        // Exhaustive destructuring makes the compiler flag this impl when a field is
        // added. All fields are `Option` of `Copy` types, so this only copies.
        let Self {
            full,
            branches,
            remotes,
            summary,
            task_timeout_ms,
            timeout_ms,
        } = *other;
        Self {
            full: full.or(self.full),
            branches: branches.or(self.branches),
            remotes: remotes.or(self.remotes),
            summary: summary.or(self.summary),
            task_timeout_ms: task_timeout_ms.or(self.task_timeout_ms),
            timeout_ms: timeout_ms.or(self.timeout_ms),
        }
    }
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default, JsonSchema)]
pub struct CommitConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub stage: Option<StageMode>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub generation: Option<CommitGenerationConfig>,
}
impl CommitConfig {
pub fn stage(&self) -> StageMode {
self.stage.unwrap_or_default()
}
}
impl Merge for CommitConfig {
fn merge_with(&self, other: &Self) -> Self {
Self {
stage: other.stage.or(self.stage),
generation: match (&self.generation, &other.generation) {
(None, None) => None,
(Some(s), None) => Some(s.clone()),
(None, Some(o)) => Some(o.clone()),
(Some(s), Some(o)) => Some(s.merge_with(o)),
},
}
}
}
// Settings for the merge command. Each flag is optional for layering; every accessor
// defaults to `true` when the flag is unset.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct MergeConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub squash: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub commit: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rebase: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remove: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verify: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ff: Option<bool>,
}
impl MergeConfig {
    /// Effective `squash` flag; `true` unless explicitly disabled.
    pub fn squash(&self) -> bool {
        self.squash != Some(false)
    }
    /// Effective `commit` flag; `true` unless explicitly disabled.
    pub fn commit(&self) -> bool {
        self.commit != Some(false)
    }
    /// Effective `rebase` flag; `true` unless explicitly disabled.
    pub fn rebase(&self) -> bool {
        self.rebase != Some(false)
    }
    /// Effective `remove` flag; `true` unless explicitly disabled.
    pub fn remove(&self) -> bool {
        self.remove != Some(false)
    }
    /// Effective `verify` flag; `true` unless explicitly disabled.
    pub fn verify(&self) -> bool {
        self.verify != Some(false)
    }
    /// Effective `ff` flag; `true` unless explicitly disabled.
    pub fn ff(&self) -> bool {
        self.ff != Some(false)
    }
}
impl Merge for MergeConfig {
    /// Field-wise overlay: each flag set in `other` wins, otherwise `self`'s is kept.
    fn merge_with(&self, other: &Self) -> Self {
        // Exhaustive destructuring makes the compiler flag this impl when a field is added.
        let Self {
            squash,
            commit,
            rebase,
            remove,
            verify,
            ff,
        } = *other;
        Self {
            squash: squash.or(self.squash),
            commit: commit.or(self.commit),
            rebase: rebase.or(self.rebase),
            remove: remove.or(self.remove),
            verify: verify.or(self.verify),
            ff: ff.or(self.ff),
        }
    }
}
// Settings for the interactive switch picker: an optional pager and a timeout.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct SwitchPickerConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pager: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "timeout-ms")]
    pub timeout_ms: Option<u64>,
}
impl SwitchPickerConfig {
    /// The configured pager, if any, borrowed as a string slice.
    pub fn pager(&self) -> Option<&str> {
        self.pager.as_ref().map(String::as_str)
    }
    /// Effective picker timeout.
    ///
    /// - Returns `None` whenever `WORKTRUNK_TEST_PICKER_NO_TIMEOUT` is present in the
    ///   environment (its value is ignored).
    /// - Otherwise an unset `timeout-ms` means 500 ms, and `0` means "no timeout".
    pub fn timeout(&self) -> Option<std::time::Duration> {
        const DEFAULT_TIMEOUT_MS: u64 = 500;
        if std::env::var_os("WORKTRUNK_TEST_PICKER_NO_TIMEOUT").is_some() {
            return None;
        }
        match self.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS) {
            0 => None,
            ms => Some(std::time::Duration::from_millis(ms)),
        }
    }
}
impl Merge for SwitchPickerConfig {
    /// Field-wise overlay: values set in `other` win, otherwise `self`'s are kept.
    fn merge_with(&self, other: &Self) -> Self {
        Self {
            pager: other.pager.as_ref().or(self.pager.as_ref()).cloned(),
            timeout_ms: other.timeout_ms.or(self.timeout_ms),
        }
    }
}
// Settings for the switch command: the `cd` flag and an optional picker sub-table.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct SwitchConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cd: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub picker: Option<SwitchPickerConfig>,
}
impl SwitchConfig {
    /// Effective `cd` flag; `true` unless explicitly disabled.
    pub fn cd(&self) -> bool {
        self.cd != Some(false)
    }
}
impl Merge for SwitchConfig {
    /// Overlays `other` on `self`. When both layers have a `picker` table the two are
    /// merged recursively; otherwise whichever one is present is used.
    fn merge_with(&self, other: &Self) -> Self {
        let picker = match (self.picker.as_ref(), other.picker.as_ref()) {
            (Some(base), Some(overlay)) => Some(base.merge_with(overlay)),
            (base, overlay) => overlay.or(base).cloned(),
        };
        Self {
            cd: other.cd.or(self.cd),
            picker,
        }
    }
}
// Settings for the copy-ignored step: patterns to exclude.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct CopyIgnoredConfig {
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub exclude: Vec<String>,
}
impl CopyIgnoredConfig {
    /// Returns a config whose `exclude` list is `self`'s patterns followed by those in
    /// `other` that are not already present.
    ///
    /// Order is preserved. Duplicates inside `other` are collapsed as they are appended;
    /// duplicates already inside `self` are left as they are.
    pub fn merged_with(&self, other: &Self) -> Self {
        let exclude = other
            .exclude
            .iter()
            .fold(self.exclude.clone(), |mut acc, pattern| {
                if !acc.contains(pattern) {
                    acc.push(pattern.clone());
                }
                acc
            });
        Self { exclude }
    }
}
impl Merge for CopyIgnoredConfig {
    /// Delegates to [`CopyIgnoredConfig::merged_with`] (union of exclude patterns).
    fn merge_with(&self, other: &Self) -> Self {
        Self::merged_with(self, other)
    }
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default, JsonSchema)]
pub struct StepConfig {
#[serde(
default,
rename = "copy-ignored",
skip_serializing_if = "Option::is_none"
)]
pub copy_ignored: Option<CopyIgnoredConfig>,
}
impl StepConfig {
pub fn copy_ignored(&self) -> CopyIgnoredConfig {
self.copy_ignored.clone().unwrap_or_default()
}
}
impl Merge for StepConfig {
fn merge_with(&self, other: &Self) -> Self {
Self {
copy_ignored: merge_optional(self.copy_ignored.as_ref(), other.copy_ignored.as_ref()),
}
}
}
// Settings that a configuration layer may override: hooks (flattened into the same
// table), the worktree path template, per-command sections, and aliases.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct OverridableConfig {
    #[serde(default, flatten)]
    pub hooks: HooksConfig,
    #[serde(
        default,
        rename = "worktree-path",
        skip_serializing_if = "Option::is_none"
    )]
    pub worktree_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub list: Option<ListConfig>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub commit: Option<CommitConfig>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub merge: Option<MergeConfig>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub switch: Option<SwitchConfig>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub step: Option<StepConfig>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub aliases: Option<BTreeMap<String, CommandConfig>>,
}
impl OverridableConfig {
    /// Returns `true` when nothing is overridden: hooks equal `HooksConfig::default()`
    /// and every optional setting is unset.
    pub fn is_empty(&self) -> bool {
        let no_optional_settings = self.worktree_path.is_none()
            && self.list.is_none()
            && self.commit.is_none()
            && self.merge.is_none()
            && self.switch.is_none()
            && self.step.is_none()
            && self.aliases.is_none();
        no_optional_settings && self.hooks == HooksConfig::default()
    }
}
impl Merge for OverridableConfig {
fn merge_with(&self, other: &Self) -> Self {
use super::merge::merge_optional;
Self {
hooks: self.hooks.merge_with(&other.hooks), worktree_path: other
.worktree_path
.clone()
.or_else(|| self.worktree_path.clone()),
list: merge_optional(self.list.as_ref(), other.list.as_ref()),
commit: merge_optional(self.commit.as_ref(), other.commit.as_ref()),
merge: merge_optional(self.merge.as_ref(), other.merge.as_ref()),
switch: merge_optional(self.switch.as_ref(), other.switch.as_ref()),
step: merge_optional(self.step.as_ref(), other.step.as_ref()),
aliases: merge_alias_maps(&self.aliases, &other.aliases), }
}
}
/// Merges two optional alias maps, with `other` acting as the overriding layer.
///
/// - If `other` is `None`, `base` is returned unchanged (cloned).
/// - If only `other` is present, it is returned (cloned).
/// - If both are present, `other` is folded into a copy of `base` via
///   `crate::config::commands::append_aliases`. The `tests` module in this file shows
///   that commands under a shared alias name are appended, base first.
fn merge_alias_maps(
    base: &Option<BTreeMap<String, CommandConfig>>,
    other: &Option<BTreeMap<String, CommandConfig>>,
) -> Option<BTreeMap<String, CommandConfig>> {
    let overlay = match other {
        Some(map) => map,
        None => return base.clone(),
    };
    let mut combined = match base {
        Some(map) => map.clone(),
        None => return Some(overlay.clone()),
    };
    crate::config::commands::append_aliases(&mut combined, overlay);
    Some(combined)
}
// Per-project user settings: commands the user has approved for this project, plus any
// overrides of the shared `OverridableConfig` settings (flattened into the same table).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default, JsonSchema)]
pub struct UserProjectOverrides {
    #[serde(
        default,
        rename = "approved-commands",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub approved_commands: Vec<String>,
    #[serde(flatten, default)]
    pub overrides: OverridableConfig,
}
impl UserProjectOverrides {
    /// Returns `true` only when this project entry carries no data at all: no approved
    /// commands and no overrides.
    ///
    /// Both fields must be checked. An entry that holds only `approved-commands` is not
    /// empty, and reporting it as empty would let any code that prunes empty project
    /// entries silently discard those approvals.
    pub fn is_empty(&self) -> bool {
        self.approved_commands.is_empty() && self.overrides.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an alias map from `(name, command)` pairs, one single-command entry each.
    fn alias_map(entries: &[(&str, &str)]) -> BTreeMap<String, CommandConfig> {
        entries
            .iter()
            .map(|&(name, cmd)| (name.to_string(), CommandConfig::single(cmd)))
            .collect()
    }

    #[test]
    fn test_merge_alias_maps_both_none() {
        let merged = merge_alias_maps(&None, &None);
        assert!(merged.is_none());
    }

    #[test]
    fn test_merge_alias_maps_base_only() {
        let base = alias_map(&[("a", "1")]);
        let expected = Some(base.clone());
        assert_eq!(merge_alias_maps(&Some(base), &None), expected);
    }

    #[test]
    fn test_merge_alias_maps_other_only() {
        let other = alias_map(&[("b", "2")]);
        let expected = Some(other.clone());
        assert_eq!(merge_alias_maps(&None, &Some(other)), expected);
    }

    #[test]
    fn test_merge_alias_maps_appends_on_collision() {
        let base = alias_map(&[("a", "1"), ("shared", "base-cmd")]);
        let other = alias_map(&[("b", "2"), ("shared", "other-cmd")]);
        let merged =
            merge_alias_maps(&Some(base), &Some(other)).expect("both inputs are Some");

        // Non-colliding aliases pass through untouched.
        for key in ["a", "b"] {
            assert_eq!(merged[key].commands().count(), 1, "alias `{key}`");
        }

        // A colliding alias keeps base's command first, then appends other's.
        let shared: Vec<_> = merged["shared"].commands().collect();
        assert_eq!(shared.len(), 2);
        assert_eq!(shared[0].template, "base-cmd");
        assert_eq!(shared[1].template, "other-cmd");
    }
}