use crate::constants;
use crate::error::KinesisErrorResponse;
use crate::types::{Consumer, StoredRecord, Stream};
use redb::backends::InMemoryBackend;
use redb::{Database, ReadableDatabase, ReadableTable, TableDefinition};
use serde::Serialize;
use serde_json::Value;
use std::collections::BTreeMap;
use std::sync::Arc;
#[derive(Debug, thiserror::Error)]
pub enum StoreHealthError {
#[error("db read failed: {0}")]
ReadFailed(String),
#[error("table open failed: {0}")]
TableOpenFailed(String),
}
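// Table layout; all keys are strings, values are serialized blobs unless noted:
//   STREAMS          stream name         -> JSON `Stream` plus sidecar fields
//   RECORDS          "{stream}\0{key}"   -> postcard `StoredRecord`
//   CONSUMERS        consumer ARN        -> JSON `Consumer`
//   POLICIES         resource ARN        -> policy document (plain string)
//   RESOURCE_TAGS    resource ARN        -> JSON tag map
//   ACCOUNT_SETTINGS fixed key           -> JSON settings object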
const STREAMS: TableDefinition<&str, &[u8]> = TableDefinition::new("streams");
const RECORDS: TableDefinition<&str, &[u8]> = TableDefinition::new("records");
const CONSUMERS: TableDefinition<&str, &[u8]> = TableDefinition::new("consumers");
const POLICIES: TableDefinition<&str, &str> = TableDefinition::new("policies");
const RESOURCE_TAGS: TableDefinition<&str, &[u8]> = TableDefinition::new("resource_tags");
const ACCOUNT_SETTINGS: TableDefinition<&str, &[u8]> = TableDefinition::new("account_settings");
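/// Tunables for the store: emulated latencies for stream lifecycle
/// transitions, the account shard limit, shard-iterator TTL, how often the
/// retention sweep runs, and the account/region used when building ARNs.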
#[derive(Debug, Clone)]
pub struct StoreOptions {
pub create_stream_ms: u64,
pub delete_stream_ms: u64,
pub update_stream_ms: u64,
pub shard_limit: u32,
pub iterator_ttl_seconds: u64,
pub retention_check_interval_secs: u64,
pub aws_account_id: String,
pub aws_region: String,
}
impl Default for StoreOptions {
fn default() -> Self {
Self {
create_stream_ms: 500,
delete_stream_ms: 500,
update_stream_ms: 500,
shard_limit: 10,
iterator_ttl_seconds: 300,
retention_check_interval_secs: 0,
aws_account_id: "000000000000".to_string(),
aws_region: "us-east-1".to_string(),
}
}
}
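/// Handle to all emulator state. Cloning is cheap: the redb database sits
/// behind an `Arc`, so clones share the same underlying data.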
#[derive(Clone)]
pub struct Store {
pub options: StoreOptions,
pub aws_account_id: String,
pub aws_region: String,
db: Arc<Database>,
pub(crate) capture_writer: Option<crate::capture::CaptureWriter>,
}
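// `serialize_stream`/`deserialize_stream` round-trip a `Stream` through JSON,
// carrying internal state in `_`-prefixed sidecar fields (`_seq_ix`, `_tags`,
// `_key_id`, ...) that would otherwise be lost in the plain serde
// representation.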
fn serialize_stream(stream: &Stream) -> Vec<u8> {
let mut val = serde_json::to_value(stream).unwrap();
let obj = val.as_object_mut().unwrap();
obj.insert(
"_seq_ix".to_string(),
serde_json::to_value(&stream.seq_ix).unwrap(),
);
obj.insert(
"_tags".to_string(),
serde_json::to_value(&stream.tags).unwrap(),
);
if let Some(ref key_id) = stream.key_id {
obj.insert("_key_id".to_string(), serde_json::to_value(key_id).unwrap());
}
obj.insert(
"_warm_throughput_mibps".to_string(),
serde_json::to_value(stream.warm_throughput_mibps).unwrap(),
);
obj.insert(
"_max_record_size_kib".to_string(),
serde_json::to_value(stream.max_record_size_kib).unwrap(),
);
serde_json::to_vec(&val).unwrap()
}
fn deserialize_stream(bytes: &[u8]) -> Stream {
let val: Value = serde_json::from_slice(bytes).unwrap();
let mut stream: Stream = serde_json::from_value(val.clone()).unwrap();
if let Some(arr) = val.get("_seq_ix") {
stream.seq_ix = serde_json::from_value(arr.clone()).unwrap_or_default();
}
if let Some(obj) = val.get("_tags") {
stream.tags = serde_json::from_value(obj.clone()).unwrap_or_default();
}
if let Some(key_id) = val.get("_key_id") {
stream.key_id = serde_json::from_value(key_id.clone()).ok();
}
if let Some(v) = val.get("_warm_throughput_mibps") {
stream.warm_throughput_mibps = serde_json::from_value(v.clone()).unwrap_or(0);
}
if let Some(v) = val.get("_max_record_size_kib") {
stream.max_record_size_kib = serde_json::from_value(v.clone()).unwrap_or(1024);
}
stream
}
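// Record keys are composite: "{stream_name}\0{shard_key}". NUL cannot occur in
// a stream name, so the half-open range ["{name}\0", "{name}\x01") selects
// exactly the records of one stream.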
fn record_key(stream_name: &str, shard_key: &str) -> String {
format!("{stream_name}\0{shard_key}")
}
fn record_range(stream_name: &str, start: &str, end: &str) -> (String, String) {
(
format!("{stream_name}\0{start}"),
format!("{stream_name}\0{end}"),
)
}
impl Store {
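    /// Construct a store backed by a fresh in-memory redb database.
    ///
    /// A minimal usage sketch (hypothetical caller; assumes an async runtime
    /// for the `async` accessors):
    ///
    /// ```ignore
    /// let store = Store::new(StoreOptions::default());
    /// assert!(!store.contains_stream("no-such-stream").await);
    /// ```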
pub fn new(options: StoreOptions) -> Self {
Self::with_capture(options, None)
}
pub fn with_capture(
options: StoreOptions,
capture_writer: Option<crate::capture::CaptureWriter>,
) -> Self {
let aws_account_id: String = options
.aws_account_id
.chars()
.filter(|c| c.is_ascii_digit())
.collect();
if aws_account_id.len() != 12 {
tracing::warn!(
"AWS account ID has {} digits after stripping non-digits (expected 12)",
aws_account_id.len()
);
}
let aws_region = options.aws_region.clone();
let db = Database::builder()
.create_with_backend(InMemoryBackend::new())
.expect("Failed to create in-memory redb database");
let write_txn = db.begin_write().unwrap();
write_txn.open_table(STREAMS).unwrap();
write_txn.open_table(RECORDS).unwrap();
write_txn.open_table(CONSUMERS).unwrap();
write_txn.open_table(POLICIES).unwrap();
write_txn.open_table(RESOURCE_TAGS).unwrap();
write_txn.open_table(ACCOUNT_SETTINGS).unwrap();
write_txn.commit().unwrap();
Self {
options,
aws_account_id,
aws_region,
db: Arc::new(db),
capture_writer,
}
}
pub async fn get_stream(&self, name: &str) -> Result<Stream, KinesisErrorResponse> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(STREAMS).unwrap();
match table.get(name).unwrap() {
Some(guard) => Ok(deserialize_stream(guard.value())),
None => Err(KinesisErrorResponse::stream_not_found(
name,
&self.aws_account_id,
)),
}
}
pub async fn put_stream(&self, name: &str, stream: Stream) {
let bytes = serialize_stream(&stream);
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(STREAMS).unwrap();
table.insert(name, bytes.as_slice()).unwrap();
}
write_txn.commit().unwrap();
}
pub async fn delete_stream(&self, name: &str) {
let write_txn = self.db.begin_write().unwrap();
{
let mut streams = write_txn.open_table(STREAMS).unwrap();
streams.remove(name).unwrap();
let mut records = write_txn.open_table(RECORDS).unwrap();
let prefix = format!("{name}\0");
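            // "\x01" is the byte after the NUL separator, so this range
            // covers every record key belonging to `name`.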
let prefix_end = format!("{name}\x01");
let keys_to_remove: Vec<String> = records
.range(prefix.as_str()..prefix_end.as_str())
.unwrap()
.map(|r| {
let (k, _) = r.unwrap();
k.value().to_string()
})
.collect();
for key in keys_to_remove {
records.remove(key.as_str()).unwrap();
}
}
write_txn.commit().unwrap();
}
pub async fn contains_stream(&self, name: &str) -> bool {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(STREAMS).unwrap();
table.get(name).unwrap().is_some()
}
pub async fn list_stream_names(&self) -> Vec<String> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(STREAMS).unwrap();
table
.iter()
.unwrap()
.map(|r| {
let (k, _) = r.unwrap();
k.value().to_string()
})
.collect()
}
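    /// Read-modify-write a single stream inside one write transaction. If the
    /// closure returns an error, the transaction is dropped without
    /// committing, so the stream is left untouched.
    ///
    /// A minimal sketch, assuming `Stream::tags` is a string map:
    ///
    /// ```ignore
    /// store
    ///     .update_stream("my-stream", |s| {
    ///         s.tags.insert("env".to_string(), "test".to_string());
    ///         Ok(())
    ///     })
    ///     .await?;
    /// ```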
pub async fn update_stream<F, R>(&self, name: &str, f: F) -> Result<R, KinesisErrorResponse>
where
F: FnOnce(&mut Stream) -> Result<R, KinesisErrorResponse>,
{
let write_txn = self.db.begin_write().unwrap();
let result;
{
let mut table = write_txn.open_table(STREAMS).unwrap();
let bytes = table.get(name).unwrap().ok_or_else(|| {
KinesisErrorResponse::stream_not_found(name, &self.aws_account_id)
})?;
let mut stream = deserialize_stream(bytes.value());
drop(bytes);
result = f(&mut stream)?;
let new_bytes = serialize_stream(&stream);
table.insert(name, new_bytes.as_slice()).unwrap();
}
write_txn.commit().unwrap();
Ok(result)
}
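    /// Run a closure over a snapshot of all streams, then reconcile: streams
    /// the closure removed from the map are deleted, and every remaining
    /// entry is re-serialized and upserted, all in one write transaction.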
pub async fn with_streams_write<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut BTreeMap<String, Stream>, &StoreOptions, &str, &str) -> R,
{
let write_txn = self.db.begin_write().unwrap();
let result;
{
let mut table = write_txn.open_table(STREAMS).unwrap();
let mut streams: BTreeMap<String, Stream> = table
.iter()
.unwrap()
.map(|r| {
let (k, v) = r.unwrap();
(k.value().to_string(), deserialize_stream(v.value()))
})
.collect();
result = f(
&mut streams,
&self.options,
&self.aws_account_id,
&self.aws_region,
);
let existing_keys: Vec<String> = table
.iter()
.unwrap()
.map(|r| r.unwrap().0.value().to_string())
.collect();
for key in &existing_keys {
if !streams.contains_key(key) {
table.remove(key.as_str()).unwrap();
}
}
for (name, stream) in &streams {
let bytes = serialize_stream(stream);
table.insert(name.as_str(), bytes.as_slice()).unwrap();
}
}
write_txn.commit().unwrap();
result
}
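    /// Load every record of a stream into an ordered map keyed by shard key
    /// (the composite key with the "{stream}\0" prefix stripped).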
pub async fn get_record_store(&self, stream_name: &str) -> BTreeMap<String, StoredRecord> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(RECORDS).unwrap();
let prefix = format!("{stream_name}\0");
let prefix_end = format!("{stream_name}\x01");
let prefix_len = stream_name.len() + 1;
table
.range(prefix.as_str()..prefix_end.as_str())
.unwrap()
.map(|r| {
let (k, v) = r.unwrap();
let full_key = k.value().to_string();
let shard_key = full_key[prefix_len..].to_string();
let record: StoredRecord = postcard::from_bytes(v.value()).unwrap();
(shard_key, record)
})
.collect()
}
pub async fn put_record<R: Serialize>(&self, stream_name: &str, key: &str, record: &R) {
let composite_key = record_key(stream_name, key);
let bytes = postcard::to_allocvec(record).unwrap();
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(RECORDS).unwrap();
table
.insert(composite_key.as_str(), bytes.as_slice())
.unwrap();
}
write_txn.commit().unwrap();
}
pub async fn put_records_batch<R: Serialize>(&self, stream_name: &str, batch: &[(String, R)]) {
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(RECORDS).unwrap();
for (key, record) in batch {
let composite_key = record_key(stream_name, key);
let bytes = postcard::to_allocvec(record).unwrap();
table
.insert(composite_key.as_str(), bytes.as_slice())
.unwrap();
}
}
write_txn.commit().unwrap();
}
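    /// Sweep records older than the retention window. Keys are collected
    /// under a read transaction, then removed under a write transaction. A
    /// record's age comes from the sequence number embedded in its shard key;
    /// unparseable sequence numbers are skipped, and a missing timestamp
    /// counts as expired.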
pub async fn delete_expired_records(&self, stream_name: &str, retention_hours: u32) -> usize {
let now = crate::util::current_time_ms();
        // Saturate so an oversized retention window can never underflow.
        let cutoff_time = now.saturating_sub(retention_hours as u64 * 60 * 60 * 1000);
let prefix = format!("{stream_name}\0");
let prefix_end = format!("{stream_name}\x01");
let prefix_len = stream_name.len() + 1;
let keys_to_delete: Vec<String> = {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(RECORDS).unwrap();
table
.range(prefix.as_str()..prefix_end.as_str())
.unwrap()
.filter_map(|r| {
let (k, _) = r.unwrap();
let full_key = k.value().to_string();
let shard_key = &full_key[prefix_len..];
let seq_num = shard_key.split('/').nth(1)?;
let seq_obj = crate::sequence::parse_sequence(seq_num).ok()?;
if seq_obj.seq_time.unwrap_or(0) < cutoff_time {
Some(full_key)
} else {
None
}
})
.collect()
};
let count = keys_to_delete.len();
if count > 0 {
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(RECORDS).unwrap();
for key in &keys_to_delete {
table.remove(key.as_str()).unwrap();
}
}
write_txn.commit().unwrap();
}
count
}
pub async fn delete_record_keys(&self, stream_name: &str, keys: &[String]) {
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(RECORDS).unwrap();
for key in keys {
let composite_key = record_key(stream_name, key);
table.remove(composite_key.as_str()).unwrap();
}
}
write_txn.commit().unwrap();
}
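    /// Count open shards (those with no ending sequence number) across all
    /// streams.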
pub async fn sum_open_shards(&self) -> u32 {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(STREAMS).unwrap();
table
.iter()
.unwrap()
.map(|r| {
let (_, v) = r.unwrap();
let stream = deserialize_stream(v.value());
stream
.shards
.iter()
.filter(|s| s.sequence_number_range.ending_sequence_number.is_none())
.count() as u32
})
.sum()
}
pub async fn get_records_range(
&self,
stream_name: &str,
range_start: &str,
range_end: &str,
) -> Vec<(String, StoredRecord)> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(RECORDS).unwrap();
let (start, end) = record_range(stream_name, range_start, range_end);
let prefix_len = stream_name.len() + 1;
table
.range(start.as_str()..end.as_str())
.unwrap()
.map(|r| {
let (k, v) = r.unwrap();
let full_key = k.value().to_string();
let shard_key = full_key[prefix_len..].to_string();
let record: StoredRecord = postcard::from_bytes(v.value()).unwrap();
(shard_key, record)
})
.collect()
}
pub async fn get_records_range_limited(
&self,
stream_name: &str,
range_start: &str,
range_end: &str,
limit: usize,
) -> Vec<(String, StoredRecord)> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(RECORDS).unwrap();
let (start, end) = record_range(stream_name, range_start, range_end);
        let prefix_len = stream_name.len() + 1;
        table
.range(start.as_str()..end.as_str())
.unwrap()
.take(limit)
.map(|r| {
let (k, v) = r.unwrap();
let full_key = k.value().to_string();
let shard_key = full_key[prefix_len..].to_string();
let record: StoredRecord = postcard::from_bytes(v.value()).unwrap();
(shard_key, record)
})
.collect()
}
pub async fn find_first_record_at_timestamp(
&self,
stream_name: &str,
range_start: &str,
range_end: &str,
timestamp: f64,
) -> Option<(String, StoredRecord)> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(RECORDS).unwrap();
let (start, end) = record_range(stream_name, range_start, range_end);
let prefix_len = stream_name.len() + 1;
for r in table.range(start.as_str()..end.as_str()).unwrap() {
let (k, v) = r.unwrap();
let record: StoredRecord = postcard::from_bytes(v.value()).unwrap();
if record.approximate_arrival_timestamp >= timestamp {
let full_key = k.value().to_string();
let shard_key = full_key[prefix_len..].to_string();
return Some((shard_key, record));
}
}
None
}
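    /// Consumers are stored as JSON keyed by ARN; consumer ARNs extend the
    /// stream ARN with "/consumer/...", which is what the prefix scan in
    /// `list_consumers_for_stream` relies on.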
pub async fn put_consumer(&self, consumer_arn: &str, consumer: Consumer) {
let bytes = serde_json::to_vec(&consumer).unwrap();
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(CONSUMERS).unwrap();
table.insert(consumer_arn, bytes.as_slice()).unwrap();
}
write_txn.commit().unwrap();
}
pub async fn get_consumer(&self, consumer_arn: &str) -> Option<Consumer> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(CONSUMERS).unwrap();
table
.get(consumer_arn)
.unwrap()
.map(|guard| serde_json::from_slice(guard.value()).unwrap())
}
pub async fn delete_consumer(&self, consumer_arn: &str) {
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(CONSUMERS).unwrap();
table.remove(consumer_arn).unwrap();
}
write_txn.commit().unwrap();
}
pub async fn list_consumers_for_stream(&self, stream_arn: &str) -> Vec<Consumer> {
let prefix = format!("{stream_arn}/consumer/");
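        // '0' is the ASCII byte immediately after '/', so this half-open
        // range matches exactly the keys carrying the consumer prefix.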
        let prefix_end = format!("{stream_arn}/consumer0");
        let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(CONSUMERS).unwrap();
table
.range(prefix.as_str()..prefix_end.as_str())
.unwrap()
.map(|r| {
let (_, v) = r.unwrap();
serde_json::from_slice(v.value()).unwrap()
})
.collect()
}
pub async fn find_consumer(&self, stream_arn: &str, consumer_name: &str) -> Option<Consumer> {
let consumers = self.list_consumers_for_stream(stream_arn).await;
consumers
.into_iter()
.find(|c| c.consumer_name == consumer_name)
}
pub async fn put_policy(&self, resource_arn: &str, policy: &str) {
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(POLICIES).unwrap();
table.insert(resource_arn, policy).unwrap();
}
write_txn.commit().unwrap();
}
pub async fn get_policy(&self, resource_arn: &str) -> Option<String> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(POLICIES).unwrap();
table
.get(resource_arn)
.unwrap()
.map(|guard| guard.value().to_string())
}
pub async fn delete_policy(&self, resource_arn: &str) {
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(POLICIES).unwrap();
table.remove(resource_arn).unwrap();
}
write_txn.commit().unwrap();
}
pub fn stream_name_from_arn(&self, arn: &str) -> Option<String> {
arn.split("/").nth(1).map(|s| s.to_string())
}
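    /// Resolve the target stream name from a request body that may carry
    /// `StreamName`, `StreamARN`, or (invalidly) both.
    ///
    /// A minimal sketch, assuming `constants::STREAM_NAME` is the literal
    /// `"StreamName"`:
    ///
    /// ```ignore
    /// let body = serde_json::json!({ "StreamName": "my-stream" });
    /// assert_eq!(store.resolve_stream_name(&body)?, "my-stream");
    /// ```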
pub fn resolve_stream_name(&self, data: &Value) -> Result<String, KinesisErrorResponse> {
let stream_name_raw = data[constants::STREAM_NAME].as_str().unwrap_or("");
let stream_arn = data[constants::STREAM_ARN].as_str().unwrap_or("");
if !stream_name_raw.is_empty() && !stream_arn.is_empty() {
return Err(KinesisErrorResponse::client_error(
constants::INVALID_ARGUMENT,
Some("StreamARN and StreamName cannot be provided together."),
));
}
if !stream_name_raw.is_empty() {
Ok(stream_name_raw.to_string())
} else if !stream_arn.is_empty() {
self.stream_name_from_arn(stream_arn).ok_or_else(|| {
KinesisErrorResponse::client_error(
constants::RESOURCE_NOT_FOUND,
Some("Could not resolve stream from ARN."),
)
})
} else {
Err(KinesisErrorResponse::client_error(
constants::INVALID_ARGUMENT,
Some("Either StreamName or StreamARN must be provided."),
))
}
}
pub async fn get_resource_tags(&self, resource_arn: &str) -> BTreeMap<String, String> {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(RESOURCE_TAGS).unwrap();
table
.get(resource_arn)
.unwrap()
.map(|guard| serde_json::from_slice(guard.value()).unwrap_or_default())
.unwrap_or_default()
}
pub async fn put_resource_tags(&self, resource_arn: &str, tags: &BTreeMap<String, String>) {
let bytes = serde_json::to_vec(tags).unwrap();
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(RESOURCE_TAGS).unwrap();
table.insert(resource_arn, bytes.as_slice()).unwrap();
}
write_txn.commit().unwrap();
}
pub async fn get_account_settings(&self) -> Value {
let read_txn = self.db.begin_read().unwrap();
let table = read_txn.open_table(ACCOUNT_SETTINGS).unwrap();
table
.get("account_settings")
.unwrap()
.map(|guard| serde_json::from_slice(guard.value()).unwrap())
.unwrap_or(Value::Object(Default::default()))
}
pub async fn put_account_settings(&self, settings: &Value) {
let bytes = serde_json::to_vec(settings).unwrap();
let write_txn = self.db.begin_write().unwrap();
{
let mut table = write_txn.open_table(ACCOUNT_SETTINGS).unwrap();
table.insert("account_settings", bytes.as_slice()).unwrap();
}
write_txn.commit().unwrap();
}
pub fn check_ready(&self) -> Result<(), StoreHealthError> {
let txn = self
.db
.begin_read()
.map_err(|e| StoreHealthError::ReadFailed(e.to_string()))?;
for table in [STREAMS, RECORDS, CONSUMERS, RESOURCE_TAGS, ACCOUNT_SETTINGS] {
txn.open_table(table)
.map_err(|e| StoreHealthError::TableOpenFailed(e.to_string()))?;
}
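        // POLICIES has a different value type (&str rather than &[u8]), so it
        // cannot join the array above and is checked separately.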
txn.open_table(POLICIES)
.map_err(|e| StoreHealthError::TableOpenFailed(e.to_string()))?;
Ok(())
}
}