1#![doc = include_str!("../README.MD")]
2use std::future::Future;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::mpsc::error::{SendError, SendTimeoutError};
6use tokio::sync::mpsc::{Receiver, Sender};
7use tokio::sync::{mpsc, Mutex};
8use tokio::sync::{Notify, Semaphore};
9use tokio::time::Instant;
10use tracing::{Instrument, Span};
11
12struct Shared<R> {
13 notify: Arc<Notify>,
14 data: Arc<Mutex<Option<R>>>,
15}
16
17impl<R> Clone for Shared<R> {
18 fn clone(&self) -> Self {
19 Self {
20 notify: self.notify.clone(),
21 data: self.data.clone(),
22 }
23 }
24}
25
26impl<R> Shared<R> {
27 fn new() -> Self {
28 Self {
29 notify: Arc::new(Notify::new()),
30 data: Arc::new(Mutex::new(None)),
31 }
32 }
33
34 async fn set_result(self, result: R) {
35 self.data.lock().await.replace(result);
36 self.notify.notify_one();
37 }
38
39 async fn wait_result(self) -> Option<R> {
40 self.notify.notified().await;
41 self.data.lock().await.take()
42 }
43}
44
45pub struct Task<T, R> {
46 inner: T,
47 shared: Shared<R>,
48 start_time: Instant,
49 span: Option<Span>,
50}
51
52impl<T, R> Task<T, R> {
53 fn new(inner: T, shared: Shared<R>, span: Option<Span>) -> Self {
54 Self {
55 inner,
56 shared,
57 start_time: Instant::now(),
58 span,
59 }
60 }
61}
62
63pub struct TaskState<R> {
64 shared: Shared<R>,
65}
66
67impl<R> TaskState<R> {
68 pub async fn wait_result(self) -> Option<R> {
69 self.shared.wait_result().await
70 }
71}
72
73pub struct QueuedTask<T, R> {
89 sender: Sender<Task<T, R>>,
90}
91
92impl<T, R> QueuedTask<T, R> {
93 pub fn capacity(&self) -> usize {
94 self.sender.capacity()
95 }
96
97 pub async fn push(&self, inner: T) -> Result<TaskState<R>, SendError<Task<T, R>>> {
98 self.push_with_span(inner, None).await
99 }
100
101 pub async fn push_with_span(
102 &self,
103 inner: T,
104 span: Option<Span>,
105 ) -> Result<TaskState<R>, SendError<Task<T, R>>> {
106 let shared = Shared::new();
107 self.sender
108 .send(Task::new(inner, shared.clone(), span))
109 .await?;
110 Ok(TaskState { shared })
111 }
112
113 pub async fn push_timeout(
114 &self,
115 inner: T,
116 timeout: Duration,
117 ) -> Result<TaskState<R>, SendTimeoutError<Task<T, R>>> {
118 self.push_timeout_with_span(inner, timeout, None).await
119 }
120
121 pub async fn push_timeout_with_span(
122 &self,
123 inner: T,
124 time_out: Duration,
125 span: Option<Span>,
126 ) -> Result<TaskState<R>, SendTimeoutError<Task<T, R>>> {
127 let shared = Shared::new();
128 self.sender
129 .send_timeout(Task::new(inner, shared.clone(), span), time_out)
130 .await?;
131 Ok(TaskState { shared })
132 }
133}
134
135pub struct QueuedTaskBuilder<F, T, R> {
136 handle: Option<F>,
138 sem: Semaphore,
139 sender: Sender<Task<T, R>>,
140 receiver: Receiver<Task<T, R>>,
141}
142
143impl<F, T, Fut, R> QueuedTaskBuilder<F, T, R>
144where
145 F: Fn(Duration, T) -> Fut + Send + Sync + 'static,
146 Fut: Future<Output = R> + Send + 'static,
147 T: Send + 'static,
148 R: Send + 'static,
149{
150 pub fn new(queue_len: usize, rate: usize) -> Self {
151 let (sender, receiver) = mpsc::channel(queue_len);
152 Self {
153 sem: Semaphore::new(rate),
155 handle: None,
156 sender,
157 receiver,
158 }
159 }
160
161 pub fn handle(mut self, f: F) -> Self {
162 self.handle = Some(f);
163 self
164 }
165
166 pub fn build(self) -> QueuedTask<T, R> {
167 let Self {
168 sem,
169 mut handle,
170 sender,
171 mut receiver,
172 ..
173 } = self;
174 let handle = handle.take().unwrap();
175 tokio::spawn(async move {
176 let arc_sem = Arc::new(sem);
177 let arc_handle = Arc::new(handle);
178 while let Some(Task {
179 inner,
180 shared,
181 start_time,
182 span,
183 }) = receiver.recv().await
184 {
185 let p = arc_sem.clone().acquire_owned().await.unwrap();
186 let h = arc_handle.clone();
187 match span {
188 None => {
189 tokio::spawn(async move {
190 let wait = start_time.elapsed();
191 let result = h(wait, inner).await;
192 shared.set_result(result).await;
193 drop(p)
194 });
195 }
196 Some(span) => {
197 tokio::spawn(async move {
198 let wait = start_time.elapsed();
199 let result = span
200 .in_scope(|| async { h(wait, inner).await }.in_current_span())
201 .await;
202 shared.set_result(result).await;
203 drop(p);
204 });
205 }
206 }
207 }
208 });
209 QueuedTask { sender }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Handler invocations currently in flight.
    static RUNNING: AtomicUsize = AtomicUsize::new(0);
    // Highest value `RUNNING` reached during the test.
    static PEAK: AtomicUsize = AtomicUsize::new(0);

    /// Echoes its payload after a short sleep, tracking peak concurrency.
    async fn handle(_wait_time: Duration, c: usize) -> usize {
        let now = RUNNING.fetch_add(1, Ordering::SeqCst) + 1;
        PEAK.fetch_max(now, Ordering::SeqCst);
        tokio::time::sleep(Duration::from_millis(20)).await;
        RUNNING.fetch_sub(1, Ordering::SeqCst);
        c
    }

    #[tokio::test]
    async fn results_round_trip_and_concurrency_is_bounded() {
        const RATE: usize = 2;
        let queue = Arc::new(QueuedTaskBuilder::new(10, RATE).handle(handle).build());

        let workers: Vec<_> = (0..20usize)
            .map(|i| {
                let queue = Arc::clone(&queue);
                tokio::spawn(async move {
                    let state = queue.push(i).await.expect("queue closed unexpectedly");
                    (i, state.wait_result().await)
                })
            })
            .collect();

        for worker in workers {
            let (i, result) = worker.await.expect("worker task panicked");
            assert_eq!(result, Some(i));
        }

        let peak = PEAK.load(Ordering::SeqCst);
        assert!(
            (1..=RATE).contains(&peak),
            "peak concurrency {} outside 1..={}",
            peak,
            RATE
        );
    }
}