1use std::path::Path;
7
8use git2::{Repository, Signature};
9use libgrite_core::{Lock, LockPolicy, LockCheckResult, resource_hash, DEFAULT_LOCK_TTL_MS};
10
11use crate::GitError;
12
/// Counters reported by a lock garbage-collection pass (`LockManager::gc`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LockGcStats {
    /// Number of expired lock refs that were deleted.
    pub removed: usize,
    /// Number of still-live locks that were left in place.
    pub kept: usize,
}
21
22pub struct LockManager {
24 repo: Repository,
25}
26
27impl LockManager {
28 pub fn open(git_dir: &Path) -> Result<Self, GitError> {
30 let repo = Repository::open(git_dir)?;
31 Ok(Self { repo })
32 }
33
34 pub fn acquire(&self, resource: &str, owner: &str, ttl_ms: Option<u64>) -> Result<Lock, GitError> {
38 let ttl = ttl_ms.unwrap_or(DEFAULT_LOCK_TTL_MS);
39 let ref_name = lock_ref_name(resource);
40
41 if let Some(existing) = self.read_lock(resource)? {
43 if !existing.is_expired() {
44 if existing.owner == owner {
45 return self.renew(resource, owner, Some(ttl));
47 } else {
48 let expires_in_ms = existing.time_remaining_ms();
49 return Err(GitError::LockConflict {
50 resource: resource.to_string(),
51 owner: existing.owner,
52 expires_in_ms,
53 });
54 }
55 }
56 }
58
59 let lock = Lock::new(owner.to_string(), resource.to_string(), ttl);
61 self.write_lock(&ref_name, &lock)?;
62
63 Ok(lock)
64 }
65
66 pub fn release(&self, resource: &str, owner: &str) -> Result<(), GitError> {
68 let ref_name = lock_ref_name(resource);
69
70 if let Some(existing) = self.read_lock(resource)? {
72 if existing.owner != owner && !existing.is_expired() {
73 return Err(GitError::LockNotOwned {
74 resource: resource.to_string(),
75 owner: existing.owner,
76 });
77 }
78 }
79
80 self.delete_ref(&ref_name)?;
82
83 Ok(())
84 }
85
86 pub fn renew(&self, resource: &str, owner: &str, ttl_ms: Option<u64>) -> Result<Lock, GitError> {
88 let ttl = ttl_ms.unwrap_or(DEFAULT_LOCK_TTL_MS);
89 let ref_name = lock_ref_name(resource);
90
91 if let Some(mut existing) = self.read_lock(resource)? {
93 if existing.owner != owner {
94 return Err(GitError::LockNotOwned {
95 resource: resource.to_string(),
96 owner: existing.owner,
97 });
98 }
99
100 existing.renew(ttl);
102 self.write_lock(&ref_name, &existing)?;
103 return Ok(existing);
104 }
105
106 self.acquire(resource, owner, Some(ttl))
108 }
109
110 pub fn read_lock(&self, resource: &str) -> Result<Option<Lock>, GitError> {
112 let ref_name = lock_ref_name(resource);
113 self.read_lock_ref(&ref_name)
114 }
115
116 pub fn list_locks(&self) -> Result<Vec<Lock>, GitError> {
118 let mut locks = Vec::new();
119
120 let refs = self.repo.references_glob("refs/grite/locks/*")?;
122 for ref_result in refs {
123 let reference = ref_result?;
124 if let Some(lock) = self.read_lock_from_ref(&reference)? {
125 locks.push(lock);
126 }
127 }
128
129 Ok(locks)
130 }
131
132 pub fn check_conflicts(&self, resource: &str, current_owner: &str, policy: LockPolicy) -> Result<LockCheckResult, GitError> {
134 if policy == LockPolicy::Off {
135 return Ok(LockCheckResult::Clear);
136 }
137
138 let locks = self.list_locks()?;
139 let conflicts: Vec<Lock> = locks
140 .into_iter()
141 .filter(|lock| {
142 !lock.is_expired() &&
143 lock.owner != current_owner &&
144 lock.conflicts_with(resource)
145 })
146 .collect();
147
148 if conflicts.is_empty() {
149 Ok(LockCheckResult::Clear)
150 } else if policy == LockPolicy::Warn {
151 Ok(LockCheckResult::Warning(conflicts))
152 } else {
153 Ok(LockCheckResult::Blocked(conflicts))
154 }
155 }
156
157 pub fn gc(&self) -> Result<LockGcStats, GitError> {
159 let mut stats = LockGcStats::default();
160
161 let refs: Vec<_> = self.repo.references_glob("refs/grite/locks/*")?
162 .collect::<Result<Vec<_>, _>>()?;
163
164 for reference in refs {
165 if let Some(lock) = self.read_lock_from_ref(&reference)? {
166 if lock.is_expired() {
167 if let Some(name) = reference.name() {
168 self.delete_ref(name)?;
169 stats.removed += 1;
170 }
171 } else {
172 stats.kept += 1;
173 }
174 }
175 }
176
177 Ok(stats)
178 }
179
180 fn read_lock_ref(&self, ref_name: &str) -> Result<Option<Lock>, GitError> {
182 let reference = match self.repo.find_reference(ref_name) {
183 Ok(r) => r,
184 Err(e) if e.code() == git2::ErrorCode::NotFound => return Ok(None),
185 Err(e) => return Err(e.into()),
186 };
187
188 self.read_lock_from_ref(&reference)
189 }
190
191 fn read_lock_from_ref(&self, reference: &git2::Reference) -> Result<Option<Lock>, GitError> {
193 let commit = reference.peel_to_commit()?;
194 let tree = commit.tree()?;
195
196 let entry = match tree.get_name("lock.json") {
198 Some(e) => e,
199 None => return Ok(None),
200 };
201
202 let blob = self.repo.find_blob(entry.id())?;
203 let content = std::str::from_utf8(blob.content())
204 .map_err(|e| GitError::ParseError(e.to_string()))?;
205
206 let lock: Lock = serde_json::from_str(content)
207 .map_err(|e| GitError::ParseError(e.to_string()))?;
208
209 Ok(Some(lock))
210 }
211
212 fn write_lock(&self, ref_name: &str, lock: &Lock) -> Result<(), GitError> {
214 let json = serde_json::to_string_pretty(lock)
215 .map_err(|e| GitError::ParseError(e.to_string()))?;
216
217 let blob_id = self.repo.blob(json.as_bytes())?;
219
220 let mut tree_builder = self.repo.treebuilder(None)?;
222 tree_builder.insert("lock.json", blob_id, 0o100644)?;
223 let tree_id = tree_builder.write()?;
224 let tree = self.repo.find_tree(tree_id)?;
225
226 let sig = Signature::now("grite", "grit@localhost")?;
228 let message = format!("Lock: {}", lock.resource);
229
230 let parent = self.repo.find_reference(ref_name)
232 .ok()
233 .and_then(|r| r.peel_to_commit().ok());
234
235 let parents: Vec<&git2::Commit> = parent.iter().collect();
236
237 let _commit_id = self.repo.commit(
238 Some(ref_name),
239 &sig,
240 &sig,
241 &message,
242 &tree,
243 &parents,
244 )?;
245
246 Ok(())
247 }
248
249 fn delete_ref(&self, ref_name: &str) -> Result<(), GitError> {
251 match self.repo.find_reference(ref_name) {
252 Ok(mut reference) => {
253 reference.delete()?;
254 Ok(())
255 }
256 Err(e) if e.code() == git2::ErrorCode::NotFound => Ok(()),
257 Err(e) => Err(e.into()),
258 }
259 }
260}
261
262fn lock_ref_name(resource: &str) -> String {
264 format!("refs/grite/locks/{}", resource_hash(resource))
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Creates a temporary repository whose `HEAD` holds a single empty commit.
    ///
    /// The returned `TempDir` must stay alive for as long as the repository is used.
    fn setup_repo() -> tempfile::TempDir {
        let dir = tempdir().unwrap();
        let repo = Repository::init(dir.path()).unwrap();

        let sig = Signature::now("test", "test@test.com").unwrap();
        // Seed an initial commit built from an empty tree.
        let tree_id = repo.treebuilder(None).unwrap().write().unwrap();
        {
            let tree = repo.find_tree(tree_id).unwrap();
            repo.commit(Some("HEAD"), &sig, &sig, "Initial", &tree, &[]).unwrap();
        }

        dir
    }

    #[test]
    fn test_acquire_and_release() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        let lock = manager.acquire("repo:global", "actor1", Some(60000)).unwrap();
        assert_eq!(lock.owner, "actor1");
        assert_eq!(lock.resource, "repo:global");
        assert!(!lock.is_expired());

        let read_lock = manager.read_lock("repo:global").unwrap().unwrap();
        assert_eq!(read_lock.owner, "actor1");

        manager.release("repo:global", "actor1").unwrap();

        let read_lock = manager.read_lock("repo:global").unwrap();
        assert!(read_lock.is_none());
    }

    #[test]
    fn test_acquire_conflict() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.acquire("repo:global", "actor1", Some(60000)).unwrap();

        // A second actor must be refused with a conflict naming the current holder.
        let result = manager.acquire("repo:global", "actor2", Some(60000));
        assert!(matches!(
            result,
            Err(GitError::LockConflict { ref owner, .. }) if owner.as_str() == "actor1"
        ));
    }

    #[test]
    fn test_reacquire_by_owner_extends_lock() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        let first = manager.acquire("issue:abc123", "actor1", Some(1000)).unwrap();
        let second = manager.acquire("issue:abc123", "actor1", Some(60000)).unwrap();

        assert_eq!(second.owner, "actor1");
        assert!(second.expires_unix_ms > first.expires_unix_ms);
    }

    #[test]
    fn test_release_by_non_owner_fails() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.acquire("repo:global", "actor1", Some(60000)).unwrap();

        let result = manager.release("repo:global", "actor2");
        assert!(matches!(result, Err(GitError::LockNotOwned { .. })));

        // The live lock must survive the rejected release.
        let still_held = manager.read_lock("repo:global").unwrap().unwrap();
        assert_eq!(still_held.owner, "actor1");
    }

    #[test]
    fn test_release_missing_lock_is_ok() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.release("issue:missing", "actor1").unwrap();
        assert!(manager.read_lock("issue:missing").unwrap().is_none());
    }

    #[test]
    fn test_renew_lock() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        let lock1 = manager.acquire("issue:abc123", "actor1", Some(1000)).unwrap();
        let expires1 = lock1.expires_unix_ms;

        std::thread::sleep(std::time::Duration::from_millis(10));

        let lock2 = manager.renew("issue:abc123", "actor1", Some(60000)).unwrap();
        assert!(lock2.expires_unix_ms > expires1);
    }

    #[test]
    fn test_list_locks() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.acquire("repo:global", "actor1", Some(60000)).unwrap();
        manager.acquire("issue:abc123", "actor2", Some(60000)).unwrap();

        let locks = manager.list_locks().unwrap();
        assert_eq!(locks.len(), 2);
    }

    #[test]
    fn test_gc_expired() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.acquire("issue:abc123", "actor1", Some(1)).unwrap();

        // Let the 1 ms TTL lapse before collecting.
        std::thread::sleep(std::time::Duration::from_millis(10));

        let stats = manager.gc().unwrap();
        assert_eq!(stats.removed, 1);
        assert_eq!(stats.kept, 0);

        let locks = manager.list_locks().unwrap();
        assert!(locks.is_empty());
    }

    #[test]
    fn test_check_conflicts() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.acquire("repo:global", "actor1", Some(60000)).unwrap();

        let result = manager
            .check_conflicts("issue:abc123", "actor2", LockPolicy::Warn)
            .unwrap();
        assert!(matches!(result, LockCheckResult::Warning(_)));

        let result = manager
            .check_conflicts("issue:abc123", "actor2", LockPolicy::Require)
            .unwrap();
        assert!(matches!(result, LockCheckResult::Blocked(_)));

        // The holder itself never conflicts with its own lock.
        let result = manager
            .check_conflicts("issue:abc123", "actor1", LockPolicy::Require)
            .unwrap();
        assert!(matches!(result, LockCheckResult::Clear));
    }

    #[test]
    fn test_check_conflicts_policy_off() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.acquire("repo:global", "actor1", Some(60000)).unwrap();

        let result = manager
            .check_conflicts("issue:abc123", "actor2", LockPolicy::Off)
            .unwrap();
        assert!(matches!(result, LockCheckResult::Clear));
    }
}