// dynamo_runtime/storage/key_value_store/etcd.rs
use std::collections::HashMap;
5use std::pin::Pin;
6use std::time::Duration;
7
8use crate::{storage::key_value_store::Key, transports::etcd::Client};
9use async_stream::stream;
10use async_trait::async_trait;
11use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
12
13use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
14
/// [`KeyValueStore`] implementation backed by an etcd cluster.
#[derive(Clone)]
pub struct EtcdStore {
    // etcd client used for every bucket operation; buckets clone it.
    client: Client,
}
19
20impl EtcdStore {
21 pub fn new(client: Client) -> Self {
22 Self { client }
23 }
24}
25
26#[async_trait]
27impl KeyValueStore for EtcdStore {
28 async fn get_or_create_bucket(
30 &self,
31 bucket_name: &str,
32 _ttl: Option<Duration>, ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
34 Ok(self.get_bucket(bucket_name).await?.unwrap())
35 }
36
37 async fn get_bucket(
40 &self,
41 bucket_name: &str,
42 ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
43 Ok(Some(Box::new(EtcdBucket {
44 client: self.client.clone(),
45 bucket_name: bucket_name.to_string(),
46 })))
47 }
48
49 fn connection_id(&self) -> u64 {
50 self.client.lease_id() as u64
53 }
54}
55
/// A [`KeyValueBucket`] backed by an etcd key prefix.
///
/// "Bucket" is purely a naming convention: every key is stored under
/// "<bucket_name>/<key>" (see `make_key`); no bucket object exists in etcd.
pub struct EtcdBucket {
    // Client shared with the owning `EtcdStore`.
    client: Client,
    // Prefix prepended to every key in this bucket.
    bucket_name: String,
}
60
#[async_trait]
impl KeyValueBucket for EtcdBucket {
    /// Insert or update `key`.
    ///
    /// `revision == 0` means "create only if absent" (atomic, via a txn in
    /// `create`); any other value is treated as the version the caller last
    /// observed and routes to `update`, whose version check is advisory only.
    async fn insert(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        if version == 0 {
            self.create(key, value).await
        } else {
            self.update(key, value, version).await
        }
    }

    /// Fetch the raw value stored at `key`, or `None` if absent.
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd get: {k}");

        let mut kvs = self
            .client
            .kv_get(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Ok(None);
        }
        // Exact-key get: take the first (and presumably only) kv; the key
        // half of the pair is discarded.
        let (_, val) = kvs.swap_remove(0).into_key_value();
        Ok(Some(val.into()))
    }

    /// Delete `key`. Deleting a non-existent key is not an error; the
    /// delete response is discarded.
    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd delete: {k}");
        let _ = self
            .client
            .kv_delete(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(())
    }

    /// Watch the whole bucket (prefix watch) and stream the raw value of
    /// every subsequent `Put`. Delete events are ignored and keys are
    /// dropped, so consumers only ever see new/updated values.
    ///
    /// `'life0` is the borrowed-`self` lifetime introduced by the
    /// `#[async_trait]` expansion.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
    {
        // Empty key + with_prefix() == watch everything under the bucket.
        let k = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd watch: {k}");
        // NOTE(review): `_watcher` is dropped when this function returns;
        // confirm the watch stream stays alive without holding it.
        let (_watcher, mut watch_stream) = self
            .client
            .etcd_client()
            .clone()
            .watch(k.as_bytes(), Some(WatchOptions::new().with_prefix()))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let output = stream! {
            // The generator (and thus the stream) ends when the watch
            // errors or the server closes it.
            while let Ok(Some(resp)) = watch_stream.message().await {
                for e in resp.events() {
                    if matches!(e.event_type(), EventType::Put) && e.kv().is_some() {
                        let b: bytes::Bytes = e.kv().unwrap().value().to_vec().into();
                        yield b;
                    }
                }
            }
        };
        Ok(Box::pin(output))
    }

    /// Fetch every key/value under the bucket prefix in one shot.
    ///
    /// Note: the map keys are the FULL etcd keys ("<bucket>/<key>"), not
    /// bucket-relative names.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd entries: {k}");

        let resp = self
            .client
            .kv_get_prefix(k)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let out: HashMap<String, bytes::Bytes> = resp
            .into_iter()
            .map(|kv| {
                let (k, v) = kv.into_key_value();
                // Keys are expected to be UTF-8; invalid bytes are replaced
                // rather than erroring.
                (String::from_utf8_lossy(&k).to_string(), v.into())
            })
            .collect();

        Ok(out)
    }
}
151
impl EtcdBucket {
    /// Atomically create `key` only if it does not already exist.
    ///
    /// Uses a transaction: `Compare::version == 0` is true only for an
    /// absent key. On success the value is put under this client's primary
    /// lease; on compare failure the current kv is fetched in the same txn
    /// so the caller learns the existing version.
    async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd create: {k}");

        // Tie the key's lifetime to our lease so it is removed when the
        // client's lease expires.
        let put_options = PutOptions::new().with_lease(self.client.primary_lease().id());

        let txn = Txn::new()
            .when(vec![Compare::version(k.as_str(), CompareOp::Equal, 0)])
            .and_then(vec![TxnOp::put(k.as_str(), value, Some(put_options))])
            .or_else(vec![
                TxnOp::get(k.as_str(), None),
            ]);

        let result = self
            .client
            .etcd_client()
            .kv_client()
            .txn(txn)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;

        if result.succeeded() {
            // Freshly created keys start at version 1.
            return Ok(StoreOutcome::Created(1));
        }

