use crate::protocol::{Message, Offset, PartitionId, TopicName};
use crate::performance::ultra_performance::UltraPerformanceBroker;
use crossbeam::queue::SegQueue;
use dashmap::DashMap;
use std::sync::{Arc, atomic::{AtomicU64, AtomicBool, Ordering}};
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
use tokio::time::sleep;
use tracing::{info, debug, warn};
#[derive(Clone)]
pub struct BatchAggregatorConfig {
pub max_batch_time_ms: u64,
pub max_batch_size: usize,
pub max_batch_memory_bytes: usize,
pub flush_workers: usize,
}
impl Default for BatchAggregatorConfig {
fn default() -> Self {
Self {
max_batch_time_ms: 25, max_batch_size: 500, max_batch_memory_bytes: 1024 * 1024, flush_workers: 4, }
}
}
pub struct PendingMessage {
pub topic: TopicName,
pub partition: PartitionId,
pub messages: Vec<Message>,
pub response_tx: oneshot::Sender<crate::Result<Offset>>,
pub created_at: Instant,
}
pub struct MessageBatch {
pub topic: TopicName,
pub partition: PartitionId,
pub messages: Vec<Message>,
pub response_channels: Vec<oneshot::Sender<crate::Result<Offset>>>,
pub created_at: Instant,
pub total_bytes: usize,
}
impl MessageBatch {
pub fn new(topic: TopicName, partition: PartitionId) -> Self {
Self {
topic,
partition,
messages: Vec::new(),
response_channels: Vec::new(),
created_at: Instant::now(),
total_bytes: 0,
}
}
pub fn add_message(&mut self, message: Message, response_tx: oneshot::Sender<crate::Result<Offset>>) {
self.total_bytes += message.value.len() + message.key.as_ref().map(|k| k.len()).unwrap_or(0);
self.messages.push(message);
self.response_channels.push(response_tx);
}
pub fn is_ready_for_flush(&self, config: &BatchAggregatorConfig) -> bool {
self.messages.len() >= config.max_batch_size ||
self.total_bytes >= config.max_batch_memory_bytes ||
self.created_at.elapsed() >= Duration::from_millis(config.max_batch_time_ms)
}
pub fn should_force_flush(&self) -> bool {
self.created_at.elapsed() >= Duration::from_millis(100)
}
}
pub type TopicPartitionKey = (TopicName, PartitionId);
pub struct ServerBatchAggregator {
config: BatchAggregatorConfig,
message_queue: Arc<SegQueue<PendingMessage>>,
batches: Arc<DashMap<TopicPartitionKey, MessageBatch>>,
ultra_broker: Arc<UltraPerformanceBroker>,
messages_batched: AtomicU64,
batches_flushed: AtomicU64,
average_batch_size: AtomicU64,
shutdown_signal: Arc<AtomicBool>,
}
impl ServerBatchAggregator {
pub fn new(ultra_broker: Arc<UltraPerformanceBroker>) -> Self {
Self::with_config(ultra_broker, BatchAggregatorConfig::default())
}
pub fn with_config(ultra_broker: Arc<UltraPerformanceBroker>, config: BatchAggregatorConfig) -> Self {
Self {
config,
message_queue: Arc::new(SegQueue::new()),
batches: Arc::new(DashMap::new()),
ultra_broker,
messages_batched: AtomicU64::new(0),
batches_flushed: AtomicU64::new(0),
average_batch_size: AtomicU64::new(0),
shutdown_signal: Arc::new(AtomicBool::new(false)),
}
}
pub async fn start(&self) {
info!("🚀 Starting ServerBatchAggregator with {} flush workers", self.config.flush_workers);
info!(" Max batch time: {}ms", self.config.max_batch_time_ms);
info!(" Max batch size: {} messages", self.config.max_batch_size);
info!(" Max batch memory: {} bytes", self.config.max_batch_memory_bytes);
for worker_id in 0..self.config.flush_workers {
let aggregator = self.clone();
tokio::spawn(async move {
aggregator.flush_worker_loop(worker_id).await;
});
}
let aggregator = self.clone();
tokio::spawn(async move {
aggregator.periodic_batch_checker().await;
});
}
pub async fn submit_message(
&self,
topic: TopicName,
partition: PartitionId,
messages: Vec<Message>,
) -> crate::Result<Offset> {
let (response_tx, response_rx) = oneshot::channel();
let pending = PendingMessage {
topic,
partition,
messages,
response_tx,
created_at: Instant::now(),
};
self.message_queue.push(pending);
self.try_process_queue().await;
response_rx.await.map_err(|_| {
crate::FluxmqError::Network("Batch processing channel closed".to_string())
})?
}
async fn try_process_queue(&self) {
while let Some(pending) = self.message_queue.pop() {
let key = (pending.topic.clone(), pending.partition);
let mut batch_ref = self.batches.entry(key.clone()).or_insert_with(|| {
MessageBatch::new(key.0.clone(), key.1)
});
if pending.messages.len() == 1 {
if let Some(message) = pending.messages.into_iter().next() {
batch_ref.add_message(message, pending.response_tx);
}
} else {
let mut messages_iter = pending.messages.into_iter();
if let Some(first_message) = messages_iter.next() {
batch_ref.add_message(first_message, pending.response_tx);
}
for message in messages_iter {
let (dummy_tx, _dummy_rx) = oneshot::channel();
batch_ref.add_message(message, dummy_tx);
}
}
if batch_ref.is_ready_for_flush(&self.config) {
drop(batch_ref);
if let Some((_, batch)) = self.batches.remove(&key) {
self.flush_batch(batch).await;
}
}
}
}
async fn flush_worker_loop(&self, worker_id: usize) {
debug!("🔧 Batch flush worker {} started", worker_id);
while !self.shutdown_signal.load(Ordering::Relaxed) {
self.try_process_queue().await;
sleep(Duration::from_millis(1)).await;
}
debug!("🔧 Batch flush worker {} shutting down", worker_id);
}
async fn periodic_batch_checker(&self) {
debug!("⏰ Periodic batch checker started");
while !self.shutdown_signal.load(Ordering::Relaxed) {
let batch_timeout = Duration::from_millis(self.config.max_batch_time_ms / 2);
sleep(batch_timeout).await;
let keys_to_flush: Vec<_> = self.batches.iter()
.filter_map(|entry| {
let (key, batch) = entry.pair();
if batch.should_force_flush() || batch.is_ready_for_flush(&self.config) {
Some(key.clone())
} else {
None
}
})
.collect();
for key in keys_to_flush {
if let Some((_, batch)) = self.batches.remove(&key) {
self.flush_batch(batch).await;
}
}
}
debug!("⏰ Periodic batch checker shutting down");
}
async fn flush_batch(&self, batch: MessageBatch) {
let batch_size = batch.messages.len();
let start_time = Instant::now();
debug!("🔥 Flushing batch: topic={}, partition={}, size={}, bytes={}",
batch.topic, batch.partition, batch_size, batch.total_bytes);
let messages_arc = Arc::new(batch.messages);
let result = self.ultra_broker
.append_messages_ultra_shared(&batch.topic, batch.partition, messages_arc);
let flush_duration = start_time.elapsed();
match result {
Ok(base_offset) => {
info!("🚀 BATCH-AGGREGATION: Successfully flushed {} messages in {:?} (offset: {})",
batch_size, flush_duration, base_offset);
for (i, response_tx) in batch.response_channels.into_iter().enumerate() {
let offset = base_offset + i as u64;
let _ = response_tx.send(Ok(offset));
}
self.messages_batched.fetch_add(batch_size as u64, Ordering::Relaxed);
self.batches_flushed.fetch_add(1, Ordering::Relaxed);
let total_batches = self.batches_flushed.load(Ordering::Relaxed);
let total_messages = self.messages_batched.load(Ordering::Relaxed);
if total_batches > 0 {
self.average_batch_size.store(total_messages / total_batches, Ordering::Relaxed);
}
}
Err(e) => {
warn!("⚠️ BATCH-AGGREGATION: Failed to flush batch: {}", e);
for response_tx in batch.response_channels {
let _ = response_tx.send(Err(crate::FluxmqError::Network(format!("Batch processing failed: {}", e))));
}
}
}
}
pub fn get_stats(&self) -> BatchAggregatorStats {
BatchAggregatorStats {
messages_batched: self.messages_batched.load(Ordering::Relaxed),
batches_flushed: self.batches_flushed.load(Ordering::Relaxed),
average_batch_size: self.average_batch_size.load(Ordering::Relaxed),
pending_batches: self.batches.len() as u64,
queued_messages: self.message_queue.len() as u64,
}
}
pub async fn shutdown(&self) {
info!("🔄 Shutting down ServerBatchAggregator");
self.shutdown_signal.store(true, Ordering::Relaxed);
let remaining_keys: Vec<_> = self.batches.iter().map(|entry| entry.key().clone()).collect();
for key in remaining_keys {
if let Some((_, batch)) = self.batches.remove(&key) {
self.flush_batch(batch).await;
}
}
info!("✅ ServerBatchAggregator shutdown complete");
}
}
impl Clone for ServerBatchAggregator {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
message_queue: Arc::clone(&self.message_queue),
batches: Arc::clone(&self.batches),
ultra_broker: Arc::clone(&self.ultra_broker),
messages_batched: AtomicU64::new(self.messages_batched.load(Ordering::Relaxed)),
batches_flushed: AtomicU64::new(self.batches_flushed.load(Ordering::Relaxed)),
average_batch_size: AtomicU64::new(self.average_batch_size.load(Ordering::Relaxed)),
shutdown_signal: Arc::clone(&self.shutdown_signal),
}
}
}
#[derive(Debug, Clone)]
pub struct BatchAggregatorStats {
pub messages_batched: u64,
pub batches_flushed: u64,
pub average_batch_size: u64,
pub pending_batches: u64,
pub queued_messages: u64,
}
impl std::fmt::Display for BatchAggregatorStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f,
"BatchAggregator Stats: {} messages in {} batches (avg: {} msg/batch), {} pending, {} queued",
self.messages_batched,
self.batches_flushed,
self.average_batch_size,
self.pending_batches,
self.queued_messages
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::performance::ultra_performance::UltraPerformanceBroker;
use crate::protocol::{Message, TopicName, PartitionId};
use std::sync::Arc;
use tokio::time::{timeout, Duration};
use std::collections::HashMap;
async fn create_test_ultra_broker() -> Arc<UltraPerformanceBroker> {
Arc::new(UltraPerformanceBroker::new())
}
fn create_test_message(key: Option<String>, value: String) -> Message {
Message {
key: key.map(|k| k.into_bytes().into()),
value: value.into_bytes().into(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
headers: HashMap::new(),
}
}
#[tokio::test]
async fn test_batch_aggregator_creation() {
let ultra_broker = create_test_ultra_broker().await;
let aggregator = ServerBatchAggregator::new(ultra_broker);
let stats = aggregator.get_stats();
assert_eq!(stats.messages_batched, 0);
assert_eq!(stats.batches_flushed, 0);
assert_eq!(stats.pending_batches, 0);
assert_eq!(stats.queued_messages, 0);
}
#[tokio::test]
async fn test_batch_aggregator_with_custom_config() {
let ultra_broker = create_test_ultra_broker().await;
let config = BatchAggregatorConfig {
max_batch_time_ms: 10,
max_batch_size: 100,
max_batch_memory_bytes: 512 * 1024,
flush_workers: 2,
};
let aggregator = ServerBatchAggregator::with_config(ultra_broker, config);
assert_eq!(aggregator.config.max_batch_time_ms, 10);
assert_eq!(aggregator.config.max_batch_size, 100);
assert_eq!(aggregator.config.flush_workers, 2);
}
#[tokio::test]
async fn test_message_batch_creation() {
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let batch = MessageBatch::new(topic.clone(), partition);
assert_eq!(batch.topic, topic);
assert_eq!(batch.partition, partition);
assert_eq!(batch.messages.len(), 0);
assert_eq!(batch.response_channels.len(), 0);
assert_eq!(batch.total_bytes, 0);
}
#[tokio::test]
async fn test_message_batch_add_message() {
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let mut batch = MessageBatch::new(topic, partition);
let message = create_test_message(Some("key1".to_string()), "value1".to_string());
let expected_bytes = message.value.len() + message.key.as_ref().unwrap().len();
let (tx, _rx) = tokio::sync::oneshot::channel();
batch.add_message(message, tx);
assert_eq!(batch.messages.len(), 1);
assert_eq!(batch.response_channels.len(), 1);
assert_eq!(batch.total_bytes, expected_bytes);
}
#[tokio::test]
async fn test_batch_ready_for_flush_size() {
let config = BatchAggregatorConfig {
max_batch_size: 2, max_batch_time_ms: 1000, max_batch_memory_bytes: 1024 * 1024, flush_workers: 1,
};
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let mut batch = MessageBatch::new(topic, partition);
assert!(!batch.is_ready_for_flush(&config));
let (tx1, _rx1) = tokio::sync::oneshot::channel();
batch.add_message(create_test_message(None, "msg1".to_string()), tx1);
assert!(!batch.is_ready_for_flush(&config));
let (tx2, _rx2) = tokio::sync::oneshot::channel();
batch.add_message(create_test_message(None, "msg2".to_string()), tx2);
assert!(batch.is_ready_for_flush(&config));
}
#[tokio::test]
async fn test_batch_ready_for_flush_memory() {
let config = BatchAggregatorConfig {
max_batch_size: 1000, max_batch_time_ms: 1000, max_batch_memory_bytes: 10, flush_workers: 1,
};
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let mut batch = MessageBatch::new(topic, partition);
let (tx, _rx) = tokio::sync::oneshot::channel();
batch.add_message(create_test_message(None, "this is a long message".to_string()), tx);
assert!(batch.is_ready_for_flush(&config));
}
#[tokio::test]
async fn test_batch_ready_for_flush_time() {
let config = BatchAggregatorConfig {
max_batch_size: 1000, max_batch_time_ms: 1, max_batch_memory_bytes: 1024 * 1024, flush_workers: 1,
};
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let mut batch = MessageBatch::new(topic, partition);
let (tx, _rx) = tokio::sync::oneshot::channel();
batch.add_message(create_test_message(None, "msg".to_string()), tx);
tokio::time::sleep(Duration::from_millis(10)).await;
assert!(batch.is_ready_for_flush(&config));
}
#[tokio::test]
async fn test_batch_force_flush() {
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let batch = MessageBatch::new(topic, partition);
assert!(!batch.should_force_flush());
tokio::time::sleep(Duration::from_millis(110)).await;
assert!(batch.should_force_flush());
}
#[tokio::test]
async fn test_single_message_submission() {
let ultra_broker = create_test_ultra_broker().await;
let config = BatchAggregatorConfig {
max_batch_time_ms: 10,
max_batch_size: 1, max_batch_memory_bytes: 1024 * 1024,
flush_workers: 1,
};
let aggregator = Arc::new(ServerBatchAggregator::with_config(ultra_broker, config));
aggregator.start().await;
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let message = create_test_message(Some("key1".to_string()), "test message".to_string());
let result = timeout(
Duration::from_secs(5),
aggregator.submit_message(topic, partition, vec![message])
).await;
assert!(result.is_ok(), "Message submission should not timeout");
let offset_result = result.unwrap();
assert!(offset_result.is_ok(), "Message should be processed successfully: {:?}", offset_result);
tokio::time::sleep(Duration::from_millis(50)).await; let stats = aggregator.get_stats();
assert!(stats.messages_batched >= 1, "Should have batched at least 1 message");
assert!(stats.batches_flushed >= 1, "Should have flushed at least 1 batch");
}
#[tokio::test]
async fn test_multiple_message_batching() {
let ultra_broker = create_test_ultra_broker().await;
let config = BatchAggregatorConfig {
max_batch_time_ms: 50, max_batch_size: 3, max_batch_memory_bytes: 1024 * 1024,
flush_workers: 2,
};
let aggregator = Arc::new(ServerBatchAggregator::with_config(ultra_broker, config));
aggregator.start().await;
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let mut handles = Vec::new();
for i in 0..3 {
let agg = Arc::clone(&aggregator);
let topic_clone = topic.clone();
let message = create_test_message(Some(format!("key{}", i)), format!("message {}", i));
let handle = tokio::spawn(async move {
agg.submit_message(topic_clone, partition, vec![message]).await
});
handles.push(handle);
}
let results = futures::future::join_all(handles).await;
for result in results {
assert!(result.is_ok(), "Task should complete successfully");
let offset_result = result.unwrap();
assert!(offset_result.is_ok(), "Message should be processed successfully: {:?}", offset_result);
}
tokio::time::sleep(Duration::from_millis(100)).await;
let stats = aggregator.get_stats();
assert_eq!(stats.messages_batched, 3, "Should have batched exactly 3 messages");
assert_eq!(stats.batches_flushed, 1, "Should have flushed exactly 1 batch");
assert_eq!(stats.average_batch_size, 3, "Average batch size should be 3");
}
#[tokio::test]
async fn test_size_based_immediate_flushing() {
let ultra_broker = create_test_ultra_broker().await;
let config = BatchAggregatorConfig {
max_batch_time_ms: 1000, max_batch_size: 1, max_batch_memory_bytes: 1024 * 1024,
flush_workers: 1,
};
let aggregator = Arc::new(ServerBatchAggregator::with_config(ultra_broker, config));
aggregator.start().await;
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let message = create_test_message(Some("key1".to_string()), "test message".to_string());
let result = timeout(
Duration::from_secs(2),
aggregator.submit_message(topic, partition, vec![message])
).await;
assert!(result.is_ok(), "Message submission should not timeout");
let offset_result = result.unwrap();
assert!(offset_result.is_ok(), "Message should be processed successfully: {:?}", offset_result);
tokio::time::sleep(Duration::from_millis(10)).await;
let stats = aggregator.get_stats();
assert!(stats.messages_batched >= 1, "Should have batched at least 1 message, got: {}", stats.messages_batched);
assert!(stats.batches_flushed >= 1, "Should have flushed at least 1 batch (size-based), got: {}", stats.batches_flushed);
}
#[tokio::test]
async fn test_multi_partition_batching() {
let ultra_broker = create_test_ultra_broker().await;
let config = BatchAggregatorConfig {
max_batch_time_ms: 10,
max_batch_size: 1, max_batch_memory_bytes: 1024 * 1024,
flush_workers: 2,
};
let aggregator = Arc::new(ServerBatchAggregator::with_config(ultra_broker, config));
aggregator.start().await;
let topic: TopicName = "test-topic".to_string();
let partitions = [0, 1, 2];
let mut handles = Vec::new();
for &partition in &partitions {
let agg = Arc::clone(&aggregator);
let topic_clone = topic.clone();
let message = create_test_message(
Some(format!("key_p{}", partition)),
format!("message for partition {}", partition)
);
let handle = tokio::spawn(async move {
agg.submit_message(topic_clone, partition, vec![message]).await
});
handles.push(handle);
}
let results = futures::future::join_all(handles).await;
for result in results {
assert!(result.is_ok(), "Task should complete successfully");
let offset_result = result.unwrap();
assert!(offset_result.is_ok(), "Message should be processed successfully: {:?}", offset_result);
}
tokio::time::sleep(Duration::from_millis(100)).await;
let stats = aggregator.get_stats();
assert_eq!(stats.messages_batched, 3, "Should have batched 3 messages total");
assert_eq!(stats.batches_flushed, 3, "Should have flushed 3 batches (one per partition)");
}
#[tokio::test]
async fn test_aggregator_shutdown() {
let ultra_broker = create_test_ultra_broker().await;
let aggregator = Arc::new(ServerBatchAggregator::new(ultra_broker));
aggregator.start().await;
let topic: TopicName = "test-topic".to_string();
let partition: PartitionId = 0;
let message = create_test_message(Some("key".to_string()), "test".to_string());
let _handle = {
let agg = Arc::clone(&aggregator);
tokio::spawn(async move {
agg.submit_message(topic, partition, vec![message]).await
})
};
tokio::time::sleep(Duration::from_millis(10)).await;
let shutdown_result = timeout(Duration::from_secs(5), aggregator.shutdown()).await;
assert!(shutdown_result.is_ok(), "Shutdown should complete within timeout");
let stats = aggregator.get_stats();
assert_eq!(stats.pending_batches, 0, "No batches should be pending after shutdown");
}
#[tokio::test]
async fn test_batch_aggregator_stats_display() {
let stats = BatchAggregatorStats {
messages_batched: 1000,
batches_flushed: 50,
average_batch_size: 20,
pending_batches: 2,
queued_messages: 5,
};
let display = format!("{}", stats);
assert!(display.contains("1000 messages"));
assert!(display.contains("50 batches"));
assert!(display.contains("avg: 20 msg/batch"));
assert!(display.contains("2 pending"));
assert!(display.contains("5 queued"));
}
#[tokio::test]
async fn test_config_defaults() {
let default_config = BatchAggregatorConfig::default();
assert_eq!(default_config.max_batch_time_ms, 25);
assert_eq!(default_config.max_batch_size, 500);
assert_eq!(default_config.max_batch_memory_bytes, 1024 * 1024);
assert_eq!(default_config.flush_workers, 4);
}
}