use std::fmt::Display;
use std::future::Future;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, instrument, warn};
use async_nats::jetstream;
use cellos_core::ports::EventSink;
use cellos_core::{redact_url_if_echoed_in_text, CellosError, CloudEventV1};
/// Total number of publish attempts `emit` allows `with_retry` for a single event.
const PUBLISH_MAX_ATTEMPTS: u32 = 3;
/// Base delay between publish attempts; `with_retry` scales it by `4^(attempt - 1)`.
const PUBLISH_BACKOFF_BASE: Duration = Duration::from_millis(100);
/// Delay used for the first reconnect backoff step (`attempt == 0`).
const RECONNECT_BACKOFF_BASE: Duration = Duration::from_millis(100);
/// Upper bound on any reconnect backoff delay.
const RECONNECT_BACKOFF_CAP: Duration = Duration::from_secs(180);

/// Returns the reconnect backoff for the given zero-based `attempt`.
///
/// The delay is `RECONNECT_BACKOFF_BASE * 2^attempt`, never exceeding
/// `RECONNECT_BACKOFF_CAP`. Arbitrarily large `attempt` values are safe: once
/// the power of two no longer fits the multiplier it is pinned to `u32::MAX`,
/// which already lands far beyond the cap.
pub(crate) fn reconnect_backoff(attempt: u32) -> Duration {
    // `checked_shl` yields `None` when the shift amount is >= 32 bits.
    let multiplier = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
    match RECONNECT_BACKOFF_BASE.checked_mul(multiplier) {
        Some(delay) if delay < RECONNECT_BACKOFF_CAP => delay,
        // Either the product overflowed `Duration` or it reached the cap.
        _ => RECONNECT_BACKOFF_CAP,
    }
}
/// Placeholder that `resolve_tenant_subject` replaces inside a subject template.
pub const TENANT_ID_PLACEHOLDER: &str = "{tenantId}";
/// Token substituted for the placeholder when the event data has no string `tenantId`.
pub const TENANT_ID_DEFAULT_TOKEN: &str = "single";
/// Expands a subject `template` for the given `event`.
///
/// A template without [`TENANT_ID_PLACEHOLDER`] is returned unchanged.
/// Otherwise every occurrence of the placeholder is replaced with the string
/// value found under `"tenantId"` in the event's `data`; when `data` is
/// absent, has no such key, or the value is not a string,
/// [`TENANT_ID_DEFAULT_TOKEN`] is used instead.
pub fn resolve_tenant_subject(template: &str, event: &CloudEventV1) -> String {
    if template.contains(TENANT_ID_PLACEHOLDER) {
        let tenant_field = event.data.as_ref().and_then(|data| data.get("tenantId"));
        let tenant = match tenant_field.and_then(|value| value.as_str()) {
            Some(id) => id,
            None => TENANT_ID_DEFAULT_TOKEN,
        };
        template.replace(TENANT_ID_PLACEHOLDER, tenant)
    } else {
        template.to_string()
    }
}
/// Runs the fallible async operation `f`, retrying it with exponential backoff.
///
/// `f` is called to produce a fresh future for every attempt. The first `Ok`
/// is returned immediately. After a failed attempt, if fewer than
/// `max_attempts` attempts have been made, a warning is logged and the task
/// sleeps for `PUBLISH_BACKOFF_BASE * 4^(attempt - 1)` (saturating) before
/// trying again; otherwise the error from that final attempt is returned.
///
/// `f` is always invoked at least once, even when `max_attempts` is `0`.
pub async fn with_retry<F, Fut, T, E>(max_attempts: u32, mut f: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: Display,
{
    // Attempts are numbered from 1 so the count reads naturally in the log line.
    let mut attempt: u32 = 1;
    loop {
        let err = match f().await {
            Ok(value) => return Ok(value),
            Err(err) => err,
        };
        if attempt >= max_attempts {
            // Budget exhausted: surface the most recent error unchanged.
            return Err(err);
        }

        let multiplier = 4u32.saturating_pow(attempt - 1);
        let backoff = PUBLISH_BACKOFF_BASE.saturating_mul(multiplier);
        warn!(
            attempt,
            max_attempts,
            backoff_ms = backoff.as_millis() as u64,
            error = %err,
            "publish attempt failed; retrying after backoff"
        );
        tokio::time::sleep(backoff).await;
        attempt += 1;
    }
}
/// Publish health of a `JetStreamEventSink`, updated by its `emit` implementation.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub enum ReconnectState {
/// Initial state, and the state restored after any successful publish.
Connected,
/// The most recent publish failed even after its retries were exhausted.
Reconnecting {
/// Index handed to `reconnect_backoff`: 0 on the first failure, then incremented
/// (saturating) on each further consecutive failure.
attempt: u32,
/// Instant before which `emit` returns an error without attempting to publish.
next_after: Instant,
},
}
/// Publish seam used by `JetStreamEventSink`; `from_publisher` accepts any implementation
/// in place of the real JetStream-backed one.
#[doc(hidden)]
#[async_trait]
pub trait Publisher: Send + Sync {
/// Publishes `payload` on `subject`. The sink treats an `Err` as one failed attempt
/// in its retry loop.
async fn publish(&self, subject: String, payload: Bytes) -> Result<(), CellosError>;
}
/// `Publisher` implementation that forwards to a NATS JetStream context.
struct JetStreamPublisher {
// Created by `connect_with_root_ca` from the established NATS connection.
context: jetstream::Context,
}
#[async_trait]
impl Publisher for JetStreamPublisher {
    /// Publishes `payload` on `subject` and waits for the JetStream acknowledgement.
    ///
    /// # Errors
    ///
    /// Returns `CellosError::EventSink` when the message cannot be sent
    /// (`"jetstream publish: ..."`) or when the server does not acknowledge it
    /// (`"jetstream publish ack: ..."`).
    async fn publish(&self, subject: String, payload: Bytes) -> Result<(), CellosError> {
        let ack = self
            .context
            .publish(subject, payload)
            .await
            .map_err(|e| CellosError::EventSink(format!("jetstream publish: {e}")))?;
        // Awaiting `publish` only hands the message to the client; the value it
        // returns is a pending acknowledgement. Dropping it unawaited would hide
        // server-side rejections from the caller's retry and reconnect handling,
        // so wait for the ack and surface its failure as a publish error.
        ack.await
            .map_err(|e| CellosError::EventSink(format!("jetstream publish ack: {e}")))?;
        Ok(())
    }
}
/// `EventSink` that serializes CloudEvents to JSON and publishes them through a `Publisher`.
pub struct JetStreamEventSink {
// Transport used for every publish attempt.
publisher: Arc<dyn Publisher>,
// Subject or subject template; `emit` runs it through `resolve_tenant_subject` per event.
subject: String,
// Shared publish-health state that `emit` reads before and updates after each publish.
state: Arc<Mutex<ReconnectState>>,
}
impl JetStreamEventSink {
    /// Connects to `nats_url` with default connection options and returns a
    /// sink publishing to `subject`.
    ///
    /// Equivalent to [`Self::connect_with_root_ca`] with no root CA file.
    pub async fn connect(nats_url: &str, subject: impl Into<String>) -> Result<Self, CellosError> {
        Self::connect_with_root_ca(nats_url, subject, None).await
    }

