Skip to main content

dynamo_runtime/storage/kv/
etcd.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::time::Duration;
7
8use crate::transports::etcd;
9use async_stream::stream;
10use async_trait::async_trait;
11use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
12
13use super::{Bucket, Key, KeyValue, Store, StoreError, StoreOutcome, WatchEvent};
14
/// [`Store`] implementation backed by etcd.
///
/// Cloning shares the same underlying `etcd::Client`.
#[derive(Clone)]
pub struct EtcdStore {
    client: etcd::Client,
}
19
20impl EtcdStore {
21    pub fn new(client: etcd::Client) -> Self {
22        Self { client }
23    }
24}
25
26#[async_trait]
27impl Store for EtcdStore {
28    type Bucket = EtcdBucket;
29
30    /// A "bucket" in etcd is a path prefix
31    async fn get_or_create_bucket(
32        &self,
33        bucket_name: &str,
34        _ttl: Option<Duration>, // TODO ttl not used yet
35    ) -> Result<Self::Bucket, StoreError> {
36        Ok(EtcdBucket {
37            client: self.client.clone(),
38            bucket_name: bucket_name.to_string(),
39        })
40    }
41
42    /// A "bucket" in etcd is a path prefix. This creates an EtcdBucket object without doing
43    /// any network calls.
44    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
45        Ok(Some(EtcdBucket {
46            client: self.client.clone(),
47            bucket_name: bucket_name.to_string(),
48        }))
49    }
50
51    fn connection_id(&self) -> u64 {
52        self.client.lease_id()
53    }
54
55    fn shutdown(&self) {
56        // Revoke the lease? etcd will do it for us on disconnect.
57    }
58}
59
/// Handle to one key-prefix ("bucket") in etcd.
pub struct EtcdBucket {
    client: etcd::Client,
    // Prefix joined with '/' in front of every key; see `make_key`.
    bucket_name: String,
}
64
#[async_trait]
impl Bucket for EtcdBucket {
    /// Insert or update `key`.
    ///
    /// `revision` here is etcd's per-key *version*: `0` means "create only if
    /// absent"; any other value is treated as the caller's last-seen version
    /// and routed to an optimistic update (see `update`).
    async fn insert(
        &self,
        key: &Key,
        value: bytes::Bytes,
        // "version" in etcd speak. revision is a global cluster-wide value
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        if version == 0 {
            self.create(key, value).await
        } else {
            self.update(key, value, version).await
        }
    }

    /// Fetch the value at `key`, or `None` if the key does not exist.
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd get: {k}");

        let mut kvs = self
            .client
            .kv_get(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Ok(None);
        }
        // Exact-key get: take the first (and presumably only) result.
        let (_, val) = kvs.swap_remove(0).into_key_value();
        Ok(Some(val.into()))
    }

    /// Delete `key`. Deleting a non-existent key is not an error.
    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd delete: {k}");
        let _ = self
            .client
            .kv_delete(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(())
    }

    /// Watch every key under this bucket's prefix, yielding put/delete events
    /// as they arrive.
    ///
    /// Keys that are not valid UTF-8 are logged and skipped (contrast with
    /// `entries`, which decodes them lossily). The stream ends when the
    /// underlying etcd watch channel closes.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
        // Empty key component => watch the whole "<bucket>/" prefix.
        let prefix = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd watch: {prefix}");
        let watcher = self
            .client
            .kv_watch_prefix(&prefix)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let (_, mut watch_stream) = watcher.dissolve();
        let output = stream! {
            while let Some(event) = watch_stream.recv().await {
                match event {
                    etcd::WatchEvent::Put(kv) => {
                        let (k, v) = kv.into_key_value();
                        let key = match String::from_utf8(k) {
                            Ok(k) => Key::new(k),
                            Err(err) => {
                                tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
                                continue;
                            }
                        };
                        let item = KeyValue::new(key, v.into());
                        yield WatchEvent::Put(item);
                    }
                    etcd::WatchEvent::Delete(kv) => {
                        let (k, _) = kv.into_key_value();
                        let key = match String::from_utf8(k) {
                            Ok(k) => Key::new(k),
                            Err(err) => {
                                tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
                                continue;
                            }
                        };
                        yield WatchEvent::Delete(key);
                    }
                }
            }
        };
        Ok(Box::pin(output))
    }

    /// Return all key/value pairs currently stored under this bucket's prefix.
    ///
    /// NOTE(review): keys are decoded with `from_utf8_lossy` here, while
    /// `watch` *skips* invalid-UTF-8 keys — confirm this asymmetry is intended.
    /// Returned keys appear to include the bucket prefix (no stripping).
    async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd entries: {k}");

        let resp = self
            .client
            .kv_get_prefix(k)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let out: HashMap<Key, bytes::Bytes> = resp
            .into_iter()
            .map(|kv| {
                let (k, v) = kv.into_key_value();
                (Key::new(String::from_utf8_lossy(&k).to_string()), v.into())
            })
            .collect();

        Ok(out)
    }
}
172
impl EtcdBucket {
    /// Create `key` only if it does not already exist.
    ///
    /// Returns `Created(1)` on success (a freshly created etcd key always has
    /// version 1) or `Exists(..)` when the key was already present.
    async fn create(
        &self,
        key: &Key,
        value: impl Into<Vec<u8>>,
    ) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd create: {k}");

