// flowly_service/concurrent_each.rs

1use std::{
2    marker::PhantomData,
3    pin::{Pin, pin},
4    sync::Arc,
5    task::{Poll, ready},
6};
7
8use futures::{FutureExt, Stream, StreamExt};
9use tokio::sync::{Mutex, OwnedMutexGuard};
10
11use crate::{Context, Service};
12
13pub struct ConcurrentRx<T: Send> {
14    guard: OwnedMutexGuard<flowly_spsc::Receiver<Option<T>>>,
15}
16
17impl<T: Send> Stream for ConcurrentRx<T> {
18    type Item = T;
19
20    fn poll_next(
21        mut self: Pin<&mut Self>,
22        cx: &mut std::task::Context<'_>,
23    ) -> Poll<Option<Self::Item>> {
24        match ready!(self.guard.poll_recv(cx)) {
25            Some(Some(val)) => Poll::Ready(Some(val)),
26            Some(None) => Poll::Ready(None),
27            None => Poll::Ready(None),
28        }
29    }
30}
31
32struct ConcurrentTask<I: Send, S: Service<I>> {
33    #[allow(dead_code)]
34    id: u32,
35    tx: flowly_spsc::Sender<I>,
36    m: PhantomData<S>,
37    _handle: tokio::task::JoinHandle<()>,
38    rx: Arc<Mutex<flowly_spsc::Receiver<Option<S::Out>>>>,
39}
40
41impl<I, S> ConcurrentTask<I, S>
42where
43    S::Out: Send + 'static,
44    I: Send + 'static,
45    S: Service<I> + Send + 'static,
46{
47    fn new(id: u32, mut s: S, cx: Context) -> Self {
48        let (tx, mut in_rx) = flowly_spsc::channel(1);
49        let (mut out_tx, out_rx) = flowly_spsc::channel(1);
50
51        let _handle = tokio::spawn(async move {
52            'recv: while let Some(item) = in_rx.recv().await {
53                let mut s = pin!(s.handle(item, &cx));
54
55                while let Some(x) = s.next().await {
56                    if out_tx.send(Some(x)).await.is_err() {
57                        log::error!("cannot send the message. channel closed!");
58                        break 'recv;
59                    }
60                }
61
62                if out_tx.send(None).await.is_err() {
63                    log::error!("cannot send the message. channel closed!");
64                    break 'recv;
65                }
66            }
67        });
68
69        Self {
70            id,
71            tx,
72            rx: Arc::new(tokio::sync::Mutex::new(out_rx)),
73            _handle,
74            m: PhantomData,
75        }
76    }
77
78    #[inline]
79    fn is_available(&self) -> bool {
80        self.rx.try_lock().is_ok()
81    }
82
83    #[inline]
84    async fn send(
85        &mut self,
86        input: I,
87    ) -> Result<ConcurrentRx<S::Out>, flowly_spsc::TrySendError<I>> {
88        self.tx.send(input).await?;
89
90        Ok(ConcurrentRx {
91            guard: self.rx.clone().lock_owned().await,
92        })
93    }
94}
95
96pub struct ConcurrentEach<I: Send + 'static, S: Service<I>> {
97    service: S,
98    tasks: Vec<ConcurrentTask<I, S>>,
99    _m: PhantomData<I>,
100    limit: usize,
101}
102
103impl<I: Send + 'static + Clone, S: Service<I> + Clone> Clone for ConcurrentEach<I, S> {
104    fn clone(&self) -> Self {
105        Self {
106            service: self.service.clone(),
107            tasks: Vec::new(),
108            _m: self._m,
109            limit: self.limit,
110        }
111    }
112}
113
114impl<I, S> ConcurrentEach<I, S>
115where
116    I: Send,
117    S: Service<I> + Send,
118    S::Out: Send,
119{
120    pub fn new(service: S, limit: usize) -> Self {
121        Self {
122            service,
123            tasks: Vec::with_capacity(limit),
124            _m: PhantomData,
125            limit,
126        }
127    }
128}
129
130impl<I, R, E, S> Service<I> for ConcurrentEach<I, S>
131where
132    I: Send,
133    R: Send + 'static,
134    E: Send + 'static,
135    S: Service<I, Out = Result<R, E>> + Clone + Send + 'static,
136{
137    type Out = Result<ConcurrentRx<S::Out>, E>;
138
139    fn handle(&mut self, input: I, cx: &Context) -> impl Stream<Item = Self::Out> + Send {
140        async move {
141            let index = if self.tasks.len() < self.limit {
142                let index = self.tasks.len();
143                self.tasks.push(ConcurrentTask::new(
144                    index as u32,
145                    self.service.clone(),
146                    cx.clone(),
147                ));
148                index
149            } else {
150                let mut index = fastrand::usize(0..self.tasks.len());
151
152                for idx in 0..self.tasks.len() {
153                    let idx = (idx + self.tasks.len()) % self.tasks.len();
154                    if self.tasks[idx].is_available() {
155                        index = idx;
156                        break;
157                    }
158                }
159
160                index
161            };
162
163            Ok(self.tasks[index].send(input).await.unwrap())
164        }
165        .into_stream()
166    }
167}
168
169pub fn concurrent_each<I, S>(service: S, limit: usize) -> ConcurrentEach<I, S>
170where
171    I: Send,
172    S: Send + Service<I> + Clone + 'static,
173    S::Out: Send,
174{
175    ConcurrentEach::new(service, limit)
176}