    /// Connects to `nats_url`, optionally registering `root_ca_pem_file` as a
    /// root certificate file, and returns a sink publishing to `subject`.
    ///
    /// # Errors
    ///
    /// Returns `CellosError::EventSink` when the connection fails. The error
    /// text is passed through `redact_url_if_echoed_in_text` together with
    /// `nats_url` before being embedded in the message.
    pub async fn connect_with_root_ca(
        nats_url: &str,
        subject: impl Into<String>,
        root_ca_pem_file: Option<&Path>,
    ) -> Result<Self, CellosError> {
        let defaults = async_nats::ConnectOptions::new();
        let options = match root_ca_pem_file {
            Some(path) => defaults.add_root_certificates(path.to_path_buf()),
            None => defaults,
        };

        let client = match options.connect(nats_url).await {
            Ok(client) => client,
            Err(e) => {
                let msg = redact_url_if_echoed_in_text(&e.to_string(), nats_url);
                return Err(CellosError::EventSink(format!("nats connect: {msg}")));
            }
        };

        let publisher = JetStreamPublisher {
            context: jetstream::new(client),
        };
        Ok(Self::from_publisher(Arc::new(publisher), subject))
    }

    /// Builds a sink around an arbitrary [`Publisher`], starting in the
    /// `Connected` state.
    #[doc(hidden)]
    pub fn from_publisher(publisher: Arc<dyn Publisher>, subject: impl Into<String>) -> Self {
        let state = Arc::new(Mutex::new(ReconnectState::Connected));
        let subject = subject.into();
        Self {
            publisher,
            subject,
            state,
        }
    }

