Skip to main content

repo/
merge_state.rs

// SPDX-License-Identifier: Apache-2.0
//! Merge state tracking for conflict resolution.
3
4use std::{fs, path::PathBuf};
5
6use objects::{fs_atomic::write_file_atomic, lock::RepoLock, object::ChangeId};
7use serde::{Deserialize, Serialize};
8
9use crate::{Repository, Result};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct MergeState {
13    pub ours: ChangeId,
14    pub theirs: ChangeId,
15    pub base: Option<ChangeId>,
16    pub conflicts: Vec<String>,
17    pub resolved: Vec<String>,
18}
19
20pub struct MergeStateManager {
21    merge_state_path: PathBuf,
22    lock: RepoLock,
23}
24
25impl MergeStateManager {
26    pub fn new(heddle_dir: impl AsRef<std::path::Path>) -> Self {
27        let heddle_dir = heddle_dir.as_ref();
28        Self {
29            merge_state_path: heddle_dir.join("MERGE_STATE"),
30            lock: RepoLock::at(heddle_dir.join("locks/merge_state.lock")),
31        }
32    }
33
34    pub fn start(
35        &self,
36        ours: ChangeId,
37        theirs: ChangeId,
38        base: Option<ChangeId>,
39        conflicts: Vec<String>,
40    ) -> Result<()> {
41        let _lock = self.write_lock()?;
42        let state = MergeState {
43            ours,
44            theirs,
45            base,
46            conflicts,
47            resolved: Vec::new(),
48        };
49        self.write_state(&state)?;
50        Ok(())
51    }
52
53    pub fn load(&self) -> Result<Option<MergeState>> {
54        let _lock = self.read_lock()?;
55        self.load_unlocked()
56    }
57
58    pub fn resolve(&self, path: &str) -> Result<()> {
59        let _lock = self.write_lock()?;
60        let mut state = self
61            .load_unlocked()?
62            .ok_or_else(|| crate::HeddleError::NotFound("No merge in progress".to_string()))?;
63
64        if state.conflicts.iter().any(|conflict| conflict == path)
65            && state.resolved.iter().all(|resolved| resolved != path)
66        {
67            state.resolved.push(path.to_string());
68        }
69
70        self.write_state(&state)?;
71        Ok(())
72    }
73
74    pub fn resolve_all(&self) -> Result<Vec<String>> {
75        let _lock = self.write_lock()?;
76        let mut state = self
77            .load_unlocked()?
78            .ok_or_else(|| crate::HeddleError::NotFound("No merge in progress".to_string()))?;
79
80        let newly_resolved: Vec<String> = state
81            .conflicts
82            .iter()
83            .filter(|c| !state.resolved.contains(c))
84            .cloned()
85            .collect();
86
87        state.resolved = state.conflicts.clone();
88
89        self.write_state(&state)?;
90        Ok(newly_resolved)
91    }
92
93    pub fn unresolved(&self) -> Result<Vec<String>> {
94        let _lock = self.read_lock()?;
95        let state = self
96            .load_unlocked()?
97            .ok_or_else(|| crate::HeddleError::NotFound("No merge in progress".to_string()))?;
98
99        Ok(state
100            .conflicts
101            .iter()
102            .filter(|c| !state.resolved.contains(c))
103            .cloned()
104            .collect())
105    }
106
107    pub fn abort(&self) -> Result<MergeState> {
108        let _lock = self.write_lock()?;
109        let state = self
110            .load_unlocked()?
111            .ok_or_else(|| crate::HeddleError::NotFound("No merge in progress".to_string()))?;
112
113        if !self.merge_state_path.exists() {
114            return Ok(state);
115        }
116
117        fs::remove_file(&self.merge_state_path)?;
118        Ok(state)
119    }
120
121    pub fn finish(&self) -> Result<MergeState> {
122        let _lock = self.write_lock()?;
123        let state = self
124            .load_unlocked()?
125            .ok_or_else(|| crate::HeddleError::NotFound("No merge in progress".to_string()))?;
126
127        let unresolved: Vec<_> = state
128            .conflicts
129            .iter()
130            .filter(|c| !state.resolved.contains(c))
131            .collect();
132
133        if !unresolved.is_empty() {
134            let unresolved_strs: Vec<&str> = unresolved.iter().map(|s| s.as_str()).collect();
135            return Err(crate::HeddleError::Conflict(format!(
136                "Unresolved conflicts: {}",
137                unresolved_strs.join(", ")
138            )));
139        }
140
141        if self.merge_state_path.exists() {
142            fs::remove_file(&self.merge_state_path)?;
143        }
144
145        Ok(state)
146    }
147
148    pub fn is_merge_in_progress(&self) -> bool {
149        self.read_lock().is_ok() && self.merge_state_path.exists()
150    }
151
152    fn load_unlocked(&self) -> Result<Option<MergeState>> {
153        if !self.merge_state_path.exists() {
154            return Ok(None);
155        }
156        let content = fs::read_to_string(&self.merge_state_path)?;
157        let state: MergeState = serde_json::from_str(&content)?;
158        Ok(Some(state))
159    }
160
161    fn write_state(&self, state: &MergeState) -> Result<()> {
162        let content = serde_json::to_vec(state)?;
163        write_file_atomic(&self.merge_state_path, &content)?;
164        Ok(())
165    }
166
167    fn read_lock(&self) -> Result<objects::lock::ReadLockGuard> {
168        self.lock
169            .read()
170            .map_err(|e| crate::HeddleError::Io(std::io::Error::other(e.to_string())))
171    }
172
173    fn write_lock(&self) -> Result<objects::lock::WriteLockGuard> {
174        self.lock
175            .write()
176            .map_err(|e| crate::HeddleError::Io(std::io::Error::other(e.to_string())))
177    }
178}
179
180impl Repository {
181    pub fn merge_state_manager(&self) -> MergeStateManager {
182        MergeStateManager::new(self.heddle_dir())
183    }
184}
185
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;

    /// Sets up a manager over a fresh `.heddle` directory inside a temp dir.
    ///
    /// The `TempDir` is handed back so the directory outlives the manager.
    fn create_manager() -> (TempDir, MergeStateManager) {
        let temp = TempDir::new().unwrap();
        let heddle_dir = temp.path().join(".heddle");
        std::fs::create_dir_all(&heddle_dir).unwrap();
        let manager = MergeStateManager::new(&heddle_dir);
        (temp, manager)
    }

    /// Generates three fresh change ids, returned as `(ours, theirs, base)`.
    fn sample_state_ids() -> (ChangeId, ChangeId, ChangeId) {
        let ours = ChangeId::generate();
        let theirs = ChangeId::generate();
        let base = ChangeId::generate();
        (ours, theirs, base)
    }

    /// Starting a merge and resolving one path must be visible on reload.
    #[test]
    fn start_and_resolve_persist_state_atomically() {
        let (_temp, manager) = create_manager();
        let (ours, theirs, base) = sample_state_ids();
        let conflicts = vec![String::from("a.txt"), String::from("b.txt")];

        manager.start(ours, theirs, Some(base), conflicts).unwrap();
        manager.resolve("a.txt").unwrap();

        let persisted = manager.load().unwrap().unwrap();
        assert_eq!(persisted.ours, ours);
        assert_eq!(persisted.theirs, theirs);
        assert_eq!(persisted.base, Some(base));
        assert_eq!(persisted.resolved, vec![String::from("a.txt")]);
    }

    /// `resolve_all` reports every open conflict and leaves none unresolved.
    #[test]
    fn resolve_all_marks_everything_resolved() {
        let (_temp, manager) = create_manager();
        let (ours, theirs, _base) = sample_state_ids();
        let conflicts = vec![String::from("a.txt"), String::from("b.txt")];

        manager.start(ours, theirs, None, conflicts).unwrap();
        let newly_resolved = manager.resolve_all().unwrap();

        let expected = vec![String::from("a.txt"), String::from("b.txt")];
        assert_eq!(newly_resolved, expected);

        let still_open = manager.unresolved().unwrap();
        assert!(still_open.is_empty());
    }
}