use std::any::Any;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use bytes::Bytes;
use crate::processor::record::RecordContext;
use crate::processor::serde::Serde;
use crate::store::api::{KeyValueStore, StateStore};
use crate::store::byte::{ByteKeyValueStore, InMemoryBytes};
use crate::store::cache::kv::CachingKeyValueStore;
use crate::store::cache::named::NamedCache;
/// Storage layer that a `KeyValueBytesStore` delegates to: either a byte-level
/// store used directly, or a `CachingKeyValueStore`.
enum Backing {
/// Byte-level key-value store accessed directly, with no cache in front of it.
Plain(Box<dyn ByteKeyValueStore>),
/// Caching store; `enable_cache` builds it from a shared `NamedCache`, the
/// previous plain backend and the store name.
Cached(CachingKeyValueStore),
}
impl Backing {
async fn get(&self, key: &[u8]) -> Option<Bytes> {
match self {
Backing::Plain(b) => b.get(key).await,
Backing::Cached(c) => c.get(key).await,
}
}
async fn range(&self, lo: &[u8], hi: &[u8]) -> Vec<(Bytes, Bytes)> {
match self {
Backing::Plain(b) => b.range(lo, hi).await,
Backing::Cached(c) => c.range(lo, hi).await,
}
}
async fn scan_all(&self) -> Vec<(Bytes, Bytes)> {
match self {
Backing::Plain(b) => b.scan_all().await,
Backing::Cached(c) => c.scan_all().await,
}
}
async fn approx_len(&self) -> u64 {
match self {
Backing::Plain(b) => b.approx_len().await,
Backing::Cached(c) => c.scan_all().await.len() as u64,
}
}
async fn put(&mut self, key: Bytes, value: Bytes, ctx: RecordContext) {
match self {
Backing::Plain(b) => b.put(key, value).await,
Backing::Cached(c) => c.put(key, value, ctx).await,
}
}
async fn delete(&mut self, key: Bytes, ctx: RecordContext) -> Option<Bytes> {
match self {
Backing::Plain(b) => b.delete(&key).await,
Backing::Cached(c) => {
let prev = c.get(&key).await;
c.delete(key, ctx).await;
prev
}
}
}
async fn apply(&mut self, key: Bytes, value: Option<Bytes>) {
match (self, value) {
(Backing::Plain(b), Some(v)) => b.put(key, v).await,
(Backing::Plain(b), None) => {
b.delete(&key).await;
}
(Backing::Cached(c), Some(v)) => c.put_inner(key, v).await,
(Backing::Cached(c), None) => c.delete_inner(&key).await,
}
}
async fn clear(&mut self) {
match self {
Backing::Plain(b) => b.clear().await,
Backing::Cached(c) => c.clear().await,
}
}
}
/// Typed key-value state store that keeps its contents as serialized bytes.
///
/// Keys and values are converted with `key_serde` / `value_serde`, the bytes
/// are held by `backing`, and changelog entries are buffered in `changelog`
/// until `take_changelog` drains them.
pub struct KeyValueBytesStore<K, V> {
/// Store name; returned by `name()` and handed to the cache layer by `enable_cache`.
name: String,
/// Changelog topic name; also passed as the topic argument to every serde call.
changelog_topic: String,
/// Where the serialized entries live (plain or cached).
backing: Backing,
/// Converts keys of type `K` to and from bytes.
key_serde: Box<dyn Serde<K>>,
/// Converts values of type `V` to and from bytes.
value_serde: Box<dyn Serde<V>>,
/// Buffered changelog entries as `(key, value)`; a `None` value marks a delete.
changelog: Vec<(Bytes, Option<Bytes>)>,
/// When `false`, writes and cache flushes do not append to `changelog`.
logging: bool,
/// Context most recently supplied through `set_record_context`, if any.
pending_ctx: Option<RecordContext>,
}
impl<K: 'static, V: 'static> KeyValueBytesStore<K, V> {
    /// Creates an uncached store on top of `backend`.
    ///
    /// The store starts with an empty changelog buffer, changelog logging
    /// switched on, and no record context set.
    #[must_use]
    pub(crate) fn new(
        name: String,
        backend: Box<dyn ByteKeyValueStore>,
        key_serde: Box<dyn Serde<K>>,
        value_serde: Box<dyn Serde<V>>,
        changelog_topic: String,
    ) -> Self {
        Self {
            name,
            changelog_topic,
            backing: Backing::Plain(backend),
            key_serde,
            value_serde,
            changelog: Vec::new(),
            logging: true,
            pending_ctx: None,
        }
    }

