Skip to main content

repo/
stash.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Stash storage and operations.
3
4use std::{fs, path::PathBuf};
5
6use objects::{
7    fs_atomic::{sync_directory, temp_path, write_file_atomic},
8    fs_ops::remove_path_recursively,
9    lock::RepoLock,
10    object::{ChangeId, ContentHash},
11};
12use serde::{Deserialize, Serialize};
13
14use crate::{Repository, Result};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct StashEntry {
18    pub index: usize,
19    pub change_id: ChangeId,
20    pub tree_hash: String,
21    pub parent_tree_hash: String,
22    pub message: Option<String>,
23    pub created_at: chrono::DateTime<chrono::Utc>,
24}
25
26pub struct StashManager {
27    stash_dir: PathBuf,
28    lock: RepoLock,
29}
30
31impl StashManager {
32    pub fn new(heddle_dir: impl AsRef<std::path::Path>) -> Self {
33        Self {
34            stash_dir: heddle_dir.as_ref().join("stashes"),
35            lock: RepoLock::at(heddle_dir.as_ref().join("locks/stash.lock")),
36        }
37    }
38
39    pub fn init(&self) -> Result<()> {
40        if !self.stash_dir.exists() {
41            fs::create_dir_all(&self.stash_dir)?;
42        }
43        Ok(())
44    }
45
46    pub fn push(
47        &self,
48        tree_hash: ContentHash,
49        parent_tree_hash: String,
50        message: Option<String>,
51    ) -> Result<StashEntry> {
52        let _lock = self.write_lock()?;
53        let stashes = self.list_unlocked()?;
54
55        let change_id = ChangeId::generate();
56
57        let entry = StashEntry {
58            index: stashes.len(),
59            change_id,
60            tree_hash: tree_hash.to_string(),
61            parent_tree_hash,
62            message,
63            created_at: chrono::Utc::now(),
64        };
65
66        let entry_path = self.stash_dir.join(format!("{}", entry.index));
67        let content = serde_json::to_string(&entry)?;
68        write_file_atomic(&entry_path, content.as_bytes())?;
69
70        Ok(entry)
71    }
72
73    pub fn list(&self) -> Result<Vec<StashEntry>> {
74        let _lock = self.read_lock()?;
75        self.list_unlocked()
76    }
77
78    fn list_unlocked(&self) -> Result<Vec<StashEntry>> {
79        if !self.stash_dir.exists() {
80            return Ok(Vec::new());
81        }
82
83        let mut stashes = Vec::new();
84
85        for entry in fs::read_dir(&self.stash_dir)? {
86            let entry = entry?;
87            let path = entry.path();
88
89            if path.extension().is_none()
90                && let Ok(content) = fs::read_to_string(&path)
91                && let Ok(stash) = serde_json::from_str::<StashEntry>(&content)
92            {
93                stashes.push(stash);
94            }
95        }
96
97        stashes.sort_by_key(|s| s.index);
98        Ok(stashes)
99    }
100
101    pub fn top(&self) -> Result<Option<StashEntry>> {
102        let stashes = self.list()?;
103        Ok(stashes.last().cloned())
104    }
105
106    pub fn drop(&self) -> Result<Option<StashEntry>> {
107        let _lock = self.write_lock()?;
108        let mut stashes = self.list_unlocked()?;
109
110        if stashes.is_empty() {
111            return Ok(None);
112        }
113
114        let removed = stashes.pop();
115        self.rewrite_unlocked(&mut stashes)?;
116        Ok(removed)
117    }
118
119    pub fn pop_with<F>(&self, apply: F) -> Result<Option<StashEntry>>
120    where
121        F: FnOnce(&StashEntry) -> Result<()>,
122    {
123        let _lock = self.write_lock()?;
124        let mut stashes = self.list_unlocked()?;
125
126        let Some(removed) = stashes.pop() else {
127            return Ok(None);
128        };
129
130        apply(&removed)?;
131        self.rewrite_unlocked(&mut stashes)?;
132        Ok(Some(removed))
133    }
134
135    fn rewrite_unlocked(&self, stashes: &mut [StashEntry]) -> Result<()> {
136        let parent = self
137            .stash_dir
138            .parent()
139            .ok_or_else(|| std::io::Error::other("invalid stash directory"))?;
140        fs::create_dir_all(parent)?;
141
142        let replacement_dir = temp_path(&self.stash_dir);
143        fs::create_dir_all(&replacement_dir)?;
144
145        for (new_index, entry) in stashes.iter_mut().enumerate() {
146            entry.index = new_index;
147            let path = replacement_dir.join(format!("{}", new_index));
148            let content = serde_json::to_string(entry)?;
149            write_file_atomic(&path, content.as_bytes())?;
150        }
151
152        sync_directory(&replacement_dir)?;
153
154        let backup_dir = self.stash_dir.with_extension("old");
155        remove_stash_path(&backup_dir)?;
156        fs::rename(&self.stash_dir, &backup_dir)?;
157        sync_directory(parent)?;
158        if let Err(error) = fs::rename(&replacement_dir, &self.stash_dir) {
159            fs::rename(&backup_dir, &self.stash_dir)?;
160            sync_directory(parent)?;
161            return Err(error.into());
162        }
163        sync_directory(parent)?;
164        remove_stash_path(&backup_dir)?;
165        sync_directory(parent)?;
166
167        Ok(())
168    }
169
170    pub fn clear(&self) -> Result<usize> {
171        let _lock = self.write_lock()?;
172        let stashes = self.list_unlocked()?;
173        let count = stashes.len();
174
175        if self.stash_dir.exists() {
176            if self.stash_dir.is_symlink() {
177                fs::remove_file(&self.stash_dir)?;
178            } else {
179                remove_path_recursively(&self.stash_dir)?;
180            }
181        }
182        fs::create_dir_all(&self.stash_dir)?;
183
184        Ok(count)
185    }
186
187    fn read_lock(&self) -> Result<objects::lock::ReadLockGuard> {
188        self.lock.read().map_err(|err| {
189            std::io::Error::other(format!("failed to acquire stash lock: {err}")).into()
190        })
191    }
192
193    fn write_lock(&self) -> Result<objects::lock::WriteLockGuard> {
194        self.lock.write().map_err(|err| {
195            std::io::Error::other(format!("failed to acquire stash lock: {err}")).into()
196        })
197    }
198}
199
200fn remove_stash_path(path: &std::path::Path) -> Result<()> {
201    if !path.exists() {
202        return Ok(());
203    }
204
205    if path.is_symlink() {
206        fs::remove_file(path)?;
207    } else {
208        remove_path_recursively(path)?;
209    }
210
211    Ok(())
212}
213
214impl Repository {
215    pub fn stash_manager(&self) -> StashManager {
216        StashManager::new(self.heddle_dir())
217    }
218}
219
#[cfg(test)]
mod tests {
    use std::{
        collections::HashSet,
        sync::{Arc, Barrier},
        thread,
    };

