use crate::orderbook::serialization::{EventSerializer, JsonEventSerializer};
use crate::orderbook::trade::{TradeListener, TradeResult};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::{error, trace, warn};
/// Default for `NatsTradePublisher::max_retries`: how many extra publish attempts
/// follow a failed first attempt for each subject.
const DEFAULT_MAX_RETRIES: u32 = 3;

/// Base of the exponential backoff between publish attempts, in milliseconds;
/// the wait after attempt `k` (0-based) starts from `BASE_RETRY_DELAY_MS * 2^k`.
const BASE_RETRY_DELAY_MS: u64 = 10;
/// Publishes trade results to NATS JetStream.
///
/// Each serialized trade is sent twice: once to `"{subject_prefix}.{symbol}"` and
/// once to `"{subject_prefix}.all"`, from tasks spawned on `runtime`.
pub struct NatsTradePublisher {
// JetStream context used for every publish and ack.
jetstream: async_nats::jetstream::Context,
// Prefix for both the per-symbol subject and the `.all` subject.
subject_prefix: String,
// Handle of the Tokio runtime on which publish tasks are spawned.
runtime: tokio::runtime::Handle,
// Shared counter from which each trade takes two numbers (one per subject),
// sent in the `Nats-Sequence` header.
sequence: AtomicU64,
// Trades whose publishes to both subjects were acknowledged.
publish_count: AtomicU64,
// Serialization failures plus per-subject publishes that exhausted their retries.
error_count: AtomicU64,
// Extra attempts after the first one for each subject publish.
max_retries: u32,
// Encodes trade results into payload bytes and supplies the `Content-Type` header.
serializer: Arc<dyn EventSerializer>,
}
impl NatsTradePublisher {
    /// Upper bound, in milliseconds, on the wait between two publish attempts.
    ///
    /// Without a cap the backoff grows as `BASE_RETRY_DELAY_MS * 2^attempt`. For a
    /// large `max_retries`, a failing publish task would then sleep for hours, or
    /// effectively forever once the multiplication saturates, while still holding its
    /// payload.
    const MAX_RETRY_DELAY_MS: u64 = 5_000;

    /// Creates a publisher that sends trade events through `jetstream`.
    ///
    /// Per-symbol events go to `"{subject_prefix}.{symbol}"`, and every event is also
    /// sent to `"{subject_prefix}.all"`. Publishing runs on tasks spawned on `runtime`.
    /// It starts with `DEFAULT_MAX_RETRIES` retries and the JSON serializer.
    #[inline]
    pub fn new(
        jetstream: async_nats::jetstream::Context,
        subject_prefix: String,
        runtime: tokio::runtime::Handle,
    ) -> Self {
        Self {
            jetstream,
            subject_prefix,
            runtime,
            sequence: AtomicU64::new(0),
            publish_count: AtomicU64::new(0),
            error_count: AtomicU64::new(0),
            max_retries: DEFAULT_MAX_RETRIES,
            serializer: Arc::new(JsonEventSerializer),
        }
    }

    /// Sets how many extra attempts follow a failed first attempt per subject.
    /// `0` means a single attempt.
    #[must_use = "builders do nothing unless consumed"]
    #[inline]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Replaces the serializer used for payloads and the `Content-Type` header.
    #[must_use = "builders do nothing unless consumed"]
    #[inline]
    pub fn with_serializer(mut self, serializer: Arc<dyn EventSerializer>) -> Self {
        self.serializer = serializer;
        self
    }

    /// Number of trade events whose publishes to both subjects were acknowledged.
    #[must_use]
    #[inline]
    pub fn publish_count(&self) -> u64 {
        self.publish_count.load(Ordering::Relaxed)
    }

    /// Number of failures: serialization errors plus each subject publish that
    /// exhausted its retries.
    #[must_use]
    #[inline]
    pub fn error_count(&self) -> u64 {
        self.error_count.load(Ordering::Relaxed)
    }

    /// Next value of the shared sequence counter. It advances by two per
    /// successfully serialized trade, one number for each subject.
    #[must_use]
    #[inline]
    pub fn sequence(&self) -> u64 {
        self.sequence.load(Ordering::Relaxed)
    }

    /// The serializer currently used for payloads.
    #[must_use]
    #[inline]
    pub fn serializer(&self) -> &dyn EventSerializer {
        self.serializer.as_ref()
    }

