1use pin_project::pin_project;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7use tokio::sync::{RwLock, Semaphore};
8use tokio::time::timeout;
9use tracing::{debug, error, warn};
10
11pub struct AsyncOptimizer {
13 max_concurrent: Arc<Semaphore>,
14 default_timeout: Duration,
15 retry_config: RetryConfig,
16}
17
/// Tunable parameters for retrying a failed operation with exponential backoff.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Upper bound on the number of attempts, counting the first one.
    pub max_attempts: usize,
    /// Pause before the first retry.
    pub base_delay: Duration,
    /// Ceiling applied to the pause as it grows between retries.
    pub max_delay: Duration,
    /// Factor by which the pause is multiplied after every failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 100 ms and doubling up to a 5 s ceiling.
    fn default() -> Self {
        let base_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(5);

        Self {
            max_attempts: 3,
            base_delay,
            max_delay,
            backoff_multiplier: 2.0,
        }
    }
}
36
37impl AsyncOptimizer {
38 pub fn new(max_concurrent: usize, default_timeout: Duration) -> Self {
40 Self {
41 max_concurrent: Arc::new(Semaphore::new(max_concurrent)),
42 default_timeout,
43 retry_config: RetryConfig::default(),
44 }
45 }
46
47 pub async fn execute<F, T>(&self, future: F) -> Result<T, AsyncError>
49 where
50 F: Future<Output = Result<T, AsyncError>>,
51 {
52 self.execute_with_timeout(future, self.default_timeout)
53 .await
54 }
55
56 pub async fn execute_with_timeout<F, T>(
58 &self,
59 future: F,
60 timeout_duration: Duration,
61 ) -> Result<T, AsyncError>
62 where
63 F: Future<Output = Result<T, AsyncError>>,
64 {
65 let _permit = self
66 .max_concurrent
67 .acquire()
68 .await
69 .map_err(|_| AsyncError::ResourceExhausted)?;
70
71 let start_time = Instant::now();
72
73 match timeout(timeout_duration, future).await {
74 Ok(result) => {
75 let duration = start_time.elapsed();
76 if duration > Duration::from_millis(100) {
77 debug!(
78 "Async operation took {:.2}ms",
79 duration.as_secs_f64() * 1000.0
80 );
81 }
82 result
83 }
84 Err(_) => {
85 warn!(
86 "Async operation timed out after {:.2}s",
87 timeout_duration.as_secs_f64()
88 );
89 Err(AsyncError::Timeout)
90 }
91 }
92 }
93
94 pub async fn execute_with_retry<F, Fut, T>(&self, mut operation: F) -> Result<T, AsyncError>
96 where
97 F: FnMut() -> Fut,
98 Fut: Future<Output = Result<T, AsyncError>>,
99 {
100 let mut attempt = 0;
101 let mut delay = self.retry_config.base_delay;
102
103 loop {
104 attempt += 1;
105
106 match self.execute(operation()).await {
107 Ok(result) => {
108 if attempt > 1 {
109 debug!("Operation succeeded on attempt {}", attempt);
110 }
111 return Ok(result);
112 }
113 Err(e) => {
114 if attempt >= self.retry_config.max_attempts {
115 error!("Operation failed after {} attempts: {:?}", attempt, e);
116 return Err(e);
117 }
118
119 if !e.is_retryable() {
120 error!("Non-retryable error: {:?}", e);
121 return Err(e);
122 }
123
124 warn!(
125 "Attempt {} failed, retrying in {:.2}ms: {:?}",
126 attempt,
127 delay.as_secs_f64() * 1000.0,
128 e
129 );
130
131 tokio::time::sleep(delay).await;
132
133 delay = Duration::from_millis(
135 (delay.as_millis() as f64 * self.retry_config.backoff_multiplier) as u64,
136 )
137 .min(self.retry_config.max_delay);
138 }
139 }
140 }
141 }
142
143 pub async fn batch_execute<F, T>(&self, mut operations: Vec<F>) -> Vec<Result<T, AsyncError>>
145 where
146 F: Future<Output = Result<T, AsyncError>>,
147 {
148 let chunk_size = self
149 .max_concurrent
150 .available_permits()
151 .min(operations.len());
152 let mut results = Vec::with_capacity(operations.len());
153
154 while !operations.is_empty() {
155 let chunk_len = chunk_size.min(operations.len());
156 let chunk: Vec<_> = operations.drain(..chunk_len).collect();
157
158 let chunk_futures: Vec<_> = chunk.into_iter().map(|op| self.execute(op)).collect();
159
160 let chunk_results = futures::future::join_all(chunk_futures).await;
161 results.extend(chunk_results);
162 }
163
164 results
165 }
166}
167
168type ProcessorFn<T> = Arc<
170 dyn Fn(Vec<T>) -> Pin<Box<dyn Future<Output = Result<(), AsyncError>> + Send>> + Send + Sync,
171>;
172
173pub struct StreamProcessor<T> {
175 buffer_size: usize,
176 batch_timeout: Duration,
177 processor: ProcessorFn<T>,
178}
179
180impl<T: Send + 'static> StreamProcessor<T> {
181 pub fn new<F, Fut>(buffer_size: usize, batch_timeout: Duration, processor: F) -> Self
183 where
184 F: Fn(Vec<T>) -> Fut + Send + Sync + 'static,
185 Fut: Future<Output = Result<(), AsyncError>> + Send + 'static,
186 {
187 Self {
188 buffer_size,
189 batch_timeout,
190 processor: Arc::new(move |batch| Box::pin(processor(batch))),
191 }
192 }
193
194 pub async fn process_stream(
196 &self,
197 mut receiver: tokio::sync::mpsc::Receiver<T>,
198 ) -> Result<(), AsyncError> {
199 let mut buffer = Vec::with_capacity(self.buffer_size);
200 let mut last_flush = Instant::now();
201
202 while let Some(item) = receiver.recv().await {
203 buffer.push(item);
204
205 let should_flush =
207 buffer.len() >= self.buffer_size || last_flush.elapsed() >= self.batch_timeout;
208
209 if should_flush {
210 self.flush_buffer(&mut buffer).await?;
211 last_flush = Instant::now();
212 }
213 }
214
215 if !buffer.is_empty() {
217 self.flush_buffer(&mut buffer).await?;
218 }
219
220 Ok(())
221 }
222
223 async fn flush_buffer(&self, buffer: &mut Vec<T>) -> Result<(), AsyncError> {
225 if buffer.is_empty() {
226 return Ok(());
227 }
228
229 let batch = std::mem::take(buffer);
230 let batch_size = batch.len();
231
232 debug!("Processing batch of {} items", batch_size);
233
234 let start = Instant::now();
235 (self.processor)(batch).await?;
236 let duration = start.elapsed();
237
238 debug!(
239 "Batch processed in {:.2}ms",
240 duration.as_secs_f64() * 1000.0
241 );
242
243 Ok(())
244 }
245}
246
247pub struct TaskPool {
249 pool: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
250 max_tasks: usize,
251}
252
253impl TaskPool {
254 pub fn new(max_tasks: usize) -> Self {
256 Self {
257 pool: Arc::new(RwLock::new(Vec::with_capacity(max_tasks))),
258 max_tasks,
259 }
260 }
261
262 pub async fn spawn<F>(&self, future: F) -> Result<(), AsyncError>
264 where
265 F: Future<Output = ()> + Send + 'static,
266 {
267 let mut pool = self.pool.write().await;
268
269 pool.retain(|handle| !handle.is_finished());
271
272 if pool.len() >= self.max_tasks {
273 return Err(AsyncError::ResourceExhausted);
274 }
275
276 let handle = tokio::spawn(future);
277 pool.push(handle);
278
279 Ok(())
280 }
281
282 pub async fn join_all(&self) -> Result<(), AsyncError> {
284 let mut pool = self.pool.write().await;
285 let handles = std::mem::take(&mut *pool);
286
287 for handle in handles {
288 if let Err(e) = handle.await {
289 error!("Task failed: {:?}", e);
290 }
291 }
292
293 Ok(())
294 }
295
296 pub async fn shutdown(&self) {
298 let mut pool = self.pool.write().await;
299
300 for handle in pool.drain(..) {
301 handle.abort();
302 }
303 }
304}
305
306#[derive(Debug, Clone, thiserror::Error)]
308pub enum AsyncError {
309 #[error("Operation timed out")]
310 Timeout,
311
312 #[error("Resource exhausted")]
313 ResourceExhausted,
314
315 #[error("Task cancelled")]
316 Cancelled,
317
318 #[error("IO error: {0}")]
319 Io(String),
320
321 #[error("Network error: {0}")]
322 Network(String),
323
324 #[error("Internal error: {0}")]
325 Internal(String),
326}
327
328impl AsyncError {
329 pub fn is_retryable(&self) -> bool {
331 match self {
332 AsyncError::Timeout => true,
333 AsyncError::ResourceExhausted => true,
334 AsyncError::Network(_) => true,
335 AsyncError::Io(_) => true,
336 AsyncError::Cancelled => false,
337 AsyncError::Internal(_) => false,
338 }
339 }
340}
341
342type ErrorHandlerFn = Box<dyn Fn(&AsyncError) + Send + Sync>;
344
345#[pin_project]
347pub struct ErrorPropagationFuture<F> {
348 #[pin]
349 inner: F,
350 error_handler: Option<ErrorHandlerFn>,
351}
352
353impl<F> ErrorPropagationFuture<F> {
354 pub fn new(future: F) -> Self {
356 Self {
357 inner: future,
358 error_handler: None,
359 }
360 }
361
362 pub fn with_error_handler<H>(mut self, handler: H) -> Self
364 where
365 H: Fn(&AsyncError) + Send + Sync + 'static,
366 {
367 self.error_handler = Some(Box::new(handler));
368 self
369 }
370}
371
372impl<F, T> Future for ErrorPropagationFuture<F>
373where
374 F: Future<Output = Result<T, AsyncError>>,
375{
376 type Output = Result<T, AsyncError>;
377
378 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
379 let this = self.project();
380
381 match this.inner.poll(cx) {
382 Poll::Ready(Err(e)) => {
383 if let Some(ref handler) = this.error_handler {
384 handler(&e);
385 }
386 Poll::Ready(Err(e))
387 }
388 poll => poll,
389 }
390 }
391}
392
393pub struct ResourceLimiter {
395 memory_limit: usize,
396 task_limit: usize,
397 current_memory: Arc<RwLock<usize>>,
398 current_tasks: Arc<RwLock<usize>>,
399}
400
401impl ResourceLimiter {
402 pub fn new(memory_limit: usize, task_limit: usize) -> Self {
404 Self {
405 memory_limit,
406 task_limit,
407 current_memory: Arc::new(RwLock::new(0)),
408 current_tasks: Arc::new(RwLock::new(0)),
409 }
410 }
411
412 pub async fn check_resources(
414 &self,
415 memory_needed: usize,
416 tasks_needed: usize,
417 ) -> Result<(), AsyncError> {
418 let current_memory = *self.current_memory.read().await;
419 let current_tasks = *self.current_tasks.read().await;
420
421 if current_memory + memory_needed > self.memory_limit {
422 return Err(AsyncError::ResourceExhausted);
423 }
424
425 if current_tasks + tasks_needed > self.task_limit {
426 return Err(AsyncError::ResourceExhausted);
427 }
428
429 Ok(())
430 }
431
432 pub async fn acquire_resources(
434 self: &Arc<Self>,
435 memory: usize,
436 tasks: usize,
437 ) -> Result<ResourceGuard, AsyncError> {
438 self.check_resources(memory, tasks).await?;
439
440 {
441 let mut current_memory = self.current_memory.write().await;
442 let mut current_tasks = self.current_tasks.write().await;
443
444 *current_memory += memory;
445 *current_tasks += tasks;
446 }
447
448 Ok(ResourceGuard {
449 limiter: Arc::clone(self),
450 memory,
451 tasks,
452 })
453 }
454
455 async fn release_resources(&self, memory: usize, tasks: usize) {
457 let mut current_memory = self.current_memory.write().await;
458 let mut current_tasks = self.current_tasks.write().await;
459
460 *current_memory = current_memory.saturating_sub(memory);
461 *current_tasks = current_tasks.saturating_sub(tasks);
462 }
463}
464
465pub struct ResourceGuard {
467 limiter: Arc<ResourceLimiter>,
468 memory: usize,
469 tasks: usize,
470}
471
472impl Drop for ResourceGuard {
473 fn drop(&mut self) {
474 let limiter = Arc::clone(&self.limiter);
477 let memory = self.memory;
478 let tasks = self.tasks;
479
480 if let Ok(handle) = tokio::runtime::Handle::try_current() {
481 handle.spawn(async move {
482 limiter.release_resources(memory, tasks).await;
483 });
484 }
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// A future that succeeds immediately passes straight through `execute`.
    #[tokio::test]
    async fn test_async_optimizer() {
        let optimizer = AsyncOptimizer::new(2, Duration::from_secs(1));

        let outcome = optimizer
            .execute(async { Ok::<i32, AsyncError>(42) })
            .await;

        assert_eq!(outcome.unwrap(), 42);
    }

    /// Two retryable failures followed by a success take exactly three attempts.
    #[tokio::test]
    async fn test_retry_logic() {
        let optimizer = AsyncOptimizer::new(2, Duration::from_secs(1));
        let mut calls = 0;

        let outcome = optimizer
            .execute_with_retry(|| {
                calls += 1;
                async move {
                    if calls >= 3 {
                        Ok(42)
                    } else {
                        Err(AsyncError::Network("temporary failure".to_string()))
                    }
                }
            })
            .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls, 3);
    }

    /// A second reservation that would exceed the memory limit is rejected.
    #[tokio::test]
    async fn test_resource_limiter() {
        let limiter = Arc::new(ResourceLimiter::new(1000, 10));

        // Kept alive so the first reservation still counts.
        let _guard = limiter.acquire_resources(500, 5).await.unwrap();

        let second = limiter.acquire_resources(600, 3).await;
        assert!(second.is_err());
    }

    /// A pool of two accepts two running tasks and refuses a third.
    #[tokio::test]
    async fn test_task_pool() {
        let pool = TaskPool::new(2);

        let first = pool
            .spawn(async {
                tokio::time::sleep(Duration::from_millis(10)).await;
            })
            .await;
        first.unwrap();

        let second = pool
            .spawn(async {
                tokio::time::sleep(Duration::from_millis(10)).await;
            })
            .await;
        second.unwrap();

        let third = pool.spawn(async {}).await;
        assert!(third.is_err());

        pool.join_all().await.unwrap();
    }
}