dynamo_runtime/storage/key_value_store/
etcd.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::time::Duration;
7
8use crate::{storage::key_value_store::Key, transports::etcd::Client};
9use async_stream::stream;
10use async_trait::async_trait;
11use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
12
13use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
14
15#[derive(Clone)]
16pub struct EtcdStore {
17    client: Client,
18}
19
20impl EtcdStore {
21    pub fn new(client: Client) -> Self {
22        Self { client }
23    }
24}
25
26#[async_trait]
27impl KeyValueStore for EtcdStore {
28    /// A "bucket" in etcd is a path prefix
29    async fn get_or_create_bucket(
30        &self,
31        bucket_name: &str,
32        _ttl: Option<Duration>, // TODO ttl not used yet
33    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
34        Ok(self.get_bucket(bucket_name).await?.unwrap())
35    }
36
37    /// A "bucket" in etcd is a path prefix. This creates an EtcdBucket object without doing
38    /// any network calls.
39    async fn get_bucket(
40        &self,
41        bucket_name: &str,
42    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
43        Ok(Some(Box::new(EtcdBucket {
44            client: self.client.clone(),
45            bucket_name: bucket_name.to_string(),
46        })))
47    }
48
49    fn connection_id(&self) -> u64 {
50        // This conversion from i64 to u64 is safe because etcd lease IDs are u64 internally.
51        // They present as i64 because of the limitations of the etcd grpc/HTTP JSON API.
52        self.client.lease_id() as u64
53    }
54}
55
/// A "bucket" in etcd: the set of keys sharing the `bucket_name` path prefix.
pub struct EtcdBucket {
    // Shared etcd client handle (cloned from the owning `EtcdStore`).
    client: Client,
    // Path prefix prepended (with "/") to every key in this bucket.
    bucket_name: String,
}
60
61#[async_trait]
62impl KeyValueBucket for EtcdBucket {
63    async fn insert(
64        &self,
65        key: &Key,
66        value: &str,
67        // "version" in etcd speak. revision is a global cluster-wide value
68        revision: u64,
69    ) -> Result<StoreOutcome, StoreError> {
70        let version = revision;
71        if version == 0 {
72            self.create(key, value).await
73        } else {
74            self.update(key, value, version).await
75        }
76    }
77
78    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
79        let k = make_key(&self.bucket_name, key);
80        tracing::trace!("etcd get: {k}");
81
82        let mut kvs = self
83            .client
84            .kv_get(k, None)
85            .await
86            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
87        if kvs.is_empty() {
88            return Ok(None);
89        }
90        let (_, val) = kvs.swap_remove(0).into_key_value();
91        Ok(Some(val.into()))
92    }
93
94    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
95        let k = make_key(&self.bucket_name, key);
96        tracing::trace!("etcd delete: {k}");
97        let _ = self
98            .client
99            .kv_delete(k, None)
100            .await
101            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
102        Ok(())
103    }
104
105    async fn watch(
106        &self,
107    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
108    {
109        let k = make_key(&self.bucket_name, &"".into());
110        tracing::trace!("etcd watch: {k}");
111        let (_watcher, mut watch_stream) = self
112            .client
113            .etcd_client()
114            .clone()
115            .watch(k.as_bytes(), Some(WatchOptions::new().with_prefix()))
116            .await
117            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
118        let output = stream! {
119            while let Ok(Some(resp)) = watch_stream.message().await {
120                for e in resp.events() {
121                    if matches!(e.event_type(), EventType::Put) && e.kv().is_some() {
122                        let b: bytes::Bytes = e.kv().unwrap().value().to_vec().into();
123                        yield b;
124                    }
125                }
126            }
127        };
128        Ok(Box::pin(output))
129    }
130
131    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
132        let k = make_key(&self.bucket_name, &"".into());
133        tracing::trace!("etcd entries: {k}");
134
135        let resp = self
136            .client
137            .kv_get_prefix(k)
138            .await
139            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
140        let out: HashMap<String, bytes::Bytes> = resp
141            .into_iter()
142            .map(|kv| {
143                let (k, v) = kv.into_key_value();
144                (String::from_utf8_lossy(&k).to_string(), v.into())
145            })
146            .collect();
147
148        Ok(out)
149    }
150}
151
impl EtcdBucket {
    /// Atomically create `key` if and only if it does not already exist,
    /// attaching the client's primary lease so the key expires with the lease.
    ///
    /// Returns `StoreOutcome::Created(1)` on success (a freshly created etcd
    /// key always has version 1) or `StoreOutcome::Exists(version)` when the
    /// key was already present.
    async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd create: {k}");

