rtmp_rs/session/
state.rs

1//! Session state machine
2//!
3//! Tracks the overall state of an RTMP session from connection to disconnection.
4
5use std::collections::HashMap;
6use std::net::SocketAddr;
7use std::time::Instant;
8
9use super::stream::StreamState;
10use crate::protocol::message::ConnectParams;
11use crate::protocol::quirks::EncoderType;
12
/// Where an RTMP session currently is in its lifecycle.
///
/// The transition methods on `SessionState` move a session from `Connected`
/// to `Handshaking` to `WaitingConnect` to `Active`, and `close` moves it to
/// `Closing`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionPhase {
    /// TCP connection accepted; the handshake has not begun yet.
    Connected,
    /// The handshake is under way.
    Handshaking,
    /// Handshake finished; awaiting the client's `connect` command.
    WaitingConnect,
    /// The `connect` command was received and accepted.
    Active,
    /// The session is shutting down.
    Closing,
    /// The session has been closed.
    Closed,
}
29
30/// Complete session state
31#[derive(Debug)]
32pub struct SessionState {
33    /// Unique session ID
34    pub id: u64,
35
36    /// Remote peer address
37    pub peer_addr: SocketAddr,
38
39    /// Current phase
40    pub phase: SessionPhase,
41
42    /// Connection start time
43    pub connected_at: Instant,
44
45    /// Time when handshake completed
46    pub handshake_completed_at: Option<Instant>,
47
48    /// Connect parameters (after connect command)
49    pub connect_params: Option<ConnectParams>,
50
51    /// Detected encoder type
52    pub encoder_type: EncoderType,
53
54    /// Per-stream states (keyed by message stream ID)
55    pub streams: HashMap<u32, StreamState>,
56
57    /// Next message stream ID to allocate
58    next_stream_id: u32,
59
60    /// Negotiated chunk size (incoming)
61    pub in_chunk_size: u32,
62
63    /// Negotiated chunk size (outgoing)
64    pub out_chunk_size: u32,
65
66    /// Window acknowledgement size
67    pub window_ack_size: u32,
68
69    /// Bytes received since last acknowledgement
70    pub bytes_received: u64,
71
72    /// Bytes sent
73    pub bytes_sent: u64,
74
75    /// Last acknowledgement sequence
76    pub last_ack_sequence: u32,
77}
78
79impl SessionState {
80    /// Create a new session state
81    pub fn new(id: u64, peer_addr: SocketAddr) -> Self {
82        Self {
83            id,
84            peer_addr,
85            phase: SessionPhase::Connected,
86            connected_at: Instant::now(),
87            handshake_completed_at: None,
88            connect_params: None,
89            encoder_type: EncoderType::Unknown,
90            streams: HashMap::new(),
91            next_stream_id: 1, // Stream 0 is reserved for NetConnection
92            in_chunk_size: 128,
93            out_chunk_size: 128,
94            window_ack_size: 2_500_000,
95            bytes_received: 0,
96            bytes_sent: 0,
97            last_ack_sequence: 0,
98        }
99    }
100
101    /// Transition to handshaking phase
102    pub fn start_handshake(&mut self) {
103        if self.phase == SessionPhase::Connected {
104            self.phase = SessionPhase::Handshaking;
105        }
106    }
107
108    /// Complete handshake
109    pub fn complete_handshake(&mut self) {
110        if self.phase == SessionPhase::Handshaking {
111            self.phase = SessionPhase::WaitingConnect;
112            self.handshake_completed_at = Some(Instant::now());
113        }
114    }
115
116    /// Handle connect command
117    pub fn on_connect(&mut self, params: ConnectParams, encoder_type: EncoderType) {
118        self.connect_params = Some(params);
119        self.encoder_type = encoder_type;
120        self.phase = SessionPhase::Active;
121    }
122
123    /// Allocate a new message stream ID
124    pub fn allocate_stream_id(&mut self) -> u32 {
125        let id = self.next_stream_id;
126        self.next_stream_id += 1;
127        self.streams.insert(id, StreamState::new(id));
128        id
129    }
130
131    /// Get a stream by ID
132    pub fn get_stream(&self, stream_id: u32) -> Option<&StreamState> {
133        self.streams.get(&stream_id)
134    }
135
136    /// Get a mutable stream by ID
137    pub fn get_stream_mut(&mut self, stream_id: u32) -> Option<&mut StreamState> {
138        self.streams.get_mut(&stream_id)
139    }
140
141    /// Remove a stream
142    pub fn remove_stream(&mut self, stream_id: u32) -> Option<StreamState> {
143        self.streams.remove(&stream_id)
144    }
145
146    /// Update bytes received and check if acknowledgement needed
147    pub fn add_bytes_received(&mut self, bytes: u64) -> bool {
148        self.bytes_received += bytes;
149
150        // Check if we need to send acknowledgement
151        let delta = self.bytes_received as u32 - self.last_ack_sequence;
152        delta >= self.window_ack_size
153    }
154
155    /// Mark acknowledgement sent
156    pub fn mark_ack_sent(&mut self) {
157        self.last_ack_sequence = self.bytes_received as u32;
158    }
159
160    /// Get session duration
161    pub fn duration(&self) -> std::time::Duration {
162        self.connected_at.elapsed()
163    }
164
165    /// Check if session is active
166    pub fn is_active(&self) -> bool {
167        self.phase == SessionPhase::Active
168    }
169
170    /// Start closing the session
171    pub fn close(&mut self) {
172        self.phase = SessionPhase::Closing;
173    }
174
175    /// Get the application name
176    pub fn app(&self) -> Option<&str> {
177        self.connect_params.as_ref().map(|p| p.app.as_str())
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback peer address shared by every test.
    fn localhost() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1935)
    }

    #[test]
    fn test_session_lifecycle() {
        let mut state = SessionState::new(1, localhost());

        assert_eq!(state.phase, SessionPhase::Connected);

        state.start_handshake();
        assert_eq!(state.phase, SessionPhase::Handshaking);

        state.complete_handshake();
        assert_eq!(state.phase, SessionPhase::WaitingConnect);
        assert!(state.handshake_completed_at.is_some());

        let params = ConnectParams::default();
        state.on_connect(params, EncoderType::Obs);
        assert_eq!(state.phase, SessionPhase::Active);
        assert!(state.is_active());
    }

    #[test]
    fn test_out_of_order_transitions_are_ignored() {
        let mut state = SessionState::new(1, localhost());

        // Completing a handshake that never started must be a no-op.
        state.complete_handshake();
        assert_eq!(state.phase, SessionPhase::Connected);
        assert!(state.handshake_completed_at.is_none());

        // Starting twice stays in Handshaking.
        state.start_handshake();
        state.start_handshake();
        assert_eq!(state.phase, SessionPhase::Handshaking);
    }

    #[test]
    fn test_close() {
        let mut state = SessionState::new(1, localhost());
        state.close();
        assert_eq!(state.phase, SessionPhase::Closing);
        assert!(!state.is_active());
    }

    #[test]
    fn test_app_before_connect() {
        let state = SessionState::new(1, localhost());
        assert_eq!(state.app(), None);
    }

    #[test]
    fn test_stream_allocation() {
        let mut state = SessionState::new(1, localhost());

        let id1 = state.allocate_stream_id();
        let id2 = state.allocate_stream_id();

        assert_eq!(id1, 1);
        assert_eq!(id2, 2);
        assert!(state.get_stream(1).is_some());
        assert!(state.get_stream(2).is_some());
    }

    #[test]
    fn test_stream_removal_does_not_reuse_ids() {
        let mut state = SessionState::new(1, localhost());
        let id = state.allocate_stream_id();

        assert!(state.remove_stream(id).is_some());
        assert!(state.get_stream(id).is_none());
        assert!(state.remove_stream(id).is_none());

        // IDs keep increasing even after a removal.
        assert_eq!(state.allocate_stream_id(), id + 1);
    }

    #[test]
    fn test_ack_window() {
        let mut state = SessionState::new(1, localhost());
        let window = u64::from(state.window_ack_size);

        assert!(!state.add_bytes_received(window - 1));
        assert!(state.add_bytes_received(1));

        state.mark_ack_sent();
        assert_eq!(u64::from(state.last_ack_sequence), window);
        assert!(!state.add_bytes_received(1));
    }
}