dynamo_runtime/storage/key_value_store/
etcd.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::time::Duration;
7
8use crate::{
9    storage::key_value_store::{Key, KeyValue, WatchEvent},
10    transports::etcd::Client,
11};
12use async_stream::stream;
13use async_trait::async_trait;
14use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
15
16use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
17
18#[derive(Clone)]
19pub struct EtcdStore {
20    client: Client,
21}
22
23impl EtcdStore {
24    pub fn new(client: Client) -> Self {
25        Self { client }
26    }
27}
28
29#[async_trait]
30impl KeyValueStore for EtcdStore {
31    type Bucket = EtcdBucket;
32
33    /// A "bucket" in etcd is a path prefix
34    async fn get_or_create_bucket(
35        &self,
36        bucket_name: &str,
37        _ttl: Option<Duration>, // TODO ttl not used yet
38    ) -> Result<Self::Bucket, StoreError> {
39        Ok(EtcdBucket {
40            client: self.client.clone(),
41            bucket_name: bucket_name.to_string(),
42        })
43    }
44
45    /// A "bucket" in etcd is a path prefix. This creates an EtcdBucket object without doing
46    /// any network calls.
47    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
48        Ok(Some(EtcdBucket {
49            client: self.client.clone(),
50            bucket_name: bucket_name.to_string(),
51        }))
52    }
53
54    fn connection_id(&self) -> u64 {
55        self.client.lease_id()
56    }
57}
58
/// Handle to one "bucket": the set of etcd keys under the `{bucket_name}/` prefix.
pub struct EtcdBucket {
    // Client used for all kv operations on this bucket.
    client: Client,
    // Path prefix identifying the bucket; prepended to every key (see `make_key`).
    bucket_name: String,
}
63
#[async_trait]
impl KeyValueBucket for EtcdBucket {
    /// Create (`revision == 0`) or update (`revision != 0`) `key` in this bucket.
    async fn insert(
        &self,
        key: &Key,
        value: &str,
        // "version" in etcd speak. revision is a global cluster-wide value
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        if version == 0 {
            self.create(key, value).await
        } else {
            self.update(key, value, version).await
        }
    }

    /// Fetch the value stored at `key`, or `None` if the key is absent.
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd get: {k}");

        let mut kvs = self
            .client
            .kv_get(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Ok(None);
        }
        // A non-prefix get should yield at most one KV; take the first entry.
        let (_, val) = kvs.swap_remove(0).into_key_value();
        Ok(Some(val.into()))
    }

    /// Delete `key`. The delete response is ignored, so this succeeds whether
    /// or not the key existed.
    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd delete: {k}");
        let _ = self
            .client
            .kv_delete(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(())
    }

    /// Watch every key under this bucket's prefix, yielding a [`WatchEvent`]
    /// per etcd Put/Delete event.
    ///
    /// NOTE(review): `'life0` is presumably the lifetime `#[async_trait]`
    /// generates for `&self` — confirm against the trait definition.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
        // An empty key under the bucket yields the bare "{bucket_name}/" prefix.
        let prefix = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd watch: {prefix}");
        let (watcher, mut watch_stream) = self
            .client
            .etcd_client()
            .clone()
            .watch(prefix.as_bytes(), Some(WatchOptions::new().with_prefix()))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let output = stream! {
            let _watcher = watcher; // Keep it alive. Not sure if necessary.
            // Stream ends on the first transport error or when etcd closes the watch.
            while let Ok(Some(resp)) = watch_stream.message().await {
                for e in resp.events() {
                    let Some(kv) = e.kv() else {
                        continue;
                    };
                    let (k_bytes, v_bytes) = kv.clone().into_key_value();
                    let key = match String::from_utf8(k_bytes) {
                        Ok(k) => k,
                        Err(err) => {
                            // Skip undecodable keys rather than aborting the whole watch.
                            tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
                            continue;
                        }
                    };
                    let item = KeyValue::new(key, v_bytes.into());
                    match e.event_type() {
                        EventType::Put => {
                            yield WatchEvent::Put(item);
                        }
                        EventType::Delete => {
                            yield WatchEvent::Delete(item);
                        }
                    }
                }
            }
        };
        Ok(Box::pin(output))
    }

    /// List all entries under this bucket's prefix as a map of full etcd key
    /// (including the bucket prefix) to value bytes.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd entries: {k}");

        let resp = self
            .client
            .kv_get_prefix(k)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let out: HashMap<String, bytes::Bytes> = resp
            .into_iter()
            .map(|kv| {
                let (k, v) = kv.into_key_value();
                // Lossy conversion: non-UTF8 key bytes become replacement chars.
                (String::from_utf8_lossy(&k).to_string(), v.into())
            })
            .collect();

        Ok(out)
    }
}
170
impl EtcdBucket {
    /// Atomically create `key` only if it does not already exist.
    ///
    /// Returns `Created(1)` on success (a new etcd key always has version 1),
    /// or `Exists(version)` with the current version if the key was already
    /// present.
    async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd create: {k}");

