1use std::time::Duration;
4
5use bytes::BytesMut;
6use scuffle_bytes_util::{BytesCursorExt, StringCow};
7use scuffle_context::ContextFutExt;
8use scuffle_future_ext::FutureExt;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10
11use crate::chunk::CHUNK_SIZE;
12use crate::chunk::reader::ChunkReader;
13use crate::chunk::writer::ChunkWriter;
14use crate::command_messages::netconnection::{
15    CapsExMask, NetConnectionCommand, NetConnectionCommandConnect, NetConnectionCommandConnectResult,
16};
17use crate::command_messages::netstream::{NetStreamCommand, NetStreamCommandPublishPublishingType};
18use crate::command_messages::on_status::{OnStatus, OnStatusCode};
19use crate::command_messages::{Command, CommandResultLevel, CommandType};
20use crate::handshake;
21use crate::handshake::HandshakeServer;
22use crate::messages::MessageData;
23use crate::protocol_control_messages::{
24    ProtocolControlMessageAcknowledgement, ProtocolControlMessageSetChunkSize, ProtocolControlMessageSetPeerBandwidth,
25    ProtocolControlMessageSetPeerBandwidthLimitType, ProtocolControlMessageWindowAcknowledgementSize,
26};
27use crate::user_control_messages::EventMessageStreamBegin;
28
29mod error;
30mod handler;
31
32pub use error::ServerSessionError;
33pub use handler::{SessionData, SessionHandler};
34
/// Acknowledgement window size (in bytes) a new session starts with; it is
/// replaced when the peer sends a `SetAcknowledgementWindowSize` message.
const DEFAULT_ACKNOWLEDGEMENT_WINDOW_SIZE: u32 = 2_500_000;

