use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use arrow::record_batch::RecordBatch;
use arrow_ipc::reader::StreamReader;
use arrow_ipc::writer::StreamWriter;
use arrow_schema::SchemaRef;
use async_nats::header::HeaderMap;
use async_nats::jetstream::consumer::{self, DeliverPolicy};
use async_nats::jetstream::stream::{Config as StreamConfig, RetentionPolicy};
use async_nats::jetstream::{self, Context as JetStreamContext};
use async_stream::try_stream;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use futures::{StreamExt, TryStreamExt};
use parking_lot::RwLock;
use std::collections::HashMap;
use crate::trigger::broker::{BrokerKind, TriggerBroker};
use crate::trigger::consumer::ConsumerOffsetSnapshot;
use crate::trigger::error::TriggerError;
use crate::trigger::ids::{SubscriptionId, TopicId};
use crate::trigger::offset::Offset;
use crate::trigger::predicate::Predicate;
use crate::trigger::subscription::{DeliveredBatch, Subscription};
use crate::trigger::topic::TopicDefinition;
const HDR_OFFSET: &str = "jammi-offset";
const HDR_PRODUCED_AT: &str = "jammi-produced-at-us";
pub struct JetStreamBroker {
context: JetStreamContext,
schemas: RwLock<HashMap<TopicId, SchemaRef>>,
retention: Duration,
}
impl JetStreamBroker {
pub async fn connect(url: &str, retention_seconds: u64) -> Result<Self, TriggerError> {
let client = async_nats::connect(url)
.await
.map_err(|e| TriggerError::Driver(format!("nats connect: {e}")))?;
Ok(Self::from_client(client, retention_seconds))
}
pub async fn connect_with_credentials(
url: &str,
retention_seconds: u64,
credentials_path: &Path,
) -> Result<Self, TriggerError> {
let client = async_nats::ConnectOptions::with_credentials_file(credentials_path)
.await
.map_err(|e| TriggerError::Driver(format!("nats creds: {e}")))?
.connect(url)
.await
.map_err(|e| TriggerError::Driver(format!("nats connect: {e}")))?;
Ok(Self::from_client(client, retention_seconds))
}
fn from_client(client: async_nats::Client, retention_seconds: u64) -> Self {
Self {
context: jetstream::new(client),
schemas: RwLock::new(HashMap::new()),
retention: Duration::from_secs(retention_seconds),
}
}
fn stream_name(topic_id: TopicId) -> String {
format!("jammi_topic_{}", topic_id.as_uuid().simple())
}
fn subject_for(topic_id: TopicId) -> String {
format!("jammi.topic.{}.batch", topic_id.as_uuid().simple())
}
fn retention_for(&self, topic: &TopicDefinition) -> Duration {
topic
.broker_metadata
.get("retention_seconds")
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs)
.unwrap_or(self.retention)
}
}
#[async_trait]
impl TriggerBroker for JetStreamBroker {
async fn register_topic(&self, topic: &TopicDefinition) -> Result<(), TriggerError> {
let cfg = StreamConfig {
name: Self::stream_name(topic.id),
subjects: vec![Self::subject_for(topic.id)],
retention: RetentionPolicy::Limits,
max_age: self.retention_for(topic),
..Default::default()
};
self.context
.get_or_create_stream(cfg)
.await
.map_err(|e| TriggerError::Driver(format!("create_stream: {e}")))?;
self.schemas
.write()
.insert(topic.id, Arc::clone(&topic.schema));
Ok(())
}
async fn drop_topic(&self, topic_id: TopicId) -> Result<(), TriggerError> {
let name = Self::stream_name(topic_id);
self.context
.delete_stream(&name)
.await
.map_err(|e| TriggerError::Driver(format!("delete_stream {name}: {e}")))?;
self.schemas.write().remove(&topic_id);
Ok(())
}
async fn publish(
&self,
topic_id: TopicId,
batch: RecordBatch,
produced_at: DateTime<Utc>,
offset: u64,
) -> Result<Offset, TriggerError> {
let schema = self
.schemas
.read()
.get(&topic_id)
.cloned()
.ok_or_else(|| TriggerError::TopicNotFound(topic_id.to_string()))?;
let payload = encode_batch_ipc(&schema, &batch)?;
let mut headers = HeaderMap::new();
headers.insert(HDR_OFFSET, offset.to_string());
headers.insert(HDR_PRODUCED_AT, produced_at.timestamp_micros().to_string());
let subject = Self::subject_for(topic_id);
self.context
.publish_with_headers(subject, headers, payload.into())
.await
.map_err(|e| TriggerError::Driver(format!("publish: {e}")))?
.await
.map_err(|e| TriggerError::Driver(format!("publish-ack: {e}")))?;
Ok(Offset::new(offset, produced_at))
}
async fn subscribe(
&self,
topic_id: TopicId,
predicate: Predicate,
from_offset: Option<Offset>,
) -> Result<Subscription, TriggerError> {
let schema = self
.schemas
.read()
.get(&topic_id)
.cloned()
.ok_or_else(|| TriggerError::TopicNotFound(topic_id.to_string()))?;
let stream_name = Self::stream_name(topic_id);
let stream = self
.context
.get_stream(&stream_name)
.await
.map_err(|e| TriggerError::Driver(format!("get_stream {stream_name}: {e}")))?;
let deliver_policy = deliver_policy_for(from_offset);
let consumer = stream
.create_consumer(consumer::pull::Config {
deliver_policy,
filter_subject: Self::subject_for(topic_id),
ack_policy: consumer::AckPolicy::Explicit,
..Default::default()
})
.await
.map_err(|e| TriggerError::Driver(format!("create_consumer: {e}")))?;
let mut messages = consumer
.messages()
.await
.map_err(|e| TriggerError::Driver(format!("consumer messages: {e}")))?;
let inner = try_stream! {
while let Some(message) = messages.next().await {
let message = message
.map_err(|e| TriggerError::Driver(format!("recv: {e}")))?;
let delivered = decode_message(&schema, &message)?;
message
.ack()
.await
.map_err(|e| TriggerError::Driver(format!("ack: {e}")))?;
if let Some(filtered) = predicate.evaluate(&delivered.batch)? {
yield DeliveredBatch {
offset: delivered.offset,
produced_at: delivered.produced_at,
batch: filtered,
};
}
}
};
Ok(Subscription::new(SubscriptionId::new(), Box::pin(inner)))
}
async fn list_consumers(
&self,
topic_id: TopicId,
) -> Result<Vec<ConsumerOffsetSnapshot>, TriggerError> {
if !self.schemas.read().contains_key(&topic_id) {
return Err(TriggerError::TopicNotFound(topic_id.to_string()));
}
let stream_name = Self::stream_name(topic_id);
let stream = self
.context
.get_stream(&stream_name)
.await
.map_err(|e| TriggerError::Driver(format!("get_stream {stream_name}: {e}")))?;
let mut consumers = stream.consumers();
let mut snapshots: Vec<ConsumerOffsetSnapshot> = Vec::new();
while let Some(info) = consumers
.try_next()
.await
.map_err(|e| TriggerError::Driver(format!("list consumers {stream_name}: {e}")))?
{
snapshots.push(ConsumerOffsetSnapshot {
consumer_name: info.name,
topic_id,
last_delivered_stream_sequence: info.delivered.stream_sequence,
last_ack_stream_sequence: info.ack_floor.stream_sequence,
});
}
Ok(snapshots)
}
fn driver_kind(&self) -> BrokerKind {
BrokerKind::JetStream
}
}
fn encode_batch_ipc(schema: &SchemaRef, batch: &RecordBatch) -> Result<Vec<u8>, TriggerError> {
let mut buf: Vec<u8> = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut buf, schema.as_ref())
.map_err(|e| TriggerError::Driver(format!("ipc encode: {e}")))?;
writer
.write(batch)
.map_err(|e| TriggerError::Driver(format!("ipc encode batch: {e}")))?;
writer
.finish()
.map_err(|e| TriggerError::Driver(format!("ipc finish: {e}")))?;
}
Ok(buf)
}
fn decode_message(
schema: &SchemaRef,
message: &async_nats::jetstream::Message,
) -> Result<DeliveredBatch, TriggerError> {
let headers = message
.headers
.as_ref()
.ok_or_else(|| TriggerError::Driver("message missing headers".into()))?;
let offset_value: u64 = headers
.get(HDR_OFFSET)
.ok_or_else(|| TriggerError::Driver(format!("message missing `{HDR_OFFSET}`")))?
.as_str()
.parse()
.map_err(|e| TriggerError::Driver(format!("`{HDR_OFFSET}` parse: {e}")))?;
let produced_at_us: i64 = headers
.get(HDR_PRODUCED_AT)
.ok_or_else(|| TriggerError::Driver(format!("message missing `{HDR_PRODUCED_AT}`")))?
.as_str()
.parse()
.map_err(|e| TriggerError::Driver(format!("`{HDR_PRODUCED_AT}` parse: {e}")))?;
let produced_at = DateTime::<Utc>::from_timestamp_micros(produced_at_us).ok_or_else(|| {
TriggerError::Driver(format!(
"`{HDR_PRODUCED_AT}` out of range: {produced_at_us}"
))
})?;
let cursor = std::io::Cursor::new(&message.payload[..]);
let mut reader = StreamReader::try_new(cursor, None)
.map_err(|e| TriggerError::Driver(format!("ipc decode: {e}")))?;
let batch = reader
.next()
.ok_or_else(|| TriggerError::Driver("ipc stream had no batch".into()))?
.map_err(|e| TriggerError::Driver(format!("ipc decode batch: {e}")))?;
if batch.schema().as_ref() != schema.as_ref() {
return Err(TriggerError::Driver(
"message schema does not match topic schema".into(),
));
}
Ok(DeliveredBatch {
offset: Offset::new(offset_value, produced_at),
produced_at,
batch,
})
}
fn deliver_policy_for(from_offset: Option<Offset>) -> DeliverPolicy {
match from_offset {
None => DeliverPolicy::New,
Some(_) => DeliverPolicy::All,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn replay_point_over_delivers_never_maps_to_a_native_start_sequence() {
let now = Utc::now();
for engine_offset in [0u64, 1, 7, 42, u64::MAX] {
let policy = deliver_policy_for(Some(Offset::new(engine_offset, now)));
assert!(
matches!(policy, DeliverPolicy::All),
"engine offset {engine_offset} must over-deliver (DeliverAll), got {policy:?}"
);
assert!(
!matches!(policy, DeliverPolicy::ByStartSequence { .. }),
"engine offset must never be conflated with a JetStream stream sequence"
);
}
assert!(matches!(deliver_policy_for(None), DeliverPolicy::New));
}
}