dynamo_runtime/storage/kv/
etcd.rs1use std::collections::HashMap;
5use std::pin::Pin;
6use std::time::Duration;
7
8use crate::transports::etcd;
9use async_stream::stream;
10use async_trait::async_trait;
11use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
12
13use super::{Bucket, Key, KeyValue, Store, StoreError, StoreOutcome, WatchEvent};
14
/// [`Store`] implementation backed by an etcd cluster.
///
/// "Buckets" are purely key prefixes (see `make_key`); etcd itself has no
/// bucket concept, so no server-side state is created per bucket.
#[derive(Clone)]
pub struct EtcdStore {
    // Client handle used for all KV operations; cloned into each bucket.
    client: etcd::Client,
}
19
20impl EtcdStore {
21 pub fn new(client: etcd::Client) -> Self {
22 Self { client }
23 }
24}
25
#[async_trait]
impl Store for EtcdStore {
    type Bucket = EtcdBucket;

    /// Return a handle scoped to `bucket_name`.
    ///
    /// Because a bucket is only a key prefix, the handle is built
    /// unconditionally — nothing is created on the server. The TTL
    /// argument is currently ignored (note the `_ttl` binding).
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        _ttl: Option<Duration>,
    ) -> Result<Self::Bucket, StoreError> {
        Ok(EtcdBucket {
            client: self.client.clone(),
            bucket_name: bucket_name.to_string(),
        })
    }

    /// Always returns `Some`: bucket existence cannot be checked without a
    /// read (buckets are prefixes), so callers get a handle regardless.
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
        Ok(Some(EtcdBucket {
            client: self.client.clone(),
            bucket_name: bucket_name.to_string(),
        }))
    }

    /// Identify this connection by its etcd lease id.
    fn connection_id(&self) -> u64 {
        self.client.lease_id()
    }

    /// No-op. NOTE(review): presumably teardown happens when the client is
    /// dropped elsewhere — confirm nothing here needs explicit cleanup.
    fn shutdown(&self) {
    }
}
59
/// A [`Bucket`] view over etcd: all keys live under the
/// `"<bucket_name>/"` prefix produced by `make_key`.
pub struct EtcdBucket {
    // Shared etcd client (cloned from the owning EtcdStore).
    client: etcd::Client,
    // Prefix prepended to every key in this bucket.
    bucket_name: String,
}
64
#[async_trait]
impl Bucket for EtcdBucket {
    /// Insert or update `value` at `key`.
    ///
    /// `revision == 0` means "create only": the write fails over to
    /// `StoreOutcome::Exists` if the key is already present. Any non-zero
    /// revision is treated as the caller's expected version and routed to
    /// `update`.
    async fn insert(
        &self,
        key: &Key,
        value: bytes::Bytes,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        if version == 0 {
            self.create(key, value).await
        } else {
            self.update(key, value, version).await
        }
    }

    /// Fetch the value stored at `key`, or `None` if the key is absent.
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd get: {k}");

        let mut kvs = self
            .client
            .kv_get(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Ok(None);
        }
        // Single-key get: take the first (and only expected) kv pair.
        let (_, val) = kvs.swap_remove(0).into_key_value();
        Ok(Some(val.into()))
    }

    /// Delete `key`. Succeeds whether or not the key existed (the delete
    /// response is discarded).
    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd delete: {k}");
        let _ = self
            .client
            .kv_delete(k, None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(())
    }

    /// Stream put/delete events for every key under this bucket's prefix.
    ///
    /// `'life0` is the lifetime `#[async_trait]` introduces for `&self`,
    /// tying the returned stream to this bucket.
    ///
    /// NOTE(review): the yielded `Key`s are built from the raw etcd key
    /// bytes, so they still include the `"<bucket>/"` prefix — confirm that
    /// consumers expect prefixed keys.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
        let prefix = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd watch: {prefix}");
        let watcher = self
            .client
            .kv_watch_prefix(&prefix)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        // Keep only the event receiver; the other half of the watcher is
        // dropped here — presumably that does not cancel the watch. TODO confirm.
        let (_, mut watch_stream) = watcher.dissolve();
        let output = stream! {
            while let Some(event) = watch_stream.recv().await {
                match event {
                    etcd::WatchEvent::Put(kv) => {
                        let (k, v) = kv.into_key_value();
                        // Keys with invalid UTF-8 are logged and skipped.
                        let key = match String::from_utf8(k) {
                            Ok(k) => Key::new(k),
                            Err(err) => {
                                tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
                                continue;
                            }
                        };
                        let item = KeyValue::new(key, v.into());
                        yield WatchEvent::Put(item);
                    }
                    etcd::WatchEvent::Delete(kv) => {
                        let (k, _) = kv.into_key_value();
                        let key = match String::from_utf8(k) {
                            Ok(k) => Key::new(k),
                            Err(err) => {
                                tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
                                continue;
                            }
                        };
                        yield WatchEvent::Delete(key);
                    }
                }
            }
        };
        Ok(Box::pin(output))
    }

    /// Return all key/value pairs currently under this bucket's prefix.
    ///
    /// Unlike `watch`, keys here go through `from_utf8_lossy`, so invalid
    /// UTF-8 is replaced (U+FFFD) rather than the entry being skipped.
    /// The returned keys also include the bucket prefix, matching `watch`.
    async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
        let k = make_key(&self.bucket_name, &"".into());
        tracing::trace!("etcd entries: {k}");

        let resp = self
            .client
            .kv_get_prefix(k)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let out: HashMap<Key, bytes::Bytes> = resp
            .into_iter()
            .map(|kv| {
                let (k, v) = kv.into_key_value();
                (Key::new(String::from_utf8_lossy(&k).to_string()), v.into())
            })
            .collect();

        Ok(out)
    }
}
172
impl EtcdBucket {
    /// Create `key` only if it does not already exist.
    ///
    /// Returns `Created(1)` on success (etcd's first version of a new key is
    /// 1) or `Exists(revision)` when another writer got there first.
    async fn create(
        &self,
        key: &Key,
        value: impl Into<Vec<u8>>,
    ) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd create: {k}");

