1#![cfg_attr(
17 not(any(feature = "pg", feature = "mysql", feature = "sqlite")),
18 allow(unused_imports, unused_variables, dead_code, unreachable_code)
19)]
20
21use std::path::PathBuf;
22use std::time::{Duration, Instant};
23use thiserror::Error;
24use xxhash_rust::xxh3::xxh3_64;
25
26use chrono::SecondsFormat;
27
28#[cfg(feature = "mysql")]
29use sea_orm::sqlx::MySql;
30#[cfg(feature = "pg")]
31use sea_orm::sqlx::Postgres;
32
/// Pooled PostgreSQL connection type; a lock guard keeps one of these for as long as it
/// holds an advisory lock.
#[cfg(feature = "pg")]
type PostgresPoolConnection = sea_orm::sqlx::pool::PoolConnection<Postgres>;
/// Pooled MySQL connection type; a lock guard keeps one of these for as long as it holds
/// a `GET_LOCK` lock.
#[cfg(feature = "mysql")]
type MySqlPoolConnection = sea_orm::sqlx::pool::PoolConnection<MySql>;
37use tokio::fs::File;
38
39use crate::{DbEngine, DbPool};
40
/// Retry policy for `LockManager::try_lock`.
///
/// After each failed attempt the caller sleeps for the current backoff (scaled by a
/// jitter factor). The backoff is then multiplied by `backoff_multiplier` and capped at
/// `max_backoff`. Retrying stops when either `max_wait` or `max_attempts` is exhausted.
#[derive(Debug, Clone, PartialEq)]
pub struct LockConfig {
    /// Total time budget for acquiring the lock; `None` means no time limit.
    pub max_wait: Option<Duration>,
    /// Backoff used after the first failed attempt.
    pub initial_backoff: Duration,
    /// Upper bound the backoff is clamped to as it grows.
    pub max_backoff: Duration,
    /// Factor the backoff is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
    /// Jitter amplitude as a fraction of the backoff, clamped to `[0.0, 1.0]`.
    /// The sleep is scaled by a factor in `[1 - jitter_pct, 1 + jitter_pct]`.
    pub jitter_pct: f32,
    /// Maximum number of acquisition attempts; `None` means unlimited.
    pub max_attempts: Option<u32>,
}
59
60impl Default for LockConfig {
61 fn default() -> Self {
62 Self {
63 max_wait: Some(Duration::from_secs(30)),
64 initial_backoff: Duration::from_millis(50),
65 max_backoff: Duration::from_secs(5),
66 backoff_multiplier: 1.5,
67 jitter_pct: 0.2,
68 max_attempts: None,
69 }
70 }
71}
72
73#[derive(Debug)]
76enum GuardInner {
77 #[cfg(feature = "pg")]
78 Postgres {
79 conn: PostgresPoolConnection,
81 key_hash: i64,
82 },
83 #[cfg(feature = "mysql")]
84 MySql {
85 conn: MySqlPoolConnection,
87 lock_name: String,
88 },
89 File { path: PathBuf, file: File },
91}
92
93#[derive(Debug)]
96pub struct DbLockGuard {
97 namespaced_key: String,
98 inner: Option<GuardInner>, }
100
101impl DbLockGuard {
102 pub fn key(&self) -> &str {
104 &self.namespaced_key
105 }
106
107 pub async fn release(mut self) {
109 if let Some(inner) = self.inner.take() {
110 unlock_inner(inner).await;
111 }
112 }
114}
115
116impl Drop for DbLockGuard {
117 fn drop(&mut self) {
118 if let Some(inner) = self.inner.take()
120 && let Ok(handle) = tokio::runtime::Handle::try_current()
121 {
122 handle.spawn(async move { unlock_inner(inner).await });
123 }
124 }
128}
129
130async fn unlock_inner(inner: GuardInner) {
131 match inner {
132 #[cfg(feature = "pg")]
133 GuardInner::Postgres { mut conn, key_hash } => {
134 if let Err(e) = sea_orm::sqlx::query("SELECT pg_advisory_unlock($1)")
135 .bind(key_hash)
136 .execute(&mut *conn)
137 .await
138 {
139 tracing::warn!(error=%e, "failed to release PostgreSQL advisory lock");
140 }
141 }
143 #[cfg(feature = "mysql")]
144 GuardInner::MySql {
145 mut conn,
146 lock_name,
147 } => {
148 if let Err(e) = sea_orm::sqlx::query_scalar::<_, Option<i64>>("SELECT RELEASE_LOCK(?)")
150 .bind(&lock_name)
151 .fetch_one(&mut *conn)
152 .await
153 {
154 tracing::warn!(error=%e, "failed to release MySQL advisory lock");
155 }
156 }
157 GuardInner::File { path, file } => {
158 drop(file);
160 let _ = tokio::fs::remove_file(&path).await;
161 }
162 }
163}
164
165pub struct LockManager {
169 engine: DbEngine,
170 #[cfg_attr(
172 all(feature = "sqlite", not(any(feature = "pg", feature = "mysql"))),
173 allow(dead_code)
174 )]
175 pool: DbPool,
176 dsn: String,
177}
178
179impl LockManager {
180 #[must_use]
181 pub fn new(engine: DbEngine, pool: DbPool, dsn: String) -> Self {
182 Self { engine, pool, dsn }
183 }
184
185 pub async fn lock(&self, module: &str, key: &str) -> Result<DbLockGuard, DbLockError> {
193 let namespaced_key = format!("{module}:{key}");
194 match self.engine {
195 #[cfg(feature = "pg")]
196 DbEngine::Postgres => self.lock_pg(&namespaced_key).await,
197 #[cfg(not(feature = "pg"))]
198 DbEngine::Postgres => Err(DbLockError::InvalidState(
199 "PostgreSQL feature not enabled".to_owned(),
200 )),
201 #[cfg(feature = "mysql")]
202 DbEngine::MySql => self.lock_mysql(&namespaced_key).await,
203 #[cfg(not(feature = "mysql"))]
204 DbEngine::MySql => Err(DbLockError::InvalidState(
205 "MySQL feature not enabled".to_owned(),
206 )),
207 DbEngine::Sqlite => self.lock_file(&namespaced_key).await,
208 }
209 }
210
211 pub async fn try_lock(
221 &self,
222 module: &str,
223 key: &str,
224 config: LockConfig,
225 ) -> Result<Option<DbLockGuard>, DbLockError> {
226 let namespaced_key = format!("{module}:{key}");
227 let start = Instant::now();
228 let mut attempt = 0u32;
229 let mut backoff = config.initial_backoff;
230
231 loop {
232 attempt += 1;
233
234 if let Some(max_attempts) = config.max_attempts
235 && attempt > max_attempts
236 {
237 return Ok(None);
238 }
239 if let Some(max_wait) = config.max_wait
240 && start.elapsed() >= max_wait
241 {
242 return Ok(None);
243 }
244
245 if let Some(guard) = self.try_acquire_once(&namespaced_key).await? {
246 return Ok(Some(guard));
247 }
248
249 let remaining = config
251 .max_wait
252 .map_or(backoff, |mw| mw.saturating_sub(start.elapsed()));
253
254 if remaining.is_zero() {
255 return Ok(None);
256 }
257
258 #[allow(clippy::cast_precision_loss)]
259 let jitter_factor = {
260 let pct = f64::from(config.jitter_pct.clamp(0.0, 1.0));
261 let lo = 1.0 - pct;
262 let hi = 1.0 + pct;
263 let h = xxh3_64(namespaced_key.as_bytes()) as f64;
265 let frac = h / u64::MAX as f64; lo + frac * (hi - lo)
267 };
268
269 let sleep_for = std::cmp::min(backoff, remaining);
270 tokio::time::sleep(sleep_for.mul_f64(jitter_factor)).await;
271
272 let next = backoff.mul_f64(config.backoff_multiplier);
274 backoff = std::cmp::min(next, config.max_backoff);
275 }
276 }
277
278 #[cfg(feature = "pg")]
281 async fn lock_pg(&self, namespaced_key: &str) -> Result<DbLockGuard, DbLockError> {
282 #[allow(irrefutable_let_patterns)] let DbPool::Postgres(ref pool) = self.pool else {
284 return Err(DbLockError::InvalidState(
285 "not a PostgreSQL pool".to_owned(),
286 ));
287 };
288 let mut conn = pool.acquire().await?; #[allow(
291 clippy::cast_possible_wrap,
292 reason = "intentional wrapping of hash into i64 advisory lock key"
293 )]
294 let key_hash = xxh3_64(namespaced_key.as_bytes()) as i64;
295
296 sea_orm::sqlx::query("SELECT pg_advisory_lock($1)")
297 .bind(key_hash)
298 .execute(&mut *conn)
299 .await?; Ok(DbLockGuard {
302 namespaced_key: namespaced_key.to_owned(),
303 inner: Some(GuardInner::Postgres { conn, key_hash }),
304 })
305 }
306
307 #[cfg(feature = "pg")]
308 async fn try_lock_pg(&self, namespaced_key: &str) -> Result<Option<DbLockGuard>, DbLockError> {
309 #[allow(irrefutable_let_patterns)] let DbPool::Postgres(ref pool) = self.pool else {
311 return Err(DbLockError::InvalidState(
312 "not a PostgreSQL pool".to_owned(),
313 ));
314 };
315 let mut conn = pool.acquire().await?; #[allow(
318 clippy::cast_possible_wrap,
319 reason = "intentional wrapping of hash into i64 advisory lock key"
320 )]
321 let key_hash = xxh3_64(namespaced_key.as_bytes()) as i64;
322
323 let (ok,): (bool,) = sea_orm::sqlx::query_as("SELECT pg_try_advisory_lock($1)")
324 .bind(key_hash)
325 .fetch_one(&mut *conn)
326 .await?; if ok {
329 Ok(Some(DbLockGuard {
330 namespaced_key: namespaced_key.to_owned(),
331 inner: Some(GuardInner::Postgres { conn, key_hash }),
332 }))
333 } else {
334 drop(conn);
335 Ok(None)
336 }
337 }
338
339 #[cfg(feature = "mysql")]
340 async fn lock_mysql(&self, namespaced_key: &str) -> Result<DbLockGuard, DbLockError> {
341 #[allow(irrefutable_let_patterns)] let DbPool::MySql(ref pool) = self.pool else {
343 return Err(DbLockError::InvalidState("not a MySQL pool".to_owned()));
344 };
345 let mut conn = pool.acquire().await?; let (ok,): (i64,) = sea_orm::sqlx::query_as("SELECT GET_LOCK(?, 31536000)") .bind(namespaced_key)
351 .fetch_one(&mut *conn)
352 .await?; if ok != 1 {
355 return Err(DbLockError::InvalidState(
356 "failed to acquire MySQL lock".to_owned(),
357 ));
358 }
359
360 Ok(DbLockGuard {
361 namespaced_key: namespaced_key.to_owned(),
362 inner: Some(GuardInner::MySql {
363 conn,
364 lock_name: namespaced_key.to_owned(),
365 }),
366 })
367 }
368
369 #[cfg(feature = "mysql")]
370 async fn try_lock_mysql(
371 &self,
372 namespaced_key: &str,
373 ) -> Result<Option<DbLockGuard>, DbLockError> {
374 #[allow(irrefutable_let_patterns)] let DbPool::MySql(ref pool) = self.pool else {
376 return Err(DbLockError::InvalidState("not a MySQL pool".to_owned()));
377 };
378 let mut conn = pool.acquire().await?; let (ok,): (i64,) = sea_orm::sqlx::query_as("SELECT GET_LOCK(?, 0)")
382 .bind(namespaced_key)
383 .fetch_one(&mut *conn)
384 .await?; if ok == 1 {
387 Ok(Some(DbLockGuard {
388 namespaced_key: namespaced_key.to_owned(),
389 inner: Some(GuardInner::MySql {
390 conn,
391 lock_name: namespaced_key.to_owned(),
392 }),
393 }))
394 } else {
395 drop(conn);
396 Ok(None)
397 }
398 }
399
400 async fn lock_file(&self, namespaced_key: &str) -> Result<DbLockGuard, DbLockError> {
401 let path = self.get_lock_file_path(namespaced_key);
402 if let Some(parent) = path.parent() {
403 tokio::fs::create_dir_all(parent).await?;
404 }
405
406 let file_res = tokio::fs::OpenOptions::new()
408 .write(true)
409 .create_new(true)
410 .open(&path)
411 .await;
412 let file = match file_res {
413 Ok(f) => f,
414 Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
415 return Err(DbLockError::AlreadyHeld {
416 lock_name: namespaced_key.to_owned(),
417 });
418 }
419 Err(e) => return Err(e.into()),
420 };
421
422 {
424 use tokio::io::AsyncWriteExt;
425 let mut f = file.try_clone().await?;
426 let _ = f
427 .write_all(
428 format!(
429 "PID: {}\nKey: {}\nTimestamp: {}\n",
430 std::process::id(),
431 namespaced_key,
432 chrono::Utc::now().to_rfc3339_opts(SecondsFormat::Millis, true)
433 )
434 .as_bytes(),
435 )
436 .await;
437 }
438
439 Ok(DbLockGuard {
440 namespaced_key: namespaced_key.to_owned(),
441 inner: Some(GuardInner::File { path, file }),
442 })
443 }
444
445 async fn try_lock_file(
446 &self,
447 namespaced_key: &str,
448 ) -> Result<Option<DbLockGuard>, DbLockError> {
449 let path = self.get_lock_file_path(namespaced_key);
450 if let Some(parent) = path.parent() {
451 tokio::fs::create_dir_all(parent).await?;
452 }
453
454 match tokio::fs::OpenOptions::new()
455 .write(true)
456 .create_new(true)
457 .open(&path)
458 .await
459 {
460 Ok(file) => {
461 {
463 use tokio::io::AsyncWriteExt;
464 let mut f = file.try_clone().await?;
465 let _ = f
466 .write_all(
467 format!(
468 "PID: {}\nKey: {}\nTimestamp: {}\n",
469 std::process::id(),
470 namespaced_key,
471 chrono::Utc::now().to_rfc3339_opts(SecondsFormat::Millis, true)
472 )
473 .as_bytes(),
474 )
475 .await;
476 }
477
478 Ok(Some(DbLockGuard {
479 namespaced_key: namespaced_key.to_owned(),
480 inner: Some(GuardInner::File { path, file }),
481 }))
482 }
483 Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => Ok(None),
484 Err(e) => Err(e.into()),
485 }
486 }
487
488 async fn try_acquire_once(
489 &self,
490 namespaced_key: &str,
491 ) -> Result<Option<DbLockGuard>, DbLockError> {
492 match self.engine {
493 #[cfg(feature = "pg")]
494 DbEngine::Postgres => self.try_lock_pg(namespaced_key).await,
495 #[cfg(not(feature = "pg"))]
496 DbEngine::Postgres => Err(DbLockError::InvalidState(
497 "PostgreSQL feature not enabled".to_owned(),
498 )),
499 #[cfg(feature = "mysql")]
500 DbEngine::MySql => self.try_lock_mysql(namespaced_key).await,
501 #[cfg(not(feature = "mysql"))]
502 DbEngine::MySql => Err(DbLockError::InvalidState(
503 "MySQL feature not enabled".to_owned(),
504 )),
505 DbEngine::Sqlite => self.try_lock_file(namespaced_key).await,
506 }
507 }
508
509 fn get_lock_file_path(&self, namespaced_key: &str) -> PathBuf {
511 let base_dir = if self.dsn.contains("memdb") || cfg!(test) {
513 std::env::temp_dir().join("hyperspot_test_locks")
514 } else {
515 let cache = dirs::cache_dir().unwrap_or_else(std::env::temp_dir);
517 cache.join("hyperspot").join("locks")
518 };
519
520 let dsn_hash = format!("{:x}", xxh3_64(self.dsn.as_bytes()));
521 let key_hash = format!("{:x}", xxh3_64(namespaced_key.as_bytes()));
522 base_dir.join(dsn_hash).join(format!("{key_hash}.lock"))
523 }
524}
525
526#[derive(Error, Debug)]
529pub enum DbLockError {
530 #[error("I/O error: {0}")]
531 Io(#[from] std::io::Error),
532
533 #[cfg(any(feature = "pg", feature = "mysql"))]
535 #[error("SQLx error: {0}")]
536 Sqlx(#[from] sea_orm::sqlx::Error),
537
538 #[error("Lock already held: {lock_name}")]
539 AlreadyHeld { lock_name: String },
540
541 #[error("Lock not found: {lock_name}")]
542 NotFound { lock_name: String },
543
544 #[error("Invalid state: {0}")]
545 InvalidState(String),
546}
547
#[cfg(test)]
#[cfg(feature = "sqlite")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    //! Every test releases its guards explicitly. `Drop` only spawns the release onto the
    //! runtime, and the per-test runtime may shut down before that task runs, which would
    //! leave lock files behind in the temp directory.

    use super::*;
    use anyhow::Result;

    /// Builds an identifier from `prefix` and the current time in nanoseconds so that
    /// tests never contend on the same lock file.
    fn unique_id(prefix: &str) -> String {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is set after the Unix epoch")
            .as_nanos();
        format!("{prefix}_{nanos}")
    }

    /// Creates a `LockManager` for the SQLite engine (file locks) on an in-memory DB.
    async fn sqlite_lock_manager() -> Result<LockManager> {
        let dsn = "sqlite::memory:";
        let pool = sea_orm::sqlx::SqlitePool::connect(dsn).await?;
        Ok(LockManager::new(
            DbEngine::Sqlite,
            DbPool::Sqlite(pool),
            dsn.to_owned(),
        ))
    }

    #[tokio::test]
    async fn test_namespaced_locks() -> Result<()> {
        let lock_manager = sqlite_lock_manager().await?;
        let key = format!("{}_key", unique_id("test_ns"));

        // The same key in two modules maps to two independent locks.
        let guard1 = lock_manager.lock("module1", &key).await?;
        let guard2 = lock_manager.lock("module2", &key).await?;

        assert_eq!(guard1.key(), format!("module1:{key}"));
        assert_eq!(guard2.key(), format!("module2:{key}"));

        guard1.release().await;
        guard2.release().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_try_lock_with_timeout() -> Result<()> {
        let lock_manager = sqlite_lock_manager().await?;
        let test_id = unique_id("test_timeout");
        let held_key = format!("{test_id}_key");

        let held = lock_manager.lock("test_module", &held_key).await?;

        let config = LockConfig {
            max_wait: Some(Duration::from_millis(200)),
            initial_backoff: Duration::from_millis(50),
            max_attempts: Some(3),
            ..Default::default()
        };

        // Contended key: gives up once the retry budget is exhausted.
        let contended = lock_manager
            .try_lock("test_module", &held_key, config.clone())
            .await?;
        assert!(contended.is_none(), "expected timeout on a held key");

        // Uncontended key: acquired right away.
        let acquired = lock_manager
            .try_lock("test_module", &format!("{test_id}_different_key"), config)
            .await?
            .expect("expected successful lock acquisition");

        acquired.release().await;
        held.release().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_try_lock_success() -> Result<()> {
        let lock_manager = sqlite_lock_manager().await?;
        let key = format!("{}_key", unique_id("test_success"));

        let guard = lock_manager
            .try_lock("test_module", &key, LockConfig::default())
            .await?
            .expect("expected lock acquisition");

        guard.release().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_double_lock_same_key_errors() -> Result<()> {
        let lock_manager = sqlite_lock_manager().await?;
        let test_id = unique_id("test_double");

        let guard = lock_manager.lock("test_module", &test_id).await?;
        let err = lock_manager
            .lock("test_module", &test_id)
            .await
            .unwrap_err();
        match err {
            DbLockError::AlreadyHeld { lock_name } => {
                assert!(lock_name.contains(&test_id));
            }
            other => panic!("unexpected error: {other:?}"),
        }

        guard.release().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_try_lock_conflict_returns_none() -> Result<()> {
        let lock_manager = sqlite_lock_manager().await?;
        let key = unique_id("test_conflict");

        let guard = lock_manager.lock("module", &key).await?;
        let config = LockConfig {
            max_wait: Some(Duration::from_millis(100)),
            max_attempts: Some(2),
            ..Default::default()
        };
        let res = lock_manager.try_lock("module", &key, config).await?;
        assert!(res.is_none());

        guard.release().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_release_allows_reacquire() -> Result<()> {
        let lock_manager = sqlite_lock_manager().await?;
        let key = unique_id("test_reacquire");

        let first = lock_manager.lock("module", &key).await?;
        first.release().await;

        // Release must have removed the lock file, so the same key is free again.
        let second = lock_manager.lock("module", &key).await?;
        second.release().await;
        Ok(())
    }
}