        // Compare failed: key already exists. The or_else GET carries its
        // current version.
        if let Some(etcd_client::TxnOpResponse::Get(get_resp)) =
            result.op_responses().into_iter().next()
            && let Some(kv) = get_resp.kvs().first()
        {
            let version = kv.version() as u64;
            return Ok(StoreOutcome::Exists(version));
        }
        // Unexpected response shape (e.g. the GET returned no kvs).
        Err(StoreError::EtcdError(
            "Unexpected transaction response".to_string(),
        ))
    }

    /// Overwrite `key`, with the caller claiming it last saw `revision`.
    ///
    /// The revision check is advisory only: a mismatch is logged as a
    /// warning but the put still proceeds (last-writer-wins). Updating a
    /// missing key is an explicit `MissingKey` error.
    async fn update(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd update: {k}");

        // Read first so that updating a missing key fails loudly instead of
        // silently creating it.
        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StoreError::MissingKey(key.to_string()));
        }
        let current_version = kvs.first().unwrap().version() as u64;
        if current_version != version + 1 {
            // NOTE(review): warn-only, not a CAS — the get-then-put below
            // can still race with concurrent writers.
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                %key,
                "update: Wrong revision"
            );
        }

        // with_prev_key() makes the put return the kv it replaced, which is
        // what the outcome below is derived from.
        let put_options = PutOptions::new()
            .with_lease(self.client.primary_lease().id())
            .with_prev_key();
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(put_options))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // No previous kv: this put created the key (version 1).
            None => StoreOutcome::Created(1),
            // NOTE(review): this arm reports the caller's `version`, not the
            // key's new version (prev + 1) like the arm below — confirm the
            // asymmetry is intended.
            Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
        })
    }
}
247
248fn make_key(bucket_name: &str, key: &Key) -> String {
249 [bucket_name.to_string(), key.to_string()].join("/")
250}
251
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};
    use std::sync::Arc;
    use tokio::sync::Barrier;

    /// Entry point: bring up a DistributedRuntime and run the race test on it.
    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();
        let config = DistributedConfig::from_settings(false);

        rt_clone.primary().block_on(async move {
            let drt = DistributedRuntime::new(rt, config).await.unwrap();
            test_concurrent_create(drt).await.unwrap();
        });
    }

    /// Spawn `num_workers` tasks that all try to create the same key with
    /// revision 0 at the same instant, and assert that exactly one create
    /// wins while every other worker observes `StoreOutcome::Exists`.
    ///
    /// Fix over the previous version: the bucket was shared behind an
    /// `Arc<tokio::sync::Mutex<..>>`, which serialized the inserts — the
    /// "race" never actually raced. Each worker now opens its own bucket
    /// handle so the inserts genuinely hit etcd concurrently.
    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
        let etcd_client = drt.etcd_client().expect("etcd client should be available");
        let storage = EtcdStore::new(etcd_client);

        let num_workers = 10;
        // Barrier lines all workers up so the inserts fire together.
        let barrier = Arc::new(Barrier::new(num_workers));

        // Unique key per test run so reruns don't collide with leftovers.
        let test_key: Key = Key::new(&format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
        let test_value = "test_value";

        let success_count = Arc::new(tokio::sync::Mutex::new(0));
        let exists_count = Arc::new(tokio::sync::Mutex::new(0));

        let mut handles = Vec::new();
        for worker_id in 0..num_workers {
            // Per-worker store clone; the bucket handle is opened inside the
            // task so no shared lock serializes the inserts.
            let storage_clone = storage.clone();
            let barrier_clone = barrier.clone();
            let key_clone = test_key.clone();
            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
            let success_count_clone = success_count.clone();
            let exists_count_clone = exists_count.clone();

            let handle = tokio::spawn(async move {
                let bucket = storage_clone
                    .get_or_create_bucket("test_concurrent_bucket", None)
                    .await?;

                barrier_clone.wait().await;

                // revision 0 => create-if-absent path.
                let result = bucket.insert(&key_clone, &value_clone, 0).await;

                match result {
                    Ok(StoreOutcome::Created(version)) => {
                        println!(
                            "Worker {} successfully created key with version {}",
                            worker_id, version
                        );
                        let mut count = success_count_clone.lock().await;
                        *count += 1;
                        Ok(version)
                    }
                    Ok(StoreOutcome::Exists(version)) => {
                        println!(
                            "Worker {} found key already exists with version {}",
                            worker_id, version
                        );
                        let mut count = exists_count_clone.lock().await;
                        *count += 1;
                        Ok(version)
                    }
                    Err(e) => {
                        println!("Worker {} got error: {:?}", worker_id, e);
                        Err(e)
                    }
                }
            });

            handles.push(handle);
        }

        // Collect versions from workers that completed without a StoreError.
        let mut results = Vec::new();
        for handle in handles {
            let result = handle.await.unwrap();
            if let Ok(version) = result {
                results.push(version);
            }
        }

        let final_success_count = *success_count.lock().await;
        let final_exists_count = *exists_count.lock().await;

        println!(
            "Final counts - Created: {}, Exists: {}",
            final_success_count, final_exists_count
        );

        assert_eq!(
            final_success_count, 1,
            "Exactly one worker should create the key"
        );

        assert_eq!(
            final_exists_count,
            num_workers - 1,
            "All other workers should see key exists"
        );

        assert_eq!(
            results.len(),
            num_workers,
            "All workers should complete successfully"
        );

        // Verify the winning value landed, using a fresh bucket handle.
        let bucket = storage
            .get_or_create_bucket("test_concurrent_bucket", None)
            .await?;
        let stored_value = bucket.get(&test_key).await?;
        assert!(stored_value.is_some(), "Key should exist in etcd");

        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Clean up so reruns start from a clean slate.
        bucket.delete(&test_key).await?;

        Ok(())
    }
}