Skip to main content

gosuto_livekit/room/data_stream/
incoming.rs

1// Copyright 2025 LiveKit, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{
16    AnyStreamInfo, ByteStreamInfo, StreamError, StreamProgress, StreamResult, TextStreamInfo,
17};
18use crate::{e2ee::EncryptionType, TakeCell};
19use bytes::{Bytes, BytesMut};
20use futures_util::{Stream, StreamExt};
21use livekit_protocol::data_stream as proto;
22use parking_lot::Mutex;
23use std::{
24    collections::HashMap,
25    fmt::Debug,
26    pin::Pin,
27    sync::Arc,
28    task::{Context, Poll},
29};
30use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
31
32/// Reader for an incoming data stream.
33///
34/// The stream being read from is kept open as long as its reader exists;
35/// dropping the reader will close the stream.
36///
37pub trait StreamReader: Stream<Item = StreamResult<Self::Output>> {
38    /// Type of output this reader produces.
39    type Output;
40
41    /// Information about the underlying data stream.
42    type Info;
43
44    /// Returns a reference to the stream info.
45    fn info(&self) -> &Self::Info;
46
47    /// Reads all incoming chunks from the byte stream, concatenating them
48    /// into a single value which is returned once the stream closes normally.
49    ///
50    /// Returns the data consisting of all concatenated chunks.
51    ///
52    fn read_all(self) -> impl std::future::Future<Output = StreamResult<Self::Output>> + Send;
53}
54
55impl<T> TakeCell<T>
56where
57    T: StreamReader,
58{
59    /// Takes the reader out of the cell if its info matches the given predicate.
60    ///
61    /// Use this method to conditionally handle incoming streams based on info fields
62    /// such as topic or attributes.
63    ///
64    /// This method will only take the reader if the provided predicate returns `true` when called with the reader's info.
65    /// If the predicate returns `false` or the reader has already been taken, this method returns `None`.
66    ///
67    pub fn take_if(&self, predicate: impl FnOnce(&T::Info) -> bool) -> Option<T> {
68        self.take_if_raw(|reader| predicate(reader.info()))
69    }
70}
71
72/// Reader for an incoming byte data stream.
73pub struct ByteStreamReader {
74    info: ByteStreamInfo,
75    chunk_rx: UnboundedReceiver<StreamResult<Bytes>>,
76}
77
78/// Reader for an incoming text data stream.
79pub struct TextStreamReader {
80    info: TextStreamInfo,
81    chunk_rx: UnboundedReceiver<StreamResult<Bytes>>,
82}
83
84impl StreamReader for ByteStreamReader {
85    type Output = Bytes;
86    type Info = ByteStreamInfo;
87
88    fn info(&self) -> &ByteStreamInfo {
89        &self.info
90    }
91
92    async fn read_all(mut self) -> StreamResult<Bytes> {
93        let mut buffer = BytesMut::new();
94        while let Some(result) = self.next().await {
95            match result {
96                Ok(bytes) => buffer.extend_from_slice(&bytes),
97                Err(e) => return Err(e),
98            }
99        }
100        Ok(buffer.freeze())
101    }
102}
103
104impl ByteStreamReader {
105    /// Reads incoming chunks from the byte stream, writing them to a file as they are received.
106    ///
107    /// Parameters:
108    ///   - directory: The directory to write the file in. The system temporary directory is used if not specified.
109    ///   - name_override: The name to use for the written file, overriding stream name.
110    ///
111    /// Returns: The path of the written file on disk.
112    ///
113    pub async fn write_to_file(
114        mut self,
115        directory: Option<impl AsRef<std::path::Path>>,
116        name_override: Option<&str>,
117    ) -> StreamResult<std::path::PathBuf> {
118        let directory =
119            directory.map(|d| d.as_ref().to_path_buf()).unwrap_or_else(|| std::env::temp_dir());
120        let name = name_override.unwrap_or_else(|| &self.info.name);
121        let file_path = directory.join(name);
122
123        let mut file = tokio::fs::File::create(&file_path).await.map_err(StreamError::Io)?;
124
125        while let Some(result) = self.next().await {
126            let bytes = result?;
127            tokio::io::AsyncWriteExt::write_all(&mut file, &bytes)
128                .await
129                .map_err(StreamError::Io)?;
130        }
131        tokio::io::AsyncWriteExt::flush(&mut file).await.map_err(StreamError::Io)?;
132
133        Ok(file_path)
134    }
135}
136
137impl Stream for ByteStreamReader {
138    type Item = StreamResult<Bytes>;
139
140    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141        let this = self.get_mut();
142        match Pin::new(&mut this.chunk_rx).poll_recv(cx) {
143            Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
144            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
145            Poll::Ready(None) => Poll::Ready(None),
146            Poll::Pending => Poll::Pending,
147        }
148    }
149}
150
151impl StreamReader for TextStreamReader {
152    type Output = String;
153    type Info = TextStreamInfo;
154
155    fn info(&self) -> &TextStreamInfo {
156        &self.info
157    }
158
159    async fn read_all(mut self) -> StreamResult<String> {
160        let mut result = String::new();
161        while let Some(chunk) = self.next().await {
162            match chunk {
163                Ok(text) => result.push_str(&text),
164                Err(e) => return Err(e),
165            }
166        }
167        Ok(result)
168    }
169}
170
171impl Stream for TextStreamReader {
172    type Item = StreamResult<String>;
173
174    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175        let this = self.get_mut();
176        match Pin::new(&mut this.chunk_rx).poll_recv(cx) {
177            Poll::Ready(Some(Ok(chunk))) => match String::from_utf8(chunk.into()) {
178                Ok(content) => Poll::Ready(Some(Ok(content))),
179                Err(e) => {
180                    this.chunk_rx.close();
181                    Poll::Ready(Some(Err(StreamError::from(e))))
182                }
183            },
184            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
185            Poll::Ready(None) => Poll::Ready(None),
186            Poll::Pending => Poll::Pending,
187        }
188    }
189}
190
191impl Debug for ByteStreamReader {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        f.debug_struct("ByteStreamReader")
194            .field("id", &self.info.id())
195            .field("topic", &self.info.topic)
196            .finish()
197    }
198}
199
200impl Debug for TextStreamReader {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        f.debug_struct("TextStreamReader")
203            .field("id", &self.info.id())
204            .field("topic", &self.info.topic)
205            .finish()
206    }
207}
208
209pub(crate) enum AnyStreamReader {
210    Byte(ByteStreamReader),
211    Text(TextStreamReader),
212}
213
214impl AnyStreamReader {
215    /// Creates a stream reader for the stream with the given info.
216    pub(super) fn from(info: AnyStreamInfo) -> (Self, UnboundedSender<StreamResult<Bytes>>) {
217        let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
218        let reader = match info {
219            AnyStreamInfo::Byte(info) => Self::Byte(ByteStreamReader { info, chunk_rx }),
220            AnyStreamInfo::Text(info) => Self::Text(TextStreamReader { info, chunk_rx }),
221        };
222        return (reader, chunk_tx);
223    }
224}
225struct Descriptor {
226    progress: StreamProgress,
227    chunk_tx: UnboundedSender<StreamResult<Bytes>>,
228    encryption_type: EncryptionType,
229    // TODO(ladvoc): keep track of open time.
230}
231
232#[derive(Clone)]
233pub(crate) struct IncomingStreamManager {
234    inner: Arc<Mutex<ManagerInner>>,
235    open_tx: UnboundedSender<(AnyStreamReader, String)>,
236}
237
238#[derive(Default)]
239struct ManagerInner {
240    open_streams: HashMap<String, Descriptor>,
241}
242
243impl IncomingStreamManager {
244    pub fn new() -> (Self, UnboundedReceiver<(AnyStreamReader, String)>) {
245        let (open_tx, open_rx) = mpsc::unbounded_channel();
246        (Self { inner: Arc::new(Mutex::new(Default::default())), open_tx }, open_rx)
247    }
248
249    /// Handles an incoming header packet.
250    pub fn handle_header(
251        &self,
252        header: proto::Header,
253        identity: String,
254        encryption_type: livekit_protocol::encryption::Type,
255    ) {
256        let Ok(info) = AnyStreamInfo::try_from_with_encryption(header, encryption_type.into())
257            .inspect_err(|e| log::error!("Invalid header: {}", e))
258        else {
259            return;
260        };
261
262        let id = info.id().to_owned();
263        let bytes_total = info.total_length();
264        let stream_encryption_type = info.encryption_type();
265
266        let mut inner = self.inner.lock();
267        if inner.open_streams.contains_key(&id) {
268            log::error!("Stream '{}' already open", id);
269            return;
270        }
271
272        let (reader, chunk_tx) = AnyStreamReader::from(info);
273        let _ = self.open_tx.send((reader, identity));
274
275        let descriptor = Descriptor {
276            progress: StreamProgress { bytes_total, ..Default::default() },
277            chunk_tx,
278            encryption_type: stream_encryption_type,
279        };
280        inner.open_streams.insert(id, descriptor);
281    }
282
283    /// Handles an incoming chunk packet.
284    pub fn handle_chunk(
285        &self,
286        chunk: proto::Chunk,
287        encryption_type: livekit_protocol::encryption::Type,
288    ) {
289        let id = chunk.stream_id;
290        let mut inner = self.inner.lock();
291        let Some(descriptor) = inner.open_streams.get_mut(&id) else {
292            return;
293        };
294
295        if descriptor.encryption_type != encryption_type.into() {
296            inner.close_stream_with_error(&id, StreamError::EncryptionTypeMismatch);
297            return;
298        }
299
300        if descriptor.progress.chunk_index != chunk.chunk_index {
301            inner.close_stream_with_error(&id, StreamError::MissedChunk);
302            return;
303        }
304
305        descriptor.progress.chunk_index += 1;
306        descriptor.progress.bytes_processed += chunk.content.len() as u64;
307
308        if match descriptor.progress.bytes_total {
309            Some(total) => descriptor.progress.bytes_processed > total as u64,
310            None => false,
311        } {
312            inner.close_stream_with_error(&id, StreamError::LengthExceeded);
313            return;
314        }
315        inner.yield_chunk(&id, Bytes::from(chunk.content));
316        // TODO: also yield progress
317    }
318
319    /// Handles an incoming trailer packet.
320    pub fn handle_trailer(&self, trailer: proto::Trailer) {
321        let id = trailer.stream_id;
322        let mut inner = self.inner.lock();
323        let Some(descriptor) = inner.open_streams.get_mut(&id) else {
324            return;
325        };
326
327        if !match descriptor.progress.bytes_total {
328            Some(total) => descriptor.progress.bytes_processed >= total as u64,
329            None => true,
330        } {
331            inner.close_stream_with_error(&id, StreamError::Incomplete);
332            return;
333        }
334        if !trailer.reason.is_empty() {
335            inner.close_stream_with_error(&id, StreamError::AbnormalEnd(trailer.reason));
336            return;
337        }
338        inner.close_stream(&id);
339    }
340}
341
342impl ManagerInner {
343    fn yield_chunk(&mut self, id: &str, chunk: Bytes) {
344        let Some(descriptor) = self.open_streams.get_mut(id) else {
345            return;
346        };
347        if descriptor.chunk_tx.send(Ok(chunk)).is_err() {
348            // Reader has been dropped, close the stream.
349            self.close_stream(id);
350        }
351    }
352
353    fn close_stream(&mut self, id: &str) {
354        // Dropping the sender closes the channel.
355        self.open_streams.remove(id);
356    }
357
358    fn close_stream_with_error(&mut self, id: &str, error: StreamError) {
359        if let Some(descriptor) = self.open_streams.remove(id) {
360            let _ = descriptor.chunk_tx.send(Err(error));
361        }
362    }
363}