    /// Returns a snapshot (clone) of the current reconnect state.
    #[doc(hidden)]
    pub async fn debug_state(&self) -> ReconnectState {
        let guard = self.state.lock().await;
        guard.clone()
    }
}
#[async_trait]
impl EventSink for JetStreamEventSink {
    /// Serializes `event` as JSON and publishes it on the tenant-resolved subject.
    ///
    /// While the sink is in `Reconnecting` state and its `next_after` deadline
    /// has not passed, the call fails immediately without publishing. Otherwise
    /// the publish is attempted up to `PUBLISH_MAX_ATTEMPTS` times via
    /// `with_retry`. Success resets the state to `Connected`; failure moves it
    /// to `Reconnecting` with the next `reconnect_backoff` deadline and returns
    /// the publish error.
    ///
    /// # Errors
    ///
    /// Returns `CellosError::EventSink` when serialization fails or when the
    /// backoff window is still open, and the publisher's error when every
    /// publish attempt fails.
    #[instrument(skip(self, event), fields(ce_id = %event.id, ce_type = %event.ty))]
    async fn emit(&self, event: &CloudEventV1) -> Result<(), CellosError> {
        let payload = match serde_json::to_vec(event) {
            Ok(encoded) => Bytes::from(encoded),
            Err(e) => {
                return Err(CellosError::EventSink(format!(
                    "serialize CloudEvent: {e}"
                )));
            }
        };

        // Gate: refuse to publish while the backoff window is still open. The
        // guard is scoped so the lock is released before any publish attempt.
        {
            let guard = self.state.lock().await;
            match *guard {
                ReconnectState::Connected => {}
                ReconnectState::Reconnecting {
                    attempt,
                    next_after,
                } => {
                    let now = Instant::now();
                    if now < next_after {
                        let wait_ms = next_after.saturating_duration_since(now).as_millis() as u64;
                        return Err(CellosError::EventSink(format!(
                            "jetstream sink in reconnecting state (attempt={attempt}, next probe in {wait_ms}ms)"
                        )));
                    }
                }
            }
        }

        let resolved_subject = resolve_tenant_subject(&self.subject, event);
        let outcome = with_retry(PUBLISH_MAX_ATTEMPTS, || {
            // Each attempt needs its own owned handles because the future is `move`.
            let publisher = Arc::clone(&self.publisher);
            let subject = resolved_subject.clone();
            let payload = payload.clone();
            async move { publisher.publish(subject, payload).await }
        })
        .await;

        let mut guard = self.state.lock().await;
        let failure = match outcome {
            Err(failure) => failure,
            Ok(()) => {
                if matches!(*guard, ReconnectState::Reconnecting { .. }) {
                    debug!("jetstream sink: publish recovered; resetting backoff to Connected");
                }
                *guard = ReconnectState::Connected;
                return Ok(());
            }
        };

        // First failure starts the schedule at 0; consecutive failures advance it.
        let next_attempt = if let ReconnectState::Reconnecting { attempt, .. } = *guard {
            attempt.saturating_add(1)
        } else {
            0
        };
        let backoff = reconnect_backoff(next_attempt);
        let next_after = Instant::now() + backoff;
        warn!(
            attempt = next_attempt,
            backoff_ms = backoff.as_millis() as u64,
            error = %failure,
            "jetstream sink: publish failed; entering reconnecting state"
        );
        *guard = ReconnectState::Reconnecting {
            attempt: next_attempt,
            next_after,
        };
        Err(failure)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Subject template shared by the tenant-substitution tests.
    const TENANT_TEMPLATE: &str = "cellos.events.{tenantId}.v1";

    /// Current-thread runtime with the timer enabled and the clock started
    /// paused, so the retry sleeps are driven by tokio's paused clock rather
    /// than wall time.
    fn paused_runtime() -> tokio::runtime::Runtime {
        let mut builder = tokio::runtime::Builder::new_current_thread();
        builder.enable_time().start_paused(true);
        builder.build().unwrap()
    }

    /// Minimal CloudEvent fixture whose only varying part is `data`.
    fn cloud_event(data: Option<serde_json::Value>) -> CloudEventV1 {
        CloudEventV1 {
            specversion: "1.0".into(),
            id: "ce-1".into(),
            source: "test".into(),
            ty: "dev.cellos.events.cell.lifecycle.v1.started".into(),
            datacontenttype: Some("application/json".into()),
            data,
            time: None,
            traceparent: None,
        }
    }

    // A first-attempt success must not trigger any retry.
    #[test]
    fn with_retry_succeeds_on_first_try() {
        let invocations = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&invocations);
        let outcome: Result<u32, &'static str> = paused_runtime().block_on(async move {
            with_retry(3, || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<u32, &'static str>(42)
                }
            })
            .await
        });
        assert_eq!(outcome, Ok(42));
        assert_eq!(invocations.load(Ordering::SeqCst), 1);
    }

    // Two failures followed by a success fit inside a budget of three attempts.
    #[test]
    fn with_retry_recovers_after_transient_failures() {
        let invocations = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&invocations);
        let outcome: Result<&'static str, &'static str> = paused_runtime().block_on(async move {
            with_retry(3, || {
                let counter = Arc::clone(&counter);
                async move {
                    let call_no = counter.fetch_add(1, Ordering::SeqCst) + 1;
                    if call_no >= 3 {
                        Ok("ok")
                    } else {
                        Err("transient")
                    }
                }
            })
            .await
        });
        assert_eq!(outcome, Ok("ok"));
        assert_eq!(invocations.load(Ordering::SeqCst), 3);
    }

    // When every attempt fails, the error of the final attempt is the one returned.
    #[test]
    fn with_retry_returns_last_error_after_exhaustion() {
        let invocations = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&invocations);
        let outcome: Result<(), String> = paused_runtime().block_on(async move {
            with_retry(3, || {
                let counter = Arc::clone(&counter);
                async move {
                    let call_no = counter.fetch_add(1, Ordering::SeqCst) + 1;
                    Err::<(), String>(format!("fail-{call_no}"))
                }
            })
            .await
        });
        assert_eq!(outcome, Err("fail-3".into()));
        assert_eq!(invocations.load(Ordering::SeqCst), 3);
    }

    // A budget of one attempt means the operation runs exactly once.
    #[test]
    fn with_retry_single_attempt_does_not_retry() {
        let invocations = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&invocations);
        let outcome: Result<(), &'static str> = paused_runtime().block_on(async move {
            with_retry(1, || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), &'static str>("nope")
                }
            })
            .await
        });
        assert_eq!(outcome, Err("nope"));
        assert_eq!(invocations.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn resolve_tenant_subject_template_without_placeholder_is_passthrough() {
        let event = cloud_event(Some(serde_json::json!({"tenantId": "acme"})));
        let subject = resolve_tenant_subject("cellos.events.v1", &event);
        assert_eq!(subject, "cellos.events.v1");
    }

    #[test]
    fn resolve_tenant_subject_substitutes_when_tenant_present() {
        let event = cloud_event(Some(serde_json::json!({"tenantId": "acme"})));
        let subject = resolve_tenant_subject(TENANT_TEMPLATE, &event);
        assert_eq!(subject, "cellos.events.acme.v1");
    }

    #[test]
    fn resolve_tenant_subject_uses_sentinel_when_tenant_absent() {
        let event = cloud_event(Some(serde_json::json!({"cellId": "c1"})));
        let subject = resolve_tenant_subject(TENANT_TEMPLATE, &event);
        assert_eq!(subject, "cellos.events.single.v1");
    }

    #[test]
    fn resolve_tenant_subject_uses_sentinel_when_data_missing() {
        let event = cloud_event(None);
        let subject = resolve_tenant_subject(TENANT_TEMPLATE, &event);
        assert_eq!(subject, "cellos.events.single.v1");
    }

    #[test]
    fn reconnect_backoff_schedule_is_exponential_and_capped() {
        // (attempt, expected delay in milliseconds) for the doubling part of the schedule.
        let doubling: &[(u32, u64)] = &[(0, 100), (1, 200), (2, 400), (3, 800), (4, 1600)];
        for &(attempt, expected_ms) in doubling {
            assert_eq!(
                reconnect_backoff(attempt),
                Duration::from_millis(expected_ms)
            );
        }
        // Far along the schedule, and at the extreme input, the cap applies.
        for &attempt in &[20u32, u32::MAX] {
            assert_eq!(reconnect_backoff(attempt), RECONNECT_BACKOFF_CAP);
        }
    }
}