flowly_service/
spawn.rs

1use std::{marker::PhantomData, pin::pin};
2
3use futures::{FutureExt, Stream, StreamExt};
4use tokio::sync::{Mutex, mpsc};
5
6use crate::{Context, Service};
7
/// Upper bound (256) on the number of worker tasks a `SpawnEach` spawns.
/// Once this many workers exist, new inputs are distributed across them
/// instead of spawning more.
const MAX_SPAWN_TASKS: usize = 1 << 8;
9
10struct SpawnEachTask<I: Send, S: Service<I>> {
11    #[allow(dead_code)]
12    id: u32,
13    tx: flowly_spsc::Sender<I>,
14    m: PhantomData<S>,
15    _handle: tokio::task::JoinHandle<()>,
16}
17
18impl<I, S> SpawnEachTask<I, S>
19where
20    S::Out: Send + 'static,
21    I: Send + 'static,
22    S: Service<I> + Send + 'static,
23{
24    fn new(
25        id: u32,
26        buffer: usize,
27        mut s: S,
28        out_tx: mpsc::Sender<Option<S::Out>>,
29        cx: Context,
30        input: I,
31    ) -> Self {
32        let (mut tx, mut rx) = flowly_spsc::channel(buffer);
33
34        let _handle = tokio::spawn(async move {
35            'recv: while let Some(item) = rx.recv().await {
36                let mut s = pin!(s.handle(item, &cx));
37
38                while let Some(x) = s.next().await {
39                    if out_tx.send(Some(x)).await.is_err() {
40                        log::error!("cannot send the message. channel closed!");
41                        break 'recv;
42                    }
43                }
44
45                if out_tx.send(None).await.is_err() {
46                    log::error!("cannot send the message. channel closed!");
47                    break 'recv;
48                }
49            }
50        });
51
52        tx.try_send(input).unwrap();
53
54        Self {
55            id,
56            tx,
57            _handle,
58            m: PhantomData,
59        }
60    }
61
62    #[inline]
63    async fn send(&mut self, input: I) -> Result<(), flowly_spsc::TrySendError<I>> {
64        self.tx.send(input).await
65    }
66}
67
68pub struct SpawnEach<I: Send + 'static, S: Service<I>> {
69    service: S,
70    sender: mpsc::Sender<Option<S::Out>>,
71    receiver: Mutex<mpsc::Receiver<Option<S::Out>>>,
72    tasks: Vec<SpawnEachTask<I, S>>,
73    _m: PhantomData<I>,
74    counter: u32,
75}
76
77impl<I, S> SpawnEach<I, S>
78where
79    I: Send,
80    S: Service<I> + Send,
81    S::Out: Send,
82{
83    pub(crate) fn new(service: S) -> Self {
84        let (sender, rx) = mpsc::channel(1);
85
86        Self {
87            service,
88            sender,
89            receiver: Mutex::new(rx),
90            tasks: Vec::with_capacity(MAX_SPAWN_TASKS),
91            _m: PhantomData,
92            counter: 0,
93        }
94    }
95
96    #[inline]
97    fn drain_rx(&mut self) -> impl Stream<Item = S::Out> + Send {
98        async_stream::stream! {
99            let mut guard = self.receiver.lock().await;
100            while let Some(res) = guard.recv().await {
101                if let Some(item) = res {
102                    yield item;
103                } else {
104                    break;
105                }
106            }
107        }
108    }
109}
110
111impl<I, S> Service<I> for SpawnEach<I, S>
112where
113    I: Send,
114    S: Service<I> + Clone + Send + 'static,
115    S::Out: Send,
116{
117    type Out = S::Out;
118
119    fn handle(&mut self, mut input: I, cx: &Context) -> impl Stream<Item = Self::Out> + Send {
120        if self.tasks.len() < MAX_SPAWN_TASKS {
121            self.tasks.push(SpawnEachTask::new(
122                self.counter,
123                2,
124                self.service.clone(),
125                self.sender.clone(),
126                cx.clone(),
127                input,
128            ));
129
130            self.counter += 1;
131            self.drain_rx().right_stream()
132        } else {
133            let index = fastrand::usize(0..self.tasks.len());
134
135            let (left, right) = self.tasks.split_at_mut(index);
136
137            for task in right.iter_mut().chain(left.iter_mut()) {
138                if let Err(err) = task.tx.try_send(input) {
139                    input = err.val;
140                } else {
141                    return self.drain_rx().right_stream();
142                }
143            }
144
145            async move {
146                if self.tasks[index].send(input).await.is_err() {
147                    log::error!("cannot send the message. channel closed!");
148                }
149
150                self.drain_rx()
151            }
152            .into_stream()
153            .flatten()
154            .left_stream()
155        }
156    }
157}
158
159pub fn spawn_each<I, S>(service: S) -> SpawnEach<I, S>
160where
161    I: Send,
162    S: Send + Service<I> + Clone + 'static,
163    S::Out: Send,
164{
165    SpawnEach::new(service)
166}
167
168// pub struct Spawn<S> {
169//     pub(crate) service: S,
170// }
171
172// pub struct SpawnLocal<S> {
173//     pub(crate) service: S,
174// }
175
176// impl<I: Send, S: Service<I> + Send + 'static> Service<I> for Spawn<S>
177// where
178//     S::Out: Send,
179// {
180//     type Out = S::Out;
181
182//     #[inline]
183//     fn handle(&mut self, input: I, cx: &Context) -> impl Stream<Item = S::Out> {
184//         let _ = cx;
185//         let _ = input;
186
187//         unimplemented!();
188
189//         futures::stream::empty()
190//         // let (tx, mut rx) = mpsc::channel(self.buffer);
191
192//         // let fut = Box::pin(async move {
193//         //     let mut stream = pin!(self.service.handle(input));
194
195//         //     while let Some(item) = stream.next().await {
196//         //         if (tx.send(item).await).is_err() {
197//         //             break;
198//         //         }
199//         //     }
200//         // }) as Pin<Box<dyn Future<Output = ()> + Send>>;
201
202//         // // SAFTY:
203//         // // This is safe because:
204//         // //  - input stream will be dropped as soon as it drained.
205//         // //  - it is garateed that input lives as long as it contains items
206//         // tokio::spawn(unsafe {
207//         //     std::mem::transmute::<
208//         //         Pin<Box<dyn Future<Output = ()> + Send>>,
209//         //         Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
210//         //     >(fut)
211//         // });
212
213//         // poll_fn(move |cx| rx.poll_recv(cx))
214//     }
215// }