use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use tracing::{debug, trace};
use super::config::BatchProcessorConfig;
pub use super::config::BatchProcessorConfig as BatchConfig;
pub type RequestId = u64;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessingStrategy {
Sequential,
Parallel,
PriorityParallel,
Adaptive,
}
impl Default for ProcessingStrategy {
fn default() -> Self {
Self::Adaptive
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum RequestPriority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
impl Default for RequestPriority {
fn default() -> Self {
Self::Normal
}
}
#[derive(Debug)]
pub struct BatchRequest<T> {
pub id: RequestId,
pub payload: T,
pub priority: RequestPriority,
pub unit_id: u8,
pub function_code: u8,
pub start_address: u16,
pub quantity: u16,
submitted_at: Instant,
response_tx: Option<oneshot::Sender<BatchResponse<T>>>,
}
impl<T> BatchRequest<T> {
pub fn new(
payload: T,
unit_id: u8,
function_code: u8,
start_address: u16,
quantity: u16,
) -> (Self, oneshot::Receiver<BatchResponse<T>>) {
let (tx, rx) = oneshot::channel();
let request = Self {
id: 0, payload,
priority: RequestPriority::Normal,
unit_id,
function_code,
start_address,
quantity,
submitted_at: Instant::now(),
response_tx: Some(tx),
};
(request, rx)
}
pub fn with_priority(mut self, priority: RequestPriority) -> Self {
self.priority = priority;
self
}
pub fn wait_time(&self) -> Duration {
self.submitted_at.elapsed()
}
pub fn can_coalesce(&self, other: &Self) -> bool {
if self.function_code > 4 || other.function_code > 4 {
return false;
}
if self.unit_id != other.unit_id || self.function_code != other.function_code {
return false;
}
let self_end = self.start_address.saturating_add(self.quantity);
let other_end = other.start_address.saturating_add(other.quantity);
(self_end >= other.start_address && self.start_address <= other_end)
|| (other_end >= self.start_address && other.start_address <= self_end)
}
}
#[derive(Debug)]
pub struct BatchResponse<T> {
pub request_id: RequestId,
pub result: Result<T, BatchError>,
pub latency: Duration,
pub coalesced: bool,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum BatchError {
#[error("Request timed out after {0:?}")]
Timeout(Duration),
#[error("Processor is shutting down")]
Shutdown,
#[error("Queue is full ({current}/{max})")]
QueueFull { current: usize, max: usize },
#[error("Processing error: {0}")]
Processing(String),
#[error("Request cancelled")]
Cancelled,
}
#[derive(Debug, Clone, Default)]
pub struct BatchStatistics {
pub total_submitted: u64,
pub total_processed: u64,
pub total_coalesced: u64,
pub total_batches: u64,
pub avg_batch_size: f64,
pub avg_latency_us: f64,
pub queue_depth: usize,
pub peak_queue_depth: usize,
}
struct ProcessingBatch<T> {
requests: Vec<BatchRequest<T>>,
created_at: Instant,
}
impl<T> ProcessingBatch<T> {
fn new() -> Self {
Self {
requests: Vec::new(),
created_at: Instant::now(),
}
}
fn is_empty(&self) -> bool {
self.requests.is_empty()
}
fn len(&self) -> usize {
self.requests.len()
}
fn age(&self) -> Duration {
self.created_at.elapsed()
}
fn push(&mut self, request: BatchRequest<T>) {
self.requests.push(request);
}
fn should_flush(&self, config: &BatchProcessorConfig) -> bool {
self.len() >= config.batch_size || self.age() >= config.batch_timeout
}
}
pub trait BatchHandler<T>: Send + Sync {
fn process(&self, request: &BatchRequest<T>) -> Result<T, BatchError>;
fn process_batch(&self, requests: &[BatchRequest<T>]) -> Vec<Result<T, BatchError>> {
requests.iter().map(|r| self.process(r)).collect()
}
fn can_coalesce(&self, a: &BatchRequest<T>, b: &BatchRequest<T>) -> bool {
a.can_coalesce(b)
}
}
pub struct BatchProcessor<T>
where
T: Send + 'static,
{
config: BatchProcessorConfig,
strategy: ProcessingStrategy,
next_id: AtomicU64,
pending_batch: Mutex<ProcessingBatch<T>>,
pending_count: AtomicUsize,
peak_queue_depth: AtomicUsize,
stats: RwLock<BatchStatisticsInternal>,
shutdown: std::sync::atomic::AtomicBool,
}
#[derive(Default)]
struct BatchStatisticsInternal {
total_submitted: u64,
total_processed: u64,
total_coalesced: u64,
total_batches: u64,
total_latency_us: u64,
}
impl<T> BatchProcessor<T>
where
T: Send + Clone + 'static,
{
pub fn new(config: BatchProcessorConfig) -> Self {
Self {
config,
strategy: ProcessingStrategy::Adaptive,
next_id: AtomicU64::new(1),
pending_batch: Mutex::new(ProcessingBatch::new()),
pending_count: AtomicUsize::new(0),
peak_queue_depth: AtomicUsize::new(0),
stats: RwLock::new(BatchStatisticsInternal::default()),
shutdown: std::sync::atomic::AtomicBool::new(false),
}
}
pub fn with_strategy(mut self, strategy: ProcessingStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn submit(&self, mut request: BatchRequest<T>) -> Result<(), BatchError> {
if self.shutdown.load(Ordering::SeqCst) {
return Err(BatchError::Shutdown);
}
let current = self.pending_count.load(Ordering::Relaxed);
if current >= self.config.max_pending {
return Err(BatchError::QueueFull {
current,
max: self.config.max_pending,
});
}
request.id = self.next_id.fetch_add(1, Ordering::Relaxed);
let request_id = request.id;
{
let mut batch = self.pending_batch.lock();
batch.push(request);
let count = batch.len();
self.pending_count.store(count, Ordering::Relaxed);
let peak = self.peak_queue_depth.load(Ordering::Relaxed);
if count > peak {
self.peak_queue_depth.store(count, Ordering::Relaxed);
}
}
{
let mut stats = self.stats.write();
stats.total_submitted += 1;
}
trace!(
request_id,
pending = self.pending_count.load(Ordering::Relaxed),
"Request submitted"
);
Ok(())
}
pub async fn flush<H: BatchHandler<T>>(&self, handler: &H) -> usize {
let batch = {
let mut pending = self.pending_batch.lock();
std::mem::replace(&mut *pending, ProcessingBatch::new())
};
if batch.is_empty() {
return 0;
}
let batch_size = batch.len();
self.pending_count.store(0, Ordering::Relaxed);
debug!(
batch_size,
age_ms = batch.age().as_millis(),
"Flushing batch"
);
let processed = match self.strategy {
ProcessingStrategy::Sequential => self.process_sequential(batch, handler).await,
ProcessingStrategy::Parallel | ProcessingStrategy::PriorityParallel => {
self.process_parallel(batch, handler).await
}
ProcessingStrategy::Adaptive => {
if batch_size > 10 {
self.process_parallel(batch, handler).await
} else {
self.process_sequential(batch, handler).await
}
}
};
{
let mut stats = self.stats.write();
stats.total_batches += 1;
}
processed
}
async fn process_sequential<H: BatchHandler<T>>(
&self,
batch: ProcessingBatch<T>,
handler: &H,
) -> usize {
let mut processed = 0;
for request in batch.requests {
let start = Instant::now();
let result = handler.process(&request);
let latency = start.elapsed();
self.send_response(request, result, latency, false);
processed += 1;
{
let mut stats = self.stats.write();
stats.total_processed += 1;
stats.total_latency_us += latency.as_micros() as u64;
}
}
processed
}
async fn process_parallel<H: BatchHandler<T>>(
&self,
batch: ProcessingBatch<T>,
handler: &H,
) -> usize {
let requests = if self.config.coalescing.enabled {
self.coalesce_requests(batch.requests, handler)
} else {
batch.requests
};
let results = handler.process_batch(&requests);
let mut processed = 0;
for (request, result) in requests.into_iter().zip(results.into_iter()) {
let latency = request.wait_time();
self.send_response(request, result, latency, false);
processed += 1;
{
let mut stats = self.stats.write();
stats.total_processed += 1;
stats.total_latency_us += latency.as_micros() as u64;
}
}
processed
}
fn coalesce_requests<H: BatchHandler<T>>(
&self,
mut requests: Vec<BatchRequest<T>>,
handler: &H,
) -> Vec<BatchRequest<T>> {
if requests.len() <= 1 {
return requests;
}
requests.sort_by(|a, b| {
(a.unit_id, a.function_code, a.start_address).cmp(&(
b.unit_id,
b.function_code,
b.start_address,
))
});
let mut coalesced = Vec::with_capacity(requests.len());
let mut i = 0;
while i < requests.len() {
let mut current = requests.swap_remove(i);
let mut merged_count = 0;
let mut j = i;
while j < requests.len() && merged_count < self.config.coalescing.max_coalesce {
if handler.can_coalesce(¤t, &requests[j]) {
let other = requests.swap_remove(j);
let new_start = current.start_address.min(other.start_address);
let current_end = current.start_address + current.quantity;
let other_end = other.start_address + other.quantity;
let new_end = current_end.max(other_end);
current.start_address = new_start;
current.quantity = new_end - new_start;
merged_count += 1;
{
let mut stats = self.stats.write();
stats.total_coalesced += 1;
}
} else {
j += 1;
}
}
coalesced.push(current);
if i < requests.len() {
i += 1;
} else {
break;
}
}
coalesced
}
fn send_response(
&self,
mut request: BatchRequest<T>,
result: Result<T, BatchError>,
latency: Duration,
coalesced: bool,
) {
if let Some(tx) = request.response_tx.take() {
let response = BatchResponse {
request_id: request.id,
result,
latency,
coalesced,
};
let _ = tx.send(response);
}
}
pub fn should_flush(&self) -> bool {
let batch = self.pending_batch.lock();
batch.should_flush(&self.config)
}
pub fn queue_depth(&self) -> usize {
self.pending_count.load(Ordering::Relaxed)
}
pub fn statistics(&self) -> BatchStatistics {
let stats = self.stats.read();
let queue_depth = self.pending_count.load(Ordering::Relaxed);
let peak_queue_depth = self.peak_queue_depth.load(Ordering::Relaxed);
let avg_batch_size = if stats.total_batches > 0 {
stats.total_processed as f64 / stats.total_batches as f64
} else {
0.0
};
let avg_latency_us = if stats.total_processed > 0 {
stats.total_latency_us as f64 / stats.total_processed as f64
} else {
0.0
};
BatchStatistics {
total_submitted: stats.total_submitted,
total_processed: stats.total_processed,
total_coalesced: stats.total_coalesced,
total_batches: stats.total_batches,
avg_batch_size,
avg_latency_us,
queue_depth,
peak_queue_depth,
}
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::SeqCst);
let batch = {
let mut pending = self.pending_batch.lock();
std::mem::replace(&mut *pending, ProcessingBatch::new())
};
for mut request in batch.requests {
if let Some(tx) = request.response_tx.take() {
let response = BatchResponse {
request_id: request.id,
result: Err(BatchError::Shutdown),
latency: request.wait_time(),
coalesced: false,
};
let _ = tx.send(response);
}
}
}
pub fn config(&self) -> &BatchProcessorConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
use super::super::config::CoalescingConfig;
use super::*;
struct EchoHandler;
impl BatchHandler<Vec<u8>> for EchoHandler {
fn process(&self, request: &BatchRequest<Vec<u8>>) -> Result<Vec<u8>, BatchError> {
Ok(request.payload.clone())
}
}
fn make_config(batch_size: usize, max_pending: usize) -> BatchProcessorConfig {
BatchProcessorConfig {
enabled: true,
batch_size,
batch_timeout: Duration::from_millis(10),
max_pending,
coalescing: CoalescingConfig::disabled(),
}
}
#[test]
fn test_request_creation() {
let payload = vec![0x03u8, 0x00, 0x00, 0x00, 0x0A];
let (request, _rx) = BatchRequest::<Vec<u8>>::new(payload.clone(), 1, 0x03, 0, 10);
assert_eq!(request.unit_id, 1);
assert_eq!(request.function_code, 0x03);
assert_eq!(request.start_address, 0);
assert_eq!(request.quantity, 10);
assert_eq!(request.payload, payload);
}
#[test]
fn test_can_coalesce() {
let (req1, _) = BatchRequest::<Vec<u8>>::new(vec![], 1, 0x03, 0, 10);
let (req2, _) = BatchRequest::<Vec<u8>>::new(vec![], 1, 0x03, 10, 10);
let (req3, _) = BatchRequest::<Vec<u8>>::new(vec![], 1, 0x03, 100, 10);
let (req4, _) = BatchRequest::<Vec<u8>>::new(vec![], 2, 0x03, 0, 10);
let (req5, _) = BatchRequest::<Vec<u8>>::new(vec![], 1, 0x06, 0, 10);
assert!(req1.can_coalesce(&req2));
assert!(!req1.can_coalesce(&req3));
assert!(!req1.can_coalesce(&req4));
assert!(!req1.can_coalesce(&req5));
assert!(!req5.can_coalesce(&req1));
}
#[tokio::test]
async fn test_submit_and_flush() {
let config = make_config(10, 100);
let processor = BatchProcessor::<Vec<u8>>::new(config);
let handler = EchoHandler;
let payload = vec![0x01, 0x02, 0x03];
let (request, rx) = BatchRequest::new(payload.clone(), 1, 0x03, 0, 10);
processor.submit(request).unwrap();
assert_eq!(processor.queue_depth(), 1);
let processed = processor.flush(&handler).await;
assert_eq!(processed, 1);
assert_eq!(processor.queue_depth(), 0);
let response = rx.await.unwrap();
assert!(response.result.is_ok());
assert_eq!(response.result.unwrap(), payload);
}
#[tokio::test]
async fn test_queue_full_rejection() {
let config = make_config(10, 2);
let processor = BatchProcessor::<Vec<u8>>::new(config);
let (req1, _) = BatchRequest::new(vec![], 1, 0x03, 0, 10);
let (req2, _) = BatchRequest::new(vec![], 1, 0x03, 0, 10);
assert!(processor.submit(req1).is_ok());
assert!(processor.submit(req2).is_ok());
let (req3, _) = BatchRequest::new(vec![], 1, 0x03, 0, 10);
let result = processor.submit(req3);
assert!(matches!(result, Err(BatchError::QueueFull { .. })));
}
#[tokio::test]
async fn test_statistics() {
let config = make_config(10, 100);
let processor = BatchProcessor::<Vec<u8>>::new(config);
let handler = EchoHandler;
for i in 0..5 {
let (request, _) = BatchRequest::new(vec![i], 1, 0x03, i as u16, 1);
processor.submit(request).unwrap();
}
processor.flush(&handler).await;
let stats = processor.statistics();
assert_eq!(stats.total_submitted, 5);
assert_eq!(stats.total_processed, 5);
assert_eq!(stats.total_batches, 1);
assert_eq!(stats.avg_batch_size, 5.0);
}
#[tokio::test]
async fn test_shutdown() {
let config = make_config(10, 100);
let processor = BatchProcessor::<Vec<u8>>::new(config);
let (request, rx) = BatchRequest::new(vec![], 1, 0x03, 0, 10);
processor.submit(request).unwrap();
processor.shutdown();
let response = rx.await.unwrap();
assert!(matches!(response.result, Err(BatchError::Shutdown)));
let (request, _) = BatchRequest::new(vec![], 1, 0x03, 0, 10);
assert!(matches!(
processor.submit(request),
Err(BatchError::Shutdown)
));
}
#[test]
fn test_priority() {
let (request, _) = BatchRequest::<Vec<u8>>::new(vec![], 1, 0x03, 0, 10);
assert_eq!(request.priority, RequestPriority::Normal);
let request: BatchRequest<Vec<u8>> = request.with_priority(RequestPriority::High);
assert_eq!(request.priority, RequestPriority::High);
assert!(RequestPriority::Critical > RequestPriority::High);
assert!(RequestPriority::High > RequestPriority::Normal);
assert!(RequestPriority::Normal > RequestPriority::Low);
}
}