Skip to main content

libgrite_core/
lock.rs

1//! Lock types for team coordination
2//!
3//! Grit uses lease-based locks stored as git refs for coordination.
4//! Locks are optional and designed for coordination, not enforcement.
5
6use serde::{Deserialize, Serialize};
7use sha2::{Sha256, Digest};
8
9/// A lease-based lock on a resource
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11pub struct Lock {
12    /// Actor ID who owns the lock (hex-encoded)
13    pub owner: String,
14    /// Unique nonce for this lock instance
15    pub nonce: String,
16    /// When the lock expires (Unix timestamp in ms)
17    pub expires_unix_ms: u64,
18    /// Resource being locked (e.g., "repo:global", "issue:abc123")
19    pub resource: String,
20}
21
22impl Lock {
23    /// Create a new lock
24    pub fn new(owner: String, resource: String, ttl_ms: u64) -> Self {
25        let now = current_time_ms();
26        Self {
27            owner,
28            nonce: uuid::Uuid::new_v4().to_string(),
29            expires_unix_ms: now + ttl_ms,
30            resource,
31        }
32    }
33
34    /// Check if the lock has expired
35    pub fn is_expired(&self) -> bool {
36        let now = current_time_ms();
37        now >= self.expires_unix_ms
38    }
39
40    /// Get time remaining in milliseconds (0 if expired)
41    pub fn time_remaining_ms(&self) -> u64 {
42        let now = current_time_ms();
43        if now >= self.expires_unix_ms {
44            0
45        } else {
46            self.expires_unix_ms - now
47        }
48    }
49
50    /// Extend the lock's expiration
51    pub fn renew(&mut self, ttl_ms: u64) {
52        let now = current_time_ms();
53        self.expires_unix_ms = now + ttl_ms;
54    }
55
56    /// Create an expired lock (for release)
57    pub fn expired(owner: String, resource: String) -> Self {
58        Self {
59            owner,
60            nonce: uuid::Uuid::new_v4().to_string(),
61            expires_unix_ms: 0,
62            resource,
63        }
64    }
65
66    /// Get the namespace of this lock's resource
67    pub fn namespace(&self) -> Option<&str> {
68        self.resource.split(':').next()
69    }
70
71    /// Check if this lock conflicts with another resource
72    pub fn conflicts_with(&self, other_resource: &str) -> bool {
73        if self.is_expired() {
74            return false;
75        }
76
77        let self_ns = self.namespace();
78        let other_ns = other_resource.split(':').next();
79
80        match (self_ns, other_ns) {
81            // Repo-wide lock conflicts with everything
82            (Some("repo"), _) => true,
83            (_, Some("repo")) => true,
84
85            // Path locks only conflict with overlapping paths
86            (Some("path"), Some("path")) => {
87                let self_path = self.resource.strip_prefix("path:").unwrap_or("");
88                let other_path = other_resource.strip_prefix("path:").unwrap_or("");
89                paths_overlap(self_path, other_path)
90            }
91
92            // Issue locks only conflict with same issue
93            (Some("issue"), Some("issue")) => self.resource == other_resource,
94
95            // Different namespaces don't conflict (except repo)
96            _ => false,
97        }
98    }
99}
100
101/// Lock policy for write operations
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
103#[serde(rename_all = "lowercase")]
104pub enum LockPolicy {
105    /// No lock checks
106    Off,
107    /// Warn on conflicts but continue (default)
108    #[default]
109    Warn,
110    /// Block if conflicting lock exists
111    Require,
112}
113
114impl LockPolicy {
115    /// Parse from string
116    pub fn from_str(s: &str) -> Option<Self> {
117        match s.to_lowercase().as_str() {
118            "off" => Some(LockPolicy::Off),
119            "warn" => Some(LockPolicy::Warn),
120            "require" => Some(LockPolicy::Require),
121            _ => None,
122        }
123    }
124
125    /// Convert to string
126    pub fn as_str(&self) -> &'static str {
127        match self {
128            LockPolicy::Off => "off",
129            LockPolicy::Warn => "warn",
130            LockPolicy::Require => "require",
131        }
132    }
133}
134
135/// Status of a lock check
136#[derive(Debug, Clone)]
137pub struct LockStatus {
138    /// The lock
139    pub lock: Lock,
140    /// Whether it's owned by the current actor
141    pub owned_by_self: bool,
142}
143
144/// Result of a lock conflict check
145#[derive(Debug, Clone)]
146pub enum LockCheckResult {
147    /// No conflicts
148    Clear,
149    /// Conflicts exist but policy allows continue (warn)
150    Warning(Vec<Lock>),
151    /// Conflicts exist and policy blocks operation
152    Blocked(Vec<Lock>),
153}
154
155impl LockCheckResult {
156    /// Check if operation should proceed
157    pub fn should_proceed(&self) -> bool {
158        !matches!(self, LockCheckResult::Blocked(_))
159    }
160
161    /// Get conflicting locks if any
162    pub fn conflicts(&self) -> &[Lock] {
163        match self {
164            LockCheckResult::Clear => &[],
165            LockCheckResult::Warning(locks) | LockCheckResult::Blocked(locks) => locks,
166        }
167    }
168}
169
170/// Compute the hash for a lock ref name
171///
172/// Returns first 16 chars of SHA256 hex
173pub fn resource_hash(resource: &str) -> String {
174    let mut hasher = Sha256::new();
175    hasher.update(resource.as_bytes());
176    let result = hasher.finalize();
177    hex::encode(&result[..8]) // 8 bytes = 16 hex chars
178}
179
/// Default lock TTL in milliseconds (5 minutes)
pub const DEFAULT_LOCK_TTL_MS: u64 = 5 * 60 * 1000;

