gosuto_livekit/room/data_stream/
incoming.rs1use super::{
16 AnyStreamInfo, ByteStreamInfo, StreamError, StreamProgress, StreamResult, TextStreamInfo,
17};
18use crate::{e2ee::EncryptionType, TakeCell};
19use bytes::{Bytes, BytesMut};
20use futures_util::{Stream, StreamExt};
21use livekit_protocol::data_stream as proto;
22use parking_lot::Mutex;
23use std::{
24 collections::HashMap,
25 fmt::Debug,
26 pin::Pin,
27 sync::Arc,
28 task::{Context, Poll},
29};
30use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
31
32pub trait StreamReader: Stream<Item = StreamResult<Self::Output>> {
38 type Output;
40
41 type Info;
43
44 fn info(&self) -> &Self::Info;
46
47 fn read_all(self) -> impl std::future::Future<Output = StreamResult<Self::Output>> + Send;
53}
54
55impl<T> TakeCell<T>
56where
57 T: StreamReader,
58{
59 pub fn take_if(&self, predicate: impl FnOnce(&T::Info) -> bool) -> Option<T> {
68 self.take_if_raw(|reader| predicate(reader.info()))
69 }
70}
71
72pub struct ByteStreamReader {
74 info: ByteStreamInfo,
75 chunk_rx: UnboundedReceiver<StreamResult<Bytes>>,
76}
77
78pub struct TextStreamReader {
80 info: TextStreamInfo,
81 chunk_rx: UnboundedReceiver<StreamResult<Bytes>>,
82}
83
84impl StreamReader for ByteStreamReader {
85 type Output = Bytes;
86 type Info = ByteStreamInfo;
87
88 fn info(&self) -> &ByteStreamInfo {
89 &self.info
90 }
91
92 async fn read_all(mut self) -> StreamResult<Bytes> {
93 let mut buffer = BytesMut::new();
94 while let Some(result) = self.next().await {
95 match result {
96 Ok(bytes) => buffer.extend_from_slice(&bytes),
97 Err(e) => return Err(e),
98 }
99 }
100 Ok(buffer.freeze())
101 }
102}
103
104impl ByteStreamReader {
105 pub async fn write_to_file(
114 mut self,
115 directory: Option<impl AsRef<std::path::Path>>,
116 name_override: Option<&str>,
117 ) -> StreamResult<std::path::PathBuf> {
118 let directory =
119 directory.map(|d| d.as_ref().to_path_buf()).unwrap_or_else(|| std::env::temp_dir());
120 let name = name_override.unwrap_or_else(|| &self.info.name);
121 let file_path = directory.join(name);
122
123 let mut file = tokio::fs::File::create(&file_path).await.map_err(StreamError::Io)?;
124
125 while let Some(result) = self.next().await {
126 let bytes = result?;
127 tokio::io::AsyncWriteExt::write_all(&mut file, &bytes)
128 .await
129 .map_err(StreamError::Io)?;
130 }
131 tokio::io::AsyncWriteExt::flush(&mut file).await.map_err(StreamError::Io)?;
132
133 Ok(file_path)
134 }
135}
136
137impl Stream for ByteStreamReader {
138 type Item = StreamResult<Bytes>;
139
140 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141 let this = self.get_mut();
142 match Pin::new(&mut this.chunk_rx).poll_recv(cx) {
143 Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
144 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
145 Poll::Ready(None) => Poll::Ready(None),
146 Poll::Pending => Poll::Pending,
147 }
148 }
149}
150
151impl StreamReader for TextStreamReader {
152 type Output = String;
153 type Info = TextStreamInfo;
154
155 fn info(&self) -> &TextStreamInfo {
156 &self.info
157 }
158
159 async fn read_all(mut self) -> StreamResult<String> {
160 let mut result = String::new();
161 while let Some(chunk) = self.next().await {
162 match chunk {
163 Ok(text) => result.push_str(&text),
164 Err(e) => return Err(e),
165 }
166 }
167 Ok(result)
168 }
169}
170
171impl Stream for TextStreamReader {
172 type Item = StreamResult<String>;
173
174 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175 let this = self.get_mut();
176 match Pin::new(&mut this.chunk_rx).poll_recv(cx) {
177 Poll::Ready(Some(Ok(chunk))) => match String::from_utf8(chunk.into()) {
178 Ok(content) => Poll::Ready(Some(Ok(content))),
179 Err(e) => {
180 this.chunk_rx.close();
181 Poll::Ready(Some(Err(StreamError::from(e))))
182 }
183 },
184 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
185 Poll::Ready(None) => Poll::Ready(None),
186 Poll::Pending => Poll::Pending,
187 }
188 }
189}
190
191impl Debug for ByteStreamReader {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 f.debug_struct("ByteStreamReader")
194 .field("id", &self.info.id())
195 .field("topic", &self.info.topic)
196 .finish()
197 }
198}
199
200impl Debug for TextStreamReader {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("TextStreamReader")
203 .field("id", &self.info.id())
204 .field("topic", &self.info.topic)
205 .finish()
206 }
207}
208
209pub(crate) enum AnyStreamReader {
210 Byte(ByteStreamReader),
211 Text(TextStreamReader),
212}
213
214impl AnyStreamReader {
215 pub(super) fn from(info: AnyStreamInfo) -> (Self, UnboundedSender<StreamResult<Bytes>>) {
217 let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
218 let reader = match info {
219 AnyStreamInfo::Byte(info) => Self::Byte(ByteStreamReader { info, chunk_rx }),
220 AnyStreamInfo::Text(info) => Self::Text(TextStreamReader { info, chunk_rx }),
221 };
222 return (reader, chunk_tx);
223 }
224}
225struct Descriptor {
226 progress: StreamProgress,
227 chunk_tx: UnboundedSender<StreamResult<Bytes>>,
228 encryption_type: EncryptionType,
229 }
231
232#[derive(Clone)]
233pub(crate) struct IncomingStreamManager {
234 inner: Arc<Mutex<ManagerInner>>,
235 open_tx: UnboundedSender<(AnyStreamReader, String)>,
236}
237
238#[derive(Default)]
239struct ManagerInner {
240 open_streams: HashMap<String, Descriptor>,
241}
242
243impl IncomingStreamManager {
244 pub fn new() -> (Self, UnboundedReceiver<(AnyStreamReader, String)>) {
245 let (open_tx, open_rx) = mpsc::unbounded_channel();
246 (Self { inner: Arc::new(Mutex::new(Default::default())), open_tx }, open_rx)
247 }
248
249 pub fn handle_header(
251 &self,
252 header: proto::Header,
253 identity: String,
254 encryption_type: livekit_protocol::encryption::Type,
255 ) {
256 let Ok(info) = AnyStreamInfo::try_from_with_encryption(header, encryption_type.into())
257 .inspect_err(|e| log::error!("Invalid header: {}", e))
258 else {
259 return;
260 };
261
262 let id = info.id().to_owned();
263 let bytes_total = info.total_length();
264 let stream_encryption_type = info.encryption_type();
265
266 let mut inner = self.inner.lock();
267 if inner.open_streams.contains_key(&id) {
268 log::error!("Stream '{}' already open", id);
269 return;
270 }
271
272 let (reader, chunk_tx) = AnyStreamReader::from(info);
273 let _ = self.open_tx.send((reader, identity));
274
275 let descriptor = Descriptor {
276 progress: StreamProgress { bytes_total, ..Default::default() },
277 chunk_tx,
278 encryption_type: stream_encryption_type,
279 };
280 inner.open_streams.insert(id, descriptor);
281 }
282
283 pub fn handle_chunk(
285 &self,
286 chunk: proto::Chunk,
287 encryption_type: livekit_protocol::encryption::Type,
288 ) {
289 let id = chunk.stream_id;
290 let mut inner = self.inner.lock();
291 let Some(descriptor) = inner.open_streams.get_mut(&id) else {
292 return;
293 };
294
295 if descriptor.encryption_type != encryption_type.into() {
296 inner.close_stream_with_error(&id, StreamError::EncryptionTypeMismatch);
297 return;
298 }
299
300 if descriptor.progress.chunk_index != chunk.chunk_index {
301 inner.close_stream_with_error(&id, StreamError::MissedChunk);
302 return;
303 }
304
305 descriptor.progress.chunk_index += 1;
306 descriptor.progress.bytes_processed += chunk.content.len() as u64;
307
308 if match descriptor.progress.bytes_total {
309 Some(total) => descriptor.progress.bytes_processed > total as u64,
310 None => false,
311 } {
312 inner.close_stream_with_error(&id, StreamError::LengthExceeded);
313 return;
314 }
315 inner.yield_chunk(&id, Bytes::from(chunk.content));
316 }
318
319 pub fn handle_trailer(&self, trailer: proto::Trailer) {
321 let id = trailer.stream_id;
322 let mut inner = self.inner.lock();
323 let Some(descriptor) = inner.open_streams.get_mut(&id) else {
324 return;
325 };
326
327 if !match descriptor.progress.bytes_total {
328 Some(total) => descriptor.progress.bytes_processed >= total as u64,
329 None => true,
330 } {
331 inner.close_stream_with_error(&id, StreamError::Incomplete);
332 return;
333 }
334 if !trailer.reason.is_empty() {
335 inner.close_stream_with_error(&id, StreamError::AbnormalEnd(trailer.reason));
336 return;
337 }
338 inner.close_stream(&id);
339 }
340}
341
342impl ManagerInner {
343 fn yield_chunk(&mut self, id: &str, chunk: Bytes) {
344 let Some(descriptor) = self.open_streams.get_mut(id) else {
345 return;
346 };
347 if descriptor.chunk_tx.send(Ok(chunk)).is_err() {
348 self.close_stream(id);
350 }
351 }
352
353 fn close_stream(&mut self, id: &str) {
354 self.open_streams.remove(id);
356 }
357
358 fn close_stream_with_error(&mut self, id: &str, error: StreamError) {
359 if let Some(descriptor) = self.open_streams.remove(id) {
360 let _ = descriptor.chunk_tx.send(Err(error));
361 }
362 }
363}