1#![deny(unused_crate_dependencies)]
41use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
42use std::future::Future;
43use tokio_util::sync::CancellationToken;
44use url::Url;
45
/// Lock lifetime, in seconds, used when the builder is not given one.
const DEFAULT_LIFETIME_SECONDS: u16 = 60;

/// Base URL used when the builder is not given one; the encoded lock ID is
/// appended directly to this string.
const DEFAULT_URL: &str = "https://justalock.dev/";

/// Encodes an identifier (any `AsRef<[u8]>` value, e.g. a lock ID or a client
/// ID) as unpadded, URL-safe base64 and yields the resulting `String`.
macro_rules! id_string {
    ($bytes:expr) => {
        base64::Engine::encode(&URL_SAFE_NO_PAD, $bytes)
    };
}
54
55#[derive(Debug)]
57pub enum Error {
58 Data(String),
60 Reqwest(reqwest::Error),
62}
63
64#[derive(Clone)]
74pub struct Lock {
75 lock_id: [u8; 16],
76 client_id: Vec<u8>,
77 lifetime_seconds: u16,
78 refresh_interval: std::time::Duration,
79 url: Url,
80 client: reqwest::Client,
81}
82
83impl std::fmt::Debug for Lock {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("Lock")
86 .field("lockID", &id_string!(&self.lock_id))
87 .field("clientID", &id_string!(&self.client_id))
88 .field("lifetime", &format_args!("{}s", self.lifetime_seconds))
89 .finish()
90 }
91}
92
93impl Lock {
94 pub fn builder<L: IntoLockID>(lock_id: L) -> LockBuilder {
98 LockBuilder::new(lock_id)
99 }
100
101 #[tracing::instrument]
102 async fn try_lock(&self) -> Result<bool, Error> {
103 let response = self
104 .client
105 .post(self.url.clone())
106 .body(self.client_id.clone())
107 .send()
108 .await
109 .map_err(Error::Reqwest)?;
110 match response.error_for_status_ref() {
111 Ok(_) => Ok(true),
112 Err(_) if matches!(response.status(), reqwest::StatusCode::CONFLICT) => Ok(false),
113 Err(_) if response.status().is_client_error() => {
114 Err(Error::Data(response.text().await.map_err(Error::Reqwest)?))
115 }
116 Err(err) => Err(Error::Reqwest(err)),
117 }
118 }
119
120 pub async fn locked<Fut, F, T>(self, f: F) -> Result<T, Error>
155 where
156 F: FnOnce(CancellationToken) -> Fut,
157 Fut: Future<Output = T>,
158 {
159 let function_token = CancellationToken::new();
160 let function_token_spawn = function_token.clone();
161 let lock_token = CancellationToken::new();
162 let lock_token_spawn = lock_token.clone();
163 let lock_time = loop {
164 let lock_time = tokio::time::Instant::now();
165 match self.try_lock().await {
166 Ok(true) => {
167 break lock_time;
168 }
169 Ok(false) => {}
170 Err(error) => return Err(error),
171 }
172 };
173 tokio::spawn(async move {
174 lock_forever(&self, lock_time, lock_token_spawn).await;
175 function_token_spawn.cancel();
176 });
177 let value = f(function_token).await;
178 lock_token.cancel();
179 Ok(value)
180 }
181}
182
183pub struct LockBuilder {
185 lock_id: [u8; 16],
186 client_id: Option<Vec<u8>>,
187 lifetime_seconds: Option<u16>,
188 url: Option<Url>,
189}
190
191impl LockBuilder {
192 pub fn new<L: IntoLockID>(lock_id: L) -> Self {
194 Self {
195 lock_id: lock_id.into_lock_id().0,
196 client_id: None,
197 lifetime_seconds: None,
198 url: None,
199 }
200 }
201
202 pub fn client_id(mut self, client_id: Vec<u8>) -> Self {
210 self.client_id = Some(client_id);
211 self
212 }
213
214 pub fn lifetime_seconds(mut self, seconds: u16) -> Self {
219 self.lifetime_seconds = Some(seconds);
220 self
221 }
222
223 pub fn url(mut self, url: &str) -> Result<Self, url::ParseError> {
230 self.url = Some(Url::parse(url)?.join(&id_string!(&self.lock_id))?);
231 Ok(self)
232 }
233
234 pub fn build(self) -> Result<Lock, reqwest::Error> {
236 let client_id = self.client_id.unwrap_or_else(random_client_id);
237 let lifetime_seconds = self.lifetime_seconds.unwrap_or(DEFAULT_LIFETIME_SECONDS);
238 let refresh_interval =
239 std::time::Duration::from_millis(((lifetime_seconds as u64) * 1000) / 3);
240 let mut url = self.url.unwrap_or_else(|| {
241 format!("{}{}", DEFAULT_URL, id_string!(self.lock_id))
242 .parse()
243 .unwrap() });
245 url.set_query(Some(&format!("s={lifetime_seconds}")));
246 let client = reqwest::Client::builder()
247 .timeout(std::time::Duration::from_secs(2))
248 .build()?;
249 Ok(Lock {
250 lock_id: self.lock_id,
251 client_id,
252 lifetime_seconds,
253 refresh_interval,
254 url,
255 client,
256 })
257 }
258}
259
/// A 16-byte lock identifier.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct LockID([u8; 16]);