/// Get the current wall-clock time in milliseconds since the Unix epoch.
///
/// Never panics: if the system clock reports a time before the epoch this
/// returns `0`, and a millisecond count too large for `u64` saturates to
/// `u64::MAX` instead of being silently truncated.
fn current_time_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    // Clamp before narrowing so the `as` cast cannot truncate.
    since_epoch.as_millis().min(u128::from(u64::MAX)) as u64
}
190
/// Check whether two paths overlap.
///
/// After ignoring trailing `/` characters, two paths overlap when they are
/// equal or one is an ancestor directory of the other. Matching is purely
/// textual at `/` boundaries: `src` overlaps `src/main.rs` but not
/// `srcfoo/main.rs`. No other normalisation (`./`, `..`, repeated inner
/// separators, case) is performed.
fn paths_overlap(path1: &str, path2: &str) -> bool {
    /// True when `child` lies strictly inside directory `parent`, i.e. it is
    /// `parent` followed by a `/` separator and more text.
    fn is_inside(child: &str, parent: &str) -> bool {
        // Comparing via `strip_prefix` avoids allocating `"{parent}/"`.
        matches!(child.strip_prefix(parent), Some(rest) if rest.starts_with('/'))
    }

    // Trailing slashes are ignored, so "src/" and "src" are the same path.
    let p1 = path1.trim_end_matches('/');
    let p2 = path2.trim_end_matches('/');

    p1 == p2 || is_inside(p2, p1) || is_inside(p1, p2)
}
211
#[cfg(test)]
mod tests {
    //! Unit tests for lock construction, expiry, conflict rules and policy parsing.

    use super::*;

    #[test]
    fn test_lock_creation() {
        let lock = Lock::new("actor123".to_string(), "repo:global".to_string(), 60000);
        assert_eq!(lock.owner, "actor123");
        assert_eq!(lock.resource, "repo:global");
        assert!(!lock.is_expired());
        assert!(lock.time_remaining_ms() > 0);
    }

    #[test]
    fn test_lock_expiration() {
        let lock = Lock::expired("actor123".to_string(), "repo:global".to_string());
        assert!(lock.is_expired());
        assert_eq!(lock.time_remaining_ms(), 0);
    }

    #[test]
    fn test_renew_extends_expired_lock() {
        let mut lock = Lock::expired("actor".to_string(), "issue:abc123".to_string());
        let nonce = lock.nonce.clone();
        lock.renew(60000);

        assert!(!lock.is_expired());
        assert!(lock.time_remaining_ms() > 0);
        assert!(lock.time_remaining_ms() <= 60000);
        // Renewal keeps the lock's identity
        assert_eq!(lock.nonce, nonce);
        assert_eq!(lock.resource, "issue:abc123");
    }

    #[test]
    fn test_lock_namespace() {
        let lock = Lock::new("actor".to_string(), "repo:global".to_string(), 1000);
        assert_eq!(lock.namespace(), Some("repo"));

        let lock = Lock::new("actor".to_string(), "path:src/main.rs".to_string(), 1000);
        assert_eq!(lock.namespace(), Some("path"));

        let lock = Lock::new("actor".to_string(), "issue:abc123".to_string(), 1000);
        assert_eq!(lock.namespace(), Some("issue"));
    }

    #[test]
    fn test_repo_lock_conflicts() {
        let repo_lock = Lock::new("actor".to_string(), "repo:global".to_string(), 60000);

        // Repo lock conflicts with everything
        assert!(repo_lock.conflicts_with("repo:global"));
        assert!(repo_lock.conflicts_with("path:src/main.rs"));
        assert!(repo_lock.conflicts_with("issue:abc123"));
    }

    #[test]
    fn test_any_live_lock_conflicts_with_repo_request() {
        let path_lock = Lock::new("actor".to_string(), "path:src/".to_string(), 60000);
        let issue_lock = Lock::new("actor".to_string(), "issue:abc123".to_string(), 60000);

        // Requesting a repo-wide lock conflicts with any live lock
        assert!(path_lock.conflicts_with("repo:global"));
        assert!(issue_lock.conflicts_with("repo:global"));
    }

