queued_task/
lib.rs

1#![doc = include_str!("../README.MD")]
2use std::future::Future;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::mpsc::error::{SendError, SendTimeoutError};
6use tokio::sync::mpsc::{Receiver, Sender};
7use tokio::sync::{mpsc, Mutex};
8use tokio::sync::{Notify, Semaphore};
9use tokio::time::Instant;
10
11struct Shared<R> {
12    notify: Arc<Notify>,
13    data: Arc<Mutex<Option<R>>>,
14}
15
16impl<R> Clone for Shared<R> {
17    fn clone(&self) -> Self {
18        Self {
19            notify: self.notify.clone(),
20            data: self.data.clone(),
21        }
22    }
23}
24
25impl<R> Shared<R> {
26    fn new() -> Self {
27        Self {
28            notify: Arc::new(Notify::new()),
29            data: Arc::new(Mutex::new(None)),
30        }
31    }
32
33    async fn set_result(self, result: R) {
34        self.data.lock().await.replace(result);
35        self.notify.notify_one();
36    }
37
38    async fn wait_result(self) -> Option<R> {
39        self.notify.notified().await;
40        self.data.lock().await.take()
41    }
42}
43
44pub struct Task<T, R> {
45    inner: T,
46    shared: Shared<R>,
47    start_time: Instant,
48}
49
50impl<T, R> Task<T, R> {
51    fn new(inner: T, shared: Shared<R>) -> Self {
52        Self {
53            inner,
54            shared,
55            start_time: Instant::now(),
56        }
57    }
58}
59
60pub struct TaskState<R> {
61    shared: Shared<R>,
62}
63
64impl<R> TaskState<R> {
65    pub async fn wait_result(self) -> Option<R> {
66        self.shared.wait_result().await
67    }
68}
69
70// #[derive(Debug)]
71// pub struct Config {
72//     length: usize,
73//     keep_alive_timeout: Duration,
74// }
75//
76// impl Default for Config {
77//     fn default() -> Self {
78//         Self {
79//             length: 16,
80//             keep_alive_timeout: Duration::from_secs(30),
81//         }
82//     }
83// }
84
85pub struct QueuedTask<T, R> {
86    sender: Sender<Task<T, R>>,
87}
88
89impl<T, R> QueuedTask<T, R> {
90    pub fn capacity(&self) -> usize {
91        self.sender.capacity()
92    }
93
94    pub async fn push(&self, inner: T) -> Result<TaskState<R>, SendError<Task<T, R>>> {
95        let shared = Shared::new();
96        self.sender.send(Task::new(inner, shared.clone())).await?;
97        Ok(TaskState { shared })
98    }
99
100    pub async fn push_timeout(
101        &self,
102        inner: T,
103        time_out: Duration,
104    ) -> Result<TaskState<R>, SendTimeoutError<Task<T, R>>> {
105        let shared = Shared::new();
106        self.sender
107            .send_timeout(Task::new(inner, shared.clone()), time_out)
108            .await?;
109        Ok(TaskState { shared })
110    }
111}
112
113pub struct QueuedTaskBuilder<F, T, R> {
114    // config: Config,
115    handle: Option<F>,
116    sem: Semaphore,
117    sender: Sender<Task<T, R>>,
118    receiver: Receiver<Task<T, R>>,
119}
120
121impl<F, T, Fut, R> QueuedTaskBuilder<F, T, R>
122where
123    F: Fn(Duration, T) -> Fut + Send + Sync + 'static,
124    Fut: Future<Output = R> + Send + 'static,
125    T: Send + 'static,
126    R: Send + 'static,
127{
128    pub fn new(queue_len: usize, rate: usize) -> Self {
129        let (sender, receiver) = mpsc::channel(queue_len);
130        Self {
131            // config,
132            sem: Semaphore::new(rate),
133            handle: None,
134            sender,
135            receiver,
136        }
137    }
138
139    pub fn handle(mut self, f: F) -> Self {
140        self.handle = Some(f);
141        self
142    }
143
144    pub fn build(self) -> QueuedTask<T, R> {
145        let Self {
146            sem,
147            mut handle,
148            sender,
149            mut receiver,
150            ..
151        } = self;
152        let handle = handle.take().unwrap();
153        tokio::spawn(async move {
154            let arc_sem = Arc::new(sem);
155            let arc_handle = Arc::new(handle);
156            while let Some(Task {
157                inner,
158                shared,
159                start_time,
160            }) = receiver.recv().await
161            {
162                let p = arc_sem.clone().acquire_owned().await.unwrap();
163                let h = arc_handle.clone();
164                tokio::spawn(async move {
165                    let wait = start_time.elapsed();
166                    let result = h(wait, inner).await;
167                    shared.set_result(result).await;
168                    drop(p)
169                });
170            }
171        });
172        QueuedTask { sender }
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Handler invocations currently in progress.
    static RUNNING: AtomicUsize = AtomicUsize::new(0);
    /// Highest value `RUNNING` has reached during the test.
    static PEAK: AtomicUsize = AtomicUsize::new(0);

    /// Test handler: records concurrency, sleeps briefly, echoes its input.
    async fn handle(wait_time: Duration, c: usize) -> usize {
        let running = RUNNING.fetch_add(1, Ordering::SeqCst) + 1;
        PEAK.fetch_max(running, Ordering::SeqCst);
        // A short sleep keeps the test fast while still forcing jobs to
        // queue up behind the concurrency limit.
        tokio::time::sleep(Duration::from_millis(10)).await;
        println!("{} {}", c, wait_time.as_millis());
        RUNNING.fetch_sub(1, Ordering::SeqCst);
        c
    }

    #[tokio::test]
    async fn test() {
        const RATE: usize = 2;
        let t = Arc::new(QueuedTaskBuilder::new(10, RATE).handle(handle).build());

        let ts: Vec<_> = (0..20)
            .map(|i| {
                let tt = Arc::clone(&t);
                tokio::spawn(async move {
                    // push task
                    let state = tt.push(i).await.unwrap();
                    // wait for the task result; the handler echoes its input
                    assert_eq!(state.wait_result().await, Some(i));
                })
            })
            .collect();

        // Propagate panics (e.g. failed assertions) from the producer tasks;
        // discarding the `JoinError` would let the test pass regardless.
        for x in ts {
            x.await.expect("producer task panicked");
        }

        let peak = PEAK.load(Ordering::SeqCst);
        assert!(
            (1..=RATE).contains(&peak),
            "observed {} concurrent handlers",
            peak
        );
    }
}