dynamodb_lease/
client.rs

1use crate::{ClientBuilder, Lease, local::LocalLocks};
2use anyhow::{Context, bail, ensure};
3use aws_sdk_dynamodb::{
4    error::SdkError,
5    operation::{
6        delete_item::{DeleteItemError, DeleteItemOutput},
7        put_item::PutItemError,
8        update_item::UpdateItemError,
9    },
10    types::{AttributeValue, KeyType, ScalarAttributeType},
11};
12use aws_smithy_runtime_api::client::orchestrator;
13use std::{
14    cmp::min,
15    sync::Arc,
16    time::{Duration, Instant},
17};
18use time::OffsetDateTime;
19use tracing::instrument;
20use uuid::Uuid;
21
/// Attribute holding the lease key; `check_schema` requires it to be the table's
/// only key, a string (`S`) hash key.
const KEY_FIELD: &str = "key";
/// Numeric attribute holding the lease expiry as a unix timestamp in seconds;
/// `check_schema` requires it to be the table's time-to-live attribute.
const LEASE_EXPIRY_FIELD: &str = "lease_expiry";
/// String attribute holding the current lease version (a uuid). Extend & delete are
/// conditional on it still matching, so a stale holder cannot modify a re-acquired lease.
const LEASE_VERSION_FIELD: &str = "lease_version";
25
/// Client for acquiring [`Lease`]s.
///
/// Communicates with dynamodb to acquire, extend and delete distributed leases.
///
/// Local mutex locks are also used to eliminate db contention for usage within
/// a single `Client` instance or clone.
///
/// Construct via [`Client::builder`].
#[derive(Debug, Clone)]
pub struct Client {
    /// Underlying DynamoDB SDK client used for every table request.
    pub(crate) client: aws_sdk_dynamodb::Client,
    /// Name of the lease table; behind an `Arc` so clones share one allocation.
    pub(crate) table_name: Arc<String>,
    /// Seconds added to the current unix time to compute a lease's expiry when it is
    /// put or extended.
    pub(crate) lease_ttl_seconds: u32,
    /// Not read in this file. Presumably the interval at which a held [`Lease`]
    /// extends itself — TODO confirm against the `Lease` implementation.
    pub(crate) extend_period: Duration,
    /// Delay between failed db put attempts in [`Client::acquire`] and
    /// [`Client::acquire_timeout`] (the latter caps it by the remaining wait).
    pub(crate) acquire_cooldown: Duration,
    /// Per-key in-process locks, taken before any db put; the guard is attached to
    /// the resulting [`Lease`].
    pub(crate) local_locks: LocalLocks,
}
41
42impl Client {
43    /// Returns a new [`Client`] builder.
44    pub fn builder() -> ClientBuilder {
45        <_>::default()
46    }
47
48    /// Tries to acquire a new [`Lease`] for the given `key`.
49    ///
50    /// If this lease has already been acquired elsewhere `Ok(None)` is returned.
51    ///
52    /// Does not wait to acquire a lease, to do so see [`Client::acquire`].
53    #[instrument(skip_all)]
54    pub async fn try_acquire(&self, key: impl Into<String>) -> anyhow::Result<Option<Lease>> {
55        let key = key.into();
56        let local_guard = match self.local_locks.try_lock(key.clone()) {
57            Ok(g) => g,
58            Err(_) => return Ok(None),
59        };
60
61        match self.put_lease(key).await {
62            Ok(Some(lease)) => Ok(Some(lease.with_local_guard(local_guard))),
63            x => x,
64        }
65    }
66
67    /// Acquires a new [`Lease`] for the given `key`. May wait until successful if the lease
68    /// has already been acquired elsewhere.
69    ///
70    /// To try to acquire without waiting see [`Client::try_acquire`].
71    #[instrument(skip_all)]
72    pub async fn acquire(&self, key: impl Into<String>) -> anyhow::Result<Lease> {
73        let key = key.into();
74        let local_guard = self.local_locks.lock(key.clone()).await;
75
76        loop {
77            if let Some(lease) = self.put_lease(key.clone()).await? {
78                return Ok(lease.with_local_guard(local_guard));
79            }
80            tokio::time::sleep(self.acquire_cooldown).await;
81        }
82    }
83
84    /// Acquires a new [`Lease`] for the given `key`. May wait until successful if the lease
85    /// has already been acquired elsewhere up to a max of `max_wait`.
86    ///
87    /// To try to acquire without waiting see [`Client::try_acquire`].
88    #[instrument(skip_all)]
89    pub async fn acquire_timeout(
90        &self,
91        key: impl Into<String>,
92        max_wait: Duration,
93    ) -> anyhow::Result<Lease> {
94        let start = Instant::now();
95        let key = key.into();
96
97        let local_guard = tokio::time::timeout(max_wait, self.local_locks.lock(key.clone()))
98            .await
99            .context("Could not acquire within {max_wait:?}")?;
100
101        loop {
102            if let Some(lease) = self.put_lease(key.clone()).await? {
103                return Ok(lease.with_local_guard(local_guard));
104            }
105            let elapsed = start.elapsed();
106            if elapsed > max_wait {
107                bail!("Could not acquire within {max_wait:?}");
108            }
109            let remaining_max_wait = max_wait - elapsed;
110            tokio::time::sleep(min(self.acquire_cooldown, remaining_max_wait)).await;
111        }
112    }
113
114    /// Put a new lease into the db.
115    async fn put_lease(&self, key: String) -> anyhow::Result<Option<Lease>> {
116        let now_timestamp = OffsetDateTime::now_utc().unix_timestamp();
117        let expiry_timestamp = now_timestamp + i64::from(self.lease_ttl_seconds);
118        let lease_v = Uuid::new_v4();
119
120        let put = self
121            .client
122            .put_item()
123            .table_name(self.table_name.as_str())
124            .item(KEY_FIELD, AttributeValue::S(key.clone()))
125            .item(
126                LEASE_EXPIRY_FIELD,
127                AttributeValue::N(expiry_timestamp.to_string()),
128            )
129            .item(LEASE_VERSION_FIELD, AttributeValue::S(lease_v.to_string()))
130            .condition_expression(format!(
131                "attribute_not_exists({LEASE_VERSION_FIELD}) OR {LEASE_EXPIRY_FIELD} < :now"
132            ))
133            .expression_attribute_values(":now", AttributeValue::N(now_timestamp.to_string()))
134            .send()
135            .await;
136
137        match put {
138            Err(SdkError::ServiceError(se))
139                if matches!(se.err(), PutItemError::ConditionalCheckFailedException(..)) =>
140            {
141                Ok(None)
142            }
143            Err(err) => Err(err.into()),
144            Ok(_) => Ok(Some(Lease::new(self.clone(), key, lease_v))),
145        }
146    }
147
148    /// Delete a lease with a given `key` & `lease_v`.
149    #[instrument(skip_all)]
150    pub(crate) async fn delete_lease(
151        &self,
152        key: String,
153        lease_v: Uuid,
154    ) -> Result<DeleteItemOutput, SdkError<DeleteItemError, orchestrator::HttpResponse>> {
155        self.client
156            .delete_item()
157            .table_name(self.table_name.as_str())
158            .key(KEY_FIELD, AttributeValue::S(key))
159            .condition_expression(format!("{LEASE_VERSION_FIELD}=:lease_v"))
160            .expression_attribute_values(":lease_v", AttributeValue::S(lease_v.to_string()))
161            .send()
162            .await
163    }
164
165    /// Cleanup local lock memory for the given `key` if not in use.
166    pub(crate) fn try_clean_local_lock(&self, key: String) {
167        self.local_locks.try_remove(key)
168    }
169
170    /// Extends an active lease. Returns the new `lease_v` uuid.
171    #[instrument(skip_all)]
172    pub(crate) async fn extend_lease(
173        &self,
174        key: String,
175        lease_v: Uuid,
176    ) -> Result<Uuid, SdkError<UpdateItemError, orchestrator::HttpResponse>> {
177        let expiry_timestamp =
178            OffsetDateTime::now_utc().unix_timestamp() + i64::from(self.lease_ttl_seconds);
179        let new_lease_v = Uuid::new_v4();
180
181        self.client
182            .update_item()
183            .table_name(self.table_name.as_str())
184            .key(KEY_FIELD, AttributeValue::S(key))
185            .update_expression(format!(
186                "SET {LEASE_VERSION_FIELD}=:new_lease_v, {LEASE_EXPIRY_FIELD}=:expiry"
187            ))
188            .condition_expression(format!("{LEASE_VERSION_FIELD}=:lease_v"))
189            .expression_attribute_values(":new_lease_v", AttributeValue::S(new_lease_v.to_string()))
190            .expression_attribute_values(":lease_v", AttributeValue::S(lease_v.to_string()))
191            .expression_attribute_values(":expiry", AttributeValue::N(expiry_timestamp.to_string()))
192            .send()
193            .await?;
194
195        Ok(new_lease_v)
196    }
197
198    /// Checks table is active & has a valid schema.
199    pub(crate) async fn check_schema(&self) -> anyhow::Result<()> {
200        // fetch table & ttl descriptions concurrently
201        let (table_desc, ttl_desc) = tokio::join!(
202            self.client
203                .describe_table()
204                .table_name(self.table_name.as_str())
205                .send(),
206            self.client
207                .describe_time_to_live()
208                .table_name(self.table_name.as_str())
209                .send()
210        );
211
212        let desc = table_desc
213            .with_context(|| format!("Missing table `{}`?", self.table_name))?
214            .table
215            .context("no table description")?;
216
217        // check "key" field is a S hash key
218        let attrs = desc.attribute_definitions.unwrap_or_default();
219        let key_schema = desc.key_schema.unwrap_or_default();
220        ensure!(
221            key_schema.len() == 1,
222            "Unexpected number of keys ({}) in key_schema, expected 1. Got {:?}",
223            key_schema.len(),
224            vec(key_schema.iter().map(|k| k.attribute_name())),
225        );
226        let described_kind = attrs
227            .iter()
228            .find(|attr| attr.attribute_name() == KEY_FIELD)
229            .with_context(|| {
230                format!(
231                    "Missing attribute definition for {KEY_FIELD}, available {:?}",
232                    vec(attrs.iter().map(|a| a.attribute_name()))
233                )
234            })?
235            .attribute_type();
236        ensure!(
237            described_kind == &ScalarAttributeType::S,
238            "Unexpected attribute type `{:?}` for {}, expected `{:?}`",
239            described_kind,
240            KEY_FIELD,
241            ScalarAttributeType::S,
242        );
243
244        let described_key_type = key_schema
245            .iter()
246            .find(|k| k.attribute_name() == KEY_FIELD)
247            .with_context(|| {
248                format!(
249                    "Missing key schema for {KEY_FIELD}, available {:?}",
250                    vec(key_schema.iter().map(|k| k.attribute_name()))
251                )
252            })?
253            .key_type();
254        ensure!(
255            described_key_type == &KeyType::Hash,
256            "Unexpected key type `{:?}` for {}, expected `{:?}`",
257            described_key_type,
258            KEY_FIELD,
259            KeyType::Hash,
260        );
261
262        // check "lease_expiry" is a ttl field
263        let update_time_to_live_desc = ttl_desc
264            .with_context(|| format!("Missing time_to_live for table `{}`?", self.table_name))?
265            .time_to_live_description
266            .context("no time to live description")?;
267
268        ensure!(
269            update_time_to_live_desc.attribute_name() == Some(LEASE_EXPIRY_FIELD),
270            "time to live for {} is not set",
271            LEASE_EXPIRY_FIELD,
272        );
273
274        Ok(())
275    }
276}
277
/// Gathers every item of `iter`, in order, into a freshly allocated [`Vec`].
///
/// Used to render attribute / key name lists inside error messages.
#[inline]
fn vec<T>(iter: impl Iterator<Item = T>) -> Vec<T> {
    let mut items = Vec::new();
    items.extend(iter);
    items
}