1use std::{marker::PhantomData, pin::pin};
2
3use futures::{FutureExt, Stream, StreamExt};
4use tokio::sync::{Mutex, mpsc};
5
6use crate::{Context, Service};
7
/// Upper bound on the number of worker tasks a `SpawnEach` spawns. Once this many
/// workers exist, further inputs are routed to the existing workers instead of
/// spawning new ones.
const MAX_SPAWN_TASKS: usize = 256;
9
10struct SpawnEachTask<I: Send, S: Service<I>> {
11 #[allow(dead_code)]
12 id: u32,
13 tx: flowly_spsc::Sender<I>,
14 m: PhantomData<S>,
15 _handle: tokio::task::JoinHandle<()>,
16}
17
18impl<I, S> SpawnEachTask<I, S>
19where
20 S::Out: Send + 'static,
21 I: Send + 'static,
22 S: Service<I> + Send + 'static,
23{
24 fn new(
25 id: u32,
26 buffer: usize,
27 mut s: S,
28 out_tx: mpsc::Sender<Option<S::Out>>,
29 cx: Context,
30 input: I,
31 ) -> Self {
32 let (mut tx, mut rx) = flowly_spsc::channel(buffer);
33
34 let _handle = tokio::spawn(async move {
35 'recv: while let Some(item) = rx.recv().await {
36 let mut s = pin!(s.handle(item, &cx));
37
38 while let Some(x) = s.next().await {
39 if out_tx.send(Some(x)).await.is_err() {
40 log::error!("cannot send the message. channel closed!");
41 break 'recv;
42 }
43 }
44
45 if out_tx.send(None).await.is_err() {
46 log::error!("cannot send the message. channel closed!");
47 break 'recv;
48 }
49 }
50 });
51
52 tx.try_send(input).unwrap();
53
54 Self {
55 id,
56 tx,
57 _handle,
58 m: PhantomData,
59 }
60 }
61
62 #[inline]
63 async fn send(&mut self, input: I) -> Result<(), flowly_spsc::TrySendError<I>> {
64 self.tx.send(input).await
65 }
66}
67
68pub struct SpawnEach<I: Send + 'static, S: Service<I>> {
69 service: S,
70 sender: mpsc::Sender<Option<S::Out>>,
71 receiver: Mutex<mpsc::Receiver<Option<S::Out>>>,
72 tasks: Vec<SpawnEachTask<I, S>>,
73 _m: PhantomData<I>,
74 counter: u32,
75}
76
77impl<I, S> SpawnEach<I, S>
78where
79 I: Send,
80 S: Service<I> + Send,
81 S::Out: Send,
82{
83 pub(crate) fn new(service: S) -> Self {
84 let (sender, rx) = mpsc::channel(1);
85
86 Self {
87 service,
88 sender,
89 receiver: Mutex::new(rx),
90 tasks: Vec::with_capacity(MAX_SPAWN_TASKS),
91 _m: PhantomData,
92 counter: 0,
93 }
94 }
95
96 #[inline]
97 fn drain_rx(&mut self) -> impl Stream<Item = S::Out> + Send {
98 async_stream::stream! {
99 let mut guard = self.receiver.lock().await;
100 while let Some(res) = guard.recv().await {
101 if let Some(item) = res {
102 yield item;
103 } else {
104 break;
105 }
106 }
107 }
108 }
109}
110
111impl<I, S> Service<I> for SpawnEach<I, S>
112where
113 I: Send,
114 S: Service<I> + Clone + Send + 'static,
115 S::Out: Send,
116{
117 type Out = S::Out;
118
119 fn handle(&mut self, mut input: I, cx: &Context) -> impl Stream<Item = Self::Out> + Send {
120 if self.tasks.len() < MAX_SPAWN_TASKS {
121 self.tasks.push(SpawnEachTask::new(
122 self.counter,
123 2,
124 self.service.clone(),
125 self.sender.clone(),
126 cx.clone(),
127 input,
128 ));
129
130 self.counter += 1;
131 self.drain_rx().right_stream()
132 } else {
133 let index = fastrand::usize(0..self.tasks.len());
134
135 let (left, right) = self.tasks.split_at_mut(index);
136
137 for task in right.iter_mut().chain(left.iter_mut()) {
138 if let Err(err) = task.tx.try_send(input) {
139 input = err.val;
140 } else {
141 return self.drain_rx().right_stream();
142 }
143 }
144
145 async move {
146 if self.tasks[index].send(input).await.is_err() {
147 log::error!("cannot send the message. channel closed!");
148 }
149
150 self.drain_rx()
151 }
152 .into_stream()
153 .flatten()
154 .left_stream()
155 }
156 }
157}
158
159pub fn spawn_each<I, S>(service: S) -> SpawnEach<I, S>
160where
161 I: Send,
162 S: Send + Service<I> + Clone + 'static,
163 S::Out: Send,
164{
165 SpawnEach::new(service)
166}
167
168