use async_nats::jetstream;
use async_trait::async_trait;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tirea_contract::storage::{
Committed, ThreadHead, ThreadListPage, ThreadListQuery, ThreadReader, ThreadStore,
ThreadStoreError, ThreadWriter, VersionPrecondition,
};
use tirea_contract::{CheckpointReason, Thread, ThreadChangeSet};
/// Name of the JetStream stream that buffers per-thread deltas.
const STREAM_NAME: &str = "THREAD_DELTAS";
/// Leading token of every delta subject (`thread.<id>.deltas`).
const SUBJECT_PREFIX: &str = "thread";
/// How long a drain waits for the next message before treating the buffer as empty.
const DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

/// Returns the subject that deltas for `thread_id` are published to:
/// `thread.<thread_id>.deltas`.
///
/// `thread_id` is embedded verbatim. NATS uses `.` as a token separator, so an
/// id containing `.` produces a subject with more than three tokens.
/// NOTE(review): such ids will not match the stream's `thread.*.deltas` filter.
fn delta_subject(thread_id: &str) -> String {
    [SUBJECT_PREFIX, thread_id, "deltas"].join(".")
}
/// [`ThreadStore`] decorator that buffers appended deltas in a NATS JetStream
/// stream and writes them through to `inner` in batches.
///
/// Deltas are published per thread to `thread.<id>.deltas`. They are applied to
/// `inner` when a delta whose reason is `CheckpointReason::RunFinished` is
/// appended, or when [`NatsBufferedThreadWriter::recover`] runs. Reads are
/// delegated straight to `inner`, so they do not observe deltas that are still
/// buffered.
pub struct NatsBufferedThreadWriter {
    // Backing store that receives materialized thread snapshots and serves all reads.
    inner: Arc<dyn ThreadStore>,
    // JetStream context used to publish, drain, and purge buffered deltas.
    jetstream: jetstream::Context,
}
impl NatsBufferedThreadWriter {
pub async fn new(
inner: Arc<dyn ThreadStore>,
jetstream: jetstream::Context,
) -> Result<Self, async_nats::Error> {
jetstream
.get_or_create_stream(jetstream::stream::Config {
name: STREAM_NAME.to_string(),
subjects: vec![format!("{SUBJECT_PREFIX}.*.deltas")],
retention: jetstream::stream::RetentionPolicy::WorkQueue,
storage: jetstream::stream::StorageType::File,
max_age: std::time::Duration::from_secs(24 * 3600), ..Default::default()
})
.await?;
Ok(Self { inner, jetstream })
}
pub async fn recover(&self) -> Result<usize, NatsBufferedThreadWriterError> {
let stream = self.stream().await?;
let consumer_name = format!("recovery_{}", uuid::Uuid::now_v7().simple());
let consumer = stream
.create_consumer(jetstream::consumer::pull::Config {
name: Some(consumer_name.clone()),
ack_policy: jetstream::consumer::AckPolicy::Explicit,
deliver_policy: jetstream::consumer::DeliverPolicy::All,
filter_subject: format!("{SUBJECT_PREFIX}.*.deltas"),
..Default::default()
})
.await
.map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
let mut pending: HashMap<String, Vec<(ThreadChangeSet, jetstream::Message)>> =
HashMap::new();
let mut messages = consumer
.messages()
.await
.map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
while let Ok(Some(Ok(msg))) = tokio::time::timeout(DRAIN_TIMEOUT, messages.next()).await {
let subject = msg.subject.to_string();
let parts: Vec<&str> = subject.split('.').collect();
if parts.len() != 3 {
let _ = msg.double_ack().await;
continue;
}
let thread_id = parts[1].to_string();
match serde_json::from_slice::<ThreadChangeSet>(&msg.payload) {
Ok(delta) => pending.entry(thread_id).or_default().push((delta, msg)),
Err(_) => {
let _ = msg.double_ack().await;
}
}
}
let mut recovered = 0usize;
for (thread_id, deltas_with_msgs) in pending {
match self
.materialize_and_save_thread(&thread_id, deltas_with_msgs)
.await
{
Ok(acked) => recovered += acked,
Err(e) => {
tracing::error!(
thread_id = %thread_id,
error = %e,
"recovery: failed to materialize thread"
);
}
}
}
let _ = stream.delete_consumer(&consumer_name).await;
Ok(recovered)
}
async fn stream(&self) -> Result<jetstream::stream::Stream, NatsBufferedThreadWriterError> {
self.jetstream
.get_stream(STREAM_NAME)
.await
.map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))
}
async fn materialize_and_save_thread(
&self,
thread_id: &str,
deltas_with_msgs: Vec<(ThreadChangeSet, jetstream::Message)>,
) -> Result<usize, NatsBufferedThreadWriterError> {
if deltas_with_msgs.is_empty() {
return Ok(0);
}
let mut thread = match self.inner.load(thread_id).await? {
Some(head) => head.thread,
None => Thread::new(thread_id.to_string()),
};
for (delta, _) in &deltas_with_msgs {
delta.apply_to(&mut thread);
}
self.inner.save(&thread).await?;
let mut acked = 0usize;
for (_, msg) in deltas_with_msgs {
let _ = msg.double_ack().await;
acked += 1;
}
Ok(acked)
}
async fn flush_thread_buffer(
&self,
thread_id: &str,
) -> Result<usize, NatsBufferedThreadWriterError> {
let stream = self.stream().await?;
let consumer_name = format!("flush_{}", uuid::Uuid::now_v7().simple());
let consumer = stream
.create_consumer(jetstream::consumer::pull::Config {
name: Some(consumer_name.clone()),
ack_policy: jetstream::consumer::AckPolicy::Explicit,
deliver_policy: jetstream::consumer::DeliverPolicy::All,
filter_subject: delta_subject(thread_id),
..Default::default()
})
.await
.map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
let mut deltas_with_msgs = Vec::new();
let mut messages = consumer
.messages()
.await
.map_err(|e| NatsBufferedThreadWriterError::JetStream(e.to_string()))?;
while let Ok(Some(Ok(msg))) = tokio::time::timeout(DRAIN_TIMEOUT, messages.next()).await {
match serde_json::from_slice::<ThreadChangeSet>(&msg.payload) {
Ok(delta) => deltas_with_msgs.push((delta, msg)),
Err(_) => {
let _ = msg.double_ack().await;
}
}
}
let result = self
.materialize_and_save_thread(thread_id, deltas_with_msgs)
.await;
let _ = stream.delete_consumer(&consumer_name).await;
result
}
}
#[async_trait]
impl ThreadWriter for NatsBufferedThreadWriter {
async fn create(&self, thread: &Thread) -> Result<Committed, ThreadStoreError> {
self.inner.create(thread).await
}
async fn append(
&self,
thread_id: &str,
delta: &ThreadChangeSet,
precondition: VersionPrecondition,
) -> Result<Committed, ThreadStoreError> {
let payload = serde_json::to_vec(delta)
.map_err(|e| ThreadStoreError::Serialization(e.to_string()))?;
self.jetstream
.publish(delta_subject(thread_id), payload.into())
.await
.map_err(|e| ThreadStoreError::Io(std::io::Error::other(e)))?
.await
.map_err(|e| ThreadStoreError::Io(std::io::Error::other(e)))?;
if delta.reason == CheckpointReason::RunFinished {
self.flush_thread_buffer(thread_id)
.await
.map_err(|e| match e {
NatsBufferedThreadWriterError::JetStream(msg) => {
ThreadStoreError::Io(std::io::Error::other(msg))
}
NatsBufferedThreadWriterError::Storage(err) => err,
})?;
}
let version = match precondition {
VersionPrecondition::Any => 0,
VersionPrecondition::Exact(v) => v.saturating_add(1),
};
Ok(Committed { version })
}
async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
self.inner.delete(thread_id).await
}
async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
self.inner.save(thread).await?;
if let Ok(stream) = self.jetstream.get_stream(STREAM_NAME).await {
let _ = stream.purge().filter(delta_subject(&thread.id)).await;
}
Ok(())
}
}
/// Read path: every query is answered by the inner store. Deltas that have
/// been appended but not yet flushed are not visible through these methods.
#[async_trait]
impl ThreadReader for NatsBufferedThreadWriter {
    /// Returns the inner store's head for `thread_id`, if any.
    async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
        let store = &self.inner;
        store.load(thread_id).await
    }

    /// Returns the inner store's page of threads matching `query`.
    async fn list_threads(
        &self,
        query: &ThreadListQuery,
    ) -> Result<ThreadListPage, ThreadStoreError> {
        let store = &self.inner;
        store.list_threads(query).await
    }
}
/// Errors raised by the buffering layer of [`NatsBufferedThreadWriter`].
#[derive(Debug, thiserror::Error)]
pub enum NatsBufferedThreadWriterError {
    /// A JetStream operation (stream lookup, consumer creation, or starting the
    /// consumer's message stream) failed; holds the stringified underlying error.
    #[error("jetstream error: {0}")]
    JetStream(String),
    /// The inner [`ThreadStore`] failed while loading or saving a thread.
    #[error("storage error: {0}")]
    Storage(#[from] ThreadStoreError),
}