flowly_service/
concurrent_each.rs1use std::{
2 marker::PhantomData,
3 pin::{Pin, pin},
4 sync::Arc,
5 task::{Poll, ready},
6};
7
8use futures::{FutureExt, Stream, StreamExt};
9use tokio::sync::{Mutex, OwnedMutexGuard};
10
11use crate::{Context, Service};
12
/// Stream of one input's results, drained from a worker's output channel.
///
/// Holds the worker's output receiver behind an owned mutex guard, so while
/// this stream is alive no other caller can read from that worker — this is
/// what `ConcurrentTask::is_available` keys off.
pub struct ConcurrentRx<T: Send> {
    // Owned guard keeps the receiver exclusively locked until this stream
    // is dropped, at which point the worker becomes available again.
    guard: OwnedMutexGuard<flowly_spsc::Receiver<Option<T>>>,
}
16
17impl<T: Send> Stream for ConcurrentRx<T> {
18 type Item = T;
19
20 fn poll_next(
21 mut self: Pin<&mut Self>,
22 cx: &mut std::task::Context<'_>,
23 ) -> Poll<Option<Self::Item>> {
24 match ready!(self.guard.poll_recv(cx)) {
25 Some(Some(val)) => Poll::Ready(Some(val)),
26 Some(None) => Poll::Ready(None),
27 None => Poll::Ready(None),
28 }
29 }
30}
31
/// One pooled worker: a service clone running on its own tokio task, fed
/// through an SPSC input channel and drained through an SPSC output channel.
struct ConcurrentTask<I: Send, S: Service<I>> {
    #[allow(dead_code)]
    id: u32,
    // Input side: items queued here are picked up by the spawned loop.
    tx: flowly_spsc::Sender<I>,
    m: PhantomData<S>,
    // Keeps the spawned worker tied to this struct's lifetime; never awaited.
    _handle: tokio::task::JoinHandle<()>,
    // Output side: `Some(item)` per result, then one `None` marking the end
    // of a single input's result stream. Mutex-wrapped so a caller can hold
    // it exclusively while draining (see `ConcurrentRx`).
    rx: Arc<Mutex<flowly_spsc::Receiver<Option<S::Out>>>>,
}
40
41impl<I, S> ConcurrentTask<I, S>
42where
43 S::Out: Send + 'static,
44 I: Send + 'static,
45 S: Service<I> + Send + 'static,
46{
47 fn new(id: u32, mut s: S, cx: Context) -> Self {
48 let (tx, mut in_rx) = flowly_spsc::channel(1);
49 let (mut out_tx, out_rx) = flowly_spsc::channel(1);
50
51 let _handle = tokio::spawn(async move {
52 'recv: while let Some(item) = in_rx.recv().await {
53 let mut s = pin!(s.handle(item, &cx));
54
55 while let Some(x) = s.next().await {
56 if out_tx.send(Some(x)).await.is_err() {
57 log::error!("cannot send the message. channel closed!");
58 break 'recv;
59 }
60 }
61
62 if out_tx.send(None).await.is_err() {
63 log::error!("cannot send the message. channel closed!");
64 break 'recv;
65 }
66 }
67 });
68
69 Self {
70 id,
71 tx,
72 rx: Arc::new(tokio::sync::Mutex::new(out_rx)),
73 _handle,
74 m: PhantomData,
75 }
76 }
77
78 #[inline]
79 fn is_available(&self) -> bool {
80 self.rx.try_lock().is_ok()
81 }
82
83 #[inline]
84 async fn send(
85 &mut self,
86 input: I,
87 ) -> Result<ConcurrentRx<S::Out>, flowly_spsc::TrySendError<I>> {
88 self.tx.send(input).await?;
89
90 Ok(ConcurrentRx {
91 guard: self.rx.clone().lock_owned().await,
92 })
93 }
94}
95
/// Service combinator that dispatches each input to one of up to `limit`
/// concurrently running clones of the wrapped service.
pub struct ConcurrentEach<I: Send + 'static, S: Service<I>> {
    // Prototype service; cloned once per spawned worker.
    service: S,
    // Lazily grown worker pool; never exceeds `limit` entries.
    tasks: Vec<ConcurrentTask<I, S>>,
    _m: PhantomData<I>,
    // Upper bound on the number of spawned workers.
    limit: usize,
}
102
103impl<I, S> ConcurrentEach<I, S>
104where
105 I: Send,
106 S: Service<I> + Send,
107 S::Out: Send,
108{
109 pub fn new(service: S, limit: usize) -> Self {
110 Self {
111 service,
112 tasks: Vec::with_capacity(limit),
113 _m: PhantomData,
114 limit,
115 }
116 }
117}
118
119impl<I, R, E, S> Service<I> for ConcurrentEach<I, S>
120where
121 I: Send,
122 R: Send + 'static,
123 E: Send + 'static,
124 S: Service<I, Out = Result<R, E>> + Clone + Send + 'static,
125{
126 type Out = Result<ConcurrentRx<S::Out>, E>;
127
128 fn handle(&mut self, input: I, cx: &Context) -> impl Stream<Item = Self::Out> + Send {
129 async move {
130 let index = if self.tasks.len() < self.limit {
131 let index = self.tasks.len();
132 self.tasks.push(ConcurrentTask::new(
133 index as u32,
134 self.service.clone(),
135 cx.clone(),
136 ));
137 index
138 } else {
139 let mut index = fastrand::usize(0..self.tasks.len());
140
141 for idx in 0..self.tasks.len() {
142 let idx = (idx + self.tasks.len()) % self.tasks.len();
143 if self.tasks[idx].is_available() {
144 index = idx;
145 break;
146 }
147 }
148
149 index
150 };
151
152 Ok(self.tasks[index].send(input).await.unwrap())
153 }
154 .into_stream()
155 }
156}
157
158pub fn concurrent_each<I, S>(service: S, limit: usize) -> ConcurrentEach<I, S>
159where
160 I: Send,
161 S: Send + Service<I> + Clone + 'static,
162 S::Out: Send,
163{
164 ConcurrentEach::new(service, limit)
165}