dynamo_runtime/storage/key_value_store/

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;

use crate::{
    storage::key_value_store::{Key, KeyValue, WatchEvent},
    transports::etcd,
};
use async_stream::stream;
use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, PutOptions, Txn, TxnOp};

use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};

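/// [`KeyValueStore`] backed by etcd. Writes made through this store are attached
/// to the client's lease, so they are removed when the client's session ends.
///
/// A minimal usage sketch (illustrative only; assumes an already-connected
/// `etcd::Client`):
///
/// ```ignore
/// let store = EtcdStore::new(etcd_client);
/// let bucket = store.get_or_create_bucket("models", None).await?;
/// // revision 0 means "create only if absent"
/// let outcome = bucket.insert(&"llama".into(), "config-bytes".into(), 0).await?;
/// ```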
#[derive(Clone)]
pub struct EtcdStore {
    client: etcd::Client,
}

impl EtcdStore {
    pub fn new(client: etcd::Client) -> Self {
        Self { client }
    }
}

#[async_trait]
impl KeyValueStore for EtcdStore {
    type Bucket = EtcdBucket;

    /// A "bucket" in etcd is just a path prefix, so this constructs the handle
    /// without making any network calls.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        _ttl: Option<Duration>, // TODO: ttl not used yet
    ) -> Result<Self::Bucket, StoreError> {
        Ok(EtcdBucket {
            client: self.client.clone(),
            bucket_name: bucket_name.to_string(),
        })
    }

    /// A "bucket" in etcd is a path prefix. This creates an EtcdBucket object
    /// without doing any network calls, so it always succeeds.
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
        Ok(Some(EtcdBucket {
            client: self.client.clone(),
            bucket_name: bucket_name.to_string(),
        }))
    }

    fn connection_id(&self) -> u64 {
        self.client.lease_id()
    }

    fn shutdown(&self) {
        // Nothing to do: etcd revokes the lease for us when the client disconnects.
    }
}

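/// A handle to one bucket (key prefix) in etcd. Keys live under
/// `{bucket_name}/{key}` (see [`make_key`]), so a bucket is purely a naming
/// convention: no etcd state exists until a key is written.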
pub struct EtcdBucket {
    client: etcd::Client,
    bucket_name: String,
}

#[async_trait]
impl KeyValueBucket for EtcdBucket {
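    /// Create or update a key. A `revision` of 0 means "create": the key is
    /// written only if it does not already exist, otherwise the existing version
    /// is reported back. Any other value is treated as an optimistic update
    /// against that version.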
    async fn insert(
        &self,
        key: &Key,
        value: bytes::Bytes,
        // This is a "version" in etcd terms; an etcd revision is a separate, cluster-wide counter
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        if version == 0 {
            self.create(key, value).await
        } else {
            self.update(key, value, version).await
        }
    }

    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd get: {k}");

        let mut kvs = self
            .client
            .kv_get(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Ok(None);
        }
        let (_, val) = kvs.swap_remove(0).into_key_value();
        Ok(Some(val.into()))
    }

    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd delete: {k}");
        let _ = self
            .client
            .kv_delete(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(())
    }

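    /// Watch this bucket's prefix for changes, yielding a [`WatchEvent`] per put
    /// or delete. A consumption sketch (illustrative only; assumes
    /// `futures::StreamExt` is in scope):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut events = bucket.watch().await?;
    /// while let Some(event) = events.next().await {
    ///     match event {
    ///         WatchEvent::Put(kv) => { /* key created or updated */ }
    ///         WatchEvent::Delete(key) => { /* key removed */ }
    ///     }
    /// }
    /// ```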
    async fn watch(
        &self,
        // `'life0` is the lifetime #[async_trait] generates for the elided `&self` borrow
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
        let prefix = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd watch: {prefix}");
        let watcher = self
            .client
            .kv_watch_prefix(&prefix)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let (_, mut watch_stream) = watcher.dissolve();
        let output = stream! {
            while let Some(event) = watch_stream.recv().await {
                match event {
                    etcd::WatchEvent::Put(kv) => {
                        let (k, v) = kv.into_key_value();
                        let key = match String::from_utf8(k) {
                            Ok(k) => k,
                            Err(err) => {
                                tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
                                continue;
                            }
                        };
                        let item = KeyValue::new(key, v.into());
                        yield WatchEvent::Put(item);
                    }
                    etcd::WatchEvent::Delete(kv) => {
                        let (k, _) = kv.into_key_value();
                        let key = match String::from_utf8(k) {
                            Ok(k) => k,
                            Err(err) => {
                                tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
                                continue;
                            }
                        };
                        yield WatchEvent::Delete(Key::from_raw(key));
                    }
                }
            }
        };
        Ok(Box::pin(output))
    }

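    /// Fetch every entry under this bucket's prefix in one call. Note the
    /// returned map is keyed by the full etcd key (bucket prefix included),
    /// not the bare key name.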
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd entries: {k}");

        let resp = self
            .client
            .kv_get_prefix(k)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let out: HashMap<String, bytes::Bytes> = resp
            .into_iter()
            .map(|kv| {
                let (k, v) = kv.into_key_value();
                (String::from_utf8_lossy(&k).to_string(), v.into())
            })
            .collect();

        Ok(out)
    }
}

impl EtcdBucket {
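    /// Atomically create `key` only if it does not already exist: a transaction
    /// compares the key's version to 0 (absent) and puts on success; otherwise
    /// it reads the existing entry so the caller gets [`StoreOutcome::Exists`]
    /// with the current version.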
    async fn create(
        &self,
        key: &Key,
        value: impl Into<Vec<u8>>,
    ) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd create: {k}");

        // Use an atomic transaction to check and create in one operation
        let put_options = PutOptions::new().with_lease(self.client.lease_id() as i64);

        // Build a transaction that creates the key only if it doesn't exist
        let txn = Txn::new()
            .when(vec![Compare::version(k.as_str(), CompareOp::Equal, 0)]) // atomic existence check
            .and_then(vec![TxnOp::put(k.as_str(), value, Some(put_options))]) // only if the check passes
            .or_else(vec![
                TxnOp::get(k.as_str(), None), // key exists, fetch its current state
            ]);

        // Execute the transaction
        let result = self
            .client
            .etcd_client()
            .kv_client()
            .txn(txn)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;

        if result.succeeded() {
            // The key was created. The version of a new key is always 1.
            return Ok(StoreOutcome::Created(1));
        }

        // The key already existed; report its version
        if let Some(etcd_client::TxnOpResponse::Get(get_resp)) =
            result.op_responses().into_iter().next()
            && let Some(kv) = get_resp.kvs().first()
        {
            let version = kv.version() as u64;
            return Ok(StoreOutcome::Exists(version));
        }
        // Shouldn't happen, but handle the edge case
        Err(StoreError::EtcdError(
            "Unexpected transaction response".to_string(),
        ))
    }

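    /// Update an existing key with optimistic versioning. The current version is
    /// fetched first and a mismatch is logged as a warning, but the write proceeds
    /// regardless, mirroring the NATS backend's resync_update behavior.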
    async fn update(
        &self,
        key: &Key,
        value: impl AsRef<[u8]>,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd update: {k}");

        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StoreError::MissingKey(key.to_string()));
        }
        let current_version = kvs.first().unwrap().version() as u64;
        if current_version != version + 1 {
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                %key,
                "update: Wrong revision"
            );
            // NATS does a resync_update, overwriting the key anyway and taking the
            // new revision, so we do the same here for etcd.
        }

        let put_options = PutOptions::new()
            .with_lease(self.client.lease_id() as i64)
            .with_prev_key();
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(put_options))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // Should this be an error? The key was deleted between our get and put,
            // so we re-created it. The version of a new key is always 1.
            // <https://etcd.io/docs/v3.5/learning/data_model/>
            None => StoreOutcome::Created(1),
            // Expected case: success
            Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
            // Should this be an error? Something updated the key between our get and put.
            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
        })
    }
}

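/// Builds the full etcd key by joining the bucket name and the key with a `/`.
///
/// A quick sketch of the resulting layout (values are illustrative):
///
/// ```ignore
/// let key: Key = "llama".into();
/// assert_eq!(make_key("models", &key), "models/llama");
/// ```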
fn make_key(bucket_name: &str, key: &Key) -> String {
    [bucket_name.to_string(), key.to_string()].join("/")
}

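// These tests require a live etcd instance, so they are gated behind the
// "integration" feature and excluded from a default `cargo test` run.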
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};
    use std::sync::Arc;
    use tokio::sync::Barrier;

    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();
        let config = DistributedConfig::from_settings(false);

        rt_clone.primary().block_on(async move {
            let drt = DistributedRuntime::new(rt, config).await.unwrap();
            test_concurrent_create(drt).await.unwrap();
        });
    }

    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
        let etcd_client = drt.etcd_client().expect("etcd client should be available");
        let storage = EtcdStore::new(etcd_client);

        // Create a bucket for testing
        let bucket = Arc::new(tokio::sync::Mutex::new(
            storage
                .get_or_create_bucket("test_concurrent_bucket", None)
                .await?,
        ));

        // Number of concurrent workers
        let num_workers = 10;
        let barrier = Arc::new(Barrier::new(num_workers));

        // Shared test data
        let test_key: Key = Key::new(&format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
        let test_value = "test_value";

        // Spawn multiple tasks that will all try to create the same key simultaneously
        let mut handles = Vec::new();
        let success_count = Arc::new(tokio::sync::Mutex::new(0));
        let exists_count = Arc::new(tokio::sync::Mutex::new(0));

        for worker_id in 0..num_workers {
            let bucket_clone = bucket.clone();
            let barrier_clone = barrier.clone();
            let key_clone = test_key.clone();
            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
            let success_count_clone = success_count.clone();
            let exists_count_clone = exists_count.clone();

            let handle = tokio::spawn(async move {
                // Wait for all workers to be ready
                barrier_clone.wait().await;

                // All workers try to create the same key at the same time
                let result = bucket_clone
                    .lock()
                    .await
                    .insert(&key_clone, value_clone.into(), 0)
                    .await;

                match result {
                    Ok(StoreOutcome::Created(version)) => {
                        println!(
                            "Worker {} successfully created key with version {}",
                            worker_id, version
                        );
                        let mut count = success_count_clone.lock().await;
                        *count += 1;
                        Ok(version)
                    }
                    Ok(StoreOutcome::Exists(version)) => {
                        println!(
                            "Worker {} found key already exists with version {}",
                            worker_id, version
                        );
                        let mut count = exists_count_clone.lock().await;
                        *count += 1;
                        Ok(version)
                    }
                    Err(e) => {
                        println!("Worker {} got error: {:?}", worker_id, e);
                        Err(e)
                    }
                }
            });

            handles.push(handle);
        }

        // Wait for all workers to complete
        let mut results = Vec::new();
        for handle in handles {
            let result = handle.await.unwrap();
            if let Ok(version) = result {
                results.push(version);
            }
        }

        // Verify results
        let final_success_count = *success_count.lock().await;
        let final_exists_count = *exists_count.lock().await;

        println!(
            "Final counts - Created: {}, Exists: {}",
            final_success_count, final_exists_count
        );

        // CRITICAL ASSERTIONS:
        // 1. Exactly ONE worker should have successfully created the key
        assert_eq!(
            final_success_count, 1,
            "Exactly one worker should create the key"
        );

        // 2. All other workers should have gotten an "Exists" response
        assert_eq!(
            final_exists_count,
            num_workers - 1,
            "All other workers should see key exists"
        );

        // 3. Total successful operations should equal the number of workers
        assert_eq!(
            results.len(),
            num_workers,
            "All workers should complete successfully"
        );

        // 4. Verify the key actually exists in etcd
        let stored_value = bucket.lock().await.get(&test_key).await?;
        assert!(stored_value.is_some(), "Key should exist in etcd");

        // 5. The stored value should be from one of the workers
        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Clean up
        bucket.lock().await.delete(&test_key).await?;

        Ok(())
    }
}