Skip to main content

flowly_service/
concurrent_each.rs

1use std::{
2    marker::PhantomData,
3    pin::{Pin, pin},
4    sync::Arc,
5    task::{Poll, ready},
6};
7
8use futures::{FutureExt, Stream, StreamExt};
9use tokio::sync::{Mutex, OwnedMutexGuard};
10
11use crate::{Context, Service};
12
13pub struct ConcurrentRx<T: Send> {
14    guard: OwnedMutexGuard<flowly_spsc::Receiver<Option<T>>>,
15}
16
17impl<T: Send> Stream for ConcurrentRx<T> {
18    type Item = T;
19
20    fn poll_next(
21        mut self: Pin<&mut Self>,
22        cx: &mut std::task::Context<'_>,
23    ) -> Poll<Option<Self::Item>> {
24        match ready!(self.guard.poll_recv(cx)) {
25            Some(Some(val)) => Poll::Ready(Some(val)),
26            Some(None) => Poll::Ready(None),
27            None => Poll::Ready(None),
28        }
29    }
30}
31
32struct ConcurrentTask<I: Send, S: Service<I>> {
33    #[allow(dead_code)]
34    id: u32,
35    tx: tokio::sync::mpsc::Sender<I>,
36    rx: Arc<Mutex<flowly_spsc::Receiver<Option<S::Out>>>>,
37    _handle: tokio::task::JoinHandle<()>,
38    m: PhantomData<S>,
39    ctx_tx: std::sync::Mutex<Option<tokio::sync::oneshot::Sender<Context>>>,
40}
41
42impl<I, S> ConcurrentTask<I, S>
43where
44    S::Out: Send + 'static,
45    I: Send + 'static,
46    S: Service<I> + Send + 'static,
47{
48    fn new(id: u32, s: Arc<S>) -> Self {
49        let (tx, mut in_rx) = tokio::sync::mpsc::channel(1);
50        let (mut out_tx, out_rx) = flowly_spsc::channel(1);
51        let (ctx_tx, ctx_rx) = tokio::sync::oneshot::channel();
52
53        let _handle = tokio::spawn(async move {
54            let Ok(cx) = ctx_rx.await else {
55                log::error!("no context got");
56                return;
57            };
58
59            'recv: while let Some(item) = in_rx.recv().await {
60                let mut s = pin!(s.handle(item, &cx));
61
62                while let Some(x) = s.next().await {
63                    if out_tx.send(Some(x)).await.is_err() {
64                        log::error!("cannot send the message. channel closed!");
65                        break 'recv;
66                    }
67                }
68
69                if out_tx.send(None).await.is_err() {
70                    log::error!("cannot send the message. channel closed!");
71                    break 'recv;
72                }
73            }
74        });
75
76        Self {
77            id,
78            tx,
79            ctx_tx: std::sync::Mutex::new(Some(ctx_tx)),
80            rx: Arc::new(tokio::sync::Mutex::new(out_rx)),
81            _handle,
82            m: PhantomData,
83        }
84    }
85
86    #[inline]
87    fn is_available(&self) -> bool {
88        self.rx.try_lock().is_ok()
89    }
90
91    #[inline]
92    async fn send(
93        &self,
94        input: I,
95    ) -> Result<ConcurrentRx<S::Out>, tokio::sync::mpsc::error::SendError<I>> {
96        self.tx.send(input).await?;
97
98        Ok(ConcurrentRx {
99            guard: self.rx.clone().lock_owned().await,
100        })
101    }
102
103    fn is_ready(&self) -> bool {
104        if let Ok(lock) = self.ctx_tx.try_lock() {
105            lock.is_none()
106        } else {
107            false
108        }
109    }
110
111    fn init(&self, ctx: Context) {
112        if let Ok(Some(sender)) = self.ctx_tx.try_lock().map(|mut x| x.take()) {
113            if sender.send(ctx).is_err() {
114                log::warn!("cannot send context: receiver closed");
115            }
116        } else {
117            log::warn!("cannot init ConcurrentTask twice");
118        }
119    }
120}
121
122pub struct ConcurrentEach<I: Send + 'static, S: Service<I>> {
123    service: Arc<S>,
124    tasks: Vec<ConcurrentTask<I, S>>,
125    _m: PhantomData<I>,
126    limit: usize,
127}
128
129impl<I: Send + 'static + Clone, S: Service<I> + Clone> Clone for ConcurrentEach<I, S> {
130    fn clone(&self) -> Self {
131        Self {
132            service: self.service.clone(),
133            tasks: Vec::new(),
134            _m: self._m,
135            limit: self.limit,
136        }
137    }
138}
139
140impl<I, S> ConcurrentEach<I, S>
141where
142    I: Send,
143    S: Service<I> + Send + 'static,
144    S::Out: Send,
145{
146    pub fn new(service: S, limit: usize) -> Self {
147        let service = Arc::new(service);
148        Self {
149            tasks: (0..limit as u32)
150                .map(|id| ConcurrentTask::new(id, service.clone()))
151                .collect(),
152            service,
153            _m: PhantomData,
154            limit,
155        }
156    }
157}
158
159impl<I, R, E, S> Service<I> for ConcurrentEach<I, S>
160where
161    I: Send + Sync,
162    R: Send + 'static,
163    E: Send + 'static,
164    S: Service<I, Out = Result<R, E>> + Clone + Send + 'static,
165{
166    type Out = Result<ConcurrentRx<S::Out>, E>;
167
168    fn handle(&self, input: I, cx: &Context) -> impl Stream<Item = Self::Out> + Send {
169        async move {
170            let mut index = fastrand::usize(0..self.tasks.len());
171
172            for idx in 0..self.tasks.len() {
173                let idx = (idx + self.tasks.len()) % self.tasks.len();
174                if self.tasks[idx].is_available() {
175                    index = idx;
176                    break;
177                }
178            }
179
180            if !self.tasks[index].is_ready() {
181                self.tasks[index].init(cx.clone());
182            }
183
184            Ok(self.tasks[index].send(input).await.unwrap())
185        }
186        .into_stream()
187    }
188}
189
190pub fn concurrent_each<I, S>(service: S, limit: usize) -> ConcurrentEach<I, S>
191where
192    I: Send,
193    S: Send + Service<I> + Clone + 'static,
194    S::Out: Send,
195{
196    ConcurrentEach::new(service, limit)
197}