queued_task/
lib.rs

1#![doc = include_str!("../README.MD")]
2use std::future::Future;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::mpsc::error::{SendError, SendTimeoutError};
6use tokio::sync::mpsc::{Receiver, Sender};
7use tokio::sync::{mpsc, Mutex};
8use tokio::sync::{Notify, Semaphore};
9use tokio::time::Instant;
10use tracing::{Instrument, Span};
11
12struct Shared<R> {
13    notify: Arc<Notify>,
14    data: Arc<Mutex<Option<R>>>,
15}
16
17impl<R> Clone for Shared<R> {
18    fn clone(&self) -> Self {
19        Self {
20            notify: self.notify.clone(),
21            data: self.data.clone(),
22        }
23    }
24}
25
26impl<R> Shared<R> {
27    fn new() -> Self {
28        Self {
29            notify: Arc::new(Notify::new()),
30            data: Arc::new(Mutex::new(None)),
31        }
32    }
33
34    async fn set_result(self, result: R) {
35        self.data.lock().await.replace(result);
36        self.notify.notify_one();
37    }
38
39    async fn wait_result(self) -> Option<R> {
40        self.notify.notified().await;
41        self.data.lock().await.take()
42    }
43}
44
45pub struct Task<T, R> {
46    inner: T,
47    shared: Shared<R>,
48    start_time: Instant,
49    span: Option<Span>,
50}
51
52impl<T, R> Task<T, R> {
53    fn new(inner: T, shared: Shared<R>, span: Option<Span>) -> Self {
54        Self {
55            inner,
56            shared,
57            start_time: Instant::now(),
58            span,
59        }
60    }
61}
62
63pub struct TaskState<R> {
64    shared: Shared<R>,
65}
66
67impl<R> TaskState<R> {
68    pub async fn wait_result(self) -> Option<R> {
69        self.shared.wait_result().await
70    }
71}
72
73// #[derive(Debug)]
74// pub struct Config {
75//     length: usize,
76//     keep_alive_timeout: Duration,
77// }
78//
79// impl Default for Config {
80//     fn default() -> Self {
81//         Self {
82//             length: 16,
83//             keep_alive_timeout: Duration::from_secs(30),
84//         }
85//     }
86// }
87
88pub struct QueuedTask<T, R> {
89    sender: Sender<Task<T, R>>,
90}
91
92impl<T, R> QueuedTask<T, R> {
93    pub fn capacity(&self) -> usize {
94        self.sender.capacity()
95    }
96
97    pub async fn push(&self, inner: T) -> Result<TaskState<R>, SendError<Task<T, R>>> {
98        self.push_with_span(inner, None).await
99    }
100
101    pub async fn push_with_span(
102        &self,
103        inner: T,
104        span: Option<Span>,
105    ) -> Result<TaskState<R>, SendError<Task<T, R>>> {
106        let shared = Shared::new();
107        self.sender
108            .send(Task::new(inner, shared.clone(), span))
109            .await?;
110        Ok(TaskState { shared })
111    }
112
113    pub async fn push_timeout(
114        &self,
115        inner: T,
116        timeout: Duration,
117    ) -> Result<TaskState<R>, SendTimeoutError<Task<T, R>>> {
118        self.push_timeout_with_span(inner, timeout, None).await
119    }
120
121    pub async fn push_timeout_with_span(
122        &self,
123        inner: T,
124        time_out: Duration,
125        span: Option<Span>,
126    ) -> Result<TaskState<R>, SendTimeoutError<Task<T, R>>> {
127        let shared = Shared::new();
128        self.sender
129            .send_timeout(Task::new(inner, shared.clone(), span), time_out)
130            .await?;
131        Ok(TaskState { shared })
132    }
133}
134
135pub struct QueuedTaskBuilder<F, T, R> {
136    // config: Config,
137    handle: Option<F>,
138    sem: Semaphore,
139    sender: Sender<Task<T, R>>,
140    receiver: Receiver<Task<T, R>>,
141}
142
143impl<F, T, Fut, R> QueuedTaskBuilder<F, T, R>
144where
145    F: Fn(Duration, T) -> Fut + Send + Sync + 'static,
146    Fut: Future<Output = R> + Send + 'static,
147    T: Send + 'static,
148    R: Send + 'static,
149{
150    pub fn new(queue_len: usize, rate: usize) -> Self {
151        let (sender, receiver) = mpsc::channel(queue_len);
152        Self {
153            // config,
154            sem: Semaphore::new(rate),
155            handle: None,
156            sender,
157            receiver,
158        }
159    }
160
161    pub fn handle(mut self, f: F) -> Self {
162        self.handle = Some(f);
163        self
164    }
165
166    pub fn build(self) -> QueuedTask<T, R> {
167        let Self {
168            sem,
169            mut handle,
170            sender,
171            mut receiver,
172            ..
173        } = self;
174        let handle = handle.take().unwrap();
175        tokio::spawn(async move {
176            let arc_sem = Arc::new(sem);
177            let arc_handle = Arc::new(handle);
178            while let Some(Task {
179                inner,
180                shared,
181                start_time,
182                span,
183            }) = receiver.recv().await
184            {
185                let p = arc_sem.clone().acquire_owned().await.unwrap();
186                let h = arc_handle.clone();
187                match span {
188                    None => {
189                        tokio::spawn(async move {
190                            let wait = start_time.elapsed();
191                            let result = h(wait, inner).await;
192                            shared.set_result(result).await;
193                            drop(p)
194                        });
195                    }
196                    Some(span) => {
197                        tokio::spawn(async move {
198                            let wait = start_time.elapsed();
199                            let result = span
200                                .in_scope(|| async { h(wait, inner).await }.in_current_span())
201                                .await;
202                            shared.set_result(result).await;
203                            drop(p);
204                        });
205                    }
206                }
207            }
208        });
209        QueuedTask { sender }
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Every pushed task yields its own result, and no more than `RATE`
    /// handlers ever run at the same time.
    #[tokio::test]
    async fn results_are_returned_and_concurrency_is_bounded() {
        const RATE: usize = 2;
        const TASKS: usize = 20;

        let running = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let handle = {
            let running = running.clone();
            let peak = peak.clone();
            move |_wait: Duration, c: usize| {
                let running = running.clone();
                let peak = peak.clone();
                async move {
                    let now = running.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(20)).await;
                    running.fetch_sub(1, Ordering::SeqCst);
                    c
                }
            }
        };

        let queue = Arc::new(QueuedTaskBuilder::new(10, RATE).handle(handle).build());

        let mut joins = Vec::with_capacity(TASKS);
        for i in 0..TASKS {
            let q = queue.clone();
            joins.push(tokio::spawn(async move {
                let state = q.push(i).await.unwrap();
                state.wait_result().await
            }));
        }

        for (i, join) in joins.into_iter().enumerate() {
            assert_eq!(join.await.unwrap(), Some(i));
        }
        let observed = peak.load(Ordering::SeqCst);
        assert!(observed >= 1 && observed <= RATE, "peak concurrency was {observed}");
    }
}