    #[test]
    fn test_path_lock_conflicts() {
        let path_lock = Lock::new("actor".to_string(), "path:src/".to_string(), 60000);

        // Path lock conflicts with overlapping paths
        assert!(path_lock.conflicts_with("path:src/main.rs"));
        assert!(path_lock.conflicts_with("path:src/lib.rs"));
        assert!(path_lock.conflicts_with("path:src/"));

        // Doesn't conflict with non-overlapping
        assert!(!path_lock.conflicts_with("path:tests/"));
        assert!(!path_lock.conflicts_with("path:docs/"));

        // Doesn't conflict with other namespaces (except repo)
        assert!(!path_lock.conflicts_with("issue:abc123"));
    }

    #[test]
    fn test_path_lock_respects_separator_boundaries() {
        let path_lock = Lock::new("actor".to_string(), "path:src".to_string(), 60000);

        // A sibling that merely shares a textual prefix is not inside `src`
        assert!(!path_lock.conflicts_with("path:srcfoo/main.rs"));
        assert!(path_lock.conflicts_with("path:src/main.rs"));
    }

    #[test]
    fn test_issue_lock_conflicts() {
        let issue_lock = Lock::new("actor".to_string(), "issue:abc123".to_string(), 60000);

        // Issue lock only conflicts with same issue
        assert!(issue_lock.conflicts_with("issue:abc123"));
        assert!(!issue_lock.conflicts_with("issue:def456"));
        assert!(!issue_lock.conflicts_with("path:src/"));
    }

    #[test]
    fn test_expired_lock_no_conflict() {
        let expired = Lock::expired("actor".to_string(), "repo:global".to_string());

        // Expired locks don't conflict
        assert!(!expired.conflicts_with("repo:global"));
        assert!(!expired.conflicts_with("path:src/"));
    }

    #[test]
    fn test_resource_hash() {
        let hash1 = resource_hash("repo:global");
        let hash2 = resource_hash("repo:global");
        let hash3 = resource_hash("issue:abc123");

        // Same resource produces same hash
        assert_eq!(hash1, hash2);
        // Different resources produce different hashes
        assert_ne!(hash1, hash3);
        // Hash is 16 hex chars
        assert_eq!(hash1.len(), 16);
        assert!(hash1.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_lock_policy_parse() {
        assert_eq!(LockPolicy::from_str("off"), Some(LockPolicy::Off));
        assert_eq!(LockPolicy::from_str("warn"), Some(LockPolicy::Warn));
        assert_eq!(LockPolicy::from_str("require"), Some(LockPolicy::Require));
        assert_eq!(LockPolicy::from_str("WARN"), Some(LockPolicy::Warn));
        assert_eq!(LockPolicy::from_str("invalid"), None);
    }

    #[test]
    fn test_lock_policy_roundtrip_and_default() {
        for policy in [LockPolicy::Off, LockPolicy::Warn, LockPolicy::Require] {
            assert_eq!(LockPolicy::from_str(policy.as_str()), Some(policy));
        }
        assert_eq!(LockPolicy::default(), LockPolicy::Warn);
    }

    #[test]
    fn test_paths_overlap() {
        // Exact match
        assert!(paths_overlap("src/main.rs", "src/main.rs"));

        // Directory contains file
        assert!(paths_overlap("src/", "src/main.rs"));
        assert!(paths_overlap("src", "src/main.rs"));

        // File in directory
        assert!(paths_overlap("src/main.rs", "src/"));

        // Non-overlapping
        assert!(!paths_overlap("src/", "tests/"));
        assert!(!paths_overlap("src/main.rs", "src/lib.rs"));
    }

    #[test]
    fn test_paths_overlap_edge_cases() {
        // Trailing slashes are ignored
        assert!(paths_overlap("src//", "src"));
        // Shared textual prefix without a separator is not containment
        assert!(!paths_overlap("src", "srcfoo/main.rs"));
        assert!(!paths_overlap("src/main", "src/main.rs"));
        // The root directory contains absolute paths
        assert!(paths_overlap("/", "/etc/passwd"));
    }

    #[test]
    fn test_lock_check_result() {
        let clear = LockCheckResult::Clear;
        assert!(clear.should_proceed());
        assert!(clear.conflicts().is_empty());

        let lock = Lock::new("other".to_string(), "repo:global".to_string(), 1000);
        let warning = LockCheckResult::Warning(vec![lock.clone()]);
        assert!(warning.should_proceed());
        assert_eq!(warning.conflicts().len(), 1);

        let blocked = LockCheckResult::Blocked(vec![lock]);
        assert!(!blocked.should_proceed());
        assert_eq!(blocked.conflicts().len(), 1);
    }
}