xrpc/
streaming.rs

1use bytes::Bytes;
2use futures::Stream;
3use parking_lot::Mutex;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::task::{Context, Poll};
10use tokio::sync::mpsc;
11
12use crate::channel::message::MessageChannel;
13use crate::codec::{BincodeCodec, Codec};
14use crate::error::{Result, RpcError};
15use crate::message::Message;
16use crate::message::types::{MessageId, MessageType};
17
/// Identifier shared by every message that belongs to one logical stream.
pub type StreamId = u64;

/// Process-wide source of stream identifiers; the first id handed out is 1.
static STREAM_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Hands out the next stream identifier.
///
/// Returns the counter's current value and advances it by one. The increment
/// is a single atomic read-modify-write, so concurrent callers never observe
/// the same value; `Ordering::Relaxed` is used because no other memory is
/// published through this counter.
pub fn next_stream_id() -> StreamId {
    let issued: StreamId = STREAM_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    issued
}
25
26pub struct StreamSender<T: MessageChannel, C: Codec = BincodeCodec> {
27    stream_id: StreamId,
28    sequence: AtomicU64,
29    transport: Arc<T>,
30    codec: C,
31    ended: std::sync::atomic::AtomicBool,
32}
33
34impl<T: MessageChannel> StreamSender<T, BincodeCodec> {
35    pub fn new(stream_id: StreamId, transport: Arc<T>) -> Self {
36        Self {
37            stream_id,
38            sequence: AtomicU64::new(0),
39            transport,
40            codec: BincodeCodec,
41            ended: std::sync::atomic::AtomicBool::new(false),
42        }
43    }
44}
45
46impl<T: MessageChannel, C: Codec> StreamSender<T, C> {
47    pub fn with_codec(stream_id: StreamId, transport: Arc<T>, codec: C) -> Self {
48        Self {
49            stream_id,
50            sequence: AtomicU64::new(0),
51            transport,
52            codec,
53            ended: std::sync::atomic::AtomicBool::new(false),
54        }
55    }
56
57    pub async fn send<D: Serialize>(&self, data: D) -> Result<()> {
58        if self.ended.load(Ordering::Acquire) {
59            return Err(RpcError::StreamError("Stream already ended".to_string()));
60        }
61
62        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
63        let payload = self.codec.encode(&data)?;
64        let chunk = Message::new(
65            MessageId::new(),
66            MessageType::StreamChunk,
67            "",
68            Bytes::from(payload),
69            crate::message::metadata::MessageMetadata::new().with_stream(self.stream_id, seq),
70        );
71        self.transport
72            .send(&chunk)
73            .await
74            .map_err(RpcError::Transport)
75    }
76
77    pub async fn end(&self) -> Result<()> {
78        if self
79            .ended
80            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
81            .is_err()
82        {
83            return Ok(());
84        }
85
86        let end_msg: Message = Message::stream_end(self.stream_id);
87        self.transport
88            .send(&end_msg)
89            .await
90            .map_err(RpcError::Transport)
91    }
92
93    pub fn stream_id(&self) -> StreamId {
94        self.stream_id
95    }
96
97    pub fn is_ended(&self) -> bool {
98        self.ended.load(Ordering::Acquire)
99    }
100}
101
102/// Receives stream chunks. Assumes transport delivers messages in order.
103pub struct StreamReceiver<D, C: Codec = BincodeCodec> {
104    stream_id: StreamId,
105    rx: mpsc::UnboundedReceiver<Result<Bytes>>,
106    ended: bool,
107    codec: C,
108    _phantom: std::marker::PhantomData<D>,
109}
110
111impl<D, C> StreamReceiver<D, C>
112where
113    D: for<'de> Deserialize<'de>,
114    C: Codec,
115{
116    pub(crate) fn new(
117        stream_id: StreamId,
118        rx: mpsc::UnboundedReceiver<Result<Bytes>>,
119        codec: C,
120    ) -> Self {
121        Self {
122            stream_id,
123            rx,
124            ended: false,
125            codec,
126            _phantom: std::marker::PhantomData,
127        }
128    }
129
130    pub fn stream_id(&self) -> StreamId {
131        self.stream_id
132    }
133
134    pub fn is_ended(&self) -> bool {
135        self.ended
136    }
137
138    pub async fn recv(&mut self) -> Option<Result<D>> {
139        if self.ended {
140            return None;
141        }
142
143        match self.rx.recv().await {
144            Some(Ok(data)) => Some(self.codec.decode(&data)),
145            Some(Err(e)) => {
146                self.ended = true;
147                Some(Err(e))
148            }
149            None => {
150                self.ended = true;
151                None
152            }
153        }
154    }
155
156    pub async fn collect(mut self) -> Result<Vec<D>> {
157        let mut items = Vec::new();
158        while let Some(result) = self.recv().await {
159            items.push(result?);
160        }
161        Ok(items)
162    }
163
164    pub fn cancel(&mut self) {
165        self.ended = true;
166        self.rx.close();
167    }
168}
169
170impl<D, C> Stream for StreamReceiver<D, C>
171where
172    D: for<'de> Deserialize<'de> + Unpin,
173    C: Codec + Unpin,
174{
175    type Item = Result<D>;
176
177    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178        if self.ended {
179            return Poll::Ready(None);
180        }
181
182        match self.rx.poll_recv(cx) {
183            Poll::Ready(Some(Ok(data))) => {
184                let result = self.codec.decode(&data);
185                Poll::Ready(Some(result))
186            }
187            Poll::Ready(Some(Err(e))) => {
188                self.ended = true;
189                Poll::Ready(Some(Err(e)))
190            }
191            Poll::Ready(None) => {
192                self.ended = true;
193                Poll::Ready(None)
194            }
195            Poll::Pending => Poll::Pending,
196        }
197    }
198}
199
200pub struct StreamManager<C: Codec = BincodeCodec> {
201    streams: Arc<Mutex<HashMap<StreamId, mpsc::UnboundedSender<Result<Bytes>>>>>,
202    codec: C,
203}
204
205impl StreamManager<BincodeCodec> {
206    pub fn new() -> Self {
207        Self {
208            streams: Arc::new(Mutex::new(HashMap::new())),
209            codec: BincodeCodec,
210        }
211    }
212}
213
214impl<C: Codec + Clone> StreamManager<C> {
215    pub fn with_codec(codec: C) -> Self {
216        Self {
217            streams: Arc::new(Mutex::new(HashMap::new())),
218            codec,
219        }
220    }
221
222    pub fn create_receiver<D>(&self, stream_id: StreamId) -> StreamReceiver<D, C>
223    where
224        D: for<'de> Deserialize<'de>,
225    {
226        let (tx, rx) = mpsc::unbounded_channel();
227        self.streams.lock().insert(stream_id, tx);
228        StreamReceiver::new(stream_id, rx, self.codec.clone())
229    }
230
231    pub fn handle_message(&self, message: &Message<C>) -> bool {
232        let stream_id = match message.metadata.stream_id {
233            Some(id) => id,
234            None => return false,
235        };
236
237        let streams = self.streams.lock();
238        let sender = match streams.get(&stream_id) {
239            Some(tx) => tx,
240            None => return false,
241        };
242
243        match message.msg_type {
244            MessageType::StreamChunk => {
245                let _ = sender.send(Ok(message.payload.clone()));
246                true
247            }
248            MessageType::StreamEnd => {
249                drop(streams);
250                self.remove_stream(stream_id);
251                true
252            }
253            _ => false,
254        }
255    }
256
257    /// Forward error to stream receiver and close the stream
258    pub fn send_error(&self, stream_id: StreamId, error_msg: String) {
259        let streams = self.streams.lock();
260        if let Some(sender) = streams.get(&stream_id) {
261            let _ = sender.send(Err(RpcError::ServerError(error_msg)));
262        }
263        drop(streams);
264        self.remove_stream(stream_id);
265    }
266
267    pub fn remove_stream(&self, stream_id: StreamId) {
268        self.streams.lock().remove(&stream_id);
269    }
270
271    pub fn active_stream_count(&self) -> usize {
272        self.streams.lock().len()
273    }
274}
275
276impl Default for StreamManager<BincodeCodec> {
277    fn default() -> Self {
278        Self::new()
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::message::MessageChannelAdapter;
    use crate::transport::channel::{ChannelConfig, ChannelFrameTransport};

    /// Sends three values plus an end marker over an in-process transport pair
    /// and checks that the receiver yields the same values in order.
    #[tokio::test]
    async fn test_stream_sender_receiver() {
        let (t1, t2) =
            ChannelFrameTransport::create_pair("test", ChannelConfig::default()).unwrap();
        let sender_channel = Arc::new(MessageChannelAdapter::new(t1));
        let receiver_channel = MessageChannelAdapter::new(t2);

        let stream_id = next_stream_id();
        let manager = StreamManager::new();
        let receiver: StreamReceiver<i32> = manager.create_receiver(stream_id);
        let sender = StreamSender::new(stream_id, sender_channel);

        // Pump transport messages into the manager until it rejects one or
        // the end-of-stream marker has been handled.
        let recv_handle = tokio::spawn(async move {
            loop {
                let msg = receiver_channel.recv().await.unwrap();
                let handled = manager.handle_message(&msg);
                if !handled || msg.msg_type == MessageType::StreamEnd {
                    break;
                }
            }
            receiver
        });

        for value in 1..=3i32 {
            sender.send(value).await.unwrap();
        }
        sender.end().await.unwrap();

        let mut receiver = recv_handle.await.unwrap();

        let mut items: Vec<i32> = Vec::with_capacity(3);
        for _ in 0..3 {
            items.push(receiver.recv().await.unwrap().unwrap());
        }
        assert_eq!(items, vec![1, 2, 3]);
    }

    /// Two consecutive ids must differ.
    #[test]
    fn test_stream_id_generation() {
        let first = next_stream_id();
        let second = next_stream_id();
        assert_ne!(first, second);
    }
}