    /// Wraps the publisher in an `Arc` and builds a `TradeListener` that publishes
    /// every trade it receives.
    ///
    /// The listener serializes synchronously. On failure it bumps `error_count`,
    /// logs, and drops the event. On success it spawns a task on `runtime` that
    /// publishes to the per-symbol and `.all` subjects.
    ///
    /// NOTE(review): each trade gets its own spawned task, so nothing here enforces
    /// that trades reach a subject in the order the listener saw them.
    ///
    /// NOTE(review): the symbol is inserted verbatim into the subject. A symbol
    /// containing `.`, `*`, `>` or whitespace would change the subject's token
    /// structure or make it invalid. Confirm that symbols are validated upstream.
    ///
    /// Returns the shared publisher (for reading counters) and the listener.
    pub fn into_listener(self) -> (Arc<Self>, TradeListener) {
        let publisher = Arc::new(self);
        let handle = Arc::clone(&publisher);
        let listener = Arc::new(move |trade_result: &TradeResult| {
            let payload = match publisher.serializer.serialize_trade(trade_result) {
                Ok(bytes) => bytes,
                Err(e) => {
                    publisher.error_count.fetch_add(1, Ordering::Relaxed);
                    error!(error = %e, "failed to serialize trade result for NATS");
                    return;
                }
            };
            // Reserve both sequence numbers in one atomic step. Two separate
            // increments could interleave with a concurrent invocation; this way
            // the pair is always adjacent. The counter still advances by two, and
            // `wrapping_add` matches the wrap-around of atomic `fetch_add`.
            let symbol_seq = publisher.sequence.fetch_add(2, Ordering::Relaxed);
            let all_seq = symbol_seq.wrapping_add(1);
            let symbol_subject = format!("{}.{}", publisher.subject_prefix, trade_result.symbol);
            let all_subject = format!("{}.all", publisher.subject_prefix);
            let payload_bytes: bytes::Bytes = payload.into();
            publisher.runtime.spawn(Self::publish_with_retry(
                Arc::clone(&publisher),
                symbol_subject,
                all_subject,
                payload_bytes,
                symbol_seq,
                all_seq,
            ));
        });
        (handle, listener)
    }

    /// Publishes `payload` to the per-symbol subject and then to the `.all` subject.
    ///
    /// Each message carries its own `Nats-Sequence` header plus the serializer's
    /// `Content-Type`. `publish_count` is incremented only when both subjects were
    /// acknowledged; failures are counted and logged inside `publish_single`.
    async fn publish_with_retry(
        publisher: Arc<Self>,
        symbol_subject: String,
        all_subject: String,
        payload: bytes::Bytes,
        symbol_seq: u64,
        all_seq: u64,
    ) {
        let content_type = publisher.serializer.content_type();
        let mut symbol_headers = async_nats::HeaderMap::new();
        symbol_headers.insert("Nats-Sequence", symbol_seq.to_string().as_str());
        symbol_headers.insert("Content-Type", content_type);
        let mut all_headers = async_nats::HeaderMap::new();
        all_headers.insert("Nats-Sequence", all_seq.to_string().as_str());
        all_headers.insert("Content-Type", content_type);
        // Sequential on purpose: the per-symbol message is attempted before `.all`.
        let symbol_ok =
            Self::publish_single(&publisher, &symbol_subject, payload.clone(), symbol_headers)
                .await;
        let all_ok = Self::publish_single(&publisher, &all_subject, payload, all_headers).await;
        if symbol_ok && all_ok {
            publisher.publish_count.fetch_add(1, Ordering::Relaxed);
            trace!(symbol_seq, all_seq, symbol = %symbol_subject, "trade event published to NATS");
        }
    }

    /// Publishes one message to `subject` and waits for its JetStream ack.
    ///
    /// It makes up to `max_retries + 1` attempts (saturating). Between attempts it
    /// sleeps for `retry_delay(attempt)`. Both a publish error and an ack error count
    /// as a failed attempt.
    ///
    /// NOTE(review): retrying after an ack error may duplicate a message the server
    /// already stored; no `Nats-Msg-Id` header is set here for deduplication.
    ///
    /// Returns `true` on an acknowledged publish. Returns `false` after every attempt
    /// has failed, in which case `error_count` is incremented once.
    async fn publish_single(
        publisher: &Arc<Self>,
        subject: &str,
        payload: bytes::Bytes,
        headers: async_nats::HeaderMap,
    ) -> bool {
        let max_attempts = publisher.max_retries.saturating_add(1);
        for attempt in 0..max_attempts {
            let publish_result = publisher
                .jetstream
                .publish_with_headers(subject.to_string(), headers.clone(), payload.clone())
                .await;
            match publish_result {
                Ok(ack_future) => match ack_future.await {
                    Ok(_) => return true,
                    Err(e) => {
                        warn!(
                            attempt = attempt + 1,
                            max = max_attempts,
                            subject,
                            error = %e,
                            "NATS ack failed, retrying"
                        );
                    }
                },
                Err(e) => {
                    warn!(
                        attempt = attempt + 1,
                        max = max_attempts,
                        subject,
                        error = %e,
                        "NATS publish failed, retrying"
                    );
                }
            }
            // No sleep after the final attempt.
            if attempt + 1 < max_attempts {
                tokio::time::sleep(Self::retry_delay(attempt)).await;
            }
        }
        publisher.error_count.fetch_add(1, Ordering::Relaxed);
        error!(subject, "NATS publish failed after all retries");
        false
    }

