use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;
use crate::transports::etcd;
use async_stream::stream;
use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
use super::{Bucket, Key, KeyValue, Store, StoreError, StoreOutcome, WatchEvent};
/// Key-value [`Store`] implementation backed by an etcd client.
///
/// Cloning the store clones the wrapped [`etcd::Client`].
#[derive(Clone)]
pub struct EtcdStore {
/// Client that is cloned into every [`EtcdBucket`] this store hands out.
client: etcd::Client,
}
impl EtcdStore {
pub fn new(client: etcd::Client) -> Self {
Self { client }
}
}
/// [`Store`] implementation whose buckets are [`EtcdBucket`] handles built from the
/// bucket name and a clone of the store's client.
#[async_trait]
impl Store for EtcdStore {
    type Bucket = EtcdBucket;

    /// Returns a bucket handle for `bucket_name`.
    ///
    /// The handle is assembled locally from a clone of the client and the name; this
    /// method itself issues no etcd request. The TTL argument is accepted but ignored.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        _ttl: Option<Duration>,
    ) -> Result<Self::Bucket, StoreError> {
        let bucket = EtcdBucket {
            client: self.client.clone(),
            bucket_name: String::from(bucket_name),
        };
        Ok(bucket)
    }

    /// Returns a bucket handle for `bucket_name`.
    ///
    /// Always yields `Some`: the handle is built unconditionally, without checking
    /// whether any key exists under that name.
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
        let bucket = EtcdBucket {
            client: self.client.clone(),
            bucket_name: String::from(bucket_name),
        };
        Ok(Some(bucket))
    }

    /// Identifies this connection by the client's lease id.
    fn connection_id(&self) -> u64 {
        self.client.lease_id()
    }

    /// No-op: this store does nothing on shutdown.
    fn shutdown(&self) {}
}
/// Handle to one named bucket in etcd.
///
/// Every key of the bucket is addressed in etcd as `<bucket_name>/<key>` (see `make_key`).
pub struct EtcdBucket {
/// etcd client used for every operation on this bucket.
client: etcd::Client,
/// Bucket name; joined with each key to form the etcd key.
bucket_name: String,
}
#[async_trait]
impl Bucket for EtcdBucket {
async fn insert(
&self,
key: &Key,
value: bytes::Bytes,
revision: u64,
) -> Result<StoreOutcome, StoreError> {
let version = revision;
if version == 0 {
self.create(key, value).await
} else {
self.update(key, value, version).await
}
}
async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd get: {k}");
let mut kvs = self
.client
.kv_get(k, None)
.await
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
if kvs.is_empty() {
return Ok(None);
}
let (_, val) = kvs.swap_remove(0).into_key_value();
Ok(Some(val.into()))
}
async fn delete(&self, key: &Key) -> Result<(), StoreError> {
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd delete: {k}");
let _ = self
.client
.kv_delete(k, None)
.await
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
Ok(())
}
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
let prefix = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd watch: {prefix}");
let watcher = self
.client
.kv_watch_prefix(&prefix)
.await
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
let (_, mut watch_stream) = watcher.dissolve();
let output = stream! {
while let Some(event) = watch_stream.recv().await {
match event {
etcd::WatchEvent::Put(kv) => {
let (k, v) = kv.into_key_value();
let key = match String::from_utf8(k) {
Ok(k) => Key::new(k),
Err(err) => {
tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
continue;
}
};
let item = KeyValue::new(key, v.into());
yield WatchEvent::Put(item);
}
etcd::WatchEvent::Delete(kv) => {
let (k, _) = kv.into_key_value();
let key = match String::from_utf8(k) {
Ok(k) => Key::new(k),
Err(err) => {
tracing::error!(%err, prefix, "Invalid UTF8 in etcd key");
continue;
}
};
yield WatchEvent::Delete(key);
}
}
}
};
Ok(Box::pin(output))
}
async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
let k = make_key(&self.bucket_name, &"".into());
tracing::trace!("etcd entries: {k}");
let resp = self
.client
.kv_get_prefix(k)
.await
.map_err(|e| StoreError::EtcdError(e.to_string()))?;
let out: HashMap<Key, bytes::Bytes> = resp
.into_iter()
.map(|kv| {
let (k, v) = kv.into_key_value();
(Key::new(String::from_utf8_lossy(&k).to_string()), v.into())
})
.collect();
Ok(out)
}
}
impl EtcdBucket {
    /// Creates `key` with `value` through the client's `kv_create`.
    ///
    /// Returns `StoreOutcome::Created(1)` when `kv_create` reports `None`, and
    /// `StoreOutcome::Exists(revision)` when it reports `Some(revision)`.
    ///
    /// # Errors
    /// Any client failure is reported as `StoreError::EtcdError` carrying the error text.
    async fn create(
        &self,
        key: &Key,
        value: impl Into<Vec<u8>>,
    ) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd create: {k}");
        let existing = self
            .client
            .kv_create(k.as_str(), value.into(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match existing {
            None => StoreOutcome::Created(1),
            Some(revision) => StoreOutcome::Exists(revision),
        })
    }