        match self
            .client
            .kv_create(k.as_str(), value.into(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?
        {
            None => {
                // Key was created successfully
                Ok(StoreOutcome::Created(1)) // version of new key is always 1
            }
            // NOTE(review): presumably `kv_create` returns the existing key's
            // version/revision when the create loses the race — confirm
            // against the `etcd::Client::kv_create` contract.
            Some(revision) => Ok(StoreOutcome::Exists(revision)),
        }
    }

    /// Overwrite `key` with optimistic version checking, mirroring NATS
    /// semantics: on a version mismatch we warn but still overwrite.
    ///
    /// NOTE(review): the read-then-put below is not transactional — the key
    /// can change (or be deleted) between the `kv_get` and the `kv_put`; the
    /// match on `take_prev_key` handles those races after the fact.
    ///
    /// # Errors
    /// `StoreError::MissingKey` if the key is absent at the pre-read;
    /// `StoreError::EtcdError` for transport failures.
    async fn update(
        &self,
        key: &Key,
        value: impl AsRef<[u8]>,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd update: {k}");

        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StoreError::MissingKey(key.to_string()));
        }
        // NOTE(review): this expects the stored version to already be
        // `version + 1` *before* the put — verify this matches the caller's
        // convention for `revision` (it differs from "current == version").
        let current_version = kvs.first().unwrap().version() as u64;
        if current_version != version + 1 {
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                %key,
                "update: Wrong revision"
            );
            // NATS does a resync_update, overwriting the key anyway and getting the new revision.
            // So we do too in etcd.
        }

        // Attach our lease so the key dies with this client; request the
        // previous kv so we can tell what the put actually replaced.
        let put_options = PutOptions::new()
            .with_lease(self.client.lease_id() as i64)
            .with_prev_key();
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(put_options))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // Should this be an error?
            // The key was deleted between our get and put. We re-created it.
            // Version of new key is always 1.
            // <https://etcd.io/docs/v3.5/learning/data_model/>
            None => StoreOutcome::Created(1),
            // Expected case, success
            // NOTE(review): this arm returns `version` while the arm below
            // returns `kv.version() + 1` (the post-put version). For the same
            // prev-version the two disagree by 2 — confirm which convention
            // callers of `StoreOutcome::Created` expect.
            Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
            // Should this be an error? Something updated the version between our get and put
            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
        })
    }
}
247
248fn make_key(bucket_name: &str, key: &Key) -> String {
249    [bucket_name.to_string(), key.to_string()].join("/")
250}
251
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
    use super::*;
    use crate::Runtime;
    use crate::transports::etcd as etcd_transport;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::Barrier;

    /// Entry point: spin up a runtime + etcd client, then run the race test.
    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();

        rt_clone.primary().block_on(async move {
            let etcd_client =
                etcd_transport::Client::new(etcd_transport::ClientOptions::default(), rt)
                    .await
                    .unwrap();
            let storage = crate::storage::kv::Manager::etcd(etcd_client);
            test_concurrent_create(&storage).await.unwrap();
        });
    }

    /// All workers race to create the same key; exactly one must win.
    ///
    /// BUGFIX: the bucket is shared via a plain `Arc` (its methods take
    /// `&self`). The previous `Arc<Mutex<_>>` held the lock across the
    /// `insert(...).await`, serializing the inserts and defeating the race
    /// this test exists to exercise. Counters are atomics for the same
    /// reason — no locking in the hot path.
    async fn test_concurrent_create(
        storage: &crate::storage::kv::Manager,
    ) -> Result<(), StoreError> {
        // Create a bucket for testing; no lock needed for concurrent use.
        let bucket = Arc::new(
            storage
                .get_or_create_bucket("test_concurrent_bucket", None)
                .await?,
        );

        // Number of concurrent workers
        let num_workers = 10;
        let barrier = Arc::new(Barrier::new(num_workers));

        // Shared test data; a fresh UUID avoids collisions with prior runs.
        let test_key: Key = Key::new(format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
        let test_value = "test_value";

        // Spawn multiple tasks that will all try to create the same key simultaneously
        let mut handles = Vec::new();
        let success_count = Arc::new(AtomicUsize::new(0));
        let exists_count = Arc::new(AtomicUsize::new(0));

        for worker_id in 0..num_workers {
            let bucket_clone = bucket.clone();
            let barrier_clone = barrier.clone();
            let key_clone = test_key.clone();
            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
            let success_count_clone = success_count.clone();
            let exists_count_clone = exists_count.clone();

            let handle = tokio::spawn(async move {
                // Wait for all workers to be ready
                barrier_clone.wait().await;

                // All workers try to create the same key at the same time.
                // revision 0 => create-only-if-absent path.
                let result = bucket_clone.insert(&key_clone, value_clone.into(), 0).await;

                match result {
                    Ok(StoreOutcome::Created(version)) => {
                        println!(
                            "Worker {} successfully created key with version {}",
                            worker_id, version
                        );
                        success_count_clone.fetch_add(1, Ordering::SeqCst);
                        Ok(version)
                    }
                    Ok(StoreOutcome::Exists(version)) => {
                        println!(
                            "Worker {} found key already exists with version {}",
                            worker_id, version
                        );
                        exists_count_clone.fetch_add(1, Ordering::SeqCst);
                        Ok(version)
                    }
                    Err(e) => {
                        println!("Worker {} got error: {:?}", worker_id, e);
                        Err(e)
                    }
                }
            });

            handles.push(handle);
        }

        // Wait for all workers to complete
        let mut results = Vec::new();
        for handle in handles {
            let result = handle.await.unwrap();
            if let Ok(version) = result {
                results.push(version);
            }
        }

        // Verify results
        let final_success_count = success_count.load(Ordering::SeqCst);
        let final_exists_count = exists_count.load(Ordering::SeqCst);

        println!(
            "Final counts - Created: {}, Exists: {}",
            final_success_count, final_exists_count
        );

        // CRITICAL ASSERTIONS:
        // 1. Exactly ONE worker should have successfully created the key
        assert_eq!(
            final_success_count, 1,
            "Exactly one worker should create the key"
        );

        // 2. All other workers should have gotten "Exists" response
        assert_eq!(
            final_exists_count,
            num_workers - 1,
            "All other workers should see key exists"
        );

        // 3. Total successful operations should equal number of workers
        assert_eq!(
            results.len(),
            num_workers,
            "All workers should complete successfully"
        );

        // 4. Verify the key actually exists in etcd
        let stored_value = bucket.get(&test_key).await?;
        assert!(stored_value.is_some(), "Key should exist in etcd");

        // 5. The stored value should be from one of the workers
        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Clean up
        bucket.delete(&test_key).await?;

        Ok(())
    }
}