/// Server side of a single RTMP connection.
///
/// Drives the handshake over `io`, decodes incoming chunks, dispatches
/// commands and media data to `handler`, and buffers outgoing messages in
/// `write_buf` until they are flushed.
pub struct ServerSession<S, H> {
    /// Optional cancellation context. When it is done and the client
    /// advertised reconnect support, a reconnect request is queued.
    ctx: Option<scuffle_context::Context>,
    /// Whether the reconnect request has already been queued (sent at most once).
    reconnect_request_sent: bool,
    /// Application name taken from the client's `connect` command; `None`
    /// until `connect` has been received.
    app_name: Option<StringCow<'static>>,
    /// Capability flags taken from the client's `connect` command.
    caps_ex: Option<CapsExMask>,
    /// Underlying transport.
    io: S,
    /// Receives publish/unpublish events, media data and unknown messages.
    handler: H,
    /// Byte interval at which acknowledgements are sent to the peer.
    acknowledgement_window_size: u32,
    /// Running (wrapping) count of bytes read from the peer, handshake included.
    sequence_number: u32,
    /// Bytes read from `io` that have not yet been consumed as chunks.
    read_buf: BytesMut,
    /// Outgoing bytes waiting to be written to `io`.
    write_buf: Vec<u8>,
    /// Set when the handshake read past its end; the next iteration then
    /// processes the already-buffered bytes instead of reading.
    skip_read: bool,
    /// Reassembles incoming chunks into messages.
    chunk_reader: ChunkReader,
    /// Serializes outgoing messages into chunks.
    chunk_writer: ChunkWriter,
    /// Stream ids that started publishing and have not been deleted yet.
    publishing_stream_ids: Vec<u32>,
}
85
86impl<S, H> ServerSession<S, H> {
87    pub fn new(io: S, handler: H) -> Self {
89        Self {
90            ctx: None,
91            reconnect_request_sent: false,
92            app_name: None,
93            caps_ex: None,
94            io,
95            handler,
96            acknowledgement_window_size: DEFAULT_ACKNOWLEDGEMENT_WINDOW_SIZE,
97            sequence_number: 0,
98            skip_read: false,
99            chunk_reader: ChunkReader::default(),
100            chunk_writer: ChunkWriter::default(),
101            read_buf: BytesMut::new(),
102            write_buf: Vec::new(),
103            publishing_stream_ids: Vec::new(),
104        }
105    }
106
107    pub fn with_context(mut self, ctx: scuffle_context::Context) -> Self {
109        self.ctx = Some(ctx);
110        self
111    }
112}
113
114impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, H: SessionHandler> ServerSession<S, H> {
115    pub async fn run(mut self) -> Result<bool, crate::error::RtmpError> {
121        let ctx = self.ctx.clone().unwrap_or_else(scuffle_context::Context::global);
122
123        let mut handshaker = HandshakeServer::default();
124        loop {
126            match self.drive_handshake(&mut handshaker).with_context(&ctx).await {
127                Some(Ok(false)) => self.flush().await?, Some(Ok(true)) => break,                Some(Err(e)) => return Err(e),
130                None => return Ok(false), }
132        }
133
134        drop(handshaker);
137
138        tracing::debug!("handshake complete");
139
140        loop {
142            match self.drive().await {
143                Ok(true) => self.flush().await?, Ok(false) => break,              Err(err) if err.is_client_closed() => {
146                    tracing::debug!("client closed the connection");
149                    break;
150                }
151                Err(e) => return Err(e),
152            }
153        }
154
155        Ok(self.publishing_stream_ids.is_empty())
160    }
161
162    async fn drive_handshake(&mut self, handshaker: &mut HandshakeServer) -> Result<bool, crate::error::RtmpError> {
170        const READ_SIZE: usize = handshake::RTMP_HANDSHAKE_SIZE + 1;
172        self.read_buf.reserve(READ_SIZE);
173
174        let mut bytes_read = 0;
175        while bytes_read < READ_SIZE {
176            let n = self
177                .io
178                .read_buf(&mut self.read_buf)
179                .with_timeout(Duration::from_secs(2))
180                .await
181                .map_err(ServerSessionError::Timeout)??;
182            bytes_read += n;
183
184            self.sequence_number = self.sequence_number.wrapping_add(n.try_into().unwrap_or(u32::MAX));
185        }
186
187        let mut cursor = std::io::Cursor::new(self.read_buf.split().freeze());
188
189        handshaker.handshake(&mut cursor, &mut self.write_buf)?;
190
191        if handshaker.is_finished() {
192            let over_read = cursor.extract_remaining();
193
194            if !over_read.is_empty() {
195                self.skip_read = true;
196                self.read_buf.extend_from_slice(&over_read);
197            }
198
199            self.send_set_chunk_size().await?;
200
201            Ok(true)
205        } else {
206            Ok(false)
210        }
211    }
212
213    async fn drive(&mut self) -> Result<bool, crate::error::RtmpError> {
222        if !self.reconnect_request_sent
224            && self.caps_ex.is_some_and(|c| c.intersects(CapsExMask::Reconnect))
225            && self.ctx.as_ref().is_some_and(|ctx| ctx.is_done())
226        {
227            tracing::debug!("sending reconnect request");
228
229            OnStatus {
230                code: OnStatusCode::NET_CONNECTION_CONNECT_RECONNECT_REQUEST,
231                level: CommandResultLevel::Status,
232                description: None,
233                others: None,
234            }
235            .write(&mut self.write_buf, 0.0)?;
236
237            self.reconnect_request_sent = true;
238        }
239
240        if self.skip_read {
242            self.skip_read = false;
243        } else {
244            self.read_buf.reserve(CHUNK_SIZE);
245
246            let n = self
247                .io
248                .read_buf(&mut self.read_buf)
249                .with_timeout(Duration::from_millis(2500))
250                .await
251                .map_err(ServerSessionError::Timeout)?? as u32;
252
253            if n == 0 {
254                return Ok(false);
255            }
256
257            if (self.sequence_number % self.acknowledgement_window_size) + n >= self.acknowledgement_window_size {
268                tracing::debug!(sequence_number = %self.sequence_number, "sending acknowledgement");
269
270                ProtocolControlMessageAcknowledgement {
272                    sequence_number: self.sequence_number,
273                }
274                .write(&mut self.write_buf, &self.chunk_writer)?;
275            }
276
277            self.sequence_number = self.sequence_number.wrapping_add(n);
279        }
280
281        self.process_chunks().await?;
282
283        Ok(true)
284    }
285
286    async fn process_chunks(&mut self) -> Result<(), crate::error::RtmpError> {
288        while let Some(chunk) = self.chunk_reader.read_chunk(&mut self.read_buf)? {
289            let timestamp = chunk.message_header.timestamp;
290            let msg_stream_id = chunk.message_header.msg_stream_id;
291
292            let msg = MessageData::read(&chunk)?;
293            self.process_message(msg, msg_stream_id, timestamp).await?;
294        }
295
296        Ok(())
297    }
298
299    async fn process_message(
301        &mut self,
302        msg: MessageData<'_>,
303        stream_id: u32,
304        timestamp: u32,
305    ) -> Result<(), crate::error::RtmpError> {
306        match msg {
307            MessageData::Amf0Command(command) => self.on_command_message(stream_id, command).await?,
308            MessageData::SetChunkSize(ProtocolControlMessageSetChunkSize { chunk_size }) => {
309                self.on_set_chunk_size(chunk_size as usize)?;
310            }
311            MessageData::SetAcknowledgementWindowSize(ProtocolControlMessageWindowAcknowledgementSize {
312                acknowledgement_window_size,
313            }) => {
314                self.on_acknowledgement_window_size(acknowledgement_window_size)?;
315            }
316            MessageData::AudioData { data } => {
317                self.handler
318                    .on_data(stream_id, SessionData::Audio { timestamp, data })
319                    .await?;
320            }
321            MessageData::VideoData { data } => {
322                self.handler
323                    .on_data(stream_id, SessionData::Video { timestamp, data })
324                    .await?;
325            }
326            MessageData::DataAmf0 { data } => {
327                self.handler.on_data(stream_id, SessionData::Amf0 { timestamp, data }).await?;
328            }
329            MessageData::Unknown(unknown_message) => {
330                self.handler.on_unknown_message(stream_id, unknown_message).await?;
331            }
332            _ => {}
334        }
335
336        Ok(())
337    }
338
339    async fn send_set_chunk_size(&mut self) -> Result<(), crate::error::RtmpError> {
341        ProtocolControlMessageSetChunkSize {
342            chunk_size: CHUNK_SIZE as u32,
343        }
344        .write(&mut self.write_buf, &self.chunk_writer)?;
345        self.chunk_writer.set_chunk_size(CHUNK_SIZE);
346
347        Ok(())
348    }
349
350    async fn on_command_message(&mut self, stream_id: u32, command: Command<'_>) -> Result<(), crate::error::RtmpError> {
353        match command.command_type {
354            CommandType::NetConnection(NetConnectionCommand::Connect(connect)) => {
355                self.on_command_connect(stream_id, command.transaction_id, connect).await?;
356            }
357            CommandType::NetConnection(NetConnectionCommand::CreateStream) => {
358                self.on_command_create_stream(stream_id, command.transaction_id).await?;
359            }
360            CommandType::NetStream(NetStreamCommand::Play { .. })
361            | CommandType::NetStream(NetStreamCommand::Play2 { .. }) => {
362                return Err(crate::error::RtmpError::Session(ServerSessionError::PlayNotSupported));
363            }
364            CommandType::NetStream(NetStreamCommand::DeleteStream {
365                stream_id: delete_stream_id,
366            }) => {
367                self.on_command_delete_stream(stream_id, command.transaction_id, delete_stream_id)
368                    .await?;
369            }
370            CommandType::NetStream(NetStreamCommand::CloseStream) => {
371                }
373            CommandType::NetStream(NetStreamCommand::Publish {
374                publishing_name,
375                publishing_type,
376            }) => {
377                self.on_command_publish(stream_id, command.transaction_id, publishing_name.as_str(), publishing_type)
378                    .await?;
379            }
380            CommandType::Unknown(unknown_command) => {
381                self.handler.on_unknown_command(stream_id, unknown_command).await?;
382            }
383            _ => {}
385        }
386
387        Ok(())
388    }
389
390    fn on_set_chunk_size(&mut self, chunk_size: usize) -> Result<(), crate::error::RtmpError> {
393        if self.chunk_reader.update_max_chunk_size(chunk_size) {
394            Ok(())
395        } else {
396            Err(crate::error::RtmpError::Session(ServerSessionError::InvalidChunkSize(
397                chunk_size,
398            )))
399        }
400    }
401
402    fn on_acknowledgement_window_size(&mut self, acknowledgement_window_size: u32) -> Result<(), crate::error::RtmpError> {
405        tracing::debug!(acknowledgement_window_size = %acknowledgement_window_size, "received new acknowledgement window size");
406        self.acknowledgement_window_size = acknowledgement_window_size;
407        Ok(())
408    }
409
410    async fn on_command_connect(
414        &mut self,
415        _stream_id: u32,
416        transaction_id: f64,
417        connect: NetConnectionCommandConnect<'_>,
418    ) -> Result<(), crate::error::RtmpError> {
419        ProtocolControlMessageWindowAcknowledgementSize {
420            acknowledgement_window_size: CHUNK_SIZE as u32,
421        }
422        .write(&mut self.write_buf, &self.chunk_writer)?;
423
424        ProtocolControlMessageSetPeerBandwidth {
425            acknowledgement_window_size: CHUNK_SIZE as u32,
426            limit_type: ProtocolControlMessageSetPeerBandwidthLimitType::Dynamic,
427        }
428        .write(&mut self.write_buf, &self.chunk_writer)?;
429
430        self.app_name = Some(connect.app.into_owned());
431        self.caps_ex = connect.caps_ex;
432
433        let result = NetConnectionCommand::ConnectResult(NetConnectionCommandConnectResult::default());
434
435        Command {
436            command_type: CommandType::NetConnection(result),
437            transaction_id,
438        }
439        .write(&mut self.write_buf, &self.chunk_writer)?;
440
441        Ok(())
442    }
443
444    async fn on_command_create_stream(
449        &mut self,
450        _stream_id: u32,
451        transaction_id: f64,
452    ) -> Result<(), crate::error::RtmpError> {
453        Command {
455            command_type: CommandType::NetConnection(NetConnectionCommand::CreateStreamResult { stream_id: 1.0 }),
456            transaction_id,
457        }
458        .write(&mut self.write_buf, &self.chunk_writer)?;
459
460        Ok(())
461    }
462
463    async fn on_command_delete_stream(
468        &mut self,
469        _stream_id: u32,
470        transaction_id: f64,
471        delete_stream_id: f64,
472    ) -> Result<(), crate::error::RtmpError> {
473        let stream_id = delete_stream_id as u32;
474
475        self.handler.on_unpublish(stream_id).await?;
476
477        self.publishing_stream_ids.retain(|id| *id != stream_id);
479
480        Command {
481            command_type: CommandType::OnStatus(OnStatus {
482                level: CommandResultLevel::Status,
483                code: OnStatusCode::NET_STREAM_DELETE_STREAM_SUCCESS,
484                description: None,
485                others: None,
486            }),
487            transaction_id,
488        }
489        .write(&mut self.write_buf, &self.chunk_writer)?;
490
491        Ok(())
492    }
493
494    async fn on_command_publish(
498        &mut self,
499        stream_id: u32,
500        transaction_id: f64,
501        publishing_name: &str,
502        _publishing_type: NetStreamCommandPublishPublishingType<'_>,
503    ) -> Result<(), crate::error::RtmpError> {
504        let Some(app_name) = &self.app_name else {
505            return Err(crate::error::RtmpError::Session(ServerSessionError::PublishBeforeConnect));
507        };
508
509        self.handler.on_publish(stream_id, app_name.as_ref(), publishing_name).await?;
510
511        self.publishing_stream_ids.push(stream_id);
512
513        EventMessageStreamBegin { stream_id }.write(&self.chunk_writer, &mut self.write_buf)?;
514
515        Command {
516            command_type: CommandType::OnStatus(OnStatus {
517                level: CommandResultLevel::Status,
518                code: OnStatusCode::NET_STREAM_PUBLISH_START,
519                description: None,
520                others: None,
521            }),
522            transaction_id,
523        }
524        .write(&mut self.write_buf, &self.chunk_writer)?;
525
526        Ok(())
527    }
528
529    async fn flush(&mut self) -> Result<(), crate::error::RtmpError> {
530        if !self.write_buf.is_empty() {
531            self.io
532                .write_all(self.write_buf.as_ref())
533                .with_timeout(Duration::from_secs(2))
534                .await
535                .map_err(ServerSessionError::Timeout)??;
536            self.write_buf.clear();
537        }
538
539        Ok(())
540    }
541}