dynamo_runtime/storage/key_value_store/etcd.rs

use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;

use crate::{
    storage::key_value_store::{Key, KeyValue, WatchEvent},
    transports::etcd::Client,
};
use async_stream::stream;
use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};

use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};

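/// A `KeyValueStore` implementation backed by etcd. Cloning is cheap and
/// shares the underlying etcd `Client`.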
#[derive(Clone)]
pub struct EtcdStore {
    client: Client,
}

impl EtcdStore {
    pub fn new(client: Client) -> Self {
        Self { client }
    }
}

#[async_trait]
impl KeyValueStore for EtcdStore {
    type Bucket = EtcdBucket;

    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        _ttl: Option<Duration>, // unused by the etcd backend
    ) -> Result<Self::Bucket, StoreError> {
        // Buckets are lazy: no etcd I/O happens here, the name simply
        // becomes a key prefix.
        Ok(EtcdBucket {
            client: self.client.clone(),
            bucket_name: bucket_name.to_string(),
        })
    }

    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
        Ok(Some(EtcdBucket {
            client: self.client.clone(),
            bucket_name: bucket_name.to_string(),
        }))
    }

    fn connection_id(&self) -> u64 {
        self.client.lease_id()
    }
}

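/// A handle to one bucket. Keys are stored in etcd as `<bucket_name>/<key>`
/// (see `make_key` at the bottom of this file).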
pub struct EtcdBucket {
    client: Client,
    bucket_name: String,
}

#[async_trait]
impl KeyValueBucket for EtcdBucket {
    async fn insert(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        // The caller's `revision` is interpreted as etcd's per-key version.
        // 0 means create-if-absent; anything else requests a versioned update.
        let version = revision;
        if version == 0 {
            self.create(key, value).await
        } else {
            self.update(key, value, version).await
        }
    }

    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd get: {k}");

        let mut kvs = self
            .client
            .kv_get(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Ok(None);
        }
        let (_, val) = kvs.swap_remove(0).into_key_value();
        Ok(Some(val.into()))
    }

    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd delete: {k}");
        let _ = self
            .client
            .kv_delete(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(())
    }

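    // Watch all keys under this bucket's prefix, yielding puts and deletes
    // as they arrive. (`'life0` below is the lifetime that #[async_trait]
    // generates for `&self`.)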
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
        let prefix = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd watch: {prefix}");
        let (watcher, mut watch_stream) = self
            .client
            .etcd_client()
            .clone()
            .watch(prefix.as_bytes(), Some(WatchOptions::new().with_prefix()))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let output = stream! {
            // Keep the watcher alive for as long as the stream is polled;
            // dropping it would cancel the etcd watch.
            let _watcher = watcher;
            while let Ok(Some(resp)) = watch_stream.message().await {
                for e in resp.events() {
                    let Some(kv) = e.kv() else {
                        continue;
                    };
                    let (k_bytes, v_bytes) = kv.clone().into_key_value();
                    let key = match String::from_utf8(k_bytes) {
                        Ok(k) => k,
                        Err(err) => {
                            tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
                            continue;
                        }
                    };
                    let item = KeyValue::new(key, v_bytes.into());
                    match e.event_type() {
                        EventType::Put => {
                            yield WatchEvent::Put(item);
                        }
                        EventType::Delete => {
                            yield WatchEvent::Delete(item);
                        }
                    }
                }
            }
        };
        Ok(Box::pin(output))
    }

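    // Snapshot of all entries under this bucket's prefix, keyed by the full
    // etcd key (the bucket prefix is not stripped).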
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd entries: {k}");

        let resp = self
            .client
            .kv_get_prefix(k)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let out: HashMap<String, bytes::Bytes> = resp
            .into_iter()
            .map(|kv| {
                let (k, v) = kv.into_key_value();
                (String::from_utf8_lossy(&k).to_string(), v.into())
            })
            .collect();

        Ok(out)
    }
}

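// Internal helpers backing `insert`: atomic create-if-absent and versioned
// update.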
impl EtcdBucket {
    async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd create: {k}");

        // Attach the key to this client's primary lease so it is removed
        // when the lease expires.
        let put_options = PutOptions::new().with_lease(self.client.primary_lease().id() as i64);

        // Atomic create: the put runs only if the key does not exist yet
        // (version == 0); otherwise the current entry is fetched instead.
        let txn = Txn::new()
            .when(vec![Compare::version(k.as_str(), CompareOp::Equal, 0)])
            .and_then(vec![TxnOp::put(k.as_str(), value, Some(put_options))])
            .or_else(vec![TxnOp::get(k.as_str(), None)]);

        let result = self
            .client
            .etcd_client()
            .kv_client()
            .txn(txn)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;

        if result.succeeded() {
            // A newly created key always starts at version 1.
            return Ok(StoreOutcome::Created(1));
        }

        // The key already exists: report its current version.
        if let Some(etcd_client::TxnOpResponse::Get(get_resp)) =
            result.op_responses().into_iter().next()
            && let Some(kv) = get_resp.kvs().first()
        {
            let version = kv.version() as u64;
            return Ok(StoreOutcome::Exists(version));
        }
        Err(StoreError::EtcdError(
            "Unexpected transaction response".to_string(),
        ))
    }

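    // Versioned update. A mismatch between the stored version and the
    // caller's expectation is logged but does not abort the put.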
    async fn update(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd update: {k}");

        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StoreError::MissingKey(key.to_string()));
        }
        let current_version = kvs.first().unwrap().version() as u64;
        if current_version != version + 1 {
            // Not fatal; the put below proceeds regardless.
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                %key,
                "update: Wrong revision"
            );
        }

        let put_options = PutOptions::new()
            .with_lease(self.client.primary_lease().id() as i64)
            .with_prev_key();
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(put_options))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // The key did not exist before the put, so it is now at version 1.
            None => StoreOutcome::Created(1),
            // The previous version was the one the caller expected.
            Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
            // Otherwise report the version the key holds after the put.
            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
        })
    }
}

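/// Builds the full etcd key for `key` within `bucket_name`: `<bucket_name>/<key>`.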
fn make_key(bucket_name: &str, key: &Key) -> String {
    [bucket_name.to_string(), key.to_string()].join("/")
}

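// These tests run only with the `integration` feature enabled and need a
// live etcd reachable through the DistributedRuntime.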
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};
    use std::sync::Arc;
    use tokio::sync::Barrier;

    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();
        let config = DistributedConfig::from_settings(false);

        rt_clone.primary().block_on(async move {
            let drt = DistributedRuntime::new(rt, config).await.unwrap();
            test_concurrent_create(drt).await.unwrap();
        });
    }

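    // All workers race to create the same key with revision 0; the etcd
    // transaction in `create` should let exactly one of them win.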
    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
        let etcd_client = drt.etcd_client().expect("etcd client should be available");
        let storage = EtcdStore::new(etcd_client);

        let bucket = Arc::new(tokio::sync::Mutex::new(
            storage
                .get_or_create_bucket("test_concurrent_bucket", None)
                .await?,
        ));

        // The barrier releases every worker at once to maximize contention.
        let num_workers = 10;
        let barrier = Arc::new(Barrier::new(num_workers));

        let test_key: Key = Key::new(&format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
        let test_value = "test_value";

        let mut handles = Vec::new();
        let success_count = Arc::new(tokio::sync::Mutex::new(0));
        let exists_count = Arc::new(tokio::sync::Mutex::new(0));

        for worker_id in 0..num_workers {
            let bucket_clone = bucket.clone();
            let barrier_clone = barrier.clone();
            let key_clone = test_key.clone();
            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
            let success_count_clone = success_count.clone();
            let exists_count_clone = exists_count.clone();

            let handle = tokio::spawn(async move {
                barrier_clone.wait().await;

                // Every worker attempts a create (revision 0) of the same key.
                let result = bucket_clone
                    .lock()
                    .await
                    .insert(&key_clone, &value_clone, 0)
                    .await;

                match result {
                    Ok(StoreOutcome::Created(version)) => {
                        println!(
                            "Worker {} successfully created key with version {}",
                            worker_id, version
                        );
                        let mut count = success_count_clone.lock().await;
                        *count += 1;
                        Ok(version)
                    }
                    Ok(StoreOutcome::Exists(version)) => {
                        println!(
                            "Worker {} found key already exists with version {}",
                            worker_id, version
                        );
                        let mut count = exists_count_clone.lock().await;
                        *count += 1;
                        Ok(version)
                    }
                    Err(e) => {
                        println!("Worker {} got error: {:?}", worker_id, e);
                        Err(e)
                    }
                }
            });

            handles.push(handle);
        }

        let mut results = Vec::new();
        for handle in handles {
            let result = handle.await.unwrap();
            if let Ok(version) = result {
                results.push(version);
            }
        }

        let final_success_count = *success_count.lock().await;
        let final_exists_count = *exists_count.lock().await;

        println!(
            "Final counts - Created: {}, Exists: {}",
            final_success_count, final_exists_count
        );

        assert_eq!(
            final_success_count, 1,
            "Exactly one worker should create the key"
        );

        assert_eq!(
            final_exists_count,
            num_workers - 1,
            "All other workers should see key exists"
        );

        assert_eq!(
            results.len(),
            num_workers,
            "All workers should complete successfully"
        );

        let stored_value = bucket.lock().await.get(&test_key).await?;
        assert!(stored_value.is_some(), "Key should exist in etcd");

        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Clean up the test key so reruns start fresh.
        bucket.lock().await.delete(&test_key).await?;

        Ok(())
    }
}