    /// Puts a `CachingKeyValueStore` (built from `cache`, the current backend
    /// and this store's name) in front of the plain backend.
    ///
    /// Calling this on a store that is already cached does nothing.
    pub(crate) fn enable_cache(&mut self, cache: Arc<Mutex<NamedCache>>) {
        if !matches!(self.backing, Backing::Plain(_)) {
            return;
        }
        // The backend cannot be moved out from behind `&mut self`, so a
        // throwaway in-memory store is swapped in while the real one is taken
        // by value; it is overwritten again immediately below.
        let placeholder = Backing::Plain(Box::new(InMemoryBytes::default()));
        let Backing::Plain(backend) = std::mem::replace(&mut self.backing, placeholder) else {
            unreachable!("guarded by the matches! above")
        };
        self.backing = Backing::Cached(CachingKeyValueStore::with_name(
            cache,
            backend,
            self.name.clone(),
        ));
    }

    /// Returns `true` once `enable_cache` has switched the store to the cached backing.
    #[must_use]
    pub(crate) fn is_cached(&self) -> bool {
        matches!(self.backing, Backing::Cached(_))
    }

    /// Creates an uncached store backed by a fresh `InMemoryBytes` instance.
    #[must_use]
    pub fn in_memory(
        name: String,
        key_serde: Box<dyn Serde<K>>,
        value_serde: Box<dyn Serde<V>>,
        changelog_topic: String,
    ) -> Self {
        Self::new(
            name,
            Box::new(InMemoryBytes::default()),
            key_serde,
            value_serde,
            changelog_topic,
        )
    }

    /// Context to attach to a write: a copy of the one last supplied through
    /// `set_record_context`, or a zeroed context on the changelog topic when
    /// none has been set.
    fn write_ctx(&self) -> RecordContext {
        // Built lazily so the topic string is only cloned when the fallback is
        // actually needed, not on every write.
        self.pending_ctx.clone().unwrap_or_else(|| RecordContext {
            topic: self.changelog_topic.clone(),
            partition: 0,
            offset: 0,
            timestamp: 0,
        })
    }
}
#[async_trait]
impl<K: Send + 'static, V: Send + 'static> StateStore for KeyValueBytesStore<K, V> {
    /// Name the store was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// No-op in this implementation.
    async fn flush(&mut self) {}

    /// No-op in this implementation.
    fn close(&mut self) {}

    /// Exposes the store as `&mut dyn Any`.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Changelog topic name the store was constructed with.
    fn changelog_topic(&self) -> &str {
        self.changelog_topic.as_str()
    }

    /// Hands over all buffered changelog entries, leaving the buffer empty.
    fn take_changelog(&mut self) -> Vec<(Bytes, Option<Bytes>)> {
        std::mem::take(&mut self.changelog)
    }

    /// Applies one changelog entry to the backing store (`None` deletes the
    /// key). Nothing is appended to the changelog buffer here.
    async fn apply_changelog(&mut self, key: Bytes, value: Option<Bytes>) {
        self.backing.apply(key, value).await;
    }

    /// Turns changelog buffering on or off.
    fn set_logging(&mut self, on: bool) {
        self.logging = on;
    }

    /// This store always supports interactive queries.
    fn as_iq(&self) -> Option<&dyn crate::store::iq::IqQueryable> {
        Some(self)
    }

    /// Remembers `ctx` as the context used for subsequent writes.
    fn set_record_context(&mut self, ctx: RecordContext) {
        self.pending_ctx = Some(ctx);
    }

    /// Type-erased entry point to `enable_cache`; always reports `true`.
    // NOTE(review): the lint allowance is presumably needed because
    // `NamedCache` is less visible than this trait method — confirm.
    #[allow(private_interfaces)]
    fn enable_cache_erased(&mut self, cache: Arc<Mutex<NamedCache>>) -> bool {
        self.enable_cache(cache);
        true
    }

    /// Type-erased entry point to `is_cached`.
    fn is_cached_erased(&self) -> bool {
        self.is_cached()
    }

