use crate::constants;
use crate::error::KinesisErrorResponse;
use crate::event_stream;
use crate::sequence;
use crate::store::Store;
use crate::types::{EpochSeconds, ResponseRecord, ShardIteratorType, StreamStatus};
use crate::util::current_time_ms;
use axum::body::Body;
use bytes::Bytes;
use serde_json::{Value, json};
/// Interval, in milliseconds, between successive shard polls while a
/// `SubscribeToShard` session is open.
const POLL_INTERVAL_MS: u64 = 200;

/// Handles a `SubscribeToShard` request and returns a streaming response body.
///
/// The request is validated up front (`ConsumerARN`, `ShardId`,
/// `StartingPosition.Type`). The consumer, its stream and the target shard are
/// then resolved, and the starting sequence number is derived from the
/// requested starting position. The returned [`Body`]:
///
/// 1. emits the event-stream initial-response frame;
/// 2. every [`POLL_INTERVAL_MS`] ms, reads up to
///    `subscribe_to_shard_event_record_limit` records from the cursor onward
///    and emits one subscribe event, encoded as CBOR when `content_type` is the
///    CBOR content type and as JSON otherwise;
/// 3. stops when `subscribe_to_shard_session_ms` has elapsed, when the stream
///    can no longer be read, when an event fails to encode, or when the shard
///    is closed and the cursor stopped advancing.
///
/// Records older than the stream's retention period are not delivered, but
/// the cursor still moves past them.
///
/// # Errors
///
/// Returns:
/// - `INVALID_ARGUMENT` for a missing `ConsumerARN`/`ShardId`, a consumer ARN
///   without `/consumer/`, an invalid `StartingPosition.Type`, or an
///   unparseable `SequenceNumber`;
/// - `RESOURCE_NOT_FOUND` when the consumer or shard does not exist, or the
///   stream cannot be resolved from the ARN;
/// - `RESOURCE_IN_USE` when the consumer is not ACTIVE;
/// - the stream-not-active error when the stream is not ACTIVE;
/// - a server error when the shard's own starting sequence number cannot be
///   parsed;
/// - any error propagated from `Store::get_stream`.
pub async fn execute_streaming(
    store: &Store,
    data: Value,
    content_type: &str,
) -> Result<Body, KinesisErrorResponse> {
    let consumer_arn = data[constants::CONSUMER_ARN].as_str().unwrap_or("");
    let shard_id_input = data[constants::SHARD_ID].as_str().unwrap_or("");
    let starting_position = &data[constants::STARTING_POSITION];

    if consumer_arn.is_empty() {
        return Err(KinesisErrorResponse::client_error(
            constants::INVALID_ARGUMENT,
            Some("ConsumerARN is required."),
        ));
    }
    if shard_id_input.is_empty() {
        return Err(KinesisErrorResponse::client_error(
            constants::INVALID_ARGUMENT,
            Some("ShardId is required."),
        ));
    }

    let consumer = store.get_consumer(consumer_arn).await.ok_or_else(|| {
        KinesisErrorResponse::client_error(
            constants::RESOURCE_NOT_FOUND,
            Some(&format!("Consumer {} not found.", consumer_arn)),
        )
    })?;

    // The stream ARN is the part of the consumer ARN before "/consumer/".
    let stream_arn_end = consumer_arn.find("/consumer/").ok_or_else(|| {
        KinesisErrorResponse::client_error(
            constants::INVALID_ARGUMENT,
            Some("Invalid ConsumerARN format."),
        )
    })?;
    let stream_arn = &consumer_arn[..stream_arn_end];
    let stream_name = store.stream_name_from_arn(stream_arn).ok_or_else(|| {
        KinesisErrorResponse::client_error(
            constants::RESOURCE_NOT_FOUND,
            Some("Could not resolve stream from ConsumerARN."),
        )
    })?;

    let stream = store.get_stream(&stream_name).await?;
    if stream.stream_status != StreamStatus::Active {
        return Err(KinesisErrorResponse::stream_not_active(
            &stream_name,
            &store.aws_account_id,
        ));
    }
    if consumer.consumer_status != crate::types::ConsumerStatus::Active {
        return Err(KinesisErrorResponse::client_error(
            constants::RESOURCE_IN_USE,
            Some("Consumer is not ACTIVE."),
        ));
    }

    let (shard_id, shard_ix) = sequence::resolve_shard_id(shard_id_input).map_err(|_| {
        KinesisErrorResponse::client_error(
            constants::RESOURCE_NOT_FOUND,
            Some(&format!("Shard {} not found.", shard_id_input)),
        )
    })?;
    // Reject negative indices as well: they would wrap around in the
    // `as usize` casts used to index `shards` below.
    if shard_ix < 0 || shard_ix >= stream.shards.len() as i64 {
        return Err(KinesisErrorResponse::client_error(
            constants::RESOURCE_NOT_FOUND,
            Some(&format!("Shard {} not found.", shard_id)),
        ));
    }

    let iterator_type: ShardIteratorType = serde_json::from_value(
        starting_position
            .get("Type")
            .cloned()
            .unwrap_or(serde_json::Value::Null),
    )
    .map_err(|_| {
        KinesisErrorResponse::client_error(
            constants::INVALID_ARGUMENT,
            Some(
                "StartingPosition.Type is required and must be one of: TRIM_HORIZON, LATEST, AT_SEQUENCE_NUMBER, AFTER_SEQUENCE_NUMBER, AT_TIMESTAMP.",
            ),
        )
    })?;

    let shard_seq = &stream.shards[shard_ix as usize]
        .sequence_number_range
        .starting_sequence_number;
    let shard_seq_obj = sequence::parse_sequence(shard_seq)
        .map_err(|_| KinesisErrorResponse::server_error(None, None))?;
    let now = current_time_ms();
    tracing::trace!(
        shard = %shard_id,
        ?iterator_type,
        "subscribe: starting position"
    );

    let start_seq = match iterator_type {
        ShardIteratorType::TrimHorizon => shard_seq.clone(),
        ShardIteratorType::Latest => {
            let seq_ix = store.current_shard_seq(&stream_name, shard_ix).await;
            sequence::stringify_sequence(&sequence::SeqObj {
                shard_create_time: shard_seq_obj.shard_create_time,
                seq_ix: Some(seq_ix),
                seq_time: Some(now),
                shard_ix: shard_seq_obj.shard_ix,
                byte1: None,
                seq_rand: None,
                version: 2,
            })
        }
        ShardIteratorType::AtSequenceNumber => {
            let seq_str = starting_position
                .get("SequenceNumber")
                .and_then(|v| v.as_str())
                .unwrap_or(shard_seq);
            // Validate the same way AFTER_SEQUENCE_NUMBER does, so a malformed
            // value is rejected instead of being used as a raw range key.
            sequence::parse_sequence(seq_str).map_err(|_| {
                KinesisErrorResponse::client_error(
                    constants::INVALID_ARGUMENT,
                    Some("Invalid SequenceNumber."),
                )
            })?;
            seq_str.to_string()
        }
        ShardIteratorType::AfterSequenceNumber => {
            let seq_str = starting_position
                .get("SequenceNumber")
                .and_then(|v| v.as_str())
                .unwrap_or(shard_seq);
            let seq_obj = sequence::parse_sequence(seq_str).map_err(|_| {
                KinesisErrorResponse::client_error(
                    constants::INVALID_ARGUMENT,
                    Some("Invalid SequenceNumber."),
                )
            })?;
            sequence::increment_sequence(&seq_obj, None)
        }
        ShardIteratorType::AtTimestamp => {
            let ts = starting_position
                .get("Timestamp")
                .and_then(|v| v.as_f64())
                .unwrap_or(0.0);
            let range_start = format!("{}/", sequence::shard_ix_to_hex(shard_ix));
            let range_end = sequence::shard_ix_to_hex(shard_ix + 1);
            let found_seq = store
                .find_first_record_at_timestamp(&stream_name, &range_start, &range_end, ts)
                .await
                .and_then(|(key, _)| key.split('/').nth(1).map(|s| s.to_string()));
            // Fall back to the shard start if no matching record is found.
            found_seq.unwrap_or_else(|| shard_seq.clone())
        }
    };

    tracing::trace!(consumer_arn, shard = %shard_id, stream = %stream_name, "shard subscription started");

    // Owned copies moved into the response stream below.
    let store = store.clone();
    let stream_name = stream_name.to_string();
    let shard_id = shard_id.to_string();
    let is_cbor = content_type == constants::CONTENT_TYPE_CBOR;
    let event_record_limit = store.options.subscribe_to_shard_event_record_limit;
    let max_subscription_ms = store.options.subscribe_to_shard_session_ms;

    let stream = async_stream::stream! {
        let mut current_seq = start_seq;
        let start_time = current_time_ms();
        yield Ok::<Bytes, std::io::Error>(Bytes::from(event_stream::encode_initial_response()));
        loop {
            let now = current_time_ms();
            // saturating_sub: a backwards clock step must not underflow.
            if now.saturating_sub(start_time) >= max_subscription_ms {
                break;
            }
            let stream_data = match store.get_stream(&stream_name).await {
                Ok(s) => s,
                Err(_) => break,
            };
            let retention_ms = stream_data.retention_period_hours as u64 * 60 * 60 * 1000;
            let cutoff_timestamp = (now.saturating_sub(retention_ms) / 1000) as f64;

            let range_start = format!("{}/{}", sequence::shard_ix_to_hex(shard_ix), current_seq);
            let range_end = sequence::shard_ix_to_hex(shard_ix + 1);
            let range_records = store
                .get_records_range_limited(&stream_name, &range_start, &range_end, event_record_limit)
                .await;

            let mut records: Vec<ResponseRecord<'_>> = Vec::with_capacity(range_records.len());
            // Last sequence number scanned in this batch. This includes records
            // skipped as expired, so the cursor always moves past them. Otherwise
            // a batch made entirely of expired records would be re-read forever.
            let mut last_scanned_seq: Option<&str> = None;
            for (key, record) in &range_records {
                let seq_num = match key.split('/').nth(1) {
                    Some(s) => s,
                    None => continue,
                };
                last_scanned_seq = Some(seq_num);
                if record.approximate_arrival_timestamp < cutoff_timestamp {
                    continue;
                }
                records.push(ResponseRecord {
                    data: &record.data,
                    partition_key: &record.partition_key,
                    sequence_number: seq_num,
                    approximate_arrival_timestamp: EpochSeconds(record.approximate_arrival_timestamp),
                });
            }

            let continuation_seq = match last_scanned_seq.and_then(|s| sequence::parse_sequence(s).ok()) {
                Some(ref last) => sequence::increment_sequence(last, None),
                None => current_seq.clone(),
            };
            // On a closed shard, a poll that does not move the cursor means
            // every remaining record has been delivered.
            let advanced = continuation_seq != current_seq;

            let mut child_shards: Vec<Value> = Vec::new();
            let shard = &stream_data.shards[shard_ix as usize];
            let shard_closed = shard.sequence_number_range.ending_sequence_number.is_some();
            if shard_closed {
                for s in &stream_data.shards {
                    let is_child = s.parent_shard_id.as_deref() == Some(&shard_id)
                        || s.adjacent_parent_shard_id.as_deref() == Some(&shard_id);
                    if is_child {
                        let mut parent_shards = vec![json!(shard_id)];
                        if let Some(ref adj) = s.adjacent_parent_shard_id
                            && adj != &shard_id {
                            parent_shards.push(json!(adj));
                        }
                        child_shards.push(json!({
                            "ShardId": s.shard_id,
                            "ParentShards": parent_shards,
                            "HashKeyRange": s.hash_key_range,
                        }));
                    }
                }
            }

            let event = json!({
                "Records": records,
                "ContinuationSequenceNumber": continuation_seq,
                // Lag is not tracked here; always reported as 0.
                "MillisBehindLatest": 0u64,
                "ChildShards": child_shards,
            });

            let encoded = if is_cbor {
                let mut buf = Vec::new();
                match ciborium::into_writer(&crate::server::BlobAwareValue::new(&event), &mut buf) {
                    Ok(()) => Ok((buf, constants::CONTENT_TYPE_CBOR)),
                    Err(e) => Err(format!("CBOR encoding failed: {e:?}")),
                }
            } else {
                match serde_json::to_vec(&event) {
                    Ok(buf) => Ok((buf, "application/json")),
                    Err(e) => Err(e.to_string()),
                }
            };
            let (payload, event_content_type) = match encoded {
                Ok(v) => v,
                Err(msg) => {
                    // End the body with an I/O error instead of panicking the task.
                    yield Err(std::io::Error::other(msg));
                    break;
                }
            };
            yield Ok(Bytes::from(event_stream::encode_subscribe_event(&payload, event_content_type)));

            current_seq = continuation_seq;
            if shard_closed && !advanced {
                break;
            }
            crate::runtime::sleep_ms(POLL_INTERVAL_MS).await;
        }
    };
    Ok(Body::from_stream(stream))
}