use crate::constants;
use crate::error::KinesisErrorResponse;
use crate::types::{Consumer, StoredRecord, Stream};
use redb::backends::InMemoryBackend;
use redb::{Database, ReadableDatabase, ReadableTable, TableDefinition};
use serde_json::Value;
use std::collections::BTreeMap;
use std::sync::Arc;
// redb table definitions. Every table is keyed by a string; values are
// JSON-serialized bytes, except POLICIES which stores the raw policy text.

// Stream metadata, keyed by stream name (see serialize_stream / deserialize_stream).
const STREAMS: TableDefinition<&str, &[u8]> = TableDefinition::new("streams");
// Data records, keyed by "<stream name>\0<shard key>" (see record_key).
const RECORDS: TableDefinition<&str, &[u8]> = TableDefinition::new("records");
// Consumers, keyed by consumer ARN.
const CONSUMERS: TableDefinition<&str, &[u8]> = TableDefinition::new("consumers");
// Resource policies, keyed by resource ARN; value is the policy document string.
const POLICIES: TableDefinition<&str, &str> = TableDefinition::new("policies");
// Tag maps, keyed by resource ARN; value is a JSON-encoded string-to-string map.
const RESOURCE_TAGS: TableDefinition<&str, &[u8]> = TableDefinition::new("resource_tags");
// Account-level settings blob, stored under the single key "account_settings".
const ACCOUNT_SETTINGS: TableDefinition<&str, &[u8]> = TableDefinition::new("account_settings");
/// Tunable knobs for the store's simulated Kinesis behavior.
///
/// NOTE(review): the `*_ms` fields look like simulated latencies for the
/// corresponding stream lifecycle operations, and `shard_limit` is
/// presumably checked against `sum_open_shards` — confirm at call sites;
/// neither is consumed within this file.
#[derive(Debug, Clone)]
pub struct StoreOptions {
pub create_stream_ms: u64,
pub delete_stream_ms: u64,
pub update_stream_ms: u64,
pub shard_limit: u32,
}
impl Default for StoreOptions {
fn default() -> Self {
Self {
create_stream_ms: 500,
delete_stream_ms: 500,
update_stream_ms: 500,
shard_limit: 10,
}
}
}
/// In-memory Kinesis-like backing store.
///
/// Cheap to clone: every clone shares the same embedded redb database
/// through the `Arc`.
#[derive(Clone)]
pub struct Store {
// Behavioral knobs (latencies, shard limit).
pub options: StoreOptions,
// Digits-only account id (see Store::new); used in error messages.
pub aws_account_id: String,
// Region name, e.g. "us-east-1" (see Store::new).
pub aws_region: String,
// Embedded in-memory redb database holding all tables.
db: Arc<Database>,
}
/// Serializes a `Stream` to JSON bytes for storage in the STREAMS table.
///
/// A handful of fields are additionally written under `_`-prefixed sidecar
/// keys (`_seq_ix`, `_tags`, `_key_id`, `_warm_throughput_mibps`,
/// `_max_record_size_kib`) that `deserialize_stream` restores. This implies
/// `Stream`'s own serde impl skips or loses those fields — confirm against
/// the definition in `types`.
///
/// # Panics
/// Panics if `Stream` does not serialize to a JSON object, or if any field
/// fails to serialize.
fn serialize_stream(stream: &Stream) -> Vec<u8> {
let mut val = serde_json::to_value(stream).unwrap();
let obj = val.as_object_mut().unwrap();
obj.insert(
"_seq_ix".to_string(),
serde_json::to_value(&stream.seq_ix).unwrap(),
);
obj.insert(
"_tags".to_string(),
serde_json::to_value(&stream.tags).unwrap(),
);
// Optional field: only written when present.
if let Some(ref key_id) = stream.key_id {
obj.insert("_key_id".to_string(), serde_json::to_value(key_id).unwrap());
}
obj.insert(
"_warm_throughput_mibps".to_string(),
serde_json::to_value(stream.warm_throughput_mibps).unwrap(),
);
obj.insert(
"_max_record_size_kib".to_string(),
serde_json::to_value(stream.max_record_size_kib).unwrap(),
);
serde_json::to_vec(&val).unwrap()
}
/// Inverse of `serialize_stream`: decodes a `Stream` and restores the
/// `_`-prefixed sidecar fields.
///
/// Each sidecar is optional and falls back to a default when missing or
/// malformed, so older/partial payloads still deserialize.
///
/// # Panics
/// Panics if `bytes` is not valid JSON or does not decode to a `Stream`.
fn deserialize_stream(bytes: &[u8]) -> Stream {
let val: Value = serde_json::from_slice(bytes).unwrap();
let mut stream: Stream = serde_json::from_value(val.clone()).unwrap();
if let Some(arr) = val.get("_seq_ix") {
stream.seq_ix = serde_json::from_value(arr.clone()).unwrap_or_default();
}
if let Some(obj) = val.get("_tags") {
stream.tags = serde_json::from_value(obj.clone()).unwrap_or_default();
}
if let Some(key_id) = val.get("_key_id") {
stream.key_id = serde_json::from_value(key_id.clone()).ok();
}
if let Some(v) = val.get("_warm_throughput_mibps") {
// Fallback 0 — presumably "no warm throughput configured"; confirm.
stream.warm_throughput_mibps = serde_json::from_value(v.clone()).unwrap_or(0);
}
if let Some(v) = val.get("_max_record_size_kib") {
// Fallback 1024 KiB matches the classic Kinesis 1 MiB record limit.
stream.max_record_size_kib = serde_json::from_value(v.clone()).unwrap_or(1024);
}
stream
}
/// Builds the composite RECORDS table key for one record: the stream name
/// and the per-shard key joined by a NUL separator.
fn record_key(stream_name: &str, shard_key: &str) -> String {
    let mut key = String::with_capacity(stream_name.len() + shard_key.len() + 1);
    key.push_str(stream_name);
    key.push('\0');
    key.push_str(shard_key);
    key
}
/// Builds the pair of composite keys bounding `[start, end)` within a
/// single stream's slice of the RECORDS table.
fn record_range(stream_name: &str, start: &str, end: &str) -> (String, String) {
    let lower = format!("{stream_name}\0{start}");
    let upper = format!("{stream_name}\0{end}");
    (lower, upper)
}
impl Store {
/// Creates a fresh store backed by a brand-new in-memory redb database.
///
/// Environment:
/// - `AWS_ACCOUNT_ID`: account id; non-digit characters are stripped.
///   Defaults to "0000-0000-0000", i.e. "000000000000" after filtering.
/// - `AWS_REGION` (falling back to `AWS_DEFAULT_REGION`): region name,
///   defaulting to "us-east-1".
///
/// All tables are opened once inside an initial write transaction so that
/// later read transactions never encounter a missing table.
///
/// # Panics
/// Panics if the database cannot be created or the bootstrap transaction
/// fails — acceptable for a mock that owns its in-memory backend.
pub fn new(options: StoreOptions) -> Self {
let aws_account_id = std::env::var("AWS_ACCOUNT_ID")
.unwrap_or_else(|_| "0000-0000-0000".to_string())
.chars()
.filter(|c| c.is_ascii_digit())
.collect();
let aws_region = std::env::var("AWS_REGION")
.or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
.unwrap_or_else(|_| "us-east-1".to_string());
let db = Database::builder()
.create_with_backend(InMemoryBackend::new())
.expect("Failed to create in-memory redb database");
let write_txn = db.begin_write().unwrap();
write_txn.open_table(STREAMS).unwrap();
write_txn.open_table(RECORDS).unwrap();
write_txn.open_table(CONSUMERS).unwrap();
write_txn.open_table(POLICIES).unwrap();
write_txn.open_table(RESOURCE_TAGS).unwrap();
write_txn.open_table(ACCOUNT_SETTINGS).unwrap();
write_txn.commit().unwrap();
Self {
options,
aws_account_id,
aws_region,
db: Arc::new(db),
}
}
/// Looks up a stream by name.
///
/// # Errors
/// Returns a `ResourceNotFoundException`-style error when no stream with
/// that name exists.
pub async fn get_stream(&self, name: &str) -> Result<Stream, KinesisErrorResponse> {
    let txn = self.db.begin_read().unwrap();
    let streams = txn.open_table(STREAMS).unwrap();
    if let Some(guard) = streams.get(name).unwrap() {
        return Ok(deserialize_stream(guard.value()));
    }
    Err(KinesisErrorResponse::client_error(
        constants::RESOURCE_NOT_FOUND,
        Some(&format!(
            "Stream {} under account {} not found.",
            name, self.aws_account_id
        )),
    ))
}
/// Inserts or replaces the stream stored under `name`.
pub async fn put_stream(&self, name: &str, stream: Stream) {
    let encoded = serialize_stream(&stream);
    let txn = self.db.begin_write().unwrap();
    {
        let mut streams = txn.open_table(STREAMS).unwrap();
        streams.insert(name, encoded.as_slice()).unwrap();
    }
    txn.commit().unwrap();
}
/// Deletes a stream and all of its stored records in one transaction.
///
/// Record keys are `"<stream name>\0<shard key>"`, so the half-open range
/// `"<name>\0" .. "<name>\x01"` covers exactly this stream's records
/// (0x01 is the first byte greater than the NUL separator).
/// Deleting a nonexistent stream is a no-op (`remove` returns the old
/// value as an `Option`, which is discarded).
pub async fn delete_stream(&self, name: &str) {
let write_txn = self.db.begin_write().unwrap();
{
let mut streams = write_txn.open_table(STREAMS).unwrap();
streams.remove(name).unwrap();
let mut records = write_txn.open_table(RECORDS).unwrap();
let prefix = format!("{name}\0");
let prefix_end = format!("{name}\x01");
// Collect the keys first: removing entries while the range iterator
// still borrows the table would not compile.
let keys_to_remove: Vec<String> = records
.range(prefix.as_str()..prefix_end.as_str())
.unwrap()
.map(|r| {
let (k, _) = r.unwrap();
k.value().to_string()
})
.collect();
for key in keys_to_remove {
records.remove(key.as_str()).unwrap();
}
}
write_txn.commit().unwrap();
}
/// Returns `true` when a stream named `name` exists.
pub async fn contains_stream(&self, name: &str) -> bool {
    let txn = self.db.begin_read().unwrap();
    let streams = txn.open_table(STREAMS).unwrap();
    matches!(streams.get(name).unwrap(), Some(_))
}
/// Returns every stored stream name, in the table's (lexicographic) order.
pub async fn list_stream_names(&self) -> Vec<String> {
    let txn = self.db.begin_read().unwrap();
    let streams = txn.open_table(STREAMS).unwrap();
    let mut names = Vec::new();
    for entry in streams.iter().unwrap() {
        let (key, _) = entry.unwrap();
        names.push(key.value().to_string());
    }
    names
}
pub async fn update_stream<F, R>(&self, name: &str, f: F) -> Result<R, KinesisErrorResponse>
where
F: FnOnce(&mut Stream) -> Result<R, KinesisErrorResponse>,
{
let write_txn = self.db.begin_write().unwrap();
let result;
{
let mut table = write_txn.open_table(STREAMS).unwrap();
let bytes = table.get(name).unwrap().ok_or_else(|| {
KinesisErrorResponse::client_error(
"ResourceNotFoundException",
Some(&format!(
"Stream {} under account {} not found.",
name, self.aws_account_id
)),
)
})?;
let mut stream = deserialize_stream(bytes.value());
drop(bytes);
result = f(&mut stream)?;
let new_bytes = serialize_stream(&stream);
table.insert(name, new_bytes.as_slice()).unwrap();
}
write_txn.commit().unwrap();
Ok(result)
}
/// Loads every stream into a `BTreeMap`, lets `f` mutate the map freely
/// (insert/remove/update), then writes the whole result back in a single
/// transaction. Entries `f` removed from the map are deleted from the
/// table; everything left in the map is (re)inserted.
///
/// `f` also receives the store options, account id, and region for
/// convenience.
pub async fn with_streams_write<F, R>(&self, f: F) -> R
where
    F: FnOnce(&mut BTreeMap<String, Stream>, &StoreOptions, &str, &str) -> R,
{
    let write_txn = self.db.begin_write().unwrap();
    let result;
    {
        let mut table = write_txn.open_table(STREAMS).unwrap();
        let mut streams: BTreeMap<String, Stream> = table
            .iter()
            .unwrap()
            .map(|r| {
                let (k, v) = r.unwrap();
                (k.value().to_string(), deserialize_stream(v.value()))
            })
            .collect();
        // Snapshot the pre-callback key set so deletions can be detected
        // afterwards without a second full table scan (the original
        // re-iterated the table here).
        let existing_keys: Vec<String> = streams.keys().cloned().collect();
        result = f(
            &mut streams,
            &self.options,
            &self.aws_account_id,
            &self.aws_region,
        );
        // Remove entries the callback dropped from the map.
        for key in &existing_keys {
            if !streams.contains_key(key) {
                table.remove(key.as_str()).unwrap();
            }
        }
        // Write back every remaining stream (inserts and updates alike).
        for (name, stream) in &streams {
            let bytes = serialize_stream(stream);
            table.insert(name.as_str(), bytes.as_slice()).unwrap();
        }
    }
    write_txn.commit().unwrap();
    result
}
/// Loads every record belonging to `stream_name` into a `BTreeMap` keyed
/// by shard key (the composite-key prefix `"<name>\0"` is stripped).
///
/// The range end `"<name>\x01"` is the first key greater than every key
/// with the `"<name>\0"` prefix, so the scan covers exactly one stream.
pub async fn get_record_store(&self, stream_name: &str) -> BTreeMap<String, StoredRecord> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(RECORDS).unwrap();
let prefix = format!("{stream_name}\0");
let prefix_end = format!("{stream_name}\x01");
table
.range(prefix.as_str()..prefix_end.as_str())
.unwrap()
.map(|r| {
let (k, v) = r.unwrap();
let full_key = k.value().to_string();
// strip_prefix should always succeed within this range; fall back
// to the full key defensively.
let shard_key = full_key
.strip_prefix(&prefix)
.unwrap_or(&full_key)
.to_string();
let record: StoredRecord = serde_json::from_slice(v.value()).unwrap();
(shard_key, record)
})
.collect()
}
/// Stores one record under its composite `"<stream>\0<key>"` key,
/// replacing any record already stored there.
pub async fn put_record(&self, stream_name: &str, key: &str, record: StoredRecord) {
    let full_key = record_key(stream_name, key);
    let payload = serde_json::to_vec(&record).unwrap();
    let txn = self.db.begin_write().unwrap();
    {
        let mut records = txn.open_table(RECORDS).unwrap();
        records.insert(full_key.as_str(), payload.as_slice()).unwrap();
    }
    txn.commit().unwrap();
}
/// Stores a batch of `(shard key, record)` pairs for one stream inside a
/// single write transaction.
pub async fn put_records_batch(&self, stream_name: &str, batch: Vec<(String, StoredRecord)>) {
    let txn = self.db.begin_write().unwrap();
    {
        let mut records = txn.open_table(RECORDS).unwrap();
        for (key, record) in batch.iter() {
            let full_key = record_key(stream_name, key);
            let payload = serde_json::to_vec(record).unwrap();
            records.insert(full_key.as_str(), payload.as_slice()).unwrap();
        }
    }
    txn.commit().unwrap();
}
/// Deletes the given shard keys for `stream_name` in one transaction.
/// Keys that are already absent are skipped without error.
pub async fn delete_record_keys(&self, stream_name: &str, keys: &[String]) {
    let txn = self.db.begin_write().unwrap();
    {
        let mut records = txn.open_table(RECORDS).unwrap();
        for shard_key in keys {
            let full_key = record_key(stream_name, shard_key);
            records.remove(full_key.as_str()).unwrap();
        }
    }
    txn.commit().unwrap();
}
/// Counts open shards across all streams. A shard counts as open while its
/// sequence number range has no ending sequence number.
/// NOTE(review): presumably compared against `StoreOptions::shard_limit` —
/// confirm at the call sites.
pub async fn sum_open_shards(&self) -> u32 {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(STREAMS).unwrap();
table
.iter()
.unwrap()
.map(|r| {
let (_, v) = r.unwrap();
let stream = deserialize_stream(v.value());
stream
.shards
.iter()
.filter(|s| s.sequence_number_range.ending_sequence_number.is_none())
.count() as u32
})
.sum()
}
/// Returns all `(shard key, record)` pairs for `stream_name` whose shard
/// keys fall in the half-open range `[range_start, range_end)`.
pub async fn get_records_range(
    &self,
    stream_name: &str,
    range_start: &str,
    range_end: &str,
) -> Vec<(String, StoredRecord)> {
    let read_txn = self.db.begin_read().unwrap();
    let table = read_txn.open_table(RECORDS).unwrap();
    let (start, end) = record_range(stream_name, range_start, range_end);
    // Hoisted out of the closure: the original rebuilt this prefix string
    // for every record in the range.
    let prefix = format!("{stream_name}\0");
    table
        .range(start.as_str()..end.as_str())
        .unwrap()
        .map(|r| {
            let (k, v) = r.unwrap();
            let full_key = k.value().to_string();
            // strip_prefix should always succeed inside this range; keep
            // the full key defensively if it somehow does not.
            let shard_key = full_key
                .strip_prefix(&prefix)
                .unwrap_or(&full_key)
                .to_string();
            let record: StoredRecord = serde_json::from_slice(v.value()).unwrap();
            (shard_key, record)
        })
        .collect()
}
/// Inserts or replaces a consumer under its ARN.
/// The ARN is expected to start with the owning stream's ARN —
/// `list_consumers_for_stream` relies on that prefix for its range scan.
pub async fn put_consumer(&self, consumer_arn: &str, consumer: Consumer) {
let bytes = serde_json::to_vec(&consumer).unwrap();
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(CONSUMERS).unwrap();
table.insert(consumer_arn, bytes.as_slice()).unwrap();
}
write_txn.commit().unwrap();
}
/// Fetches a consumer by its full ARN, or `None` when it is not registered.
pub async fn get_consumer(&self, consumer_arn: &str) -> Option<Consumer> {
    let txn = self.db.begin_read().unwrap();
    let consumers = txn.open_table(CONSUMERS).unwrap();
    let guard = consumers.get(consumer_arn).unwrap()?;
    Some(serde_json::from_slice(guard.value()).unwrap())
}
/// Removes a consumer by ARN; removing an absent consumer is a no-op.
pub async fn delete_consumer(&self, consumer_arn: &str) {
    let txn = self.db.begin_write().unwrap();
    {
        let mut consumers = txn.open_table(CONSUMERS).unwrap();
        consumers.remove(consumer_arn).unwrap();
    }
    txn.commit().unwrap();
}
/// Lists every consumer registered under `stream_arn`.
///
/// Consumer ARNs are stored with the `"<stream_arn>/consumer/"` prefix, so
/// a range scan from that prefix up to `"<stream_arn>/consumer0"` ('0' is
/// the byte right after '/') covers exactly this stream's consumers.
pub async fn list_consumers_for_stream(&self, stream_arn: &str) -> Vec<Consumer> {
    let lower = format!("{stream_arn}/consumer/");
    let upper = format!("{stream_arn}/consumer0");
    let txn = self.db.begin_read().unwrap();
    let consumers = txn.open_table(CONSUMERS).unwrap();
    let mut found = Vec::new();
    for entry in consumers.range(lower.as_str()..upper.as_str()).unwrap() {
        let (_, value) = entry.unwrap();
        found.push(serde_json::from_slice(value.value()).unwrap());
    }
    found
}
/// Finds one of a stream's consumers by name, if registered.
pub async fn find_consumer(&self, stream_arn: &str, consumer_name: &str) -> Option<Consumer> {
    for candidate in self.list_consumers_for_stream(stream_arn).await {
        if candidate.consumer_name == consumer_name {
            return Some(candidate);
        }
    }
    None
}
/// Stores (or replaces) the resource policy document for `resource_arn`.
pub async fn put_policy(&self, resource_arn: &str, policy: &str) {
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(POLICIES).unwrap();
table.insert(resource_arn, policy).unwrap();
}
write_txn.commit().unwrap();
}
/// Returns the resource policy document for `resource_arn`, if one exists.
pub async fn get_policy(&self, resource_arn: &str) -> Option<String> {
    let txn = self.db.begin_read().unwrap();
    let policies = txn.open_table(POLICIES).unwrap();
    let guard = policies.get(resource_arn).unwrap()?;
    Some(guard.value().to_string())
}
/// Deletes the policy for `resource_arn`; deleting a missing policy is a
/// no-op.
pub async fn delete_policy(&self, resource_arn: &str) {
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(POLICIES).unwrap();
table.remove(resource_arn).unwrap();
}
write_txn.commit().unwrap();
}
/// Extracts the stream name from an ARN: the segment between the first and
/// second '/' (e.g. `...:stream/<name>` or `...:stream/<name>/consumer/...`).
/// Returns `None` when the ARN contains no '/'.
pub fn stream_name_from_arn(&self, arn: &str) -> Option<String> {
    let mut segments = arn.split('/');
    segments.next()?;
    segments.next().map(str::to_string)
}
/// Returns the tag map for `resource_arn`. Missing or undecodable entries
/// yield an empty map rather than an error.
pub async fn get_resource_tags(&self, resource_arn: &str) -> BTreeMap<String, String> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(RESOURCE_TAGS).unwrap();
table
.get(resource_arn)
.unwrap()
.map(|guard| serde_json::from_slice(guard.value()).unwrap_or_default())
.unwrap_or_default()
}
/// Replaces the entire tag map for `resource_arn` with `tags`.
pub async fn put_resource_tags(&self, resource_arn: &str, tags: &BTreeMap<String, String>) {
let bytes = serde_json::to_vec(tags).unwrap();
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(RESOURCE_TAGS).unwrap();
table.insert(resource_arn, bytes.as_slice()).unwrap();
}
write_txn.commit().unwrap();
}
/// Returns the account-level settings JSON blob, or an empty JSON object
/// when none has been stored yet.
pub async fn get_account_settings(&self) -> Value {
    let txn = self.db.begin_read().unwrap();
    let table = txn.open_table(ACCOUNT_SETTINGS).unwrap();
    match table.get("account_settings").unwrap() {
        Some(guard) => serde_json::from_slice(guard.value()).unwrap(),
        None => Value::Object(Default::default()),
    }
}
/// Replaces the account-level settings blob (stored under one fixed key).
pub async fn put_account_settings(&self, settings: &Value) {
    let payload = serde_json::to_vec(settings).unwrap();
    let txn = self.db.begin_write().unwrap();
    {
        let mut table = txn.open_table(ACCOUNT_SETTINGS).unwrap();
        table.insert("account_settings", payload.as_slice()).unwrap();
    }
    txn.commit().unwrap();
}
}