jacquard_host_support/
claims.rs1use alloc::collections::BTreeSet;
19use core::fmt;
20
21#[cfg(not(feature = "std"))]
22use alloc::rc::Rc;
23#[cfg(not(feature = "std"))]
24use core::cell::RefCell;
25#[cfg(feature = "std")]
26use std::sync::{Arc, Mutex};
27
28use jacquard_macros::public_model;
29use serde::{Deserialize, Serialize};
30
/// Error returned by [`PendingClaims::try_claim`] when an equal key is
/// already held by a live claim.
#[public_model]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ClaimRejected;
34
35impl fmt::Display for ClaimRejected {
36 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
37 formatter.write_str("claim is already held")
38 }
39}
40
// Lets `ClaimRejected` be used wherever a `std::error::Error` is expected
// (e.g. boxed as `Box<dyn Error>`). Gated on the `std` feature because the
// impl names the `std` crate directly.
#[cfg(feature = "std")]
impl std::error::Error for ClaimRejected {}
43
/// Shared, interior-mutable claim set for `std` builds.
///
/// `Arc<Mutex<_>>` lets clones of the handle be shared across threads
/// (presumably when `Key: Send` — the usual `Arc`/`Mutex` auto-trait rules).
#[cfg(feature = "std")]
type SharedClaims<Key> = Arc<Mutex<BTreeSet<Key>>>;

/// Shared, interior-mutable claim set for builds without `std`.
///
/// `Rc<RefCell<_>>` is single-threaded: handles cannot cross threads.
#[cfg(not(feature = "std"))]
type SharedClaims<Key> = Rc<RefCell<BTreeSet<Key>>>;
49
50#[derive(Clone)]
51pub struct PendingClaims<Key: Ord> {
52 claimed: SharedClaims<Key>,
53}
54
55impl<Key: Ord> Default for PendingClaims<Key> {
56 fn default() -> Self {
57 Self {
58 claimed: new_claim_set(),
59 }
60 }
61}
62
63pub struct ClaimGuard<Key: Ord> {
64 claimed: SharedClaims<Key>,
65 key: Option<Key>,
66}
67
68#[cfg(feature = "std")]
69fn new_claim_set<Key>() -> SharedClaims<Key> {
70 Arc::new(Mutex::new(BTreeSet::new()))
71}
72
73#[cfg(not(feature = "std"))]
74fn new_claim_set<Key>() -> SharedClaims<Key> {
75 Rc::new(RefCell::new(BTreeSet::new()))
76}
77
78#[cfg(feature = "std")]
79fn with_claims<Key, Output>(
80 claims: &SharedClaims<Key>,
81 operation: impl FnOnce(&mut BTreeSet<Key>) -> Output,
82) -> Output {
83 let mut guard = claims
84 .lock()
85 .unwrap_or_else(|poisoned| poisoned.into_inner());
86 operation(&mut guard)
87}
88
89#[cfg(not(feature = "std"))]
90fn with_claims<Key, Output>(
91 claims: &SharedClaims<Key>,
92 operation: impl FnOnce(&mut BTreeSet<Key>) -> Output,
93) -> Output {
94 let mut guard = claims.borrow_mut();
95 operation(&mut guard)
96}
97
98impl<Key> PendingClaims<Key>
99where
100 Key: Clone + Ord,
101{
102 #[must_use]
103 pub fn new() -> Self {
104 Self::default()
105 }
106
107 pub fn try_claim(&self, key: Key) -> Result<ClaimGuard<Key>, ClaimRejected> {
108 if !with_claims(&self.claimed, |claimed| claimed.insert(key.clone())) {
109 return Err(ClaimRejected);
110 }
111 Ok(ClaimGuard {
112 claimed: self.claimed.clone(),
113 key: Some(key),
114 })
115 }
116
117 #[must_use]
118 pub fn contains(&self, key: &Key) -> bool {
119 with_claims(&self.claimed, |claimed| claimed.contains(key))
120 }
121}
122
123impl<Key: Ord> ClaimGuard<Key> {
124 #[must_use]
125 pub fn key(&self) -> &Key {
126 self.key.as_ref().expect("claim guard key")
127 }
128}
129
130impl<Key: Ord> Drop for ClaimGuard<Key> {
131 fn drop(&mut self) {
132 if let Some(key) = self.key.take() {
133 with_claims(&self.claimed, |claimed| claimed.remove(&key));
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::{ClaimRejected, PendingClaims};

    #[test]
    fn duplicate_claims_are_rejected() {
        let claims = PendingClaims::new();
        let _guard = claims.try_claim("peer-a").expect("first claim");
        assert_eq!(claims.try_claim("peer-a").err(), Some(ClaimRejected));
    }

    #[test]
    fn dropped_guards_release_claims() {
        let claims = PendingClaims::new();
        let guard = claims.try_claim("peer-a").expect("claim");
        assert!(claims.contains(&"peer-a"));
        assert_eq!(guard.key(), &"peer-a");

        drop(guard);

        assert!(!claims.contains(&"peer-a"));
    }

    // A released key must be claimable again, not just absent from `contains`.
    #[test]
    fn released_keys_can_be_claimed_again() {
        let claims = PendingClaims::new();
        drop(claims.try_claim("peer-a").expect("first claim"));
        let guard = claims.try_claim("peer-a").expect("reclaim after release");
        assert_eq!(guard.key(), &"peer-a");
    }

    // Clones are handles to one shared set, not independent copies.
    #[test]
    fn clones_share_one_claim_set() {
        let claims = PendingClaims::new();
        let other = claims.clone();
        let _guard = claims.try_claim("peer-a").expect("claim");
        assert!(other.contains(&"peer-a"));
        assert_eq!(other.try_claim("peer-a").err(), Some(ClaimRejected));
    }

    #[test]
    fn distinct_keys_are_independent() {
        let claims = PendingClaims::new();
        let _a = claims.try_claim("peer-a").expect("claim a");
        let b = claims.try_claim("peer-b").expect("claim b");
        drop(b);
        assert!(claims.contains(&"peer-a"));
        assert!(!claims.contains(&"peer-b"));
    }
}