    /// Wait before the attempt that follows `attempt` (0-based).
    ///
    /// The delay is `BASE_RETRY_DELAY_MS * 2^attempt`, saturating on overflow and
    /// capped at `MAX_RETRY_DELAY_MS`.
    fn retry_delay(attempt: u32) -> std::time::Duration {
        // The shift is clamped to 63, so `checked_shl` cannot fail. The fallback is
        // defensive only.
        let factor = 1u64.checked_shl(attempt.min(63)).unwrap_or(u64::MAX);
        let delay_ms = BASE_RETRY_DELAY_MS
            .saturating_mul(factor)
            .min(Self::MAX_RETRY_DELAY_MS);
        std::time::Duration::from_millis(delay_ms)
    }
}
impl std::fmt::Debug for NatsTradePublisher {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NatsTradePublisher")
.field("subject_prefix", &self.subject_prefix)
.field("sequence", &self.sequence.load(Ordering::Relaxed))
.field("publish_count", &self.publish_count.load(Ordering::Relaxed))
.field("error_count", &self.error_count.load(Ordering::Relaxed))
.field("max_retries", &self.max_retries)
.field("serializer", &self.serializer.content_type())
.finish()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pricelevel::{Id, MatchResult};

    /// Builds a `TradeResult` for `symbol` around a `MatchResult` created from a
    /// fresh UUID order id and the value `100`.
    fn make_trade_result(symbol: &str) -> TradeResult {
        let matched = MatchResult::new(Id::new_uuid(), 100);
        TradeResult::new(symbol.to_string(), matched)
    }

    #[test]
    fn test_trade_result_serializes_to_json() {
        let trade = make_trade_result("BTC/USD");
        let encoded = serde_json::to_vec(&trade);
        assert!(encoded.is_ok());
        let raw = encoded.unwrap_or_default();
        assert!(!raw.is_empty());
        let text = String::from_utf8(raw).unwrap_or_default();
        for needle in ["BTC/USD", "match_result"] {
            assert!(text.contains(needle));
        }
    }

    #[test]
    fn test_trade_result_serialize_roundtrip_fields() {
        let trade = make_trade_result("ETH/USDT");
        let converted = serde_json::to_value(&trade);
        assert!(converted.is_ok());
        let value = converted.unwrap_or(serde_json::Value::Null);
        assert_eq!(
            value.get("symbol").and_then(serde_json::Value::as_str),
            Some("ETH/USDT")
        );
        // Both fee totals are expected to be zero for a freshly built trade.
        for fee_field in ["total_maker_fees", "total_taker_fees"] {
            assert_eq!(
                value.get(fee_field).and_then(serde_json::Value::as_i64),
                Some(0)
            );
        }
    }

    #[test]
    fn test_subject_formatting() {
        let (prefix, symbol) = ("trades", "BTC/USD");
        assert_eq!(format!("{prefix}.{symbol}"), "trades.BTC/USD");
        assert_eq!(format!("{prefix}.all"), "trades.all");
    }

    #[test]
    fn test_subject_formatting_with_custom_prefix() {
        let (prefix, symbol) = ("orderbook.events.trades", "ETH-PERP");
        assert_eq!(
            format!("{prefix}.{symbol}"),
            "orderbook.events.trades.ETH-PERP"
        );
        assert_eq!(format!("{prefix}.all"), "orderbook.events.trades.all");
    }

    #[test]
    fn test_default_max_retries() {
        assert_eq!(DEFAULT_MAX_RETRIES, 3);
    }

    #[test]
    fn test_base_retry_delay() {
        assert_eq!(BASE_RETRY_DELAY_MS, 10);
    }

    /// The shift-and-saturate backoff expression agrees with the closed form
    /// `BASE_RETRY_DELAY_MS * 2^attempt` for small attempts.
    #[test]
    fn test_exponential_backoff_calculation() {
        (0u32..4).for_each(|attempt| {
            let factor = 1u64.checked_shl(attempt.min(63)).unwrap_or(u64::MAX);
            let delay = BASE_RETRY_DELAY_MS.saturating_mul(factor);
            assert_eq!(delay, BASE_RETRY_DELAY_MS * 2u64.pow(attempt));
        });
    }

    /// Very large attempt numbers neither panic nor shrink the delay below the base.
    #[test]
    fn test_exponential_backoff_high_retry_count_does_not_panic() {
        for attempt in [63u32, 64, 100, u32::MAX] {
            let factor = 1u64.checked_shl(attempt.min(63)).unwrap_or(u64::MAX);
            let delay = BASE_RETRY_DELAY_MS.saturating_mul(factor);
            assert!(delay >= BASE_RETRY_DELAY_MS);
        }
    }

    #[test]
    fn test_nats_publish_error_display() {
        let rendered = crate::orderbook::OrderBookError::NatsPublishError {
            message: "connection refused".to_string(),
        }
        .to_string();
        assert!(rendered.contains("nats publish error"));
        assert!(rendered.contains("connection refused"));
    }

    #[test]
    fn test_nats_serialization_error_display() {
        let rendered = crate::orderbook::OrderBookError::NatsSerializationError {
            message: "invalid utf-8".to_string(),
        }
        .to_string();
        assert!(rendered.contains("nats serialization error"));
        assert!(rendered.contains("invalid utf-8"));
    }
}