    /// Drains the cache and turns each drained entry into downstream records.
    ///
    /// For every entry returned by the cache's `flush_with_old`, the
    /// `(key, new value)` pair is appended to the changelog buffer when
    /// logging is on, and one `Change { old, new }` record per id in
    /// `children` is pushed onto `buffer`, stamped with the entry's context
    /// timestamp. Does nothing for an uncached store.
    ///
    /// # Panics
    ///
    /// Panics if a drained key or value fails to deserialize.
    #[allow(private_interfaces)]
    async fn flush_cache_into(
        &mut self,
        buffer: &mut std::collections::VecDeque<(usize, crate::processor::erased::ErasedRecord)>,
        children: &[usize],
    ) {
        use crate::dsl::processors::change::Change;
        use crate::processor::erased::ErasedRecord;

        let dirty = match &self.backing {
            Backing::Cached(cache) => cache.flush_with_old().await,
            Backing::Plain(_) => return,
        };

        for (key_bytes, old_bytes, new_bytes, ctx) in dirty {
            if self.logging {
                self.changelog.push((key_bytes.clone(), new_bytes.clone()));
            }
            // Key and values are decoded afresh for each child because every
            // record takes ownership of its own copies and `K` / `V` carry no
            // `Clone` bound.
            for child in children.iter().copied() {
                let key: K = self
                    .key_serde
                    .deserialize(&self.changelog_topic, &key_bytes)
                    .expect("flush_cache_into key deserialize");
                let old: Option<V> = old_bytes.as_ref().map(|raw| {
                    self.value_serde
                        .deserialize(&self.changelog_topic, raw)
                        .expect("flush_cache_into old value deserialize")
                });
                let new: Option<V> = new_bytes.as_ref().map(|raw| {
                    self.value_serde
                        .deserialize(&self.changelog_topic, raw)
                        .expect("flush_cache_into new value deserialize")
                });
                let record = ErasedRecord::new(
                    Some(Box::new(key)),
                    Box::new(Change { old, new }),
                    ctx.timestamp,
                );
                buffer.push_back((child, record));
            }
        }
    }

