1use std::future::Future;
7use std::time::Duration;
8use tokio::time::timeout;
9
/// Tuning knobs consumed by [`batch_execute`].
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Upper bound on how many operations are launched together in one batch.
    pub max_batch_size: usize,
    /// Timeout handed to `parallel_execute` for each individual operation in a batch.
    pub max_wait_time: Duration,
    /// Requested worker count.
    ///
    /// NOTE(review): no code visible alongside this type reads this field — confirm
    /// whether it is consumed elsewhere before relying on it.
    pub num_workers: usize,
}

impl Default for BatchConfig {
    /// Batches of 100 operations, a 100 ms per-operation timeout, and 4 workers.
    fn default() -> Self {
        let max_batch_size = 100;
        let max_wait_time = Duration::from_millis(100);
        let num_workers = 4;
        Self {
            max_batch_size,
            max_wait_time,
            num_workers,
        }
    }
}
30
31pub async fn parallel_execute<F, Fut, T, E>(
33 operations: Vec<F>,
34 operation_timeout: Duration,
35) -> Vec<Result<T, E>>
36where
37 F: FnOnce() -> Fut + Send + 'static,
38 Fut: Future<Output = Result<T, E>> + Send + 'static,
39 T: Send + 'static,
40 E: From<String> + Send + 'static,
41{
42 let handles: Vec<_> = operations
43 .into_iter()
44 .map(|op| {
45 tokio::spawn(async move {
46 match timeout(operation_timeout, op()).await {
47 Ok(result) => result,
48 Err(_) => Err(E::from("Operation timed out".to_string())),
49 }
50 })
51 })
52 .collect();
53
54 let mut results = Vec::new();
55 for handle in handles {
56 match handle.await {
57 Ok(result) => results.push(result),
58 Err(e) => results.push(Err(E::from(format!("Task panicked: {}", e)))),
59 }
60 }
61 results
62}
63
64pub async fn batch_execute<F, Fut, T, E>(
66 mut items: Vec<F>,
67 config: BatchConfig,
68) -> Vec<Result<T, E>>
69where
70 F: FnOnce() -> Fut + Send + 'static,
71 Fut: Future<Output = Result<T, E>> + Send + 'static,
72 T: Send + 'static,
73 E: From<String> + Send + 'static,
74{
75 let mut all_results = Vec::new();
76
77 while !items.is_empty() {
78 let batch_size = items.len().min(config.max_batch_size);
79 let batch: Vec<_> = items.drain(..batch_size).collect();
80
81 let results = parallel_execute(batch, config.max_wait_time).await;
82 all_results.extend(results);
83 }
84
85 all_results
86}
87
88#[derive(Debug, Clone)]
90pub struct RateLimiter {
91 max_requests: u32,
92 window: Duration,
93 requests: std::sync::Arc<tokio::sync::Mutex<Vec<std::time::Instant>>>,
94}
95
96impl RateLimiter {
97 pub fn new(max_requests: u32, window: Duration) -> Self {
99 Self {
100 max_requests,
101 window,
102 requests: std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())),
103 }
104 }
105
106 pub async fn acquire(&self) {
108 loop {
109 let mut requests = self.requests.lock().await;
110 let now = std::time::Instant::now();
111
112 requests.retain(|&req_time| now.duration_since(req_time) < self.window);
114
115 if requests.len() < self.max_requests as usize {
116 requests.push(now);
117 return;
118 }
119
120 if let Some(&oldest) = requests.first() {
122 let elapsed = now.duration_since(oldest);
123 if elapsed < self.window {
124 let wait_time = self.window - elapsed;
125 drop(requests); tokio::time::sleep(wait_time).await;
127 }
128 }
129 }
130 }
131
132 pub async fn execute<F, Fut, T>(&self, operation: F) -> T
134 where
135 F: FnOnce() -> Fut,
136 Fut: Future<Output = T>,
137 {
138 self.acquire().await;
139 operation().await
140 }
141}
142
143#[derive(Clone)]
145pub struct ConnectionPool<T: Clone> {
146 connections: std::sync::Arc<tokio::sync::RwLock<Vec<T>>>,
147 current_index: std::sync::Arc<std::sync::atomic::AtomicUsize>,
148}
149
150impl<T: Clone> ConnectionPool<T> {
151 pub fn new(connections: Vec<T>) -> Self {
153 Self {
154 connections: std::sync::Arc::new(tokio::sync::RwLock::new(connections)),
155 current_index: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)),
156 }
157 }
158
159 pub async fn get(&self) -> Option<T> {
161 let connections = self.connections.read().await;
162 if connections.is_empty() {
163 return None;
164 }
165
166 let index = self
167 .current_index
168 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
169 % connections.len();
170 Some(connections[index].clone())
171 }
172
173 pub async fn get_all(&self) -> Vec<T> {
175 self.connections.read().await.clone()
176 }
177
178 pub async fn add(&self, connection: T) {
180 self.connections.write().await.push(connection);
181 }
182
183 pub async fn remove(&self, predicate: impl Fn(&T) -> bool) {
185 self.connections.write().await.retain(|c| !predicate(c));
186 }
187
188 pub async fn size(&self) -> usize {
190 self.connections.read().await.len()
191 }
192}
193
194pub struct AsyncMemo<K, V>
196where
197 K: std::hash::Hash + Eq + Clone,
198 V: Clone,
199{
200 cache: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<K, V>>>,
201 ttl: Option<Duration>,
202 timestamps:
203 std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<K, std::time::Instant>>>,
204}
205
206impl<K, V> Default for AsyncMemo<K, V>
207where
208 K: std::hash::Hash + Eq + Clone,
209 V: Clone,
210{
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216impl<K, V> AsyncMemo<K, V>
217where
218 K: std::hash::Hash + Eq + Clone,
219 V: Clone,
220{
221 pub fn new() -> Self {
223 Self {
224 cache: std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
225 ttl: None,
226 timestamps: std::sync::Arc::new(tokio::sync::RwLock::new(
227 std::collections::HashMap::new(),
228 )),
229 }
230 }
231
232 pub fn with_ttl(ttl: Duration) -> Self {
234 Self {
235 cache: std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
236 ttl: Some(ttl),
237 timestamps: std::sync::Arc::new(tokio::sync::RwLock::new(
238 std::collections::HashMap::new(),
239 )),
240 }
241 }
242
243 pub async fn get_or_compute<F, Fut>(&self, key: K, compute: F) -> V
245 where
246 F: FnOnce() -> Fut,
247 Fut: Future<Output = V>,
248 {
249 {
251 let cache = self.cache.read().await;
252 if let Some(value) = cache.get(&key) {
253 if let Some(ttl) = self.ttl {
254 let timestamps = self.timestamps.read().await;
255 if let Some(×tamp) = timestamps.get(&key) {
256 if timestamp.elapsed() < ttl {
257 return value.clone();
258 }
259 }
260 } else {
261 return value.clone();
262 }
263 }
264 }
265
266 let value = compute().await;
268
269 {
271 let mut cache = self.cache.write().await;
272 cache.insert(key.clone(), value.clone());
273
274 if self.ttl.is_some() {
275 let mut timestamps = self.timestamps.write().await;
276 timestamps.insert(key, std::time::Instant::now());
277 }
278 }
279
280 value
281 }
282
283 pub async fn clear(&self) {
285 let mut cache = self.cache.write().await;
286 cache.clear();
287 let mut timestamps = self.timestamps.write().await;
288 timestamps.clear();
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_parallel_execute() {
        let operations: Vec<_> = (0..5)
            .map(|i| {
                move || async move {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    Ok::<_, String>(i)
                }
            })
            .collect();

        let results = parallel_execute(operations, Duration::from_secs(1)).await;

        // Results must come back in submission order, not completion order.
        let values: Vec<i32> = results.into_iter().map(|r| r.unwrap()).collect();
        assert_eq!(values, vec![0, 1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn test_parallel_execute_reports_timeouts() {
        let operations: Vec<_> = vec![0u64, 500]
            .into_iter()
            .map(|delay_ms| {
                move || async move {
                    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                    Ok::<_, String>(delay_ms)
                }
            })
            .collect();

        let results = parallel_execute(operations, Duration::from_millis(50)).await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], Ok(0));
        assert_eq!(results[1], Err("Operation timed out".to_string()));
    }

    #[tokio::test]
    async fn test_batch_execute_preserves_order_across_batches() {
        let items: Vec<_> = (0..7)
            .map(|i| move || async move { Ok::<_, String>(i) })
            .collect();
        let config = BatchConfig {
            max_batch_size: 3,
            ..BatchConfig::default()
        };

        let results = batch_execute(items, config).await;

        let values: Vec<i32> = results.into_iter().map(|r| r.unwrap()).collect();
        assert_eq!(values, (0..7).collect::<Vec<_>>());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new(2, Duration::from_millis(100));

        let counter = Arc::new(AtomicU32::new(0));
        let mut handles = vec![];

        for _ in 0..5 {
            let limiter = limiter.clone();
            let counter = counter.clone();
            handles.push(tokio::spawn(async move {
                limiter.acquire().await;
                counter.fetch_add(1, Ordering::Relaxed);
            }));
        }

        // Well inside the first window, only the first two permits can have been granted.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let count = counter.load(Ordering::Relaxed);
        assert!(count <= 2);

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    #[tokio::test]
    async fn test_connection_pool() {
        let pool = ConnectionPool::new(vec!["conn1", "conn2", "conn3"]);

        let conn1 = pool.get().await.unwrap();
        let conn2 = pool.get().await.unwrap();
        let conn3 = pool.get().await.unwrap();
        let conn4 = pool.get().await.unwrap();

        assert_eq!(conn1, "conn1");
        assert_eq!(conn2, "conn2");
        assert_eq!(conn3, "conn3");
        // Rotation wraps back to the first connection.
        assert_eq!(conn4, "conn1");
    }

    #[tokio::test]
    async fn test_connection_pool_add_and_remove() {
        let pool = ConnectionPool::<&str>::new(Vec::new());
        assert_eq!(pool.get().await, None);

        pool.add("a").await;
        pool.add("b").await;
        assert_eq!(pool.size().await, 2);

        pool.remove(|c| *c == "a").await;
        assert_eq!(pool.get_all().await, vec!["b"]);
        assert_eq!(pool.get().await, Some("b"));
    }

    #[tokio::test]
    async fn test_async_memo() {
        let memo = AsyncMemo::new();
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = counter.clone();
        let value1 = memo
            .get_or_compute("key1", || async {
                counter_clone.fetch_add(1, Ordering::Relaxed);
                42
            })
            .await;

        assert_eq!(value1, 42);
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        // A second lookup must hit the cache and never run its closure.
        let value2 = memo
            .get_or_compute("key1", || async {
                counter.fetch_add(1, Ordering::Relaxed);
                100
            })
            .await;

        assert_eq!(value2, 42);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_async_memo_with_ttl() {
        let memo = AsyncMemo::with_ttl(Duration::from_millis(50));
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = counter.clone();
        let value1 = memo
            .get_or_compute("key1", || async {
                counter_clone.fetch_add(1, Ordering::Relaxed);
                42
            })
            .await;

        assert_eq!(value1, 42);
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        // Outlive the TTL so the next lookup recomputes.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let counter_clone = counter.clone();
        let value2 = memo
            .get_or_compute("key1", || async {
                counter_clone.fetch_add(1, Ordering::Relaxed);
                100
            })
            .await;

        assert_eq!(value2, 100);
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_async_memo_clear_forces_recompute() {
        let memo = AsyncMemo::new();
        assert_eq!(memo.get_or_compute("k", || async { 1 }).await, 1);

        memo.clear().await;

        assert_eq!(memo.get_or_compute("k", || async { 2 }).await, 2);
    }
}