1use std::{
2 marker::PhantomData,
3 pin::{Pin, pin},
4 sync::Arc,
5 task::{Poll, ready},
6};
7
8use futures::{FutureExt, Stream, StreamExt};
9use tokio::sync::{Mutex, OwnedMutexGuard};
10
11use crate::{Context, Service};
12
13pub struct ConcurrentRx<T: Send> {
14 guard: OwnedMutexGuard<flowly_spsc::Receiver<Option<T>>>,
15}
16
17impl<T: Send> Stream for ConcurrentRx<T> {
18 type Item = T;
19
20 fn poll_next(
21 mut self: Pin<&mut Self>,
22 cx: &mut std::task::Context<'_>,
23 ) -> Poll<Option<Self::Item>> {
24 match ready!(self.guard.poll_recv(cx)) {
25 Some(Some(val)) => Poll::Ready(Some(val)),
26 Some(None) => Poll::Ready(None),
27 None => Poll::Ready(None),
28 }
29 }
30}
31
32struct ConcurrentTask<I: Send, S: Service<I>> {
33 #[allow(dead_code)]
34 id: u32,
35 tx: flowly_spsc::Sender<I>,
36 m: PhantomData<S>,
37 _handle: tokio::task::JoinHandle<()>,
38 rx: Arc<Mutex<flowly_spsc::Receiver<Option<S::Out>>>>,
39}
40
41impl<I, S> ConcurrentTask<I, S>
42where
43 S::Out: Send + 'static,
44 I: Send + 'static,
45 S: Service<I> + Send + 'static,
46{
47 fn new(id: u32, mut s: S, cx: Context) -> Self {
48 let (tx, mut in_rx) = flowly_spsc::channel(1);
49 let (mut out_tx, out_rx) = flowly_spsc::channel(1);
50
51 let _handle = tokio::spawn(async move {
52 'recv: while let Some(item) = in_rx.recv().await {
53 let mut s = pin!(s.handle(item, &cx));
54
55 while let Some(x) = s.next().await {
56 if out_tx.send(Some(x)).await.is_err() {
57 log::error!("cannot send the message. channel closed!");
58 break 'recv;
59 }
60 }
61
62 if out_tx.send(None).await.is_err() {
63 log::error!("cannot send the message. channel closed!");
64 break 'recv;
65 }
66 }
67 });
68
69 Self {
70 id,
71 tx,
72 rx: Arc::new(tokio::sync::Mutex::new(out_rx)),
73 _handle,
74 m: PhantomData,
75 }
76 }
77
78 #[inline]
79 fn is_available(&self) -> bool {
80 self.rx.try_lock().is_ok()
81 }
82
83 #[inline]
84 async fn send(
85 &mut self,
86 input: I,
87 ) -> Result<ConcurrentRx<S::Out>, flowly_spsc::TrySendError<I>> {
88 self.tx.send(input).await?;
89
90 Ok(ConcurrentRx {
91 guard: self.rx.clone().lock_owned().await,
92 })
93 }
94}
95
96pub struct ConcurrentEach<I: Send + 'static, S: Service<I>> {
97 service: S,
98 tasks: Vec<ConcurrentTask<I, S>>,
99 _m: PhantomData<I>,
100 limit: usize,
101}
102
103impl<I: Send + 'static + Clone, S: Service<I> + Clone> Clone for ConcurrentEach<I, S> {
104 fn clone(&self) -> Self {
105 Self {
106 service: self.service.clone(),
107 tasks: Vec::new(),
108 _m: self._m,
109 limit: self.limit,
110 }
111 }
112}
113
114impl<I, S> ConcurrentEach<I, S>
115where
116 I: Send,
117 S: Service<I> + Send,
118 S::Out: Send,
119{
120 pub fn new(service: S, limit: usize) -> Self {
121 Self {
122 service,
123 tasks: Vec::with_capacity(limit),
124 _m: PhantomData,
125 limit,
126 }
127 }
128}
129
130impl<I, R, E, S> Service<I> for ConcurrentEach<I, S>
131where
132 I: Send,
133 R: Send + 'static,
134 E: Send + 'static,
135 S: Service<I, Out = Result<R, E>> + Clone + Send + 'static,
136{
137 type Out = Result<ConcurrentRx<S::Out>, E>;
138
139 fn handle(&mut self, input: I, cx: &Context) -> impl Stream<Item = Self::Out> + Send {
140 async move {
141 let index = if self.tasks.len() < self.limit {
142 let index = self.tasks.len();
143 self.tasks.push(ConcurrentTask::new(
144 index as u32,
145 self.service.clone(),
146 cx.clone(),
147 ));
148 index
149 } else {
150 let mut index = fastrand::usize(0..self.tasks.len());
151
152 for idx in 0..self.tasks.len() {
153 let idx = (idx + self.tasks.len()) % self.tasks.len();
154 if self.tasks[idx].is_available() {
155 index = idx;
156 break;
157 }
158 }
159
160 index
161 };
162
163 Ok(self.tasks[index].send(input).await.unwrap())
164 }
165 .into_stream()
166 }
167}
168
169pub fn concurrent_each<I, S>(service: S, limit: usize) -> ConcurrentEach<I, S>
170where
171 I: Send,
172 S: Send + Service<I> + Clone + 'static,
173 S::Out: Send,
174{
175 ConcurrentEach::new(service, limit)
176}