    /// Empties the backing store and discards any buffered changelog entries.
    async fn clear(&mut self) {
        self.backing.clear().await;
        self.changelog.clear();
    }
}
#[async_trait::async_trait]
impl<K: Send + 'static, V: Send + 'static> crate::store::iq::IqQueryable
    for KeyValueBytesStore<K, V>
{
    /// Reports this store as a key-value store.
    fn kind(&self) -> crate::store::iq::StoreKind {
        use crate::store::iq::StoreKind;
        StoreKind::KeyValue
    }

    /// Raw lookup by already-serialized key.
    async fn iq_kv_get(&self, key: &[u8]) -> Option<bytes::Bytes> {
        self.backing.get(key).await
    }

    /// Raw range query; the upper bound handed to the backing store is `hi`
    /// with one zero byte appended.
    async fn iq_kv_range(&self, lo: &[u8], hi: &[u8]) -> Vec<(bytes::Bytes, bytes::Bytes)> {
        // `hi` followed by a zero byte is the smallest byte string that sorts
        // strictly after `hi`.
        let mut upper = hi.to_vec();
        upper.push(0);
        self.backing.range(lo, &upper).await
    }

    /// Every raw `(key, value)` pair in the store.
    async fn iq_kv_all(&self) -> Vec<(bytes::Bytes, bytes::Bytes)> {
        self.backing.scan_all().await
    }

    /// Entry count as reported by the backing store.
    async fn iq_kv_approx_count(&self) -> u64 {
        self.backing.approx_len().await
    }

    /// Executes a typed query.
    ///
    /// * `Key` returns a boxed `Option<V>`.
    /// * `Range` returns a boxed `Vec<(K, V)>` holding every scanned entry
    ///   whose serialized key is neither below `lo` nor above `hi` (a missing
    ///   bound is open); the rows are reversed when `descending` is set.
    ///
    /// # Errors
    ///
    /// `KeyTypeMismatch` when a supplied key is not a `K`;
    /// `UnknownQueryType` for any other query variant.
    ///
    /// # Panics
    ///
    /// Panics if a stored key or value fails to deserialize.
    async fn iq2_execute(
        &self,
        query: &crate::store::iq::Iq2Query,
    ) -> Result<Box<dyn Any + Send>, crate::store::iq::Iq2Failure> {
        use crate::store::iq::{Iq2Failure, Iq2Query};

        // Downcasts a type-erased key to `K` and serializes it.
        let encode_key = |boxed: &Box<dyn Any + Send + Sync>| -> Result<bytes::Bytes, Iq2Failure> {
            match boxed.downcast_ref::<K>() {
                Some(typed) => Ok(self.key_serde.serialize(&self.changelog_topic, typed)),
                None => Err(Iq2Failure::KeyTypeMismatch),
            }
        };

        match query {
            Iq2Query::Key { key } => {
                let key_bytes = encode_key(key)?;
                let stored = self.backing.get(&key_bytes).await;
                let found: Option<V> = stored.map(|raw| {
                    self.value_serde
                        .deserialize(&self.changelog_topic, &raw)
                        .expect("iqv2 kv value deserialize")
                });
                Ok(Box::new(found))
            }
            Iq2Query::Range { lo, hi, descending } => {
                let lo_b = lo.as_ref().map(&encode_key).transpose()?;
                let hi_b = hi.as_ref().map(&encode_key).transpose()?;

                let mut rows: Vec<(K, V)> = Vec::new();
                for (kb, vb) in self.backing.scan_all().await {
                    let below_lo = lo_b.as_ref().is_some_and(|l| kb.as_ref() < l.as_ref());
                    let above_hi = hi_b.as_ref().is_some_and(|h| kb.as_ref() > h.as_ref());
                    if below_lo || above_hi {
                        continue;
                    }
                    let key = self
                        .key_serde
                        .deserialize(&self.changelog_topic, &kb)
                        .expect("iqv2 kv range key deserialize");
                    let value = self
                        .value_serde
                        .deserialize(&self.changelog_topic, &vb)
                        .expect("iqv2 kv range value deserialize");
                    rows.push((key, value));
                }
                if *descending {
                    rows.reverse();
                }
                Ok(Box::new(rows))
            }
            _ => Err(Iq2Failure::UnknownQueryType),
        }
    }
}
#[async_trait]
impl<K: Send + Sync + 'static, V: Send + 'static> KeyValueStore<K, V> for KeyValueBytesStore<K, V> {
    /// Looks up `key` and decodes the stored value.
    ///
    /// # Panics
    ///
    /// Panics if the stored bytes fail to deserialize.
    async fn get(&self, key: &K) -> Option<V> {
        let key_bytes = self.key_serde.serialize(&self.changelog_topic, key);
        let stored = self.backing.get(&key_bytes).await;
        stored.map(|raw| {
            self.value_serde
                .deserialize(&self.changelog_topic, &raw)
                .expect("store value deserialize")
        })
    }

    /// Stores `value` under `key`.
    ///
    /// On a plain backing the entry is also appended to the changelog buffer
    /// right away (when logging is on). On a cached backing nothing is
    /// appended here; `flush_cache_into` does that when the cache is drained.
    async fn put(&mut self, key: K, value: V) {
        let key_bytes = self.key_serde.serialize(&self.changelog_topic, &key);
        let value_bytes = self.value_serde.serialize(&self.changelog_topic, &value);
        let ctx = self.write_ctx();

        if matches!(self.backing, Backing::Cached(_)) {
            self.backing.put(key_bytes, value_bytes, ctx).await;
            return;
        }

        self.backing
            .put(key_bytes.clone(), value_bytes.clone(), ctx)
            .await;
        if self.logging {
            self.changelog.push((key_bytes, Some(value_bytes)));
        }
    }

    /// Removes `key` and returns the decoded previous value, if there was one.
    ///
    /// On a plain backing with logging on, a tombstone (`None` value) is
    /// appended to the changelog buffer whether or not the key existed.
    ///
    /// # Panics
    ///
    /// Panics if the previous value fails to deserialize.
    async fn delete(&mut self, key: &K) -> Option<V> {
        let key_bytes = self.key_serde.serialize(&self.changelog_topic, key);
        let ctx = self.write_ctx();
        let is_plain = matches!(self.backing, Backing::Plain(_));

        let removed = self.backing.delete(key_bytes.clone(), ctx).await;
        let previous = removed.map(|raw| {
            self.value_serde
                .deserialize(&self.changelog_topic, &raw)
                .expect("store value deserialize")
        });

        if is_plain && self.logging {
            self.changelog.push((key_bytes, None));
        }
        previous
    }