        // Assumes kv_create returns None when the key was created and
        // Some(revision) when it already existed — TODO confirm against
        // the etcd::Client implementation.
        match self
            .client
            .kv_create(k.as_str(), value.into(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?
        {
            None => {
                Ok(StoreOutcome::Created(1)) }
            Some(revision) => Ok(StoreOutcome::Exists(revision)),
        }
    }

    /// Overwrite `key`, with `revision` as the version the caller believes
    /// it holds.
    ///
    /// NOTE(review): a version mismatch is only logged, never rejected, so
    /// the check is advisory and concurrent writers can lose updates —
    /// confirm this is intentional (a CAS via etcd Txn would enforce it).
    async fn update(
        &self,
        key: &Key,
        value: impl AsRef<[u8]>,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let version = revision;
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd update: {k}");

        // Read first so a missing key becomes an explicit error and the
        // current version can be compared with the caller's expectation.
        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StoreError::MissingKey(key.to_string()));
        }
        // Non-emptiness checked above, so first() cannot fail.
        let current_version = kvs.first().unwrap().version() as u64;
        if current_version != version + 1 {
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                %key,
                "update: Wrong revision"
            );
        }

        // Attach the put to our lease and ask for the previous kv so the
        // outcome can report the resulting version.
        let put_options = PutOptions::new()
            .with_lease(self.client.lease_id() as i64)
            .with_prev_key();
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(put_options))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // No previous kv: the key was (re)created by this put.
            None => StoreOutcome::Created(1),
            // Previous version matched expectation: echo the caller's
            // version. NOTE(review): reporting `version` rather than
            // `version + 1` looks deliberate but undocumented — confirm.
            Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
            // Otherwise report the actual post-put version.
            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
        })
    }
}
247
248fn make_key(bucket_name: &str, key: &Key) -> String {
249 [bucket_name.to_string(), key.to_string()].join("/")
250}
251
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::Barrier;

    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();
        let config = DistributedConfig::from_settings();

        rt_clone.primary().block_on(async move {
            let drt = DistributedRuntime::new(rt, config).await.unwrap();
            test_concurrent_create(drt).await.unwrap();
        });
    }

    /// Spawn N workers that all race to create the same key; exactly one
    /// must win (`Created`) and every other worker must observe `Exists`.
    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
        let storage = drt.store();

        // Bucket methods take `&self`, so a plain Arc is enough to share the
        // bucket across tasks. The previous version wrapped it in a tokio
        // Mutex and held the guard across the insert await, which serialized
        // the inserts and defeated the very race this test exercises.
        let bucket = Arc::new(
            storage
                .get_or_create_bucket("test_concurrent_bucket", None)
                .await?,
        );

        let num_workers = 10;
        // All workers block on the barrier and release together to maximize
        // contention on the create.
        let barrier = Arc::new(Barrier::new(num_workers));

        // Unique key per test run so reruns don't collide with leftovers.
        let test_key: Key = Key::new(format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
        let test_value = "test_value";

        // Atomics instead of Mutex<i32>: no await points needed to count.
        let success_count = Arc::new(AtomicUsize::new(0));
        let exists_count = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();

        for worker_id in 0..num_workers {
            let bucket_clone = bucket.clone();
            let barrier_clone = barrier.clone();
            let key_clone = test_key.clone();
            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
            let success_count_clone = success_count.clone();
            let exists_count_clone = exists_count.clone();

            let handle = tokio::spawn(async move {
                barrier_clone.wait().await;

                // revision 0 => create-only semantics (see Bucket::insert).
                let result = bucket_clone
                    .insert(&key_clone, value_clone.into(), 0)
                    .await;

                match result {
                    Ok(StoreOutcome::Created(version)) => {
                        println!(
                            "Worker {} successfully created key with version {}",
                            worker_id, version
                        );
                        success_count_clone.fetch_add(1, Ordering::SeqCst);
                        Ok(version)
                    }
                    Ok(StoreOutcome::Exists(version)) => {
                        println!(
                            "Worker {} found key already exists with version {}",
                            worker_id, version
                        );
                        exists_count_clone.fetch_add(1, Ordering::SeqCst);
                        Ok(version)
                    }
                    Err(e) => {
                        println!("Worker {} got error: {:?}", worker_id, e);
                        Err(e)
                    }
                }
            });

            handles.push(handle);
        }

        let mut results = Vec::new();
        for handle in handles {
            let result = handle.await.unwrap();
            if let Ok(version) = result {
                results.push(version);
            }
        }

        let final_success_count = success_count.load(Ordering::SeqCst);
        let final_exists_count = exists_count.load(Ordering::SeqCst);

        println!(
            "Final counts - Created: {}, Exists: {}",
            final_success_count, final_exists_count
        );

        assert_eq!(
            final_success_count, 1,
            "Exactly one worker should create the key"
        );

        assert_eq!(
            final_exists_count,
            num_workers - 1,
            "All other workers should see key exists"
        );

        assert_eq!(
            results.len(),
            num_workers,
            "All workers should complete successfully"
        );

        let stored_value = bucket.get(&test_key).await?;
        assert!(stored_value.is_some(), "Key should exist in etcd");

        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Clean up so subsequent runs start from an empty bucket.
        bucket.delete(&test_key).await?;

        Ok(())
    }
}