1#![doc = include_str!("../README.MD")]
2use std::future::Future;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::mpsc::error::{SendError, SendTimeoutError};
6use tokio::sync::mpsc::{Receiver, Sender};
7use tokio::sync::{mpsc, Mutex};
8use tokio::sync::{Notify, Semaphore};
9use tokio::time::Instant;
10
11struct Shared<R> {
12 notify: Arc<Notify>,
13 data: Arc<Mutex<Option<R>>>,
14}
15
16impl<R> Clone for Shared<R> {
17 fn clone(&self) -> Self {
18 Self {
19 notify: self.notify.clone(),
20 data: self.data.clone(),
21 }
22 }
23}
24
25impl<R> Shared<R> {
26 fn new() -> Self {
27 Self {
28 notify: Arc::new(Notify::new()),
29 data: Arc::new(Mutex::new(None)),
30 }
31 }
32
33 async fn set_result(self, result: R) {
34 self.data.lock().await.replace(result);
35 self.notify.notify_one();
36 }
37
38 async fn wait_result(self) -> Option<R> {
39 self.notify.notified().await;
40 self.data.lock().await.take()
41 }
42}
43
44pub struct Task<T, R> {
45 inner: T,
46 shared: Shared<R>,
47 start_time: Instant,
48}
49
50impl<T, R> Task<T, R> {
51 fn new(inner: T, shared: Shared<R>) -> Self {
52 Self {
53 inner,
54 shared,
55 start_time: Instant::now(),
56 }
57 }
58}
59
60pub struct TaskState<R> {
61 shared: Shared<R>,
62}
63
64impl<R> TaskState<R> {
65 pub async fn wait_result(self) -> Option<R> {
66 self.shared.wait_result().await
67 }
68}
69
70pub struct QueuedTask<T, R> {
86 sender: Sender<Task<T, R>>,
87}
88
89impl<T, R> QueuedTask<T, R> {
90 pub fn capacity(&self) -> usize {
91 self.sender.capacity()
92 }
93
94 pub async fn push(&self, inner: T) -> Result<TaskState<R>, SendError<Task<T, R>>> {
95 let shared = Shared::new();
96 self.sender.send(Task::new(inner, shared.clone())).await?;
97 Ok(TaskState { shared })
98 }
99
100 pub async fn push_timeout(
101 &self,
102 inner: T,
103 time_out: Duration,
104 ) -> Result<TaskState<R>, SendTimeoutError<Task<T, R>>> {
105 let shared = Shared::new();
106 self.sender
107 .send_timeout(Task::new(inner, shared.clone()), time_out)
108 .await?;
109 Ok(TaskState { shared })
110 }
111}
112
113pub struct QueuedTaskBuilder<F, T, R> {
114 handle: Option<F>,
116 sem: Semaphore,
117 sender: Sender<Task<T, R>>,
118 receiver: Receiver<Task<T, R>>,
119}
120
121impl<F, T, Fut, R> QueuedTaskBuilder<F, T, R>
122where
123 F: Fn(Duration, T) -> Fut + Send + Sync + 'static,
124 Fut: Future<Output = R> + Send + 'static,
125 T: Send + 'static,
126 R: Send + 'static,
127{
128 pub fn new(queue_len: usize, rate: usize) -> Self {
129 let (sender, receiver) = mpsc::channel(queue_len);
130 Self {
131 sem: Semaphore::new(rate),
133 handle: None,
134 sender,
135 receiver,
136 }
137 }
138
139 pub fn handle(mut self, f: F) -> Self {
140 self.handle = Some(f);
141 self
142 }
143
144 pub fn build(self) -> QueuedTask<T, R> {
145 let Self {
146 sem,
147 mut handle,
148 sender,
149 mut receiver,
150 ..
151 } = self;
152 let handle = handle.take().unwrap();
153 tokio::spawn(async move {
154 let arc_sem = Arc::new(sem);
155 let arc_handle = Arc::new(handle);
156 while let Some(Task {
157 inner,
158 shared,
159 start_time,
160 }) = receiver.recv().await
161 {
162 let p = arc_sem.clone().acquire_owned().await.unwrap();
163 let h = arc_handle.clone();
164 tokio::spawn(async move {
165 let wait = start_time.elapsed();
166 let result = h(wait, inner).await;
167 shared.set_result(result).await;
168 drop(p)
169 });
170 }
171 });
172 QueuedTask { sender }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes 20 payloads through a queue of length 10 limited to 2 concurrent
    /// handlers and checks that every caller receives the result computed for
    /// its own payload.
    #[tokio::test]
    async fn test() {
        async fn handle(wait_time: Duration, c: usize) -> usize {
            tokio::time::sleep(Duration::from_secs(1)).await;
            println!("{} {}", c, wait_time.as_millis());
            c
        }

        let t = Arc::new(QueuedTaskBuilder::new(10, 2).handle(handle).build());

        let mut ts = vec![];

        for i in 0..20 {
            let tt = Arc::clone(&t);
            ts.push(tokio::spawn(async move {
                let state = tt.push(i).await.unwrap();
                let result = state.wait_result().await;
                assert_eq!(result, Some(i));
            }));
        }

        for x in ts {
            // Propagate panics (including failed assertions) from the spawned
            // tasks; ignoring the join result would let the test pass anyway.
            x.await.expect("spawned test task panicked");
        }
    }
}