1use crate::{ClientBuilder, Lease, local::LocalLocks};
2use anyhow::{Context, bail, ensure};
3use aws_sdk_dynamodb::{
4 error::SdkError,
5 operation::{
6 delete_item::{DeleteItemError, DeleteItemOutput},
7 put_item::PutItemError,
8 update_item::UpdateItemError,
9 },
10 types::{AttributeValue, KeyType, ScalarAttributeType},
11};
12use aws_smithy_runtime_api::client::orchestrator;
13use std::{
14 cmp::min,
15 sync::Arc,
16 time::{Duration, Instant},
17};
18use time::OffsetDateTime;
19use tracing::instrument;
20use uuid::Uuid;
21
22const KEY_FIELD: &str = "key";
23const LEASE_EXPIRY_FIELD: &str = "lease_expiry";
24const LEASE_VERSION_FIELD: &str = "lease_version";
25
26#[derive(Debug, Clone)]
33pub struct Client {
34 pub(crate) client: aws_sdk_dynamodb::Client,
35 pub(crate) table_name: Arc<String>,
36 pub(crate) lease_ttl_seconds: u32,
37 pub(crate) extend_period: Duration,
38 pub(crate) acquire_cooldown: Duration,
39 pub(crate) local_locks: LocalLocks,
40}
41
42impl Client {
43 pub fn builder() -> ClientBuilder {
45 <_>::default()
46 }
47
48 #[instrument(skip_all)]
54 pub async fn try_acquire(&self, key: impl Into<String>) -> anyhow::Result<Option<Lease>> {
55 let key = key.into();
56 let local_guard = match self.local_locks.try_lock(key.clone()) {
57 Ok(g) => g,
58 Err(_) => return Ok(None),
59 };
60
61 match self.put_lease(key).await {
62 Ok(Some(lease)) => Ok(Some(lease.with_local_guard(local_guard))),
63 x => x,
64 }
65 }
66
67 #[instrument(skip_all)]
72 pub async fn acquire(&self, key: impl Into<String>) -> anyhow::Result<Lease> {
73 let key = key.into();
74 let local_guard = self.local_locks.lock(key.clone()).await;
75
76 loop {
77 if let Some(lease) = self.put_lease(key.clone()).await? {
78 return Ok(lease.with_local_guard(local_guard));
79 }
80 tokio::time::sleep(self.acquire_cooldown).await;
81 }
82 }
83
84 #[instrument(skip_all)]
89 pub async fn acquire_timeout(
90 &self,
91 key: impl Into<String>,
92 max_wait: Duration,
93 ) -> anyhow::Result<Lease> {
94 let start = Instant::now();
95 let key = key.into();
96
97 let local_guard = tokio::time::timeout(max_wait, self.local_locks.lock(key.clone()))
98 .await
99 .context("Could not acquire within {max_wait:?}")?;
100
101 loop {
102 if let Some(lease) = self.put_lease(key.clone()).await? {
103 return Ok(lease.with_local_guard(local_guard));
104 }
105 let elapsed = start.elapsed();
106 if elapsed > max_wait {
107 bail!("Could not acquire within {max_wait:?}");
108 }
109 let remaining_max_wait = max_wait - elapsed;
110 tokio::time::sleep(min(self.acquire_cooldown, remaining_max_wait)).await;
111 }
112 }
113
114 async fn put_lease(&self, key: String) -> anyhow::Result<Option<Lease>> {
116 let now_timestamp = OffsetDateTime::now_utc().unix_timestamp();
117 let expiry_timestamp = now_timestamp + i64::from(self.lease_ttl_seconds);
118 let lease_v = Uuid::new_v4();
119
120 let put = self
121 .client
122 .put_item()
123 .table_name(self.table_name.as_str())
124 .item(KEY_FIELD, AttributeValue::S(key.clone()))
125 .item(
126 LEASE_EXPIRY_FIELD,
127 AttributeValue::N(expiry_timestamp.to_string()),
128 )
129 .item(LEASE_VERSION_FIELD, AttributeValue::S(lease_v.to_string()))
130 .condition_expression(format!(
131 "attribute_not_exists({LEASE_VERSION_FIELD}) OR {LEASE_EXPIRY_FIELD} < :now"
132 ))
133 .expression_attribute_values(":now", AttributeValue::N(now_timestamp.to_string()))
134 .send()
135 .await;
136
137 match put {
138 Err(SdkError::ServiceError(se))
139 if matches!(se.err(), PutItemError::ConditionalCheckFailedException(..)) =>
140 {
141 Ok(None)
142 }
143 Err(err) => Err(err.into()),
144 Ok(_) => Ok(Some(Lease::new(self.clone(), key, lease_v))),
145 }
146 }
147
148 #[instrument(skip_all)]
150 pub(crate) async fn delete_lease(
151 &self,
152 key: String,
153 lease_v: Uuid,
154 ) -> Result<DeleteItemOutput, SdkError<DeleteItemError, orchestrator::HttpResponse>> {
155 self.client
156 .delete_item()
157 .table_name(self.table_name.as_str())
158 .key(KEY_FIELD, AttributeValue::S(key))
159 .condition_expression(format!("{LEASE_VERSION_FIELD}=:lease_v"))
160 .expression_attribute_values(":lease_v", AttributeValue::S(lease_v.to_string()))
161 .send()
162 .await
163 }
164
165 pub(crate) fn try_clean_local_lock(&self, key: String) {
167 self.local_locks.try_remove(key)
168 }
169
170 #[instrument(skip_all)]
172 pub(crate) async fn extend_lease(
173 &self,
174 key: String,
175 lease_v: Uuid,
176 ) -> Result<Uuid, SdkError<UpdateItemError, orchestrator::HttpResponse>> {
177 let expiry_timestamp =
178 OffsetDateTime::now_utc().unix_timestamp() + i64::from(self.lease_ttl_seconds);
179 let new_lease_v = Uuid::new_v4();
180
181 self.client
182 .update_item()
183 .table_name(self.table_name.as_str())
184 .key(KEY_FIELD, AttributeValue::S(key))
185 .update_expression(format!(
186 "SET {LEASE_VERSION_FIELD}=:new_lease_v, {LEASE_EXPIRY_FIELD}=:expiry"
187 ))
188 .condition_expression(format!("{LEASE_VERSION_FIELD}=:lease_v"))
189 .expression_attribute_values(":new_lease_v", AttributeValue::S(new_lease_v.to_string()))
190 .expression_attribute_values(":lease_v", AttributeValue::S(lease_v.to_string()))
191 .expression_attribute_values(":expiry", AttributeValue::N(expiry_timestamp.to_string()))
192 .send()
193 .await?;
194
195 Ok(new_lease_v)
196 }
197
198 pub(crate) async fn check_schema(&self) -> anyhow::Result<()> {
200 let (table_desc, ttl_desc) = tokio::join!(
202 self.client
203 .describe_table()
204 .table_name(self.table_name.as_str())
205 .send(),
206 self.client
207 .describe_time_to_live()
208 .table_name(self.table_name.as_str())
209 .send()
210 );
211
212 let desc = table_desc
213 .with_context(|| format!("Missing table `{}`?", self.table_name))?
214 .table
215 .context("no table description")?;
216
217 let attrs = desc.attribute_definitions.unwrap_or_default();
219 let key_schema = desc.key_schema.unwrap_or_default();
220 ensure!(
221 key_schema.len() == 1,
222 "Unexpected number of keys ({}) in key_schema, expected 1. Got {:?}",
223 key_schema.len(),
224 vec(key_schema.iter().map(|k| k.attribute_name())),
225 );
226 let described_kind = attrs
227 .iter()
228 .find(|attr| attr.attribute_name() == KEY_FIELD)
229 .with_context(|| {
230 format!(
231 "Missing attribute definition for {KEY_FIELD}, available {:?}",
232 vec(attrs.iter().map(|a| a.attribute_name()))
233 )
234 })?
235 .attribute_type();
236 ensure!(
237 described_kind == &ScalarAttributeType::S,
238 "Unexpected attribute type `{:?}` for {}, expected `{:?}`",
239 described_kind,
240 KEY_FIELD,
241 ScalarAttributeType::S,
242 );
243
244 let described_key_type = key_schema
245 .iter()
246 .find(|k| k.attribute_name() == KEY_FIELD)
247 .with_context(|| {
248 format!(
249 "Missing key schema for {KEY_FIELD}, available {:?}",
250 vec(key_schema.iter().map(|k| k.attribute_name()))
251 )
252 })?
253 .key_type();
254 ensure!(
255 described_key_type == &KeyType::Hash,
256 "Unexpected key type `{:?}` for {}, expected `{:?}`",
257 described_key_type,
258 KEY_FIELD,
259 KeyType::Hash,
260 );
261
262 let update_time_to_live_desc = ttl_desc
264 .with_context(|| format!("Missing time_to_live for table `{}`?", self.table_name))?
265 .time_to_live_description
266 .context("no time to live description")?;
267
268 ensure!(
269 update_time_to_live_desc.attribute_name() == Some(LEASE_EXPIRY_FIELD),
270 "time to live for {} is not set",
271 LEASE_EXPIRY_FIELD,
272 );
273
274 Ok(())
275 }
276}
277
/// Gathers every item of `items` into a `Vec`, preserving order.
///
/// Used to render attribute/key names in `Debug` form inside error messages.
#[inline]
fn vec<T>(items: impl Iterator<Item = T>) -> Vec<T> {
    items.collect::<Vec<T>>()
}