1use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::{Duration, Instant};
8
9use tokio::sync::{mpsc, Mutex};
10
/// Policy knobs controlling how requests are grouped into batches.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Flush immediately once this many requests are pending.
    pub max_size: usize,
    /// Longest a request should sit in the queue before a flush.
    pub max_wait: Duration,
    /// Preferred minimum batch size before an early flush is worthwhile.
    pub min_size: usize,
}

impl Default for BatchConfig {
    /// Conservative defaults: batches of up to 100, flushed within 50 ms,
    /// with 10 as the preferred minimum.
    fn default() -> Self {
        BatchConfig {
            min_size: 10,
            max_size: 100,
            max_wait: Duration::from_millis(50),
        }
    }
}
31
/// A single call waiting to be dispatched as part of a batch.
#[derive(Debug, Clone)]
pub struct BatchedRequest {
    // Raw argument bytes for the call (encoding is decided by the caller).
    pub args: Vec<u8>,
    // Name of the method to invoke.
    pub method: String,
    // Channel on which the serialized result — or an error string — is delivered
    // back to the original caller.
    pub response_tx: mpsc::Sender<Result<Vec<u8>, String>>,
    // When the request entered the queue; available for latency accounting.
    pub queued_at: Instant,
}
44
/// A group of requests destined for the same module, flushed as one unit.
#[derive(Debug)]
pub struct RequestBatch {
    // The requests collected into this batch.
    pub requests: Vec<BatchedRequest>,
    // Target module name (copied from the owning `RequestBatcher`).
    pub module: String,
    // When the batch was created; available for batch-age measurements.
    pub created_at: Instant,
}
55
56impl RequestBatch {
57 pub fn new(module: String) -> Self {
59 Self {
60 requests: Vec::new(),
61 module,
62 created_at: Instant::now(),
63 }
64 }
65
66 pub fn push(&mut self, request: BatchedRequest) {
68 self.requests.push(request);
69 }
70
71 pub fn len(&self) -> usize {
73 self.requests.len()
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.requests.is_empty()
79 }
80}
81
/// Snapshot of batching statistics, as returned by `RequestBatcher::stats`.
#[derive(Debug, Clone)]
pub struct BatchStats {
    // Total number of batches flushed so far.
    pub total_batches: u64,
    // Total number of requests that have gone through flushed batches.
    pub total_requests: u64,
    // Mean requests per batch (0.0 before the first flush).
    pub avg_batch_size: f64,
    // Mean per-request latency in microseconds. NOTE(review): `stats()` does
    // not track latency yet and always reports 0.0 here.
    pub avg_latency_us: f64,
    // Requests sitting in the pending queue at snapshot time.
    pub current_queue_size: usize,
}
91
/// Collects individual requests for one module and emits them as batches.
#[derive(Debug)]
pub struct RequestBatcher {
    // Module all queued requests are destined for.
    module_name: String,
    // Batching policy (size thresholds and wait bound).
    config: BatchConfig,
    // Requests accumulated since the last flush.
    pending: Mutex<Vec<BatchedRequest>>,
    // Producer side of the batch channel; `flush` sends completed batches here.
    batch_tx: mpsc::Sender<RequestBatch>,
    // Consumer side, handed out once via `take_receiver` (hence the Option).
    batch_rx: Mutex<Option<mpsc::Receiver<RequestBatch>>>,
    // Lifetime counter of flushed batches.
    total_batches: AtomicU64,
    // Lifetime counter of requests contained in flushed batches.
    total_requests: AtomicU64,
}
110
111impl RequestBatcher {
112 pub fn new(module_name: String, config: BatchConfig) -> Self {
114 let (batch_tx, batch_rx) = mpsc::channel(100);
115
116 Self {
117 module_name,
118 config,
119 pending: Mutex::new(Vec::new()),
120 batch_tx,
121 batch_rx: Mutex::new(Some(batch_rx)),
122 total_batches: AtomicU64::new(0),
123 total_requests: AtomicU64::new(0),
124 }
125 }
126
127 pub async fn queue(&self, request: BatchedRequest) -> Result<(), String> {
129 let mut pending = self.pending.lock().await;
130 pending.push(request);
131
132 let should_flush = pending.len() >= self.config.max_size;
133 drop(pending);
134
135 if should_flush {
136 self.flush().await;
137 }
138
139 Ok(())
140 }
141
142 pub async fn flush(&self) {
144 let mut pending = self.pending.lock().await;
145
146 if pending.is_empty() {
147 return;
148 }
149
150 let mut batch = RequestBatch::new(self.module_name.clone());
151 let requests: Vec<BatchedRequest> = std::mem::take(&mut *pending);
152
153 for req in requests {
154 batch.push(req);
155 }
156
157 let batch_size = batch.len();
158 let _ = self.batch_tx.send(batch).await;
159
160 self.total_batches.fetch_add(1, Ordering::Relaxed);
161 self.total_requests.fetch_add(batch_size as u64, Ordering::Relaxed);
162 }
163
164 pub fn take_receiver(&self) -> Option<mpsc::Receiver<RequestBatch>> {
166 self.batch_rx.try_lock().ok().and_then(|mut rx| rx.take())
167 }
168
169 pub fn stats(&self) -> BatchStats {
171 let total_batches = self.total_batches.load(Ordering::Relaxed);
172 let total_requests = self.total_requests.load(Ordering::Relaxed);
173
174 BatchStats {
175 total_batches,
176 total_requests,
177 avg_batch_size: if total_batches > 0 {
178 total_requests as f64 / total_batches as f64
179 } else {
180 0.0
181 },
182 avg_latency_us: 0.0, current_queue_size: 0, }
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    // The default configuration carries the documented constants.
    #[test]
    fn test_batch_config() {
        let cfg = BatchConfig::default();
        assert_eq!(
            (cfg.max_size, cfg.max_wait.as_millis(), cfg.min_size),
            (100, 50, 10)
        );
    }

    // A freshly created batch holds nothing.
    #[test]
    fn test_batch_creation() {
        let batch = RequestBatch::new(String::from("test"));
        assert!(batch.is_empty());
        assert_eq!(batch.len(), 0);
    }
}