    /// Runs a range query between the serialized forms of `lo` and `hi` and
    /// decodes every returned pair.
    ///
    /// # Panics
    ///
    /// Panics if a returned key or value fails to deserialize.
    async fn range(&self, lo: &K, hi: &K) -> Vec<(K, V)> {
        let lo_bytes = self.key_serde.serialize(&self.changelog_topic, lo);
        let hi_bytes = self.key_serde.serialize(&self.changelog_topic, hi);

        let hits = self.backing.range(&lo_bytes, &hi_bytes).await;
        let mut decoded = Vec::with_capacity(hits.len());
        for (kb, vb) in hits {
            let key = self
                .key_serde
                .deserialize(&self.changelog_topic, &kb)
                .expect("kv range key deserialize");
            let value = self
                .value_serde
                .deserialize(&self.changelog_topic, &vb)
                .expect("kv range value deserialize");
            decoded.push((key, value));
        }
        decoded
    }
}
// Unit tests for `KeyValueBytesStore`, covering both the plain and the cached backing.
#[cfg(test)]
mod tests {
use super::*;
use crate::processor::serde::{I64Serde, StringSerde};
use assert2::check;
// Plain in-memory `String -> i64` store named "s" with changelog topic "s-changelog".
fn store() -> KeyValueBytesStore<String, i64> {
KeyValueBytesStore::in_memory(
"s".into(),
Box::new(StringSerde),
Box::new(I64Serde),
"s-changelog".into(),
)
}
// Same store as `store()`, with a fresh `NamedCache` enabled on it.
fn cached_store() -> KeyValueBytesStore<String, i64> {
let mut s = store();
s.enable_cache(Arc::new(Mutex::new(NamedCache::new("s".into()))));
s
}
// Record context on topic "t", partition 0, offset 0, with the given timestamp.
fn ctx_at(ts: i64) -> RecordContext {
RecordContext {
topic: "t".into(),
partition: 0,
offset: 0,
timestamp: ts,
}
}
// Plain store: every put and delete is visible immediately and each one adds a
// changelog entry (the delete as a `None` tombstone); `take_changelog` drains the buffer.
#[tokio::test]
async fn put_get_delete_and_changelog_buffer() {
let mut s = store();
s.put("a".into(), 1).await;
s.put("a".into(), 2).await;
check!(s.get(&"a".to_string()).await == Some(2));
check!(s.delete(&"a".to_string()).await == Some(2));
check!(s.get(&"a".to_string()).await == None);
let cl = s.take_changelog();
check!(cl.len() == 3);
check!(cl[2].1.is_none());
check!(s.take_changelog().is_empty());
}
// Typed `range` on the in-memory store returns entries in key order and leaves
// out the entry whose key equals the upper bound.
#[tokio::test]
async fn range_returns_ordered_half_open() {
use crate::processor::serde::BytesSerde;
use bytes::Bytes;
let mut s = KeyValueBytesStore::<Bytes, Bytes>::in_memory(
"r".into(),
Box::new(BytesSerde),
Box::new(BytesSerde),
"r-cl".into(),
);
s.put(Bytes::from_static(&[1, 0]), Bytes::from_static(b"a"))
.await;
s.put(Bytes::from_static(&[1, 5]), Bytes::from_static(b"b"))
.await;
s.put(Bytes::from_static(&[2, 0]), Bytes::from_static(b"c"))
.await;
let r = s
.range(&Bytes::from_static(&[1, 0]), &Bytes::from_static(&[2, 0]))
.await; assert_eq!(
r,
vec![
(Bytes::from_static(&[1, 0]), Bytes::from_static(b"a")),
(Bytes::from_static(&[1, 5]), Bytes::from_static(b"b")),
]
);
}
// `iq2_execute`: key hit, key miss, a range that includes both of its bounds,
// an unbounded descending range, and a key of the wrong type.
#[tokio::test]
async fn iq2_key_and_range() {
use crate::store::iq::{Iq2Failure, Iq2Query, IqQueryable};
let mut s = store();
s.put("a".into(), 1).await;
s.put("b".into(), 2).await;
s.put("c".into(), 3).await;
let q: &dyn IqQueryable = s.as_iq().unwrap();
let got = q
.iq2_execute(&Iq2Query::Key {
key: Box::new("b".to_string()),
})
.await
.unwrap();
assert_eq!(*got.downcast::<Option<i64>>().unwrap(), Some(2));
let miss = q
.iq2_execute(&Iq2Query::Key {
key: Box::new("z".to_string()),
})
.await
.unwrap();
assert_eq!(*miss.downcast::<Option<i64>>().unwrap(), None);
let r = q
.iq2_execute(&Iq2Query::Range {
lo: Some(Box::new("a".to_string())),
hi: Some(Box::new("b".to_string())),
descending: false,
})
.await
.unwrap();
assert_eq!(
*r.downcast::<Vec<(String, i64)>>().unwrap(),
vec![("a".to_string(), 1), ("b".to_string(), 2)]
);
let all_desc = q
.iq2_execute(&Iq2Query::Range {
lo: None,
hi: None,
descending: true,
})
.await
.unwrap();
assert_eq!(
*all_desc.downcast::<Vec<(String, i64)>>().unwrap(),
vec![
("c".to_string(), 3),
("b".to_string(), 2),
("a".to_string(), 1)
]
);
// A key that is not a `String` must be rejected rather than serialized.
let bad = q
.iq2_execute(&Iq2Query::Key {
key: Box::new(7_i64),
})
.await;
assert_eq!(bad.err(), Some(Iq2Failure::KeyTypeMismatch));
}
// `apply_changelog` makes the entry readable without adding anything to the
// changelog buffer; applying `None` removes the key again.
#[tokio::test]
async fn apply_changelog_restores_without_re_logging() {
let mut s = store();
// Eight raw bytes that the assertion below expects `I64Serde` to decode as 7.
s.apply_changelog(
b"k".to_vec().into(),
Some(bytes::Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 7])),
)
.await;
check!(s.get(&"k".to_string()).await == Some(7));
check!(s.take_changelog().is_empty());
s.apply_changelog(b"k".to_vec().into(), None).await;
check!(s.get(&"k".to_string()).await == None);
}
// Cached store: a read sees the latest write before any flush, and two writes
// to the same key flush as a single record.
#[tokio::test]
async fn cached_store_reads_your_writes() {
let mut s = cached_store();
s.set_record_context(ctx_at(0));
s.put("a".into(), 1).await;
s.put("a".into(), 2).await;
check!(s.get(&"a".to_string()).await == Some(2));
let mut buffer = std::collections::VecDeque::new();
s.flush_cache_into(&mut buffer, &[0]).await;
check!(buffer.len() == 1);
}
// Cached store: puts add nothing to the changelog until `flush_cache_into`
// runs, and then only the latest value per key is logged.
#[tokio::test]
async fn cached_store_defers_changelog_until_flush() {
let mut s = cached_store();
s.set_record_context(ctx_at(0));
s.put("a".into(), 1).await;
s.put("a".into(), 2).await;
check!(s.take_changelog().is_empty());
let mut buffer = std::collections::VecDeque::new();
s.flush_cache_into(&mut buffer, &[0]).await;
let cl = s.take_changelog();
check!(cl.len() == 1);
check!(cl[0].0 == StringSerde.serialize("s-changelog", &"a".to_string()));
check!(cl[0].1 == Some(I64Serde.serialize("s-changelog", &2)));
}
// After a first flush seeds value 1, two further writes flush as one `Change`
// whose `old` is the seeded value and whose `new` is the last write, addressed
// to the given child and stamped with the latest context timestamp.
#[tokio::test]
async fn flush_cache_into_emits_deduped_change() {
use crate::dsl::processors::change::Change;
let mut s = cached_store();
s.set_record_context(ctx_at(0));
s.put("a".into(), 1).await;
let mut seed = std::collections::VecDeque::new();
s.flush_cache_into(&mut seed, &[0]).await;
let _ = s.take_changelog();
s.set_record_context(ctx_at(7));
s.put("a".into(), 2).await;
s.put("a".into(), 3).await;
let mut buffer = std::collections::VecDeque::new();
s.flush_cache_into(&mut buffer, &[7]).await;
check!(buffer.len() == 1);
let (child, rec) = &buffer[0];
check!(*child == 7);
check!(rec.timestamp == 7);
let key = rec.key.as_ref().unwrap().downcast_ref::<String>().unwrap();
check!(key == "a");
let change = rec.value.downcast_ref::<Change<i64>>().unwrap();
check!(change.old == Some(1));
check!(change.new == Some(3));
check!(s.get(&"a".to_string()).await == Some(3));
}
// Cached store: range, full scan, count and raw get all see entries that have
// not been flushed yet.
#[tokio::test]
async fn cached_store_range_scan_and_count_overlay() {
use crate::store::iq::IqQueryable;
let mut s = cached_store();
s.set_record_context(ctx_at(0));
s.put("a".into(), 1).await;
s.put("b".into(), 2).await;
s.put("c".into(), 3).await;
let r = s.range(&"a".to_string(), &"c".to_string()).await; check!(r == vec![("a".to_string(), 1), ("b".to_string(), 2)]);
check!(s.iq_kv_all().await.len() == 3);
check!(s.iq_kv_approx_count().await == 3);
check!(s.iq_kv_get(b"b").await == Some(I64Serde.serialize("s-changelog", &2)));
}
// Cached store: delete returns the previous value and hides the key at once,
// but the tombstone only reaches the changelog on flush.
#[tokio::test]
async fn cached_store_delete_returns_prev_and_stages_tombstone() {
let mut s = cached_store();
s.set_record_context(ctx_at(0));
s.put("a".into(), 5).await;
check!(s.delete(&"a".to_string()).await == Some(5));
check!(s.take_changelog().is_empty());
check!(s.get(&"a".to_string()).await == None);
let mut buffer = std::collections::VecDeque::new();
s.flush_cache_into(&mut buffer, &[0]).await;
let cl = s.take_changelog();
check!(cl.len() == 1);
check!(cl[0].1.is_none());
}
// Cached store: entries restored through `apply_changelog` are readable but
// produce neither flush records nor changelog entries.
#[tokio::test]
async fn cached_store_apply_changelog_goes_below_cache() {
let mut s = cached_store();
s.apply_changelog(
b"k".to_vec().into(),
Some(bytes::Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 7])),
)
.await;
check!(s.get(&"k".to_string()).await == Some(7));
let mut buffer = std::collections::VecDeque::new();
s.flush_cache_into(&mut buffer, &[0]).await;
check!(buffer.is_empty());
check!(s.take_changelog().is_empty());
s.apply_changelog(b"k".to_vec().into(), None).await;
check!(s.get(&"k".to_string()).await == None);
}
// Cached store: `clear` removes unflushed entries as well, so a later flush emits nothing.
#[tokio::test]
async fn cached_store_clear_empties_everything() {
use crate::store::iq::IqQueryable;
let mut s = cached_store();
s.set_record_context(ctx_at(0));
s.put("a".into(), 1).await;
StateStore::clear(&mut s).await;
check!(s.get(&"a".to_string()).await == None);
check!(s.iq_kv_all().await.is_empty());
let mut buffer = std::collections::VecDeque::new();
s.flush_cache_into(&mut buffer, &[0]).await;
check!(buffer.is_empty());
}
// Calling `enable_cache` a second time leaves the store cached.
#[tokio::test]
async fn enable_cache_is_idempotent() {
let mut s = store();
check!(!s.is_cached());
s.enable_cache(Arc::new(Mutex::new(NamedCache::new("s".into()))));
check!(s.is_cached());
s.enable_cache(Arc::new(Mutex::new(NamedCache::new("s".into()))));
check!(s.is_cached());
check!(s.is_cached_erased());
}
// Plain store: writes log immediately, `flush_cache_into` emits nothing, and
// `flush` / `close` leave the stored data readable.
#[tokio::test]
async fn plain_store_unchanged() {
let mut s = store();
s.put("a".into(), 1).await;
check!(s.get(&"a".to_string()).await == Some(1));
let cl = s.take_changelog();
check!(cl.len() == 1);
check!(cl[0].1 == Some(I64Serde.serialize("s-changelog", &1)));
let mut buffer = std::collections::VecDeque::new();
s.flush_cache_into(&mut buffer, &[0]).await;
check!(buffer.is_empty());
s.flush().await;
s.close();
check!(s.get(&"a".to_string()).await == Some(1));
}
}