1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use futures::future::{BoxFuture, FutureExt};
use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt;
use parking_lot::Mutex;
use std::sync::Arc;
use tokio::sync::Notify;
/// A boxed, `Send`, `'static` future with no output: the unit of work accepted by
/// `TaskAgg::push`, and also the type of the driver future returned by `TaskAgg::new`.
pub type Driver = BoxFuture<'static, ()>;
/// Shared state behind every [`TaskAgg`] handle and its driver future.
struct TaskAggInner {
    /// Futures handed to `TaskAgg::push` that the driver has not picked up yet.
    driver_list: Vec<Driver>,
    /// Number of pushed futures that have not completed yet, whether still queued in
    /// `driver_list` or already being polled by the driver.
    d_count: usize,
    /// Signalled by `push` so that a waiting driver notices newly queued futures.
    notify: Arc<Notify>,
    /// Whether the driver currently has a future outstanding that waits on `notify`.
    n_driver: bool,
}
/// Cloneable handle for adding futures to an aggregator created by `TaskAgg::new`.
///
/// All clones share the same state. Futures pushed through any of them are run by the single
/// driver future that `TaskAgg::new` returned alongside the first handle.
#[derive(Clone)]
pub struct TaskAgg(Arc<Mutex<TaskAggInner>>);
// Compile-time guarantee that `TaskAgg` is `Send + Sync + 'static`, so handles can be moved
// into spawned tasks and other threads. The build fails if a later change breaks this.
const _: fn() = || {
    fn assert_send_sync_static<T: Send + Sync + 'static>() {}
    assert_send_sync_static::<TaskAgg>();
};
impl TaskAgg {
pub fn new() -> (Driver, Self) {
let inner = Arc::new(Mutex::new(TaskAggInner {
notify: Arc::new(Notify::new()),
driver_list: Vec::new(),
n_driver: false,
d_count: 0,
}));
let driver = {
let inner = inner.clone();
async move {
let mut fu = FuturesUnordered::new();
let mut driver_list = Vec::new();
loop {
let cont = {
let mut lock = inner.lock();
if lock.d_count == 0 {
false
} else {
driver_list.append(&mut lock.driver_list);
if !lock.n_driver {
lock.n_driver = true;
let n = lock.notify.clone();
let inner = inner.clone();
driver_list.push(
async move {
n.notified().await;
let mut lock = inner.lock();
lock.n_driver = false;
}
.boxed(),
);
}
true
}
};
if cont {
for driver in driver_list.drain(..) {
fu.push(driver);
}
} else {
break;
}
let _ = fu.next().await;
}
}
.boxed()
};
(driver, Self(inner))
}
pub fn push(&self, f: Driver) {
let inner = self.0.clone();
let mut lock = self.0.lock();
lock.d_count += 1;
lock.driver_list.push(
async move {
f.await;
inner.lock().d_count -= 1;
}
.boxed(),
);
lock.notify.notify_waiters();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test(flavor = "multi_thread")]
async fn test_task_agg() {
let (s, r) = tokio::sync::oneshot::channel();
let (driver, agg) = TaskAgg::new();
let agg2 = agg.clone();
agg.push(
async move {
agg2.push(
async move {
println!("test");
s.send(()).unwrap();
}
.boxed(),
);
}
.boxed(),
);
driver.await;
r.await.unwrap();
}
}