    use tempfile::TempDir;

    use super::*;

    /// Builds an initialized manager inside a fresh temporary directory.
    ///
    /// The `TempDir` is handed back so the directory outlives the test body.
    fn create_manager() -> (TempDir, StashManager) {
        let scratch = TempDir::new().unwrap();
        let manager = StashManager::new(scratch.path().join(".heddle"));
        manager.init().unwrap();
        (scratch, manager)
    }

    /// Pushes a stash whose tree hash is computed from `tree` and returns the stored entry.
    fn push_stash(
        manager: &StashManager,
        tree: &[u8],
        parent: &str,
        message: Option<String>,
    ) -> StashEntry {
        manager
            .push(ContentHash::compute(tree), parent.to_string(), message)
            .unwrap()
    }

    #[test]
    fn test_drop_rewrites_remaining_entries() {
        let (_temp_dir, manager) = create_manager();
        let first = push_stash(&manager, b"one", "parent-1", None);
        let second = push_stash(&manager, b"two", "parent-2", None);
        let third = push_stash(&manager, b"three", "parent-3", None);

        let dropped = manager.drop().unwrap().unwrap();
        assert_eq!(dropped.change_id, third.change_id);

        // Survivors keep their order and are renumbered densely from zero.
        let survivors: Vec<_> = manager
            .list()
            .unwrap()
            .iter()
            .map(|stash| (stash.index, stash.change_id))
            .collect();
        assert_eq!(survivors, vec![(0, first.change_id), (1, second.change_id)]);

        // The rewrite must not leave temporary files or the backup directory behind.
        let has_temp_leftovers = fs::read_dir(&manager.stash_dir)
            .unwrap()
            .flatten()
            .any(|entry| entry.file_name().to_string_lossy().contains(".tmp-"));
        assert!(!has_temp_leftovers);
        assert!(!manager.stash_dir.with_extension("old").exists());
    }

    #[test]
    fn test_pop_with_drops_only_after_successful_apply() {
        let (_temp_dir, manager) = create_manager();
        let first = push_stash(&manager, b"one", "parent-1", None);
        let second = push_stash(&manager, b"two", "parent-2", None);

        let failure = manager
            .pop_with(|_| Err(std::io::Error::other("apply failed").into()))
            .unwrap_err();
        assert!(failure.to_string().contains("apply failed"));
        // A failed apply leaves both stashes in place.
        assert_eq!(manager.list().unwrap().len(), 2);

        let popped = manager
            .pop_with(|stash| {
                assert_eq!(stash.change_id, second.change_id);
                Ok(())
            })
            .unwrap()
            .unwrap();
        assert_eq!(popped.change_id, second.change_id);

        let remaining = manager.list().unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(
            (remaining[0].index, remaining[0].change_id),
            (0, first.change_id)
        );
    }

    #[test]
    fn test_concurrent_pushes_preserve_all_entries() {
        const WORKERS: usize = 8;

        let (_temp_dir, manager) = create_manager();
        let manager = Arc::new(manager);
        // One extra participant so the main thread releases all workers at the same moment.
        let barrier = Arc::new(Barrier::new(WORKERS + 1));

        let handles: Vec<_> = (0..WORKERS)
            .map(|i| {
                let worker = Arc::clone(&manager);
                let gate = Arc::clone(&barrier);
                thread::spawn(move || {
                    gate.wait();
                    push_stash(
                        &worker,
                        format!("tree-{i}").as_bytes(),
                        &format!("parent-{i}"),
                        Some(format!("stash-{i}")),
                    );
                })
            })
            .collect();

        barrier.wait();
        for handle in handles {
            handle.join().unwrap();
        }

        let stashes = manager.list().unwrap();
        assert_eq!(stashes.len(), WORKERS);

        let mut indices: Vec<usize> = stashes.iter().map(|entry| entry.index).collect();
        indices.sort_unstable();
        assert_eq!(indices, (0..WORKERS).collect::<Vec<_>>());

        let distinct_ids: HashSet<_> = stashes.iter().map(|entry| entry.change_id).collect();
        assert_eq!(distinct_ids.len(), WORKERS);
    }
}