        // Use atomic transaction to check and create in one operation
        let put_options = PutOptions::new().with_lease(self.client.primary_lease().id() as i64);

        // Build transaction that creates key only if it doesn't exist
        // (version == 0 means "key absent" in etcd compare semantics).
        let txn = Txn::new()
            .when(vec![Compare::version(k.as_str(), CompareOp::Equal, 0)]) // Atomic check
            .and_then(vec![TxnOp::put(k.as_str(), value, Some(put_options))]) // Only if check passes
            .or_else(vec![
                TxnOp::get(k.as_str(), None), // Key exists, get its info
            ]);

        // Execute the transaction
        let result = self
            .client
            .etcd_client()
            .kv_client()
            .txn(txn)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;

        if result.succeeded() {
            // Key was created successfully
            return Ok(StoreOutcome::Created(1)); // version of new key is always 1
        }

        // Key already existed, get its version
        if let Some(etcd_client::TxnOpResponse::Get(get_resp)) =
            result.op_responses().into_iter().next()
            && let Some(kv) = get_resp.kvs().first()
        {
            let version = kv.version() as u64;
            return Ok(StoreOutcome::Exists(version));
        }
        // Shouldn't happen, but handle edge case
        Err(StoreError::EtcdError(
            "Unexpected transaction response".to_string(),
        ))
    }

    /// Overwrite `key` with `value`, logging a warning (but still writing) if
    /// the caller-supplied version does not line up with the stored version.
    ///
    /// NOTE(review): the get-then-put here is not atomic, unlike `create`;
    /// another writer can interleave between the two calls — the match below
    /// acknowledges that window.
    async fn update(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd update: {k}");

        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StoreError::MissingKey(key.to_string()));
        }
        // NOTE(review): this expects current_version == version + 1, i.e. the
        // caller passes the version *before* their last write — confirm this is
        // the intended contract for the `revision` parameter of `insert`.
        let current_version = kvs.first().unwrap().version() as u64;
        if current_version != version + 1 {
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                %key,
                "update: Wrong revision"
            );
            // NATS does a resync_update, overwriting the key anyway and getting the new revision.
            // So we do too in etcd.
        }

        // with_prev_key so the response tells us what (if anything) we replaced.
        let put_options = PutOptions::new()
            .with_lease(self.client.primary_lease().id() as i64)
            .with_prev_key();
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(put_options))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // Should this be an error?
            // The key was deleted between our get and put. We re-created it.
            // Version of new key is always 1.
            // <https://etcd.io/docs/v3.5/learning/data_model/>
            None => StoreOutcome::Created(1),
            // Expected case, success
            // NOTE(review): this returns `version`, but after the put the stored
            // version is kv.version() + 1 = version + 2; the arm below returns the
            // post-put version instead. Confirm which value callers rely on.
            Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
            // Should this be an error? Something updated the version between our get and put
            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
        })
    }
}
266
267fn make_key(bucket_name: &str, key: &Key) -> String {
268    [bucket_name.to_string(), key.to_string()].join("/")
269}
270
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::Barrier;

    /// Integration test: N workers race to create the same etcd key; exactly
    /// one must win and the rest must observe `StoreOutcome::Exists`.
    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();
        let config = DistributedConfig::from_settings(false);

        rt_clone.primary().block_on(async move {
            let drt = DistributedRuntime::new(rt, config).await.unwrap();
            test_concurrent_create(drt).await.unwrap();
        });
    }

    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
        let etcd_client = drt.etcd_client().expect("etcd client should be available");
        let storage = EtcdStore::new(etcd_client);

        // Create a bucket for testing. Share it via a bare Arc: `insert` takes
        // `&self`, and wrapping the bucket in a Mutex (as a previous version of
        // this test did) serializes the inserts, which defeats the purpose of
        // a race-condition test — the atomicity must come from etcd itself.
        let bucket = Arc::new(
            storage
                .get_or_create_bucket("test_concurrent_bucket", None)
                .await?,
        );

        // Number of concurrent workers
        let num_workers = 10;
        let barrier = Arc::new(Barrier::new(num_workers));

        // Shared test data; a UUID suffix keeps reruns independent.
        let test_key: Key = Key::new(&format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
        let test_value = "test_value";

        // Spawn multiple tasks that will all try to create the same key simultaneously.
        // Atomic counters avoid any extra synchronization around the results.
        let mut handles = Vec::new();
        let success_count = Arc::new(AtomicUsize::new(0));
        let exists_count = Arc::new(AtomicUsize::new(0));

        for worker_id in 0..num_workers {
            let bucket_clone = bucket.clone();
            let barrier_clone = barrier.clone();
            let key_clone = test_key.clone();
            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
            let success_count_clone = success_count.clone();
            let exists_count_clone = exists_count.clone();

            let handle = tokio::spawn(async move {
                // Wait for all workers to be ready
                barrier_clone.wait().await;

                // All workers try to create the same key at the same time,
                // with no client-side locking.
                let result = bucket_clone.insert(&key_clone, &value_clone, 0).await;

                match result {
                    Ok(StoreOutcome::Created(version)) => {
                        println!(
                            "Worker {} successfully created key with version {}",
                            worker_id, version
                        );
                        success_count_clone.fetch_add(1, Ordering::SeqCst);
                        Ok(version)
                    }
                    Ok(StoreOutcome::Exists(version)) => {
                        println!(
                            "Worker {} found key already exists with version {}",
                            worker_id, version
                        );
                        exists_count_clone.fetch_add(1, Ordering::SeqCst);
                        Ok(version)
                    }
                    Err(e) => {
                        println!("Worker {} got error: {:?}", worker_id, e);
                        Err(e)
                    }
                }
            });

            handles.push(handle);
        }

        // Wait for all workers to complete
        let mut results = Vec::new();
        for handle in handles {
            let result = handle.await.unwrap();
            if let Ok(version) = result {
                results.push(version);
            }
        }

        // Verify results
        let final_success_count = success_count.load(Ordering::SeqCst);
        let final_exists_count = exists_count.load(Ordering::SeqCst);

        println!(
            "Final counts - Created: {}, Exists: {}",
            final_success_count, final_exists_count
        );

        // CRITICAL ASSERTIONS:
        // 1. Exactly ONE worker should have successfully created the key
        assert_eq!(
            final_success_count, 1,
            "Exactly one worker should create the key"
        );

        // 2. All other workers should have gotten "Exists" response
        assert_eq!(
            final_exists_count,
            num_workers - 1,
            "All other workers should see key exists"
        );

        // 3. Total successful operations should equal number of workers
        assert_eq!(
            results.len(),
            num_workers,
            "All workers should complete successfully"
        );

        // 4. Verify the key actually exists in etcd
        let stored_value = bucket.get(&test_key).await?;
        assert!(stored_value.is_some(), "Key should exist in etcd");

        // 5. The stored value should be from one of the workers
        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Clean up
        bucket.delete(&test_key).await?;

        Ok(())
    }
}