/// Private conversion trait behind [`IntoLockID`]. Because it is not public,
/// code outside this crate cannot implement `IntoLockID` for its own types.
trait IntoLockIDSealed {
    /// Converts `self` into a [`LockID`].
    fn into_lock_id(self) -> LockID;
}

/// Marker for the types accepted wherever a lock ID is expected: [`LockID`],
/// `[u8; 16]`, `i128` and `u128`.
// The private supertrait is deliberate (it seals the trait), so the lint
// about private bounds on a public trait is silenced here.
#[allow(private_bounds)]
pub trait IntoLockID: IntoLockIDSealed {}

impl IntoLockIDSealed for LockID {
    /// A `LockID` converts to itself.
    fn into_lock_id(self) -> LockID {
        self
    }
}
impl IntoLockID for LockID {}

impl IntoLockIDSealed for [u8; 16] {
    /// The bytes are used as-is.
    fn into_lock_id(self) -> LockID {
        LockID(self)
    }
}
impl IntoLockID for [u8; 16] {}

impl IntoLockIDSealed for i128 {
    /// Uses the big-endian byte representation of the integer.
    fn into_lock_id(self) -> LockID {
        self.to_be_bytes().into_lock_id()
    }
}
impl IntoLockID for i128 {}

impl IntoLockIDSealed for u128 {
    /// Uses the big-endian byte representation of the integer.
    fn into_lock_id(self) -> LockID {
        self.to_be_bytes().into_lock_id()
    }
}
impl IntoLockID for u128 {}
302
303async fn lock_forever(
304 lock: &Lock,
305 mut last_lock_time: tokio::time::Instant,
306 cancellation_token: CancellationToken,
307) {
308 assert!(lock.lifetime_seconds > 0);
309 let mut interval = tokio::time::interval(lock.refresh_interval);
310 loop {
311 interval.tick().await;
312 let sleep = tokio::time::sleep_until(
313 last_lock_time + std::time::Duration::from_secs(lock.lifetime_seconds as u64 - 1),
314 );
315 let now = tokio::time::Instant::now();
316 tokio::select! {
317 biased;
318 _ = sleep => {
319 return;
320 }
321 _ = cancellation_token.cancelled() => {
322 return;
323 }
324 result = lock.try_lock() => match result {
325 Ok(true) => {
326 last_lock_time = now;
327 }
328 Ok(false) => {
329 return;
330 }
331 Err(error) => {
332 tracing::warn!("failed to call justalock: {:?}", error);
333 }
334 }
335 }
336 }
337}
338
/// Generates a 28-byte client ID without any external randomness crate.
///
/// Layout (all big-endian): 8 bytes from a freshly seeded `RandomState`
/// hasher, 8 bytes of seconds since the Unix epoch, 4 bytes of sub-second
/// nanoseconds, then 8 more bytes from another fresh `RandomState` hasher.
/// A system clock set before the epoch contributes zeros for the time part.
fn random_client_id() -> Vec<u8> {
    use std::hash::{BuildHasher, Hasher};

    // The output of an empty hasher depends only on `RandomState`'s random keys.
    let random_u64 = || {
        std::collections::hash_map::RandomState::new()
            .build_hasher()
            .finish()
    };

    let since_epoch = std::time::SystemTime::UNIX_EPOCH
        .elapsed()
        .unwrap_or_default();

    let head = random_u64().to_be_bytes();
    let seconds = since_epoch.as_secs().to_be_bytes();
    let nanos = since_epoch.subsec_nanos().to_be_bytes();
    let tail = random_u64().to_be_bytes();

    [&head[..], &seconds[..], &nanos[..], &tail[..]].concat()
}
364
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::time::Duration;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Lock ID shared by the tests.
    const TEST_LOCK_ID: u128 = 42;

    /// Request path the client is expected to use for `TEST_LOCK_ID`.
    fn lock_path() -> String {
        format!("/{}", URL_SAFE_NO_PAD.encode(TEST_LOCK_ID.to_be_bytes()))
    }

    /// Mock answering every `POST` to the test lock's path with `response`.
    fn lock_mock(response: ResponseTemplate) -> Mock {
        Mock::given(method("POST"))
            .and(path(lock_path()))
            .respond_with(response)
    }

    /// Lock for `TEST_LOCK_ID` pointed at the given mock server.
    fn lock_against(server: &MockServer) -> Lock {
        Lock::builder(TEST_LOCK_ID)
            .url(&server.uri())
            .expect("URL should parse")
            .build()
            .expect("Should build successfully")
    }

    #[test]
    fn random_client_id_length() {
        for _ in 0..16 {
            let id = super::random_client_id();
            assert!(!id.is_empty(), "random client id is empty");
            assert!(id.len() < 128, "random client id is too long");
        }
    }

    #[test]
    fn random_client_id_uniqueness() {
        let mut seen = std::collections::HashSet::new();
        for _ in 0..100 {
            let id = super::random_client_id();
            assert!(seen.insert(id), "Generated duplicate client ID");
        }
    }

    #[test]
    fn lock_id_conversions() {
        let unsigned: u128 = 12345;
        assert_eq!(unsigned.into_lock_id().0, unsigned.to_be_bytes());

        let signed: i128 = -12345;
        assert_eq!(signed.into_lock_id().0, signed.to_be_bytes());

        let bytes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        assert_eq!(bytes.into_lock_id().0, bytes);

        let original = LockID([0; 16]);
        assert_eq!(original, original.into_lock_id());
    }

    #[test]
    fn lock_builder_defaults() {
        let builder = LockBuilder::new(TEST_LOCK_ID);

        assert_eq!(builder.lock_id, TEST_LOCK_ID.to_be_bytes());
        assert!(builder.client_id.is_none());
        assert!(builder.lifetime_seconds.is_none());
        assert!(builder.url.is_none());
    }

    #[test]
    fn lock_builder_configuration() {
        let client_id = b"test-client".to_vec();
        let lifetime = 120;

        let builder = LockBuilder::new(TEST_LOCK_ID)
            .client_id(client_id.clone())
            .lifetime_seconds(lifetime);

        assert_eq!(builder.client_id, Some(client_id));
        assert_eq!(builder.lifetime_seconds, Some(lifetime));
    }

    #[test]
    fn lock_builder_url_parsing() {
        let parsed = LockBuilder::new(TEST_LOCK_ID)
            .url("http://localhost:8080/")
            .expect("Valid URL should parse");
        assert!(parsed.url.is_some());

        let rejected = LockBuilder::new(TEST_LOCK_ID).url("not-a-valid-url");
        assert!(rejected.is_err());
    }

    #[tokio::test]
    async fn lock_builder_build() {
        let lock = Lock::builder(TEST_LOCK_ID)
            .build()
            .expect("Should build successfully");

        assert_eq!(lock.lock_id, TEST_LOCK_ID.to_be_bytes());
        assert_eq!(lock.lifetime_seconds, DEFAULT_LIFETIME_SECONDS);
        assert!(!lock.client_id.is_empty());
        assert!(lock.url.to_string().contains("justalock.dev"));
    }

    #[tokio::test]
    async fn lock_builder_build_with_custom_config() {
        let client_id = b"test-client".to_vec();
        let lifetime = 120;

        let lock = Lock::builder(TEST_LOCK_ID)
            .client_id(client_id.clone())
            .lifetime_seconds(lifetime)
            .url("http://localhost:8080/")
            .expect("URL should parse")
            .build()
            .expect("Should build successfully");

        assert_eq!(lock.lock_id, TEST_LOCK_ID.to_be_bytes());
        assert_eq!(lock.client_id, client_id);
        assert_eq!(lock.lifetime_seconds, lifetime);
        assert_eq!(
            lock.refresh_interval,
            Duration::from_millis((lifetime as u64 * 1000) / 3)
        );
        assert!(lock.url.to_string().contains("localhost:8080"));
    }

    #[tokio::test]
    async fn lock_debug_format() {
        let lock = Lock::builder(TEST_LOCK_ID)
            .client_id(b"test".to_vec())
            .build()
            .expect("Should build successfully");

        let rendered = format!("{:?}", lock);
        for expected in ["Lock", "lockID", "clientID", "lifetime"] {
            assert!(rendered.contains(expected));
        }
    }

    #[tokio::test]
    async fn try_lock_success() {
        let server = MockServer::start().await;

        // Also checks that the default lifetime is sent as the `s` parameter.
        Mock::given(method("POST"))
            .and(path(lock_path()))
            .and(query_param("s", "60"))
            .respond_with(ResponseTemplate::new(200))
            .expect(1)
            .mount(&server)
            .await;

        let acquired = lock_against(&server)
            .try_lock()
            .await
            .expect("Should not error");
        assert!(acquired, "Lock should be acquired");
    }

    #[tokio::test]
    async fn try_lock_conflict() {
        let server = MockServer::start().await;
        lock_mock(ResponseTemplate::new(409))
            .expect(1)
            .mount(&server)
            .await;

        let acquired = lock_against(&server)
            .try_lock()
            .await
            .expect("Should not error");
        assert!(!acquired, "Lock should not be acquired due to conflict");
    }

    #[tokio::test]
    async fn try_lock_client_error() {
        let server = MockServer::start().await;
        lock_mock(ResponseTemplate::new(400).set_body_string("Bad request"))
            .expect(1)
            .mount(&server)
            .await;

        let outcome = lock_against(&server).try_lock().await;
        assert_matches!(outcome, Err(Error::Data(_)));
    }

    #[tokio::test]
    async fn try_lock_server_error() {
        let server = MockServer::start().await;
        lock_mock(ResponseTemplate::new(500))
            .expect(1)
            .mount(&server)
            .await;

        let outcome = lock_against(&server).try_lock().await;
        assert_matches!(outcome, Err(Error::Reqwest(_)));
    }

    #[tokio::test]
    async fn locked_function_executes_successfully() {
        let server = MockServer::start().await;
        lock_mock(ResponseTemplate::new(200)).mount(&server).await;

        let outcome = lock_against(&server)
            .locked(|_token| async move { "success" })
            .await;

        assert_matches!(outcome, Ok("success"));
    }

    #[tokio::test]
    async fn locked_function_receives_cancellation_token() {
        let server = MockServer::start().await;
        lock_mock(ResponseTemplate::new(200)).mount(&server).await;

        let outcome = lock_against(&server)
            .locked(|token| async move {
                assert!(
                    !token.is_cancelled(),
                    "Token should not be cancelled initially"
                );
                42
            })
            .await;

        assert_matches!(outcome, Ok(42));
    }

    #[tokio::test]
    async fn locked_with_acquisition_retry() {
        let server = MockServer::start().await;

        // The first two attempts hit a conflict; later ones succeed.
        lock_mock(ResponseTemplate::new(409))
            .up_to_n_times(2)
            .mount(&server)
            .await;
        lock_mock(ResponseTemplate::new(200)).mount(&server).await;

        let outcome = lock_against(&server)
            .locked(|_token| async move { "success after retry" })
            .await;

        assert_matches!(outcome, Ok("success after retry"));
    }

    #[tokio::test]
    async fn locked_propagates_try_lock_errors() {
        let server = MockServer::start().await;
        lock_mock(ResponseTemplate::new(400).set_body_string("Invalid lock ID"))
            .mount(&server)
            .await;

        let outcome = lock_against(&server)
            .locked(|_token| async move { "should not execute" })
            .await;

        assert_matches!(outcome, Err(Error::Data(_)));
    }

    #[tokio::test]
    async fn locked_integration_with_refresh() {
        let server = MockServer::start().await;
        lock_mock(ResponseTemplate::new(200))
            .expect(1..)
            .mount(&server)
            .await;

        let lock = Lock::builder(TEST_LOCK_ID)
            .lifetime_seconds(60)
            .url(&server.uri())
            .expect("URL should parse")
            .build()
            .expect("Should build successfully");

        let outcome = lock
            .locked(|token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                assert!(
                    !token.is_cancelled(),
                    "Token should not be cancelled during work"
                );
                "completed"
            })
            .await;

        assert_matches!(outcome, Ok("completed"));
    }

    #[tokio::test]
    async fn error_display() {
        let data_error = Error::Data("test error".to_string());
        assert!(format!("{:?}", data_error).contains("test error"));

        // Exhaustive match: fails to compile if a variant is added unnoticed.
        match Error::Data("test".to_string()) {
            Error::Data(_) => {}
            Error::Reqwest(_) => {}
        }
    }

    #[tokio::test]
    async fn locked_with_network_timeout() {
        // 192.0.2.1 belongs to a documentation-only address range.
        let lock = Lock::builder(TEST_LOCK_ID)
            .url("http://192.0.2.1:9999/")
            .expect("URL should parse")
            .build()
            .expect("Should build successfully");

        let started = tokio::time::Instant::now();
        let outcome = lock
            .locked(|_token| async move { "should not execute" })
            .await;
        let elapsed = started.elapsed();

        assert_matches!(outcome, Err(Error::Reqwest(_)));
        assert!(elapsed < Duration::from_secs(10), "Should timeout quickly");
    }
}