Skip to main content

git_workty/
config.rs

1use crate::git::GitRepo;
2use anyhow::{Context, Result};
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::path::PathBuf;
6
7const CONFIG_FILENAME: &str = "workty.toml";
8const DEFAULT_BASE: &str = "main";
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(default)]
12pub struct Config {
13    pub version: u32,
14    pub base: String,
15    pub root: String,
16    pub layout: String,
17    pub open_cmd: Option<String>,
18}
19
20impl Default for Config {
21    fn default() -> Self {
22        Self {
23            version: 1,
24            base: DEFAULT_BASE.to_string(),
25            root: "~/.workty/{repo}-{id}".to_string(),
26            layout: "flat".to_string(),
27            open_cmd: None,
28        }
29    }
30}
31
32impl Config {
33    pub fn load(repo: &GitRepo) -> Result<Self> {
34        let mut candidates = vec![
35            // 1. Repo root
36            repo.root.join(CONFIG_FILENAME),
37            // 2. Git dir
38            config_path(repo),
39        ];
40
41        // 3. User config dir (~/.config/workty/workty.toml)
42        if let Some(config_dir) = dirs::config_dir() {
43            candidates.push(config_dir.join("workty").join(CONFIG_FILENAME));
44        }
45
46        if let Some(home) = dirs::home_dir() {
47            // 4. ~/.workty.toml
48            candidates.push(home.join(format!(".{}", CONFIG_FILENAME)));
49            // 5. ~/workty.toml
50            candidates.push(home.join(CONFIG_FILENAME));
51        }
52
53        let mut config: Self = candidates
54            .into_iter()
55            .find(|path| path.exists())
56            .map(|path| {
57                let contents = std::fs::read_to_string(&path)
58                    .with_context(|| format!("Failed to read config from {}", path.display()))?;
59                toml::from_str(&contents)
60                    .with_context(|| format!("Failed to parse config from {}", path.display()))
61            })
62            .transpose()?
63            .unwrap_or_default();
64
65        config.adjust_defaults(repo);
66
67        Ok(config)
68    }
69
70    fn adjust_defaults(&mut self, repo: &GitRepo) {
71        // If the base branch is the default one but it doesn't exist,
72        // we try to detect the actual default branch (e.g. master, trunk, etc)
73        if self.base == DEFAULT_BASE && !repo.branch_exists(DEFAULT_BASE) {
74            if let Some(default) = repo.default_branch() {
75                self.base = default;
76            }
77        }
78    }
79
80    #[allow(dead_code)]
81    pub fn save(&self, repo: &GitRepo) -> Result<()> {
82        let path = config_path(repo);
83        let contents = toml::to_string_pretty(self).context("Failed to serialize config")?;
84        std::fs::write(&path, contents)
85            .with_context(|| format!("Failed to write config to {}", path.display()))
86    }
87
88    pub fn workspace_root(&self, repo: &GitRepo) -> PathBuf {
89        let repo_name = repo
90            .root
91            .file_name()
92            .and_then(|s| s.to_str())
93            .unwrap_or("repo");
94
95        let id = compute_repo_id(repo);
96
97        let expanded = self.root.replace("{repo}", repo_name).replace("{id}", &id);
98
99        expand_tilde(&expanded)
100    }
101
102    pub fn worktree_path(&self, repo: &GitRepo, branch_slug: &str) -> PathBuf {
103        let root = self.workspace_root(repo);
104        root.join(branch_slug)
105    }
106}
107
108pub fn config_path(repo: &GitRepo) -> PathBuf {
109    repo.common_dir.join(CONFIG_FILENAME)
110}
111
112pub fn config_exists(repo: &GitRepo) -> bool {
113    config_path(repo).exists()
114}
115
116fn compute_repo_id(repo: &GitRepo) -> String {
117    let input = repo
118        .origin_url()
119        .unwrap_or_else(|| repo.common_dir.to_string_lossy().to_string());
120
121    let normalized = normalize_url(&input);
122    let mut hasher = Sha256::new();
123    hasher.update(normalized.as_bytes());
124    let result = hasher.finalize();
125    hex::encode(&result[..4])
126}
127
/// Canonicalizes a remote URL (or path) so trivially different spellings hash
/// to the same repository id.
///
/// The steps, in order:
///
/// 1. Drop surrounding whitespace.
/// 2. Drop every trailing `/`.
/// 3. Drop every trailing `.git` suffix.
/// 4. Lowercase the result.
fn normalize_url(url: &str) -> String {
    let mut rest = url.trim();

    while let Some(shorter) = rest.strip_suffix('/') {
        rest = shorter;
    }
    while let Some(shorter) = rest.strip_suffix(".git") {
        rest = shorter;
    }

    rest.to_lowercase()
}
134
135fn expand_tilde(path: &str) -> PathBuf {
136    if path == "~" {
137        return dirs::home_dir().unwrap_or_else(|| PathBuf::from("~"));
138    }
139
140    if let Some(rest) = path.strip_prefix("~/") {
141        if let Some(home) = dirs::home_dir() {
142            return home.join(rest);
143        }
144    }
145    PathBuf::from(path)
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert_eq!(config.version, 1);
        assert_eq!(config.base, "main");
        assert_eq!(config.root, "~/.workty/{repo}-{id}");
        assert_eq!(config.layout, "flat");
        assert_eq!(config.open_cmd, None);
    }

    #[test]
    fn test_config_missing_keys_use_defaults() {
        // `#[serde(default)]` on `Config` fills every key absent from the file.
        let partial: Config = toml::from_str("base = \"develop\"").unwrap();
        assert_eq!(partial.base, "develop");
        assert_eq!(partial.version, 1);
        assert_eq!(partial.root, "~/.workty/{repo}-{id}");
        assert_eq!(partial.layout, "flat");
        assert_eq!(partial.open_cmd, None);

        let empty: Config = toml::from_str("").unwrap();
        assert_eq!(empty.base, "main");
        assert_eq!(empty.layout, "flat");
    }

    #[test]
    fn test_normalize_url() {
        assert_eq!(
            normalize_url("https://github.com/user/repo.git"),
            "https://github.com/user/repo"
        );
        assert_eq!(
            normalize_url("git@github.com:user/repo.git/"),
            "git@github.com:user/repo"
        );
        // Whitespace, repeated trailing slashes and case are all normalized away.
        assert_eq!(
            normalize_url("  HTTPS://GitHub.com/User/Repo//  "),
            "https://github.com/user/repo"
        );
    }

    #[test]
    fn test_expand_tilde() {
        // The exact home dir is platform-specific, so compare against
        // `dirs::home_dir()` whenever one is available.
        if let Some(home) = dirs::home_dir() {
            assert_eq!(expand_tilde("~"), home);
            assert_eq!(expand_tilde("~/foo"), home.join("foo"));
        }

        assert_eq!(expand_tilde("/abs/path"), PathBuf::from("/abs/path"));
        assert_eq!(expand_tilde("rel/path"), PathBuf::from("rel/path"));
        // `~user` forms are not expanded.
        assert_eq!(expand_tilde("~foo"), PathBuf::from("~foo"));
    }

    #[test]
    fn test_config_roundtrip() {
        let config = Config {
            version: 1,
            base: "develop".to_string(),
            root: "~/.worktrees/{repo}".to_string(),
            layout: "flat".to_string(),
            open_cmd: Some("code".to_string()),
        };

        let serialized = toml::to_string_pretty(&config).unwrap();
        let deserialized: Config = toml::from_str(&serialized).unwrap();

        // Compare every field so a dropped or renamed key cannot slip through.
        assert_eq!(config.version, deserialized.version);
        assert_eq!(config.base, deserialized.base);
        assert_eq!(config.root, deserialized.root);
        assert_eq!(config.layout, deserialized.layout);
        assert_eq!(config.open_cmd, deserialized.open_cmd);
    }
}