1use std::{
2 marker::PhantomData,
3 pin::{Pin, pin},
4 sync::Arc,
5 task::{Poll, ready},
6};
7
8use futures::{FutureExt, Stream, StreamExt};
9use tokio::sync::{Mutex, OwnedMutexGuard};
10
11use crate::{Context, Service};
12
13pub struct ConcurrentRx<T: Send> {
14 guard: OwnedMutexGuard<flowly_spsc::Receiver<Option<T>>>,
15}
16
17impl<T: Send> Stream for ConcurrentRx<T> {
18 type Item = T;
19
20 fn poll_next(
21 mut self: Pin<&mut Self>,
22 cx: &mut std::task::Context<'_>,
23 ) -> Poll<Option<Self::Item>> {
24 match ready!(self.guard.poll_recv(cx)) {
25 Some(Some(val)) => Poll::Ready(Some(val)),
26 Some(None) => Poll::Ready(None),
27 None => Poll::Ready(None),
28 }
29 }
30}
31
32struct ConcurrentTask<I: Send, S: Service<I>> {
33 #[allow(dead_code)]
34 id: u32,
35 tx: tokio::sync::mpsc::Sender<I>,
36 rx: Arc<Mutex<flowly_spsc::Receiver<Option<S::Out>>>>,
37 _handle: tokio::task::JoinHandle<()>,
38 m: PhantomData<S>,
39 ctx_tx: std::sync::Mutex<Option<tokio::sync::oneshot::Sender<Context>>>,
40}
41
42impl<I, S> ConcurrentTask<I, S>
43where
44 S::Out: Send + 'static,
45 I: Send + 'static,
46 S: Service<I> + Send + 'static,
47{
48 fn new(id: u32, s: Arc<S>) -> Self {
49 let (tx, mut in_rx) = tokio::sync::mpsc::channel(1);
50 let (mut out_tx, out_rx) = flowly_spsc::channel(1);
51 let (ctx_tx, ctx_rx) = tokio::sync::oneshot::channel();
52
53 let _handle = tokio::spawn(async move {
54 let Ok(cx) = ctx_rx.await else {
55 log::error!("no context got");
56 return;
57 };
58
59 'recv: while let Some(item) = in_rx.recv().await {
60 let mut s = pin!(s.handle(item, &cx));
61
62 while let Some(x) = s.next().await {
63 if out_tx.send(Some(x)).await.is_err() {
64 log::error!("cannot send the message. channel closed!");
65 break 'recv;
66 }
67 }
68
69 if out_tx.send(None).await.is_err() {
70 log::error!("cannot send the message. channel closed!");
71 break 'recv;
72 }
73 }
74 });
75
76 Self {
77 id,
78 tx,
79 ctx_tx: std::sync::Mutex::new(Some(ctx_tx)),
80 rx: Arc::new(tokio::sync::Mutex::new(out_rx)),
81 _handle,
82 m: PhantomData,
83 }
84 }
85
86 #[inline]
87 fn is_available(&self) -> bool {
88 self.rx.try_lock().is_ok()
89 }
90
91 #[inline]
92 async fn send(
93 &self,
94 input: I,
95 ) -> Result<ConcurrentRx<S::Out>, tokio::sync::mpsc::error::SendError<I>> {
96 self.tx.send(input).await?;
97
98 Ok(ConcurrentRx {
99 guard: self.rx.clone().lock_owned().await,
100 })
101 }
102
103 fn is_ready(&self) -> bool {
104 if let Ok(lock) = self.ctx_tx.try_lock() {
105 lock.is_none()
106 } else {
107 false
108 }
109 }
110
111 fn init(&self, ctx: Context) {
112 if let Ok(Some(sender)) = self.ctx_tx.try_lock().map(|mut x| x.take()) {
113 if sender.send(ctx).is_err() {
114 log::warn!("cannot send context: receiver closed");
115 }
116 } else {
117 log::warn!("cannot init ConcurrentTask twice");
118 }
119 }
120}
121
122pub struct ConcurrentEach<I: Send + 'static, S: Service<I>> {
123 service: Arc<S>,
124 tasks: Vec<ConcurrentTask<I, S>>,
125 _m: PhantomData<I>,
126 limit: usize,
127}
128
129impl<I: Send + 'static + Clone, S: Service<I> + Clone> Clone for ConcurrentEach<I, S> {
130 fn clone(&self) -> Self {
131 Self {
132 service: self.service.clone(),
133 tasks: Vec::new(),
134 _m: self._m,
135 limit: self.limit,
136 }
137 }
138}
139
140impl<I, S> ConcurrentEach<I, S>
141where
142 I: Send,
143 S: Service<I> + Send + 'static,
144 S::Out: Send,
145{
146 pub fn new(service: S, limit: usize) -> Self {
147 let service = Arc::new(service);
148 Self {
149 tasks: (0..limit as u32)
150 .map(|id| ConcurrentTask::new(id, service.clone()))
151 .collect(),
152 service,
153 _m: PhantomData,
154 limit,
155 }
156 }
157}
158
159impl<I, R, E, S> Service<I> for ConcurrentEach<I, S>
160where
161 I: Send + Sync,
162 R: Send + 'static,
163 E: Send + 'static,
164 S: Service<I, Out = Result<R, E>> + Clone + Send + 'static,
165{
166 type Out = Result<ConcurrentRx<S::Out>, E>;
167
168 fn handle(&self, input: I, cx: &Context) -> impl Stream<Item = Self::Out> + Send {
169 async move {
170 let mut index = fastrand::usize(0..self.tasks.len());
171
172 for idx in 0..self.tasks.len() {
173 let idx = (idx + self.tasks.len()) % self.tasks.len();
174 if self.tasks[idx].is_available() {
175 index = idx;
176 break;
177 }
178 }
179
180 if !self.tasks[index].is_ready() {
181 self.tasks[index].init(cx.clone());
182 }
183
184 Ok(self.tasks[index].send(input).await.unwrap())
185 }
186 .into_stream()
187 }
188}
189
190pub fn concurrent_each<I, S>(service: S, limit: usize) -> ConcurrentEach<I, S>
191where
192 I: Send,
193 S: Send + Service<I> + Clone + 'static,
194 S::Out: Send,
195{
196 ConcurrentEach::new(service, limit)
197}