    /// Overwrites `key` with `value`, expecting the stored etcd version to be `revision + 1`.
    ///
    /// The current entry is read first. A version mismatch is only logged as a warning;
    /// the put is issued regardless. The read and the put are two separate requests, so
    /// the key can change in between. The outcome is derived from the previous entry
    /// returned by the put:
    /// - no previous entry: `Created(1)`;
    /// - previous version equal to `revision + 1`: `Created(revision)`;
    /// - any other previous version `v`: `Created(v + 1)`.
    ///
    /// # Errors
    /// `StoreError::MissingKey` when the initial read finds no entry, and
    /// `StoreError::EtcdError` for any client failure.
    async fn update(
        &self,
        key: &Key,
        value: impl AsRef<[u8]>,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError> {
        let k = make_key(&self.bucket_name, key);
        tracing::trace!("etcd update: {k}");
        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        let current_version = match kvs.first() {
            Some(kv) => kv.version() as u64,
            None => return Err(StoreError::MissingKey(key.to_string())),
        };
        // `checked_add` keeps `revision == u64::MAX` from overflowing (a panic in debug
        // builds); that case can never match a stored version and is treated as a mismatch.
        let expected_version = revision.checked_add(1);
        if Some(current_version) != expected_version {
            tracing::warn!(
                current_version,
                attempted_next_version = revision,
                %key,
                "update: Wrong revision"
            );
        }
        // Ask for the previous entry so the outcome can be derived from what was replaced.
        let put_options = PutOptions::new()
            .with_lease(self.client.lease_id() as i64)
            .with_prev_key();
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(put_options))
            .await
            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // No previous entry: the key disappeared between the read and the put.
            None => StoreOutcome::Created(1),
            // The entry replaced was the one the caller expected.
            Some(kv) if Some(kv.version() as u64) == expected_version => {
                StoreOutcome::Created(revision)
            }
            // The entry replaced had some other version.
            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
        })
    }
}
/// Builds the etcd key for `key` inside `bucket_name`: the two joined by a single `/`.
///
/// An empty key therefore yields `"<bucket_name>/"`, which callers use as the bucket prefix.
fn make_key(bucket_name: &str, key: &Key) -> String {
    format!("{}/{}", bucket_name, key)
}
// Integration test for create-if-absent semantics of `insert` with revision 0.
// Compiled only for test builds with the `integration` feature enabled; it connects a real
// etcd client with default options (presumably needs a reachable etcd - confirm in CI setup).
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
use super::*;
use crate::Runtime;
use crate::transports::etcd as etcd_transport;
use std::sync::Arc;
use tokio::sync::Barrier;
// Entry point: builds a runtime from settings, connects an etcd client on it, wraps the
// client in a kv `Manager`, and runs `test_concurrent_create` on the primary runtime.
#[test]
fn test_concurrent_etcd_create_race_condition() {
let rt = Runtime::from_settings().unwrap();
// `rt` itself is moved into the async block below, so a clone is used to drive it.
let rt_clone = rt.clone();
rt_clone.primary().block_on(async move {
let etcd_client =
etcd_transport::Client::new(etcd_transport::ClientOptions::default(), rt)
.await
.unwrap();
let storage = crate::storage::kv::Manager::etcd(etcd_client);
test_concurrent_create(&storage).await.unwrap();
});
}
// Spawns `num_workers` tasks that all insert the same fresh key with revision 0 and checks
// that exactly one sees `Created` while all the others see `Exists`. Afterwards it verifies
// that the stored value starts with `test_value` and deletes the key.
async fn test_concurrent_create(
storage: &crate::storage::kv::Manager,
) -> Result<(), StoreError> {
// NOTE(review): the bucket is shared behind a tokio Mutex, and each worker's guard is a
// temporary that lives until the end of its `let result = ...;` statement. The inserts
// therefore appear to run one at a time rather than racing - confirm whether the bucket
// can be shared without the lock.
let bucket = Arc::new(tokio::sync::Mutex::new(
storage
.get_or_create_bucket("test_concurrent_bucket", None)
.await?,
));
let num_workers = 10;
// The barrier releases all workers together once every one of them has reached it.
let barrier = Arc::new(Barrier::new(num_workers));
// A random UUID suffix gives each run its own key.
let test_key: Key = Key::new(format!("concurrent_test_key_{}", uuid::Uuid::new_v4()));
let test_value = "test_value";
let mut handles = Vec::new();
// Shared tallies of how many workers observed each outcome.
let success_count = Arc::new(tokio::sync::Mutex::new(0));
let exists_count = Arc::new(tokio::sync::Mutex::new(0));
for worker_id in 0..num_workers {
let bucket_clone = bucket.clone();
let barrier_clone = barrier.clone();
let key_clone = test_key.clone();
// Each worker writes a distinct value that still starts with `test_value`.
let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
let success_count_clone = success_count.clone();
let exists_count_clone = exists_count.clone();
let handle = tokio::spawn(async move {
barrier_clone.wait().await;
// Revision 0 routes `insert` to the create path.
let result = bucket_clone
.lock()
.await
.insert(&key_clone, value_clone.into(), 0)
.await;
match result {
Ok(StoreOutcome::Created(version)) => {
println!(
"Worker {} successfully created key with version {}",
worker_id, version
);
let mut count = success_count_clone.lock().await;
*count += 1;
Ok(version)
}
Ok(StoreOutcome::Exists(version)) => {
println!(
"Worker {} found key already exists with version {}",
worker_id, version
);
let mut count = exists_count_clone.lock().await;
*count += 1;
Ok(version)
}
Err(e) => {
println!("Worker {} got error: {:?}", worker_id, e);
Err(e)
}
}
});
handles.push(handle);
}
// Join every worker; a panicked task fails the test here via `unwrap`. Only the
// versions from successful inserts are collected.
let mut results = Vec::new();
for handle in handles {
let result = handle.await.unwrap();
if let Ok(version) = result {
results.push(version);
}
}
let final_success_count = *success_count.lock().await;
let final_exists_count = *exists_count.lock().await;
println!(
"Final counts - Created: {}, Exists: {}",
final_success_count, final_exists_count
);
assert_eq!(
final_success_count, 1,
"Exactly one worker should create the key"
);
assert_eq!(
final_exists_count,
num_workers - 1,
"All other workers should see key exists"
);
assert_eq!(
results.len(),
num_workers,
"All workers should complete successfully"
);
// The stored value must come from one of the workers, whichever one created the key.
let stored_value = bucket.lock().await.get(&test_key).await?;
assert!(stored_value.is_some(), "Key should exist in etcd");
let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
assert!(
stored_str.starts_with(test_value),
"Stored value should match expected prefix"
);
// Cleanup; not reached if any assertion above fails.
bucket.lock().await.delete(&test_key).await?;
Ok(())
}
}