use crate::constants;
use crate::error::KinesisErrorResponse;
use crate::sequence;
use crate::types::{Consumer, StoredRecord, Stream, StreamStatus};
use crate::util::current_time_ms;
use dashmap::DashMap;
use serde::Serialize;
use serde_json::Value;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
/// Errors reported by [`Store::check_ready`] health probing.
///
/// NOTE(review): nothing in this file constructs either variant (the
/// in-memory store's `check_ready` always returns `Ok`); the variants look
/// carried over from a persistent-backend implementation — confirm before
/// removing.
#[derive(Debug, thiserror::Error)]
pub enum StoreHealthError {
    /// A read against the backing database failed; payload carries detail.
    #[error("db read failed: {0}")]
    ReadFailed(String),
    /// A backing table could not be opened; payload carries detail.
    #[error("table open failed: {0}")]
    TableOpenFailed(String),
}
/// Tunable configuration for a [`Store`].
#[derive(Debug, Clone)]
pub struct StoreOptions {
    // The three *_ms fields are presumably simulated stream-lifecycle
    // transition delays (create/delete/update) — not consumed in this file;
    // TODO confirm against the handlers that read them.
    pub create_stream_ms: u64,
    pub delete_stream_ms: u64,
    pub update_stream_ms: u64,
    // Account-wide shard cap; not enforced in this file (see sum_open_shards
    // for the likely consumer-side count).
    pub shard_limit: u32,
    // TTL for shard iterators handed to clients — not consumed in this file.
    pub iterator_ttl_seconds: u64,
    // Period of the retention sweep (delete_expired_records); 0 in the
    // default presumably disables the background check — TODO confirm.
    pub retention_check_interval_secs: u64,
    // Used to build ARNs; `Store::build` strips non-digits and warns when
    // the result is not 12 digits long.
    pub aws_account_id: String,
    pub aws_region: String,
}
impl Default for StoreOptions {
fn default() -> Self {
Self {
create_stream_ms: 500,
delete_stream_ms: 500,
update_stream_ms: 500,
shard_limit: 10,
iterator_ttl_seconds: 300,
retention_check_interval_secs: 0,
aws_account_id: "000000000000".to_string(),
aws_region: "us-east-1".to_string(),
}
}
}
/// Per-shard monotonically increasing sequence counter.
struct ShardSeqState {
    // Initialized to 1 by build_shard_seq/sync_shard_seq; bumped with
    // fetch_add(Relaxed) on every allocation.
    counter: AtomicU64,
}
/// A stream's metadata plus its per-shard sequence counters, each behind
/// its own lock so sequence allocation can run under a stream read lock.
struct StreamEntry {
    stream: RwLock<Stream>,
    // Indexed by shard position in `stream.shards`; grown (never shrunk)
    // by sync_shard_seq when the stream gains shards.
    shard_seq: RwLock<Vec<ShardSeqState>>,
}
/// Result of routing a partition-key hash to a shard and drawing the next
/// sequence number (see [`Store::allocate_sequence`]).
pub struct SequenceAllocation {
    /// Shard the record was routed to (empty when no open shard matched).
    pub shard_id: String,
    /// Stringified v2 sequence number.
    pub seq_num: String,
    /// Storage key, formatted as "<shardHex>/<seq_num>".
    pub stream_key: String,
    /// Allocation timestamp in ms, clamped to at least the shard create time.
    pub now: u64,
}
/// Per-stream record storage: shard hex id -> ordered map of
/// "shardHex/sequence" keys to postcard-encoded StoredRecord bytes.
/// Each shard's BTreeMap sits behind its own Arc<RwLock> so writes to
/// different shards do not contend.
type ShardRecords = DashMap<String, Arc<RwLock<BTreeMap<String, Vec<u8>>>>>;
/// Shared state behind the cheaply-cloneable [`Store`] handle.
struct StoreInner {
    // Stream name -> metadata + sequence counters.
    streams: DashMap<String, Arc<StreamEntry>>,
    // Stream name -> per-shard record maps.
    stream_records: DashMap<String, ShardRecords>,
    // Consumer ARN -> JSON-serialized Consumer.
    consumers: DashMap<String, Vec<u8>>,
    // Resource ARN -> policy document string.
    policies: DashMap<String, String>,
    // Resource ARN -> tag key/value pairs (BTreeMap keeps them ordered).
    resource_tags: DashMap<String, BTreeMap<String, String>>,
    // Free-form account-level settings blob; starts as an empty JSON object.
    account_settings: RwLock<Value>,
}
/// Handle to the in-memory store. `Clone` is cheap: all state lives behind
/// the shared `inner` Arc.
#[derive(Clone)]
pub struct Store {
    pub options: StoreOptions,
    // Sanitized copy of options.aws_account_id (digits only; see build()).
    pub aws_account_id: String,
    pub aws_region: String,
    inner: Arc<StoreInner>,
    // Optional request-capture sink, only compiled with the server feature.
    #[cfg(feature = "server")]
    pub(crate) capture_writer: Option<crate::capture::CaptureWriter>,
}
/// Extracts the shard hex prefix from a "<shardHex>/<sequence>" storage key.
/// A key without a '/' is returned whole; an empty key yields "".
fn shard_hex_from_key(key: &str) -> &str {
    match key.find('/') {
        Some(slash) => &key[..slash],
        None => key,
    }
}
/// Returns the shard-records map for `stream_name`, creating an empty one
/// if the stream has no records yet.
///
/// The entry() call takes the DashMap shard's write lock; downgrade()
/// converts the returned guard to a read Ref so callers only hold a shared
/// lock. Callers must still drop the Ref before awaiting to avoid blocking
/// other accessors of the same DashMap shard.
fn ensure_shard_map<'a>(
    stream_records: &'a DashMap<String, ShardRecords>,
    stream_name: &str,
) -> dashmap::mapref::one::Ref<'a, String, ShardRecords> {
    stream_records
        .entry(stream_name.to_string())
        .or_default()
        .downgrade()
}
impl Store {
/// Creates a store with no capture writer attached.
pub fn new(options: StoreOptions) -> Self {
    // `build` gains an extra capture-writer parameter under the `server`
    // feature, so each cfg branch calls it with the matching arity.
    #[cfg(feature = "server")]
    {
        Self::build(options, None)
    }
    #[cfg(not(feature = "server"))]
    {
        Self::build(options)
    }
}
/// Creates a store with an optional request-capture writer (server builds
/// only).
#[cfg(feature = "server")]
pub fn with_capture(
    options: StoreOptions,
    capture_writer: Option<crate::capture::CaptureWriter>,
) -> Self {
    Self::build(options, capture_writer)
}
/// Shared constructor behind `new`/`with_capture`.
///
/// Sanitizes the configured AWS account id by stripping every non-digit
/// character; a result that is not exactly 12 digits is tolerated but
/// logged, since it produces non-standard ARNs.
fn build(
    options: StoreOptions,
    #[cfg(feature = "server")] capture_writer: Option<crate::capture::CaptureWriter>,
) -> Self {
    let aws_account_id: String = options
        .aws_account_id
        .chars()
        .filter(|c| c.is_ascii_digit())
        .collect();
    if aws_account_id.len() != 12 {
        tracing::warn!(
            "AWS account ID has {} digits after stripping non-digits (expected 12)",
            aws_account_id.len()
        );
    }
    let aws_region = options.aws_region.clone();
    Self {
        options,
        aws_account_id,
        aws_region,
        inner: Arc::new(StoreInner {
            streams: DashMap::new(),
            stream_records: DashMap::new(),
            consumers: DashMap::new(),
            policies: DashMap::new(),
            resource_tags: DashMap::new(),
            // Account settings start as an empty JSON object.
            account_settings: RwLock::new(Value::Object(Default::default())),
        }),
        #[cfg(feature = "server")]
        capture_writer,
    }
}
/// Returns a snapshot (clone) of the named stream's metadata.
///
/// # Errors
/// `stream_not_found` when no stream with that name exists.
pub async fn get_stream(&self, name: &str) -> Result<Stream, KinesisErrorResponse> {
    // Clone the entry Arc out of the map so no DashMap guard is held
    // across the read-lock await below.
    let Some(entry) = self.inner.streams.get(name).map(|e| e.value().clone()) else {
        return Err(KinesisErrorResponse::stream_not_found(
            name,
            &self.aws_account_id,
        ));
    };
    let guard = entry.stream.read().await;
    Ok(guard.clone())
}
pub async fn put_stream(&self, name: &str, stream: Stream) {
if let Some(existing) = self.inner.streams.get(name) {
let mut guard = existing.stream.write().await;
*guard = stream;
} else {
let shard_seq = build_shard_seq(&stream);
let entry = Arc::new(StreamEntry {
stream: RwLock::new(stream),
shard_seq: RwLock::new(shard_seq),
});
self.inner.streams.insert(name.to_string(), entry);
}
}
/// Removes the named stream's metadata and all of its stored records.
/// Idempotent: deleting an unknown stream is a no-op.
pub async fn delete_stream(&self, name: &str) {
    self.inner.streams.remove(name);
    self.inner.stream_records.remove(name);
}
/// Returns whether a stream with this name currently exists.
pub async fn contains_stream(&self, name: &str) -> bool {
    self.inner.streams.contains_key(name)
}
/// Returns all stream names in lexicographic order.
pub async fn list_stream_names(&self) -> Vec<String> {
    let mut names: Vec<String> = self.inner.streams.iter().map(|e| e.key().clone()).collect();
    // sort_unstable: no allocation, and stability is irrelevant for
    // unique map keys.
    names.sort_unstable();
    names
}
/// Runs `f` against the named stream's metadata under its write lock, then
/// resynchronizes the per-shard sequence counters (in case `f` added
/// shards).
///
/// NOTE(review): there is no rollback — if `f` mutates the stream and then
/// returns `Err`, the partial mutations persist (and shard_seq is not
/// resynced). Confirm callers treat `f` as all-or-nothing.
///
/// # Errors
/// `stream_not_found` when the stream does not exist, or whatever `f`
/// returns.
pub async fn update_stream<F, R>(&self, name: &str, f: F) -> Result<R, KinesisErrorResponse>
where
    F: FnOnce(&mut Stream) -> Result<R, KinesisErrorResponse>,
{
    // Clone the Arc so the DashMap guard is dropped before awaiting.
    let entry = self
        .inner
        .streams
        .get(name)
        .map(|e| e.value().clone())
        .ok_or_else(|| KinesisErrorResponse::stream_not_found(name, &self.aws_account_id))?;
    let mut stream = entry.stream.write().await;
    let result = f(&mut stream)?;
    sync_shard_seq(&entry, &stream).await;
    Ok(result)
}
/// Counts open shards (no ending sequence number) across all streams.
///
/// NOTE(review): the DashMap iterator guard is held across the per-stream
/// read-lock await; a writer blocked on the same DashMap shard during that
/// window will wait. Acceptable for an occasional accounting call —
/// confirm this is not on a hot path.
pub async fn sum_open_shards(&self) -> u32 {
    let mut sum = 0u32;
    for entry in self.inner.streams.iter() {
        let stream = entry.value().stream.read().await;
        sum += stream
            .shards
            .iter()
            // A shard is open while its sequence range has no end.
            .filter(|s| s.sequence_number_range.ending_sequence_number.is_none())
            .count() as u32;
    }
    sum
}
/// Returns every record of the stream, decoded, merged across shards into
/// a single key-ordered map. Empty map when the stream has no records.
///
/// Intended for inspection/snapshots: it clones and decodes everything.
pub async fn get_record_store(&self, stream_name: &str) -> BTreeMap<String, StoredRecord> {
    let shard_map = match self.inner.stream_records.get(stream_name) {
        Some(m) => m,
        None => return BTreeMap::new(),
    };
    let mut all_records = BTreeMap::new();
    for shard_entry in shard_map.iter() {
        let records = shard_entry.value().read().await;
        for (k, v) in records.iter() {
            // Values were written by put_record*; a decode failure would
            // mean corrupt in-memory state, so unwrap treats it as a bug.
            let record: StoredRecord = postcard::from_bytes(v).unwrap();
            all_records.insert(k.clone(), record);
        }
    }
    all_records
}
/// Serializes one record with postcard and stores it under `key`
/// ("<shardHex>/<sequence>") in the key's shard map.
///
/// The DashMap guard is explicitly dropped before the shard write-lock
/// await so no map shard lock is held across an await point.
pub async fn put_record<R: Serialize>(&self, stream_name: &str, key: &str, record: &R) {
    // In-memory serialization of our own types; treated as infallible.
    let bytes = postcard::to_allocvec(record).unwrap();
    let shard_hex = shard_hex_from_key(key);
    let shard_map = ensure_shard_map(&self.inner.stream_records, stream_name);
    let records_arc = shard_map
        .entry(shard_hex.to_string())
        .or_insert_with(|| Arc::new(RwLock::new(BTreeMap::new())))
        .value()
        .clone();
    drop(shard_map);
    let mut records = records_arc.write().await;
    records.insert(key.to_string(), bytes);
}
/// Serializes and stores a batch of `(key, record)` pairs.
///
/// Phase 1 serializes every record and resolves its shard map Arc while
/// the DashMap guard is alive; phase 2 takes each shard's write lock only
/// after the guard has been dropped (end of the block), so no DashMap
/// lock is held across an await.
pub async fn put_records_batch<R: Serialize>(&self, stream_name: &str, batch: &[(String, R)]) {
    let pending: Vec<_> = {
        let shard_map = ensure_shard_map(&self.inner.stream_records, stream_name);
        batch
            .iter()
            .map(|(key, record)| {
                // In-memory serialization; treated as infallible.
                let bytes = postcard::to_allocvec(record).unwrap();
                let shard_hex = shard_hex_from_key(key);
                let records_arc = shard_map
                    .entry(shard_hex.to_string())
                    .or_insert_with(|| Arc::new(RwLock::new(BTreeMap::new())))
                    .value()
                    .clone();
                (records_arc, key.clone(), bytes)
            })
            .collect()
    };
    for (records_arc, key, bytes) in pending {
        let mut records = records_arc.write().await;
        records.insert(key, bytes);
    }
}
pub async fn delete_expired_records(&self, stream_name: &str, retention_hours: u32) -> usize {
let now = crate::util::current_time_ms();
let cutoff_time = now - (retention_hours as u64 * 60 * 60 * 1000);
let shard_arcs: Vec<_> = match self.inner.stream_records.get(stream_name) {
Some(shard_map) => shard_map.iter().map(|e| e.value().clone()).collect(),
None => return 0,
};
let mut total_deleted = 0;
for records_arc in shard_arcs {
let keys_to_delete: Vec<String> = {
let records = records_arc.read().await;
records
.iter()
.filter_map(|(key, _)| {
let seq_num = key.split('/').nth(1)?;
let seq_obj = crate::sequence::parse_sequence(seq_num).ok()?;
if seq_obj.seq_time.unwrap_or(0) < cutoff_time {
Some(key.clone())
} else {
None
}
})
.collect()
};
if !keys_to_delete.is_empty() {
let mut records = records_arc.write().await;
for key in &keys_to_delete {
records.remove(key);
}
total_deleted += keys_to_delete.len();
}
}
total_deleted
}
/// Deletes the given storage keys from the stream. Unknown keys (or an
/// unknown stream) are silently ignored.
///
/// Phase 1 resolves each key's shard Arc while the DashMap guard is
/// alive; phase 2 takes each shard's write lock only after the guard has
/// been dropped (end of the block).
pub async fn delete_record_keys(&self, stream_name: &str, keys: &[String]) {
    let pending: Vec<_> = {
        let shard_map = match self.inner.stream_records.get(stream_name) {
            Some(m) => m,
            None => return,
        };
        keys.iter()
            .filter_map(|key| {
                let shard_hex = shard_hex_from_key(key);
                shard_map
                    .get(shard_hex)
                    .map(|r| (r.value().clone(), key.clone()))
            })
            .collect()
    };
    for (records_arc, key) in pending {
        let mut records = records_arc.write().await;
        records.remove(&key);
    }
}
/// Returns the decoded records with keys in `[range_start, range_end)`.
///
/// The shard is derived from `range_start` alone, so both bounds are
/// expected to name keys within the same shard. Empty result when the
/// stream or shard is unknown.
pub async fn get_records_range(
    &self,
    stream_name: &str,
    range_start: &str,
    range_end: &str,
) -> Vec<(String, StoredRecord)> {
    let shard_map = match self.inner.stream_records.get(stream_name) {
        Some(m) => m,
        None => return Vec::new(),
    };
    let shard_hex = shard_hex_from_key(range_start);
    let records_arc = match shard_map.get(shard_hex) {
        Some(r) => r.value().clone(),
        None => return Vec::new(),
    };
    let records = records_arc.read().await;
    // Borrow-based lookup: String keys compare as &str via Borrow<str>,
    // so no temporary String allocations are needed for the bounds.
    records
        .range::<str, _>(range_start..range_end)
        .map(|(k, v)| {
            // Decode failure would mean corrupt in-memory state.
            let record: StoredRecord = postcard::from_bytes(v).unwrap();
            (k.clone(), record)
        })
        .collect()
}
/// Like [`Self::get_records_range`] but stops after at most `limit`
/// records.
pub async fn get_records_range_limited(
    &self,
    stream_name: &str,
    range_start: &str,
    range_end: &str,
    limit: usize,
) -> Vec<(String, StoredRecord)> {
    let shard_map = match self.inner.stream_records.get(stream_name) {
        Some(m) => m,
        None => return Vec::new(),
    };
    let shard_hex = shard_hex_from_key(range_start);
    let records_arc = match shard_map.get(shard_hex) {
        Some(r) => r.value().clone(),
        None => return Vec::new(),
    };
    let records = records_arc.read().await;
    // Borrow-based lookup avoids allocating Strings for the bounds;
    // take(limit) keeps decoding lazy so at most `limit` values decode.
    records
        .range::<str, _>(range_start..range_end)
        .take(limit)
        .map(|(k, v)| {
            // Decode failure would mean corrupt in-memory state.
            let record: StoredRecord = postcard::from_bytes(v).unwrap();
            (k.clone(), record)
        })
        .collect()
}
/// Returns the first record in `[range_start, range_end)` whose arrival
/// timestamp is at or after `timestamp`, or `None` if no record
/// qualifies (or the stream/shard is unknown).
///
/// The shard is derived from `range_start`; keys iterate in order, and
/// arrival timestamps are assumed monotone with key order within a shard.
pub async fn find_first_record_at_timestamp(
    &self,
    stream_name: &str,
    range_start: &str,
    range_end: &str,
    timestamp: f64,
) -> Option<(String, StoredRecord)> {
    let shard_map = self.inner.stream_records.get(stream_name)?;
    let shard_hex = shard_hex_from_key(range_start);
    let records_arc = shard_map.get(shard_hex)?.value().clone();
    let records = records_arc.read().await;
    // Borrow-based lookup: no String allocations for the range bounds.
    for (k, v) in records.range::<str, _>(range_start..range_end) {
        // Decode failure would mean corrupt in-memory state.
        let record: StoredRecord = postcard::from_bytes(v).unwrap();
        if record.approximate_arrival_timestamp >= timestamp {
            return Some((k.clone(), record));
        }
    }
    None
}
/// Stores (or replaces) a consumer record, JSON-serialized, keyed by ARN.
pub async fn put_consumer(&self, consumer_arn: &str, consumer: Consumer) {
    // Serializing our own type; treated as infallible.
    let bytes = serde_json::to_vec(&consumer).unwrap();
    self.inner.consumers.insert(consumer_arn.to_string(), bytes);
}
/// Looks up a consumer by ARN, decoding the stored JSON.
///
/// Panics if the stored bytes do not deserialize — they are always
/// written by put_consumer, so that would indicate corrupt state.
pub async fn get_consumer(&self, consumer_arn: &str) -> Option<Consumer> {
    self.inner
        .consumers
        .get(consumer_arn)
        .map(|entry| serde_json::from_slice(entry.value()).unwrap())
}
/// Removes a consumer by ARN. Idempotent.
pub async fn delete_consumer(&self, consumer_arn: &str) {
    self.inner.consumers.remove(consumer_arn);
}
/// Returns every consumer registered under the stream, identified by the
/// "<stream_arn>/consumer/" key prefix. Order is unspecified (map
/// iteration order).
pub async fn list_consumers_for_stream(&self, stream_arn: &str) -> Vec<Consumer> {
    let prefix = format!("{stream_arn}/consumer/");
    let mut consumers = Vec::new();
    for entry in self.inner.consumers.iter() {
        if entry.key().starts_with(&prefix) {
            // Stored by put_consumer; decode failure would be corrupt state.
            consumers.push(serde_json::from_slice(entry.value()).unwrap());
        }
    }
    consumers
}
/// Finds a stream's consumer by name, or `None` if no consumer with that
/// name is registered.
pub async fn find_consumer(&self, stream_arn: &str, consumer_name: &str) -> Option<Consumer> {
    self.list_consumers_for_stream(stream_arn)
        .await
        .into_iter()
        .find(|c| c.consumer_name == consumer_name)
}
/// Stores (or replaces) the policy document for a resource ARN.
pub async fn put_policy(&self, resource_arn: &str, policy: &str) {
    self.inner
        .policies
        .insert(resource_arn.to_string(), policy.to_string());
}
/// Returns a copy of the policy document for a resource ARN, if any.
pub async fn get_policy(&self, resource_arn: &str) -> Option<String> {
    self.inner
        .policies
        .get(resource_arn)
        .map(|entry| entry.value().clone())
}
/// Removes the policy for a resource ARN. Idempotent.
pub async fn delete_policy(&self, resource_arn: &str) {
    self.inner.policies.remove(resource_arn);
}
/// Extracts the stream name from an ARN of the shape
/// "arn:aws:kinesis:<region>:<account>:stream/<name>": the segment after
/// the first '/'. Returns `None` when the ARN has no '/'.
pub fn stream_name_from_arn(&self, arn: &str) -> Option<String> {
    // char pattern instead of a &str pattern (clippy single_char_pattern).
    arn.split('/').nth(1).map(|s| s.to_string())
}
/// Resolves the target stream name from a request body that may carry
/// `StreamName`, `StreamARN`, or (invalidly) both or neither.
///
/// # Errors
/// - `INVALID_ARGUMENT` when both or neither identifier is present
///   (missing fields and non-string values are treated as absent);
/// - `RESOURCE_NOT_FOUND` when the ARN cannot be parsed into a name.
pub fn resolve_stream_name(&self, data: &Value) -> Result<String, KinesisErrorResponse> {
    let stream_name_raw = data[constants::STREAM_NAME].as_str().unwrap_or("");
    let stream_arn = data[constants::STREAM_ARN].as_str().unwrap_or("");
    if !stream_name_raw.is_empty() && !stream_arn.is_empty() {
        return Err(KinesisErrorResponse::client_error(
            constants::INVALID_ARGUMENT,
            Some("StreamARN and StreamName cannot be provided together."),
        ));
    }
    if !stream_name_raw.is_empty() {
        Ok(stream_name_raw.to_string())
    } else if !stream_arn.is_empty() {
        self.stream_name_from_arn(stream_arn).ok_or_else(|| {
            KinesisErrorResponse::client_error(
                constants::RESOURCE_NOT_FOUND,
                Some("Could not resolve stream from ARN."),
            )
        })
    } else {
        Err(KinesisErrorResponse::client_error(
            constants::INVALID_ARGUMENT,
            Some("Either StreamName or StreamARN must be provided."),
        ))
    }
}
/// Returns a copy of the tag set for a resource ARN; empty map when the
/// resource has no tags.
pub async fn get_resource_tags(&self, resource_arn: &str) -> BTreeMap<String, String> {
    self.inner
        .resource_tags
        .get(resource_arn)
        .map(|entry| entry.value().clone())
        .unwrap_or_default()
}
/// Replaces the entire tag set for a resource ARN (no merge with
/// existing tags).
pub async fn put_resource_tags(&self, resource_arn: &str, tags: &BTreeMap<String, String>) {
    self.inner
        .resource_tags
        .insert(resource_arn.to_string(), tags.clone());
}
/// Returns a snapshot of the account-level settings JSON blob.
pub async fn get_account_settings(&self) -> Value {
    self.inner.account_settings.read().await.clone()
}
/// Replaces the account-level settings JSON blob wholesale.
pub async fn put_account_settings(&self, settings: &Value) {
    *self.inner.account_settings.write().await = settings.clone();
}
/// Health probe. The in-memory store is always ready, so this never
/// returns an error; the `StoreHealthError` signature is kept for
/// interface compatibility with persistent backends.
pub fn check_ready(&self) -> Result<(), StoreHealthError> {
    Ok(())
}
/// Routes a partition-key hash to an open shard of the stream and draws
/// the next sequence number for that shard.
///
/// The stream must be Active or Updating; otherwise (and when the stream
/// does not exist) `stream_not_found` is returned. The stream read lock
/// is released before the shard_seq lock is taken, so the shard list seen
/// by routing can in principle change before the counter is read — the
/// counters only ever grow, so an index can at worst miss (fallback 0).
pub async fn allocate_sequence(
    &self,
    name: &str,
    hash_key: &u128,
) -> Result<SequenceAllocation, KinesisErrorResponse> {
    // Clone the Arc so the DashMap guard is dropped before awaiting.
    let entry = self
        .inner
        .streams
        .get(name)
        .map(|e| e.value().clone())
        .ok_or_else(|| KinesisErrorResponse::stream_not_found(name, &self.aws_account_id))?;
    let stream = entry.stream.read().await;
    if !matches!(
        stream.stream_status,
        StreamStatus::Active | StreamStatus::Updating
    ) {
        return Err(KinesisErrorResponse::stream_not_found(
            name,
            &self.aws_account_id,
        ));
    }
    let (shard_ix, shard_id, shard_create_time) = route_hash_to_shard(&stream, hash_key);
    drop(stream);
    let shard_seq = entry.shard_seq.read().await;
    // Clamp "now" so a sequence can never predate its shard's creation.
    let now = current_time_ms().max(shard_create_time);
    // fetch_add under the read lock: the Vec is only grown, each counter
    // is atomic. Missing index (stale view) falls back to 0 without a bump.
    let current_seq_ix = shard_seq
        .get(shard_ix as usize)
        .map(|s| s.counter.fetch_add(1, Ordering::Relaxed))
        .unwrap_or(0);
    drop(shard_seq);
    let seq_num = sequence::stringify_sequence(&sequence::SeqObj {
        shard_create_time,
        seq_ix: Some(current_seq_ix),
        byte1: None,
        seq_time: Some(now),
        seq_rand: None,
        shard_ix,
        version: 2,
    });
    // Storage key: "<shardHex>/<sequence>".
    let stream_key = format!("{}/{}", sequence::shard_ix_to_hex(shard_ix), seq_num);
    Ok(SequenceAllocation {
        shard_id,
        seq_num,
        stream_key,
        now,
    })
}
/// Batch form of [`Self::allocate_sequence`]: routes every hash key and
/// draws its sequence number while holding the stream read lock and a
/// single shard_seq read lock, so a whole PutRecords batch is allocated
/// against one consistent shard layout.
///
/// # Errors
/// `stream_not_found` when the stream does not exist or is not
/// Active/Updating.
pub async fn allocate_sequences_batch(
    &self,
    name: &str,
    hash_keys: &[u128],
) -> Result<Vec<SequenceAllocation>, KinesisErrorResponse> {
    // Clone the Arc so the DashMap guard is dropped before awaiting.
    let entry = self
        .inner
        .streams
        .get(name)
        .map(|e| e.value().clone())
        .ok_or_else(|| KinesisErrorResponse::stream_not_found(name, &self.aws_account_id))?;
    let stream = entry.stream.read().await;
    if !matches!(
        stream.stream_status,
        StreamStatus::Active | StreamStatus::Updating
    ) {
        return Err(KinesisErrorResponse::stream_not_found(
            name,
            &self.aws_account_id,
        ));
    }
    let shard_seq = entry.shard_seq.read().await;
    let mut allocations = Vec::with_capacity(hash_keys.len());
    for hash_key in hash_keys {
        let (shard_ix, shard_id, shard_create_time) = route_hash_to_shard(&stream, hash_key);
        // Clamp "now" so a sequence can never predate its shard's creation.
        let now = current_time_ms().max(shard_create_time);
        // Missing index falls back to 0 without bumping any counter.
        let current_seq_ix = shard_seq
            .get(shard_ix as usize)
            .map(|s| s.counter.fetch_add(1, Ordering::Relaxed))
            .unwrap_or(0);
        let seq_num = sequence::stringify_sequence(&sequence::SeqObj {
            shard_create_time,
            seq_ix: Some(current_seq_ix),
            byte1: None,
            seq_time: Some(now),
            seq_rand: None,
            shard_ix,
            version: 2,
        });
        // Storage key: "<shardHex>/<sequence>".
        let stream_key = format!("{}/{}", sequence::shard_ix_to_hex(shard_ix), seq_num);
        allocations.push(SequenceAllocation {
            shard_id,
            seq_num,
            stream_key,
            now,
        });
    }
    Ok(allocations)
}
/// Returns the current (next-to-be-allocated) sequence counter for the
/// given shard index, or 0 when the stream or index is unknown.
pub async fn current_shard_seq(&self, name: &str, shard_ix: i64) -> u64 {
    let Some(entry) = self.inner.streams.get(name).map(|e| e.value().clone()) else {
        return 0;
    };
    let shard_seq = entry.shard_seq.read().await;
    // try_from rejects a negative index explicitly, instead of relying on
    // an `as` cast wrapping to a huge value that happens to be out of
    // bounds. Observable result is the same (0), but the intent is clear.
    usize::try_from(shard_ix)
        .ok()
        .and_then(|ix| shard_seq.get(ix))
        .map(|state| state.counter.load(Ordering::Relaxed))
        .unwrap_or(0)
}
}
/// Builds one sequence-counter slot per shard of the stream, each starting
/// at 1 (sequence index 0 is never handed out by a successful bump).
fn build_shard_seq(stream: &Stream) -> Vec<ShardSeqState> {
    let mut seq = Vec::with_capacity(stream.shards.len());
    for _ in &stream.shards {
        seq.push(ShardSeqState {
            counter: AtomicU64::new(1),
        });
    }
    seq
}
/// Grows the entry's per-shard counters to match the stream's shard count,
/// appending fresh counters (starting at 1) for newly added shards.
/// Existing counters are untouched and the Vec is never shrunk.
async fn sync_shard_seq(entry: &StreamEntry, stream: &Stream) {
    let mut shard_seq = entry.shard_seq.write().await;
    let target = stream.shards.len();
    if shard_seq.len() < target {
        shard_seq.resize_with(target, || ShardSeqState {
            counter: AtomicU64::new(1),
        });
    }
}
/// Finds the open shard whose hash-key range contains `hash_key`.
///
/// Returns (shard index within `stream.shards`, shard id, shard create
/// time parsed from the shard's starting sequence number — 0 if that
/// fails to parse). Closed shards (those with an ending sequence number)
/// are skipped.
///
/// NOTE(review): when no open shard covers the hash, this falls back to
/// (0, "", 0) — i.e. shard index 0 with an empty shard id — rather than
/// an error. Callers see an empty `shard_id` in that case; confirm this
/// is the intended degradation.
fn route_hash_to_shard(stream: &Stream, hash_key: &u128) -> (i64, String, u64) {
    for (i, shard) in stream.shards.iter().enumerate() {
        if shard.sequence_number_range.ending_sequence_number.is_none() {
            let start = shard.hash_key_range.start_u128();
            let end = shard.hash_key_range.end_u128();
            // Inclusive on both ends, matching Kinesis hash-key ranges.
            if *hash_key >= start && *hash_key <= end {
                let create_time =
                    sequence::parse_sequence(&shard.sequence_number_range.starting_sequence_number)
                        .map(|s| s.shard_create_time)
                        .unwrap_or(0);
                return (i as i64, shard.shard_id.clone(), create_time);
            }
        }
    }
    (0, String::new(), 0)
}