1use std::path::Path;
7
8use git2::{Repository, Signature};
9use libgrite_core::{resource_hash, Lock, LockCheckResult, LockPolicy, DEFAULT_LOCK_TTL_MS};
10
11use crate::GitError;
12
/// Reason an attempt to write a lock ref failed, as produced by
/// `LockManager::try_create_lock`.
enum LockAcquireError {
    /// The lock ref already exists: git2 reported `ErrorCode::Exists` for the
    /// non-forced ref creation, so the lock was not written.
    Exists,
    /// Any other failure (git error, or lock serialization error from
    /// `write_lock_commit`), to be propagated to the caller unchanged.
    Git(GitError),
}
20
/// Counters reported by `LockManager::gc`.
#[derive(Debug, Clone, Default)]
pub struct LockGcStats {
    /// Number of expired lock refs deleted during the collection pass.
    pub removed: usize,
    /// Number of unexpired locks that were left in place.
    pub kept: usize,
}
29
/// Manages locks on named resources. Each lock is stored as a git ref under
/// `refs/grite/locks/` pointing at a commit whose tree holds a `lock.json`.
pub struct LockManager {
    /// Repository that holds the lock refs and their commits.
    repo: Repository,
}
34
35impl LockManager {
36 pub fn open(git_dir: &Path) -> Result<Self, GitError> {
38 let repo = Repository::open(git_dir)?;
39 Ok(Self { repo })
40 }
41
42 pub fn acquire(
46 &self,
47 resource: &str,
48 owner: &str,
49 ttl_ms: Option<u64>,
50 ) -> Result<Lock, GitError> {
51 let ttl = ttl_ms.unwrap_or(DEFAULT_LOCK_TTL_MS);
52 let ref_name = lock_ref_name(resource);
53 let lock = Lock::new(owner.to_string(), resource.to_string(), ttl);
54
55 match self.try_create_lock(&ref_name, &lock) {
57 Ok(()) => Ok(lock),
58 Err(LockAcquireError::Exists) => {
59 if let Some(existing) = self.read_lock(resource)? {
61 if !existing.is_expired() {
62 if existing.owner == owner {
63 Ok(existing)
65 } else {
66 let expires_in_ms = existing.time_remaining_ms();
67 Err(GitError::LockConflict {
68 resource: resource.to_string(),
69 owner: existing.owner,
70 expires_in_ms,
71 })
72 }
73 } else {
74 self.delete_ref(&ref_name)?;
76 match self.try_create_lock(&ref_name, &lock) {
77 Ok(()) => Ok(lock),
78 Err(LockAcquireError::Exists) => {
79 if let Some(other) = self.read_lock(resource)? {
80 if !other.is_expired() {
81 return Err(GitError::LockConflict {
82 resource: resource.to_string(),
83 owner: other.owner.clone(),
84 expires_in_ms: other.time_remaining_ms(),
85 });
86 }
87 }
88 Err(GitError::LockConflict {
89 resource: resource.to_string(),
90 owner: "unknown".to_string(),
91 expires_in_ms: 0,
92 })
93 }
94 Err(LockAcquireError::Git(e)) => Err(e),
95 }
96 }
97 } else {
98 match self.try_create_lock(&ref_name, &lock) {
100 Ok(()) => Ok(lock),
101 Err(LockAcquireError::Exists) => {
102 if let Some(other) = self.read_lock(resource)? {
103 if !other.is_expired() {
104 return Err(GitError::LockConflict {
105 resource: resource.to_string(),
106 owner: other.owner.clone(),
107 expires_in_ms: other.time_remaining_ms(),
108 });
109 }
110 }
111 Err(GitError::LockConflict {
112 resource: resource.to_string(),
113 owner: "unknown".to_string(),
114 expires_in_ms: 0,
115 })
116 }
117 Err(LockAcquireError::Git(e)) => Err(e),
118 }
119 }
120 }
121 Err(LockAcquireError::Git(e)) => Err(e),
122 }
123 }
124
125 pub fn release(&self, resource: &str, owner: &str) -> Result<(), GitError> {
127 let ref_name = lock_ref_name(resource);
128
129 if let Some(existing) = self.read_lock(resource)? {
131 if existing.owner != owner && !existing.is_expired() {
132 return Err(GitError::LockNotOwned {
133 resource: resource.to_string(),
134 owner: existing.owner,
135 });
136 }
137 }
138
139 self.delete_ref(&ref_name)?;
141
142 Ok(())
143 }
144
145 pub fn renew(
147 &self,
148 resource: &str,
149 owner: &str,
150 ttl_ms: Option<u64>,
151 ) -> Result<Lock, GitError> {
152 let ttl = ttl_ms.unwrap_or(DEFAULT_LOCK_TTL_MS);
153 let ref_name = lock_ref_name(resource);
154
155 if let Some(mut existing) = self.read_lock(resource)? {
157 if existing.owner != owner {
158 return Err(GitError::LockNotOwned {
159 resource: resource.to_string(),
160 owner: existing.owner,
161 });
162 }
163
164 existing.renew(ttl);
166 self.write_lock(&ref_name, &existing)?;
167 return Ok(existing);
168 }
169
170 self.acquire(resource, owner, Some(ttl))
172 }
173
174 pub fn read_lock(&self, resource: &str) -> Result<Option<Lock>, GitError> {
176 let ref_name = lock_ref_name(resource);
177 self.read_lock_ref(&ref_name)
178 }
179
180 pub fn list_locks(&self) -> Result<Vec<Lock>, GitError> {
182 let mut locks = Vec::new();
183
184 let refs = self.repo.references_glob("refs/grite/locks/*")?;
186 for ref_result in refs {
187 let reference = ref_result?;
188 if let Some(lock) = self.read_lock_from_ref(&reference)? {
189 locks.push(lock);
190 }
191 }
192
193 Ok(locks)
194 }
195
196 pub fn check_conflicts(
198 &self,
199 resource: &str,
200 current_owner: &str,
201 policy: LockPolicy,
202 ) -> Result<LockCheckResult, GitError> {
203 if policy == LockPolicy::Off {
204 return Ok(LockCheckResult::Clear);
205 }
206
207 let locks = self.list_locks()?;
208 let conflicts: Vec<Lock> = locks
209 .into_iter()
210 .filter(|lock| {
211 !lock.is_expired() && lock.owner != current_owner && lock.conflicts_with(resource)
212 })
213 .collect();
214
215 if conflicts.is_empty() {
216 Ok(LockCheckResult::Clear)
217 } else if policy == LockPolicy::Warn {
218 Ok(LockCheckResult::Warning(conflicts))
219 } else {
220 Ok(LockCheckResult::Blocked(conflicts))
221 }
222 }
223
224 pub fn gc(&self) -> Result<LockGcStats, GitError> {
226 let mut stats = LockGcStats::default();
227
228 let refs: Vec<_> = self
229 .repo
230 .references_glob("refs/grite/locks/*")?
231 .collect::<Result<Vec<_>, _>>()?;
232
233 for reference in refs {
234 if let Some(lock) = self.read_lock_from_ref(&reference)? {
235 if lock.is_expired() {
236 if let Some(name) = reference.name() {
237 self.delete_ref(name)?;
238 stats.removed += 1;
239 }
240 } else {
241 stats.kept += 1;
242 }
243 }
244 }
245
246 Ok(stats)
247 }
248
249 fn read_lock_ref(&self, ref_name: &str) -> Result<Option<Lock>, GitError> {
251 let reference = match self.repo.find_reference(ref_name) {
252 Ok(r) => r,
253 Err(e) if e.code() == git2::ErrorCode::NotFound => return Ok(None),
254 Err(e) => return Err(e.into()),
255 };
256
257 self.read_lock_from_ref(&reference)
258 }
259
260 fn read_lock_from_ref(&self, reference: &git2::Reference) -> Result<Option<Lock>, GitError> {
262 let commit = reference.peel_to_commit()?;
263 let tree = commit.tree()?;
264
265 let entry = match tree.get_name("lock.json") {
267 Some(e) => e,
268 None => return Ok(None),
269 };
270
271 let blob = self.repo.find_blob(entry.id())?;
272 let content =
273 std::str::from_utf8(blob.content()).map_err(|e| GitError::ParseError(e.to_string()))?;
274
275 let lock: Lock =
276 serde_json::from_str(content).map_err(|e| GitError::ParseError(e.to_string()))?;
277
278 Ok(Some(lock))
279 }
280
281 fn try_create_lock(&self, ref_name: &str, lock: &Lock) -> Result<(), LockAcquireError> {
283 let commit_oid = self
284 .write_lock_commit(lock)
285 .map_err(LockAcquireError::Git)?;
286 match self
287 .repo
288 .reference(ref_name, commit_oid, false, "lock acquire")
289 {
290 Ok(_) => Ok(()),
291 Err(e) if e.code() == git2::ErrorCode::Exists => Err(LockAcquireError::Exists),
292 Err(e) => Err(LockAcquireError::Git(e.into())),
293 }
294 }
295
296 fn write_lock_commit(&self, lock: &Lock) -> Result<git2::Oid, GitError> {
298 let json =
299 serde_json::to_string_pretty(lock).map_err(|e| GitError::ParseError(e.to_string()))?;
300
301 let blob_id = self.repo.blob(json.as_bytes())?;
303
304 let mut tree_builder = self.repo.treebuilder(None)?;
306 tree_builder.insert("lock.json", blob_id, 0o100644)?;
307 let tree_id = tree_builder.write()?;
308 let tree = self.repo.find_tree(tree_id)?;
309
310 let sig = Signature::now("grite", "grit@localhost")?;
312 let message = format!("Lock: {}", lock.resource);
313
314 let parent = self
315 .repo
316 .find_reference(&lock_ref_name(&lock.resource))
317 .ok()
318 .and_then(|r| r.peel_to_commit().ok());
319
320 let parents: Vec<&git2::Commit> = parent.iter().collect();
321
322 let commit_oid = self
323 .repo
324 .commit(None, &sig, &sig, &message, &tree, &parents)?;
325
326 Ok(commit_oid)
327 }
328
329 fn write_lock(&self, ref_name: &str, lock: &Lock) -> Result<(), GitError> {
331 let commit_oid = self.write_lock_commit(lock)?;
332 self.repo
333 .reference(ref_name, commit_oid, true, "lock update")?;
334 Ok(())
335 }
336
337 fn delete_ref(&self, ref_name: &str) -> Result<(), GitError> {
339 match self.repo.find_reference(ref_name) {
340 Ok(mut reference) => {
341 reference.delete()?;
342 Ok(())
343 }
344 Err(e) if e.code() == git2::ErrorCode::NotFound => Ok(()),
345 Err(e) => Err(e.into()),
346 }
347 }
348}
349
350fn lock_ref_name(resource: &str) -> String {
352 format!("refs/grite/locks/{}", resource_hash(resource))
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Creates a temporary repository with a single empty commit on `HEAD`.
    /// The returned `TempDir` must be kept alive for the repo to exist.
    fn setup_repo() -> tempfile::TempDir {
        let dir = tempdir().unwrap();
        let repo = Repository::init(dir.path()).unwrap();

        let sig = Signature::now("test", "test@test.com").unwrap();
        let tree_id = repo.treebuilder(None).unwrap().write().unwrap();
        {
            let tree = repo.find_tree(tree_id).unwrap();
            repo.commit(Some("HEAD"), &sig, &sig, "Initial", &tree, &[])
                .unwrap();
        }

        dir
    }

    #[test]
    fn test_acquire_and_release() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        let lock = manager
            .acquire("repo:global", "actor1", Some(60000))
            .unwrap();
        assert_eq!(lock.owner, "actor1");
        assert_eq!(lock.resource, "repo:global");
        assert!(!lock.is_expired());

        let read_lock = manager.read_lock("repo:global").unwrap().unwrap();
        assert_eq!(read_lock.owner, "actor1");

        manager.release("repo:global", "actor1").unwrap();

        let read_lock = manager.read_lock("repo:global").unwrap();
        assert!(read_lock.is_none());
    }

    #[test]
    fn test_acquire_conflict() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager
            .acquire("repo:global", "actor1", Some(60000))
            .unwrap();

        // The error must be a conflict naming the current holder, not just any error.
        match manager.acquire("repo:global", "actor2", Some(60000)) {
            Err(GitError::LockConflict { owner, .. }) => assert_eq!(owner, "actor1"),
            _ => panic!("expected GitError::LockConflict"),
        }
    }

    #[test]
    fn test_reacquire_by_same_owner_returns_existing() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        let first = manager
            .acquire("repo:global", "actor1", Some(60000))
            .unwrap();
        let second = manager
            .acquire("repo:global", "actor1", Some(60000))
            .unwrap();

        assert_eq!(second.owner, "actor1");
        // Re-acquiring returns the stored lock; it does not extend the expiry.
        assert_eq!(second.expires_unix_ms, first.expires_unix_ms);
    }

    #[test]
    fn test_expired_lock_can_be_taken_over() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.acquire("issue:abc123", "actor1", Some(1)).unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));

        let lock = manager
            .acquire("issue:abc123", "actor2", Some(60000))
            .unwrap();
        assert_eq!(lock.owner, "actor2");

        let stored = manager.read_lock("issue:abc123").unwrap().unwrap();
        assert_eq!(stored.owner, "actor2");
    }

    #[test]
    fn test_release_by_non_owner_fails() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager
            .acquire("repo:global", "actor1", Some(60000))
            .unwrap();

        assert!(matches!(
            manager.release("repo:global", "actor2"),
            Err(GitError::LockNotOwned { .. })
        ));
        // The lock must still be in place after the rejected release.
        assert!(manager.read_lock("repo:global").unwrap().is_some());
    }

    #[test]
    fn test_release_missing_lock_is_ok() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        assert!(manager.release("repo:global", "actor1").is_ok());
    }

    #[test]
    fn test_renew_lock() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        let lock1 = manager
            .acquire("issue:abc123", "actor1", Some(1000))
            .unwrap();
        let expires1 = lock1.expires_unix_ms;

        std::thread::sleep(std::time::Duration::from_millis(10));

        let lock2 = manager
            .renew("issue:abc123", "actor1", Some(60000))
            .unwrap();
        assert!(lock2.expires_unix_ms > expires1);
    }

    #[test]
    fn test_renew_by_non_owner_fails() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager
            .acquire("issue:abc123", "actor1", Some(60000))
            .unwrap();

        assert!(matches!(
            manager.renew("issue:abc123", "actor2", Some(60000)),
            Err(GitError::LockNotOwned { .. })
        ));
    }

    #[test]
    fn test_list_locks() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager
            .acquire("repo:global", "actor1", Some(60000))
            .unwrap();
        manager
            .acquire("issue:abc123", "actor2", Some(60000))
            .unwrap();

        let locks = manager.list_locks().unwrap();
        assert_eq!(locks.len(), 2);
    }

    #[test]
    fn test_gc_expired() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager.acquire("issue:abc123", "actor1", Some(1)).unwrap();

        std::thread::sleep(std::time::Duration::from_millis(10));

        let stats = manager.gc().unwrap();
        assert_eq!(stats.removed, 1);
        assert_eq!(stats.kept, 0);

        let locks = manager.list_locks().unwrap();
        assert!(locks.is_empty());
    }

    #[test]
    fn test_check_conflicts() {
        let dir = setup_repo();
        let manager = LockManager::open(dir.path()).unwrap();

        manager
            .acquire("repo:global", "actor1", Some(60000))
            .unwrap();

        let result = manager
            .check_conflicts("issue:abc123", "actor2", LockPolicy::Warn)
            .unwrap();
        assert!(matches!(result, LockCheckResult::Warning(_)));

        let result = manager
            .check_conflicts("issue:abc123", "actor2", LockPolicy::Require)
            .unwrap();
        assert!(matches!(result, LockCheckResult::Blocked(_)));

        let result = manager
            .check_conflicts("issue:abc123", "actor1", LockPolicy::Require)
            .unwrap();
        assert!(matches!(result, LockCheckResult::Clear));
    }
}