        // Use atomic transaction to check and create in one operation
        let put_options = PutOptions::new().with_lease(self.client.primary_lease().id());

        // Build transaction that creates key only if it doesn't exist.
        // In etcd, version == 0 means "no such key".
        let txn = Txn::new()
            .when(vec![Compare::version(k.as_str(), CompareOp::Equal, 0)]) // Atomic check
            .and_then(vec![TxnOp::put(k.as_str(), value, Some(put_options))]) // Only if check passes
            .or_else(vec![
                TxnOp::get(k.as_str(), None), // Key exists, get its info
            ]);

        // Execute the transaction
        let result = self
            .client
            .etcd_client()
            .kv_client()
            .txn(txn)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;

        if result.succeeded() {
            // Key was created successfully
            return Ok(StoreOutcome::Created(1)); // version of new key is always 1
        }

        // Key already existed: the `or_else` GET tells us its current version.
        if let Some(etcd_client::TxnOpResponse::Get(get_resp)) =
            result.op_responses().into_iter().next()
            && let Some(kv) = get_resp.kvs().first()
        {
            let version = kv.version() as u64;
            return Ok(StoreOutcome::Exists(version));
        }
        // Shouldn't happen, but handle edge case
        Err(StoreError::EtcdError(
            "Unexpected transaction response".to_string(),
        ))
    }

    /// Overwrite `key` with `value`, expecting the key's current etcd version
    /// to be `revision + 1`. On a mismatch we log a warning but still write
    /// (mirroring NATS resync_update, per the comment below). Errors with
    /// `MissingKey` if the key does not exist at read time.
    async fn update(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd update: {k}");

        // Read first so we can detect (and warn about) a version mismatch.
        // Note this get-then-put is not atomic; the match arms below handle
        // the cases where the key changed in between.
        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StoreError::MissingKey(key.to_string()));
        }
        let current_version = kvs.first().unwrap().version() as u64;
        if current_version != version + 1 {
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                %key,
                "update: Wrong revision"
            );
            // NATS does a resync_update, overwriting the key anyway and getting the new revision.
            // So we do too in etcd.
        }

        // with_prev_key lets us inspect what (if anything) we overwrote.
        let put_options = PutOptions::new()
            .with_lease(self.client.primary_lease().id())
            .with_prev_key();
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(put_options))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // Should this be an error?
            // The key was deleted between our get and put. We re-created it.
            // Version of new key is always 1.
            // <https://etcd.io/docs/v3.5/learning/data_model/>
            None => StoreOutcome::Created(1),
            // Expected case, success
            Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
            // Should this be an error? Something updated the version between our get and put
            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
        })
    }
}
247
248fn make_key(bucket_name: &str, key: &Key) -> String {
249    [bucket_name.to_string(), key.to_string()].join("/")
250}
251
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};
    use std::sync::Arc;
    use tokio::sync::Barrier;

    // Integration test: requires a reachable etcd (via DistributedRuntime settings).
    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();
        let config = DistributedConfig::from_settings(false);

        rt_clone.primary().block_on(async move {
            let drt = DistributedRuntime::new(rt, config).await.unwrap();
            test_concurrent_create(drt).await.unwrap();
        });
    }

    // Drives N workers that all try to create the same key with revision 0,
    // and asserts exactly one Created and N-1 Exists outcomes.
    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
        let etcd_client = drt.etcd_client().expect("etcd client should be available");
        let storage = EtcdStore::new(etcd_client);

        // Create a bucket for testing.
        // NOTE(review): this Mutex is held across the `insert(...).await` below,
        // so the workers' inserts are serialized — the "race" exercised is
        // between sequential transactions rather than truly concurrent ones.
        // KeyValueBucket methods take &self, so sharing without the Mutex may
        // be possible; verify `dyn KeyValueBucket` is Sync before changing.
        let bucket = Arc::new(tokio::sync::Mutex::new(
            storage
                .get_or_create_bucket("test_concurrent_bucket", None)
                .await?,
        ));

        // Number of concurrent workers
        let num_workers = 10;
        let barrier = Arc::new(Barrier::new(num_workers));

        // Shared test data; random key suffix avoids collisions between runs.
        let test_key: Key = Key::new(&format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
        let test_value = "test_value";

        // Spawn multiple tasks that will all try to create the same key simultaneously
        let mut handles = Vec::new();
        let success_count = Arc::new(tokio::sync::Mutex::new(0));
        let exists_count = Arc::new(tokio::sync::Mutex::new(0));

        for worker_id in 0..num_workers {
            let bucket_clone = bucket.clone();
            let barrier_clone = barrier.clone();
            let key_clone = test_key.clone();
            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
            let success_count_clone = success_count.clone();
            let exists_count_clone = exists_count.clone();

            let handle = tokio::spawn(async move {
                // Wait for all workers to be ready
                barrier_clone.wait().await;

                // All workers try to create the same key at the same time
                // (revision 0 selects the create path in `insert`).
                let result = bucket_clone
                    .lock()
                    .await
                    .insert(&key_clone, &value_clone, 0)
                    .await;

                match result {
                    Ok(StoreOutcome::Created(version)) => {
                        println!(
                            "Worker {} successfully created key with version {}",
                            worker_id, version
                        );
                        let mut count = success_count_clone.lock().await;
                        *count += 1;
                        Ok(version)
                    }
                    Ok(StoreOutcome::Exists(version)) => {
                        println!(
                            "Worker {} found key already exists with version {}",
                            worker_id, version
                        );
                        let mut count = exists_count_clone.lock().await;
                        *count += 1;
                        Ok(version)
                    }
                    Err(e) => {
                        println!("Worker {} got error: {:?}", worker_id, e);
                        Err(e)
                    }
                }
            });

            handles.push(handle);
        }

        // Wait for all workers to complete
        let mut results = Vec::new();
        for handle in handles {
            let result = handle.await.unwrap();
            if let Ok(version) = result {
                results.push(version);
            }
        }

        // Verify results
        let final_success_count = *success_count.lock().await;
        let final_exists_count = *exists_count.lock().await;

        println!(
            "Final counts - Created: {}, Exists: {}",
            final_success_count, final_exists_count
        );

        // CRITICAL ASSERTIONS:
        // 1. Exactly ONE worker should have successfully created the key
        assert_eq!(
            final_success_count, 1,
            "Exactly one worker should create the key"
        );

        // 2. All other workers should have gotten "Exists" response
        assert_eq!(
            final_exists_count,
            num_workers - 1,
            "All other workers should see key exists"
        );

        // 3. Total successful operations should equal number of workers
        assert_eq!(
            results.len(),
            num_workers,
            "All workers should complete successfully"
        );

        // 4. Verify the key actually exists in etcd
        let stored_value = bucket.lock().await.get(&test_key).await?;
        assert!(stored_value.is_some(), "Key should exist in etcd");

        // 5. The stored value should be from one of the workers
        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Clean up
        bucket.lock().await.delete(&test_key).await?;

        Ok(())
    }
}
399}