use std::cell::RefCell;
use std::convert::Infallible;
use std::rc::Rc;
use std::time::Duration;
use crate::proxy_wasm::types::{Bytes, Status};
use log::{trace, warn};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::extract::{Extract, FromContext};
use crate::host::clock::Clock;
use crate::host::shared_data::SharedData;
use crate::utils::random_generator;
use super::in_memory_cache::InMemoryCache;
/// Outcome of a write performed through `ConcurrentSharedData` (`insert` / `update`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionStatus {
    /// The value was stored (or, for `update`, removed) successfully.
    Complete,
    /// The operation failed for a reason other than a retryable CAS mismatch, e.g. a host
    /// error or a stored value that could not be deserialized.
    InternalError,
    /// The caller's callback declined to continue: `update_function` returned
    /// `UpdateError::Desist`, or `handle_consistency` returned `None`.
    Rejected,
}
/// Signal an `update` callback returns to abort the update without writing anything.
#[doc(hidden)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpdateError {
    /// Stop the update; `ConcurrentSharedData::update` then reports
    /// `TransactionStatus::Rejected` together with the value it read.
    Desist,
}
/// Expiration handed to the lock-version cache: 120 seconds.
///
/// Despite the `_IN_MILLIS` suffix this is a [`Duration`], not a raw millisecond count; the
/// name is kept for compatibility with existing users (the tests import it).
const CACHE_EXPIRATION_TIME_IN_MILLIS: Duration = Duration::from_millis(120_000);
/// Typed wrapper over the host shared-data store using optimistic (CAS-based) writes.
///
/// Values are stored bincode-serialized. Lock versions observed when reading are remembered
/// in an expiring cache so a following write can be made conditional on them.
pub struct ConcurrentSharedData {
    /// Host shared-data backend every read, write and removal goes through.
    shared_data: Rc<dyn SharedData>,
    /// Last lock version seen per key (`None` when the host reported no version).
    lock_version_cache: RefCell<InMemoryCache<Option<u32>>>,
}
impl<C> FromContext<C> for ConcurrentSharedData
where
    Rc<dyn Clock>: FromContext<C, Error = Infallible>,
    Rc<dyn SharedData>: FromContext<C, Error = Infallible>,
{
    type Error = Infallible;

    /// Builds a store from the clock and shared-data handles extracted from `context`
    /// (clock first, then shared data). Both extractions are infallible.
    fn from_context(context: &C) -> Result<Self, Self::Error> {
        let host_clock: Rc<dyn Clock> = context.extract()?;
        let host_shared_data: Rc<dyn SharedData> = context.extract()?;
        let storage = Self::new(host_clock, host_shared_data);
        Ok(storage)
    }
}
impl ConcurrentSharedData {
pub fn new(clock: Rc<dyn Clock>, shared_data: Rc<dyn SharedData>) -> Self {
Self {
shared_data,
lock_version_cache: RefCell::new(InMemoryCache::new(
clock,
CACHE_EXPIRATION_TIME_IN_MILLIS,
)),
}
}
pub fn insert<T, F>(
&self,
key: String,
data: T,
handle_consistency: F,
) -> (TransactionStatus, T)
where
T: Clone + Serialize + DeserializeOwned,
F: Fn(T, T) -> Option<T>,
{
let mut cache = self.lock_version_cache.borrow_mut();
let lock_version = match cache.get(&key) {
Some(cached_cas) => cached_cas,
None => {
if let (_, Some(stored_version)) = self.shared_data.shared_data_get(&key) {
Some(stored_version)
} else {
Self::generate_random_lock_version()
}
}
};
let (transaction_status, algorithm_state) =
self.lock_and_save(&key, data, lock_version, handle_consistency);
cache.remove(&key);
(transaction_status, algorithm_state)
}
pub fn update<T, F>(&self, key: &str, update_function: F) -> (TransactionStatus, Option<T>)
where
T: Clone + Serialize + DeserializeOwned,
F: Fn(Option<&T>) -> Result<Option<T>, UpdateError>,
{
loop {
let (data, lock) = self.shared_data.shared_data_get(key);
let data: Option<T> = data.and_then(|data| bincode::deserialize(&data).ok());
let lock = lock
.map(Option::Some)
.unwrap_or_else(Self::generate_random_lock_version);
match update_function(data.as_ref()) {
Ok(Some(value)) => match self.save(key, &value, lock) {
Ok(()) => return (TransactionStatus::Complete, Some(value)),
Err(e) => {
trace!(
"Failed to persist data for identifier {} with error {:?}",
&key,
e
);
if e != Status::CasMismatch {
return (TransactionStatus::InternalError, None);
}
}
},
Ok(None) => match self.shared_data.shared_data_remove(key, None) {
Ok(_) => return (TransactionStatus::Complete, None),
Err(_) => return (TransactionStatus::InternalError, None),
},
Err(UpdateError::Desist) => {
return (TransactionStatus::Rejected, data);
}
}
}
}
pub fn get<T>(&self, key: &str) -> Option<T>
where
T: Clone + Serialize + DeserializeOwned,
{
if let (Some(data), cas) = self.shared_data.shared_data_get(key) {
self.lock_version_cache
.borrow_mut()
.save(String::from(key), cas);
self.deserialize_value(key, data)
} else {
None
}
}
pub fn remove<T>(&self, key: &str) -> Option<T>
where
T: Clone + Serialize + DeserializeOwned,
{
match self.shared_data.shared_data_remove(key, None) {
Ok(Some(value)) => {
self.lock_version_cache.borrow_mut().remove(key);
self.deserialize_value(key, value)
}
Ok(None) => None,
Err(err) => {
let error_message = self.interpret_envoy_shared_data_errors(err);
trace!(
"Failed to remove data for identifier {} with error {}",
&key,
&error_message
);
None
}
}
}
pub fn keys(&self) -> Vec<String> {
self.shared_data.shared_data_keys()
}
pub fn safe_remove(&self, key: &str) {
match self.shared_data.shared_data_remove(key, None) {
Ok(_) => {
self.lock_version_cache.borrow_mut().remove(key);
}
Err(err) => {
let error_message = self.interpret_envoy_shared_data_errors(err);
trace!(
"Failed to remove data for identifier {} with error {}",
&key,
&error_message
);
}
}
}
fn interpret_envoy_shared_data_errors(&self, error_status: Status) -> String {
match error_status {
Status::BadArgument => String::from("Bad Argument"),
Status::InternalFailure => String::from("Internal Failure"),
Status::ParseFailure => String::from("Parse Failure"),
Status::NotFound => String::from("Entity Not Found"),
Status::Empty => String::from("Empty Entity"),
Status::CasMismatch => String::from("Cas Mismatch"),
_ => String::from("OK"),
}
}
fn lock_and_save<T, F>(
&self,
key: &str,
mut data_to_persist: T,
lock_version: Option<u32>,
handle_consistency: F,
) -> (TransactionStatus, T)
where
T: Clone + Serialize + DeserializeOwned,
F: Fn(T, T) -> Option<T>,
{
let mut result: Result<(), Status> = self.save(key, &data_to_persist, lock_version);
while let Err(Status::CasMismatch) = result {
trace!(
"Failed to persist data for identifier {} with error {:?}",
&key,
Status::CasMismatch
);
match self.shared_data.shared_data_get(key) {
(Some(data), lock_version) => {
trace!("Executing handle consistency function for key {key}");
if let Ok(current_data) = bincode::deserialize::<T>(&data) {
let consistency_result =
handle_consistency(current_data.clone(), data_to_persist.clone());
if let Some(value) = consistency_result {
data_to_persist = value;
result = self.save(key, &data_to_persist, lock_version);
} else {
return (TransactionStatus::Rejected, current_data);
}
} else {
return (TransactionStatus::InternalError, data_to_persist);
}
}
(None, Some(version)) => {
trace!("No value found for {key}, but lock version present, retrying store");
result = self.save(key, &data_to_persist, Some(version));
}
(None, None) => {
trace!("No value found for {key}, retrying store");
result = self.save(key, &data_to_persist, lock_version);
}
}
}
if let Err(err) = result {
let error_message = self.interpret_envoy_shared_data_errors(err);
trace!(
"Failed to persist data for identifier {} with error {}",
&key,
&error_message
);
(TransactionStatus::InternalError, data_to_persist)
} else {
(TransactionStatus::Complete, data_to_persist)
}
}
fn save<T>(&self, key: &str, state: &T, lock_version: Option<u32>) -> Result<(), Status>
where
T: Serialize + DeserializeOwned,
{
let serialized_state = bincode::serialize(state).unwrap();
self.shared_data
.shared_data_set(key, serialized_state.as_slice(), lock_version)
}
fn generate_random_lock_version() -> Option<u32> {
match random_generator::generate_u32() {
Ok(version) => Some(version),
Err(e) => {
log::error!("Error trying to generate version: {e}");
None
}
}
}
fn deserialize_value<T>(&self, key: &str, data: Bytes) -> Option<T>
where
T: Clone + Serialize + DeserializeOwned,
{
let result = bincode::deserialize(&data);
match result {
Ok(value) => Some(value),
Err(err) => {
warn!("Unexpected error trying to deserialize value for key {key}: {err:?}");
None
}
}
}
}
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::ops::Add;
    use std::rc::Rc;
    use std::time::{Duration, SystemTime};

    use crate::proxy_wasm::types::{Bytes, Status};
    use mockall::mock;
    use mockall::predicate::{always, eq};
    use mockall::Sequence;
    use serde::{Deserialize, Serialize};

    use super::InMemoryCache;
    use super::{ConcurrentSharedData, TransactionStatus, CACHE_EXPIRATION_TIME_IN_MILLIS};
    use crate::host::clock::TimeUnit;

    // Mock of the host shared-data API used by `ConcurrentSharedData`.
    mock! {
        pub SharedData {}
        impl crate::host::shared_data::SharedData for SharedData {
            fn shared_data_get(&self, key: &str) -> (Option<Bytes>, Option<u32>);
            fn shared_data_set(&self, key: &str, value: &[u8], version: Option<u32>) -> Result<(), Status>;
            fn shared_data_remove(&self, key: &str, version: Option<u32>) -> Result<Option<Bytes>, Status>;
            fn shared_data_keys(&self) -> Vec<String>;
        }
    }

    // Mock of the clock consumed by the lock-version cache.
    mock! {
        pub Clock {}
        impl crate::host::clock::Clock for Clock {
            fn get_current_time(&self) -> SystemTime;
            fn get_current_time_unit(&self, unit: TimeUnit) -> u128;
        }
    }

    /// Small bincode-serializable payload used by all tests.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Copy, Clone)]
    pub struct SerializableObject {
        property_one: u64,
        property_two: u64,
    }

    /// `get` returns the deserialized stored value.
    #[test]
    fn get_state_successfully() {
        let state = SerializableObject {
            property_one: 1,
            property_two: 10,
        };
        let serialized_state = bincode::serialize(&state);
        let mut mock_clock = MockClock::new();
        let now = SystemTime::now();
        let now_plus_five_seconds = now.add(Duration::new(5, 0));
        let mut mock_shared_data = MockSharedData::new();
        mock_shared_data_get(
            &mut mock_shared_data,
            Some(serialized_state.as_ref().unwrap().clone()),
            None,
            None,
        );
        mock_clock
            .expect_get_current_time()
            .times(1)
            .returning(move || now);
        mock_clock
            .expect_get_current_time()
            .times(1)
            .returning(move || now_plus_five_seconds);
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(InMemoryCache::new(
                Rc::new(mock_clock),
                CACHE_EXPIRATION_TIME_IN_MILLIS,
            )),
            shared_data: Rc::new(mock_shared_data),
        };
        let found_state = storage.get("key");
        assert_eq!(state, found_state.unwrap());
    }

    /// `get` on a missing key returns `None`.
    #[test]
    fn get_non_existent_state() {
        let mut mock_clock = MockClock::new();
        let mut mock_shared_data = MockSharedData::new();
        mock_clock_now(&mut mock_clock, &mut Sequence::new());
        mock_shared_data_get(&mut mock_shared_data, None, None, None);
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(InMemoryCache::new(
                Rc::new(mock_clock),
                CACHE_EXPIRATION_TIME_IN_MILLIS,
            )),
            shared_data: Rc::new(mock_shared_data),
        };
        let found_state: Option<SerializableObject> = storage.get("key");
        assert!(found_state.is_none());
    }

    /// `remove` on a missing key returns `None`.
    #[test]
    fn remove_non_existent() {
        let mut mock_clock = MockClock::new();
        let mut mock_shared_data = MockSharedData::new();
        mock_shared_data
            .expect_shared_data_remove()
            .with(eq("key"), eq(None))
            .times(1)
            .returning(move |_x: &str, _y: Option<u32>| Ok(None));
        mock_clock_now(&mut mock_clock, &mut Sequence::new());
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(InMemoryCache::new(
                Rc::new(mock_clock),
                CACHE_EXPIRATION_TIME_IN_MILLIS,
            )),
            shared_data: Rc::new(mock_shared_data),
        };
        let found_state: Option<SerializableObject> = storage.remove("key");
        assert!(found_state.is_none());
    }

    /// `remove` returns the removed value and drops the cached lock version.
    #[test]
    fn remove_stored_object() {
        let mut mock_clock = MockClock::new();
        let mut mock_shared_data = MockSharedData::new();
        let persisted_state = SerializableObject {
            property_one: 1,
            property_two: 100,
        };
        let serialized_persisted_state = bincode::serialize(&persisted_state);
        mock_shared_data
            .expect_shared_data_remove()
            .with(eq("key"), eq(None))
            .times(1)
            .returning(move |_x: &str, _y: Option<u32>| {
                Ok(Some(serialized_persisted_state.as_ref().unwrap().clone()))
            });
        mock_clock_now(&mut mock_clock, &mut Sequence::new());
        mock_clock_now(&mut mock_clock, &mut Sequence::new());
        let mut lock_version_cache: InMemoryCache<Option<u32>> =
            InMemoryCache::new(Rc::new(mock_clock), CACHE_EXPIRATION_TIME_IN_MILLIS);
        lock_version_cache.save(String::from("key"), Option::from(1));
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(lock_version_cache),
            shared_data: Rc::new(mock_shared_data),
        };
        let found_state: Option<SerializableObject> = storage.remove("key");
        assert_eq!(found_state, Some(persisted_state));
        let cache = storage.lock_version_cache.borrow();
        assert_eq!(cache.get("key"), None)
    }

    /// First insertion of a key with no stored version succeeds.
    #[test]
    fn save_state_without_optimistic_locking() {
        let mut mock_clock = MockClock::new();
        let state = SerializableObject {
            property_one: 1,
            property_two: 10,
        };
        let expected_state = state;
        let mut mock_shared_data = MockSharedData::new();
        mock_shared_data_get(&mut mock_shared_data, None, None, None);
        mock_clock_now(&mut mock_clock, &mut Sequence::new());
        mock_shared_data
            .expect_shared_data_set()
            .with(eq("key"), always(), always())
            .times(1)
            .returning(move |_x: &str, _value: &[u8], _lock: Option<u32>| Ok(()));
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(InMemoryCache::new(
                Rc::new(mock_clock),
                CACHE_EXPIRATION_TIME_IN_MILLIS,
            )),
            shared_data: Rc::new(mock_shared_data),
        };
        let (status, found_state) =
            storage.insert(String::from("key"), state, |_previous, _new| Option::None);
        assert_eq!(TransactionStatus::Complete, status);
        assert_eq!(expected_state, found_state);
    }

    /// Insertion with a cached, still-valid lock version succeeds, and `insert` itself
    /// clears the cached version afterwards.
    #[test]
    fn save_state_with_correct_lock_version() {
        let now = SystemTime::now();
        let now_plus_five_seconds = now.add(Duration::new(5, 0));
        let state = SerializableObject {
            property_one: 1,
            property_two: 10,
        };
        let expected_state = state;
        let mut mock_clock = MockClock::new();
        let mut mock_shared_data = MockSharedData::new();
        mock_shared_data_set(&mut mock_shared_data, &mut Sequence::new());
        let mut seq = Sequence::new();
        mock_clock
            .expect_get_current_time()
            .times(1)
            .returning(move || now)
            .in_sequence(&mut seq);
        mock_clock
            .expect_get_current_time()
            .times(1)
            .returning(move || now_plus_five_seconds)
            .in_sequence(&mut seq);
        let mut lock_version_cache: InMemoryCache<Option<u32>> =
            InMemoryCache::new(Rc::new(mock_clock), CACHE_EXPIRATION_TIME_IN_MILLIS);
        lock_version_cache.save(String::from("key"), Option::from(1));
        let mut storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(lock_version_cache),
            shared_data: Rc::new(mock_shared_data),
        };
        let (status, algorithm_obtained_state) =
            storage.insert(String::from("key"), state, |_previous, _new| Option::None);
        // Do not remove the key here: the emptiness check must verify `insert`'s cleanup.
        let cache = storage.lock_version_cache.get_mut();
        assert_eq!(TransactionStatus::Complete, status);
        assert_eq!(expected_state, algorithm_obtained_state);
        assert!(cache.is_empty())
    }

    /// A first insertion that collides with a concurrent writer is merged via
    /// `handle_consistency` and retried.
    #[test]
    fn save_state_first_time_lock_version_collision() {
        let mut mock_clock = MockClock::new();
        let state_to_persist = SerializableObject {
            property_one: 1,
            property_two: 100,
        };
        let persisted_state = SerializableObject {
            property_one: 1,
            property_two: 100,
        };
        let serialized_persisted_state = bincode::serialize(&persisted_state);
        let mut mock_shared_data = MockSharedData::new();
        let mut sequence = Sequence::new();
        mock_clock_now(&mut mock_clock, &mut sequence);
        mock_shared_data_get(&mut mock_shared_data, None, None, Some(&mut sequence));
        mock_shared_data_set_cas_mismatch(&mut mock_shared_data, &mut sequence);
        mock_shared_data_get(
            &mut mock_shared_data,
            Some(serialized_persisted_state.as_ref().unwrap().clone()),
            Some(1),
            None,
        );
        mock_shared_data_set(&mut mock_shared_data, &mut sequence);
        let lock_version_cache: InMemoryCache<Option<u32>> =
            InMemoryCache::new(Rc::new(mock_clock), CACHE_EXPIRATION_TIME_IN_MILLIS);
        let handle_consistency_mock_first_insertion =
            |_x: SerializableObject, _y: SerializableObject| -> Option<SerializableObject> {
                Option::from(SerializableObject {
                    property_one: 2,
                    property_two: 100,
                })
            };
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(lock_version_cache),
            shared_data: Rc::new(mock_shared_data),
        };
        let (status, found_state): (_, SerializableObject) = storage.insert(
            String::from("key"),
            state_to_persist,
            handle_consistency_mock_first_insertion,
        );
        let cache = storage.lock_version_cache.borrow();
        let option_key = cache.get("key");
        assert_eq!(TransactionStatus::Complete, status);
        assert_eq!(2, found_state.property_one);
        assert_eq!(100, found_state.property_two);
        assert!(option_key.is_none())
    }

    /// A stale cached version triggers a CAS mismatch; the merged value is then stored.
    #[test]
    fn save_state_with_incorrect_lock_version_and_retry_success() {
        let mut mock_clock = MockClock::new();
        let now = SystemTime::now();
        let now_plus_five_seconds = now.add(Duration::new(5, 0));
        let persisted_state = SerializableObject {
            property_one: 1,
            property_two: 10,
        };
        let serialized_persisted_state = bincode::serialize(&persisted_state);
        let state_to_persist = SerializableObject {
            property_one: 2,
            property_two: 11,
        };
        let mut mock_shared_data = MockSharedData::new();
        let mut sequence = Sequence::new();
        mock_shared_data_set_cas_mismatch(&mut mock_shared_data, &mut sequence);
        mock_shared_data_set(&mut mock_shared_data, &mut sequence);
        mock_shared_data_get(
            &mut mock_shared_data,
            Some(serialized_persisted_state.as_ref().unwrap().clone()),
            Some(1),
            None,
        );
        let mut seq = Sequence::new();
        mock_clock_now(&mut mock_clock, &mut seq);
        mock_clock
            .expect_get_current_time()
            .times(1)
            .returning(move || now_plus_five_seconds)
            .in_sequence(&mut seq);
        let handle_consistency_mock =
            |_x: SerializableObject, _y: SerializableObject| -> Option<SerializableObject> {
                Option::from(SerializableObject {
                    property_one: 2,
                    property_two: 11,
                })
            };
        let mut lock_version_cache: InMemoryCache<Option<u32>> =
            InMemoryCache::new(Rc::new(mock_clock), CACHE_EXPIRATION_TIME_IN_MILLIS);
        lock_version_cache.save(String::from("key"), Option::from(1));
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(lock_version_cache),
            shared_data: Rc::new(mock_shared_data),
        };
        let (status, algorithm_obtained_state) = storage.insert(
            String::from("key"),
            state_to_persist,
            handle_consistency_mock,
        );
        let cache = storage.lock_version_cache.borrow();
        let option_key = cache.get("key");
        assert_eq!(TransactionStatus::Complete, status);
        assert_eq!(2, algorithm_obtained_state.property_one);
        assert_eq!(11, algorithm_obtained_state.property_two);
        assert!(option_key.is_none())
    }

    /// When `handle_consistency` declines, the stored value is returned with `Rejected`.
    #[test]
    fn save_state_with_incorrect_lock_version_and_retry_with_inconsistency() {
        let mut mock_clock = MockClock::new();
        let now = SystemTime::now();
        let now_plus_five_seconds = now.add(Duration::new(5, 0));
        let persisted_state = SerializableObject {
            property_one: 1,
            property_two: 10,
        };
        let serialized_persisted_state = bincode::serialize(&persisted_state);
        let state_to_persist = SerializableObject {
            property_one: 2,
            property_two: 11,
        };
        let mut mock_shared_data = MockSharedData::new();
        let mut seq = Sequence::new();
        mock_clock_now(&mut mock_clock, &mut seq);
        mock_clock
            .expect_get_current_time()
            .times(1)
            .returning(move || now_plus_five_seconds)
            .in_sequence(&mut seq);
        let mut sequence = Sequence::new();
        mock_shared_data_set_cas_mismatch(&mut mock_shared_data, &mut sequence);
        mock_shared_data_get(
            &mut mock_shared_data,
            Some(serialized_persisted_state.as_ref().unwrap().clone()),
            Some(1),
            None,
        );
        let mut lock_version_cache: InMemoryCache<Option<u32>> =
            InMemoryCache::new(Rc::new(mock_clock), CACHE_EXPIRATION_TIME_IN_MILLIS);
        lock_version_cache.save(String::from("key"), Option::from(1));
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(lock_version_cache),
            shared_data: Rc::new(mock_shared_data),
        };
        let (status, algorithm_obtained_state) =
            storage.insert(String::from("key"), state_to_persist, |_previous, _new| {
                Option::None
            });
        let cache = storage.lock_version_cache.borrow();
        let option_key = cache.get("key");
        assert_eq!(TransactionStatus::Rejected, status);
        assert_eq!(persisted_state, algorithm_obtained_state);
        assert!(option_key.is_none())
    }

    /// After a CAS mismatch where the key has disappeared, the write is simply retried.
    #[test]
    fn save_state_with_incorrect_lock_version_and_retry_value_not_found() {
        let mut mock_clock = MockClock::new();
        let now = SystemTime::now();
        let now_plus_five_seconds = now.add(Duration::new(5, 0));
        let state_to_persist = SerializableObject {
            property_one: 2,
            property_two: 11,
        };
        let mut mock_shared_data = MockSharedData::new();
        let mut sequence = Sequence::new();
        mock_shared_data_set_cas_mismatch(&mut mock_shared_data, &mut sequence);
        mock_shared_data_get(&mut mock_shared_data, None, None, None);
        mock_shared_data_set(&mut mock_shared_data, &mut sequence);
        let mut seq = Sequence::new();
        mock_clock_now(&mut mock_clock, &mut seq);
        mock_clock
            .expect_get_current_time()
            .times(1)
            .returning(move || now_plus_five_seconds)
            .in_sequence(&mut seq);
        let mut lock_version_cache: InMemoryCache<Option<u32>> =
            InMemoryCache::new(Rc::new(mock_clock), CACHE_EXPIRATION_TIME_IN_MILLIS);
        lock_version_cache.save(String::from("key"), Option::from(1));
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(lock_version_cache),
            shared_data: Rc::new(mock_shared_data),
        };
        let (status, algorithm_obtained_state) =
            storage.insert(String::from("key"), state_to_persist, |_previous, _new| {
                Option::None
            });
        let cache = storage.lock_version_cache.borrow();
        let option_key = cache.get("key");
        assert_eq!(TransactionStatus::Complete, status);
        assert_eq!(
            state_to_persist.property_one,
            algorithm_obtained_state.property_one
        );
        assert_eq!(
            state_to_persist.property_two,
            algorithm_obtained_state.property_two
        );
        assert!(option_key.is_none())
    }

    /// `handle_consistency` may re-enter the store (here via `remove`) without hitting a
    /// `RefCell` borrow conflict on the lock-version cache.
    #[test]
    fn insert_allows_reentrant_handle_consistency() {
        let mut mock_clock = MockClock::new();
        let persisted_state = SerializableObject {
            property_one: 1,
            property_two: 10,
        };
        let serialized_persisted_state = bincode::serialize(&persisted_state);
        let state_to_persist = SerializableObject {
            property_one: 2,
            property_two: 11,
        };
        let mut mock_shared_data = MockSharedData::new();
        let mut sequence = Sequence::new();
        mock_clock_now(&mut mock_clock, &mut Sequence::new());
        mock_shared_data_get(&mut mock_shared_data, None, None, Some(&mut sequence));
        mock_shared_data_set_cas_mismatch(&mut mock_shared_data, &mut sequence);
        mock_shared_data_get(
            &mut mock_shared_data,
            Some(serialized_persisted_state.as_ref().unwrap().clone()),
            Some(1),
            None,
        );
        mock_shared_data
            .expect_shared_data_remove()
            .with(eq("other"), eq(None))
            .times(1)
            .returning(move |_x: &str, _y: Option<u32>| Ok(None));
        mock_shared_data_set(&mut mock_shared_data, &mut sequence);
        let storage = ConcurrentSharedData {
            lock_version_cache: RefCell::new(InMemoryCache::new(
                Rc::new(mock_clock),
                CACHE_EXPIRATION_TIME_IN_MILLIS,
            )),
            shared_data: Rc::new(mock_shared_data),
        };
        let (status, found_state) =
            storage.insert(String::from("key"), state_to_persist, |current, _new| {
                let _: Option<SerializableObject> = storage.remove("other");
                Some(current)
            });
        assert_eq!(TransactionStatus::Complete, status);
        assert_eq!(persisted_state, found_state);
    }

    /// Expects exactly one `shared_data_get("key")` returning `(value, cas)`, optionally
    /// ordered within `seq`.
    fn mock_shared_data_get(
        mock_shared_data: &mut MockSharedData,
        value: Option<Vec<u8>>,
        cas: Option<u32>,
        seq: Option<&mut Sequence>,
    ) {
        let ongoing = mock_shared_data
            .expect_shared_data_get()
            .with(eq("key"))
            .times(1)
            .returning(move |_x: &str| (value.clone(), cas));
        if let Some(s) = seq {
            ongoing.in_sequence(s);
        }
    }

    /// Expects exactly one successful `shared_data_set`, ordered within `sequence`.
    fn mock_shared_data_set(mock_shared_data: &mut MockSharedData, sequence: &mut Sequence) {
        mock_shared_data
            .expect_shared_data_set()
            .times(1)
            .in_sequence(sequence)
            .returning(move |_x: &str, _value: &[u8], _lock: Option<u32>| Ok(()));
    }

    /// Expects exactly one `shared_data_set` failing with `CasMismatch`, ordered within
    /// `sequence`.
    fn mock_shared_data_set_cas_mismatch(
        mock_shared_data: &mut MockSharedData,
        sequence: &mut Sequence,
    ) {
        mock_shared_data
            .expect_shared_data_set()
            .times(1)
            .in_sequence(sequence)
            .returning(move |_x: &str, _value: &[u8], _lock: Option<u32>| {
                Err(Status::CasMismatch)
            });
    }

    /// Expects exactly one `get_current_time` call returning the current time, ordered
    /// within `seq`.
    fn mock_clock_now(mock_clock: &mut MockClock, seq: &mut Sequence) {
        let now = SystemTime::now();
        mock_clock
            .expect_get_current_time()
            .times(1)
            .returning(move || now)
            .in_sequence(seq);
    }
}