1use std::collections::HashMap;
6use std::net::SocketAddr;
7use std::time::Instant;
8
9use super::stream::StreamState;
10use crate::protocol::message::ConnectParams;
11use crate::protocol::quirks::EncoderType;
12
/// Lifecycle phase of a single client session.
///
/// The transitions driven by `SessionState` are
/// `Connected` → `Handshaking` → `WaitingConnect` → `Active`, and any phase
/// → `Closing`. Nothing in this file assigns `Closed`; it is presumably set
/// by the owner once teardown finishes (TODO confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionPhase {
    /// Initial phase assigned when the session state is created.
    Connected,
    /// Handshake has been started but not yet completed.
    Handshaking,
    /// Handshake finished; waiting for the connect request.
    WaitingConnect,
    /// Connect request accepted; the session is fully usable.
    Active,
    /// Shutdown has been requested.
    Closing,
    /// Session is finished.
    Closed,
}
29
30#[derive(Debug)]
32pub struct SessionState {
33 pub id: u64,
35
36 pub peer_addr: SocketAddr,
38
39 pub phase: SessionPhase,
41
42 pub connected_at: Instant,
44
45 pub handshake_completed_at: Option<Instant>,
47
48 pub connect_params: Option<ConnectParams>,
50
51 pub encoder_type: EncoderType,
53
54 pub streams: HashMap<u32, StreamState>,
56
57 next_stream_id: u32,
59
60 pub in_chunk_size: u32,
62
63 pub out_chunk_size: u32,
65
66 pub window_ack_size: u32,
68
69 pub bytes_received: u64,
71
72 pub bytes_sent: u64,
74
75 pub last_ack_sequence: u32,
77}
78
79impl SessionState {
80 pub fn new(id: u64, peer_addr: SocketAddr) -> Self {
82 Self {
83 id,
84 peer_addr,
85 phase: SessionPhase::Connected,
86 connected_at: Instant::now(),
87 handshake_completed_at: None,
88 connect_params: None,
89 encoder_type: EncoderType::Unknown,
90 streams: HashMap::new(),
91 next_stream_id: 1, in_chunk_size: 128,
93 out_chunk_size: 128,
94 window_ack_size: 2_500_000,
95 bytes_received: 0,
96 bytes_sent: 0,
97 last_ack_sequence: 0,
98 }
99 }
100
101 pub fn start_handshake(&mut self) {
103 if self.phase == SessionPhase::Connected {
104 self.phase = SessionPhase::Handshaking;
105 }
106 }
107
108 pub fn complete_handshake(&mut self) {
110 if self.phase == SessionPhase::Handshaking {
111 self.phase = SessionPhase::WaitingConnect;
112 self.handshake_completed_at = Some(Instant::now());
113 }
114 }
115
116 pub fn on_connect(&mut self, params: ConnectParams, encoder_type: EncoderType) {
118 self.connect_params = Some(params);
119 self.encoder_type = encoder_type;
120 self.phase = SessionPhase::Active;
121 }
122
123 pub fn allocate_stream_id(&mut self) -> u32 {
125 let id = self.next_stream_id;
126 self.next_stream_id += 1;
127 self.streams.insert(id, StreamState::new(id));
128 id
129 }
130
131 pub fn get_stream(&self, stream_id: u32) -> Option<&StreamState> {
133 self.streams.get(&stream_id)
134 }
135
136 pub fn get_stream_mut(&mut self, stream_id: u32) -> Option<&mut StreamState> {
138 self.streams.get_mut(&stream_id)
139 }
140
141 pub fn remove_stream(&mut self, stream_id: u32) -> Option<StreamState> {
143 self.streams.remove(&stream_id)
144 }
145
146 pub fn add_bytes_received(&mut self, bytes: u64) -> bool {
148 self.bytes_received += bytes;
149
150 let delta = self.bytes_received as u32 - self.last_ack_sequence;
152 delta >= self.window_ack_size
153 }
154
155 pub fn mark_ack_sent(&mut self) {
157 self.last_ack_sequence = self.bytes_received as u32;
158 }
159
160 pub fn duration(&self) -> std::time::Duration {
162 self.connected_at.elapsed()
163 }
164
165 pub fn is_active(&self) -> bool {
167 self.phase == SessionPhase::Active
168 }
169
170 pub fn close(&mut self) {
172 self.phase = SessionPhase::Closing;
173 }
174
175 pub fn app(&self) -> Option<&str> {
177 self.connect_params.as_ref().map(|p| p.app.as_str())
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds a fresh session (id 1) for a loopback peer on port 1935.
    fn loopback_session() -> SessionState {
        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1935);
        SessionState::new(1, peer)
    }

    /// Drives a session through every phase from creation up to `Active`.
    #[test]
    fn test_session_lifecycle() {
        let mut session = loopback_session();
        assert_eq!(session.phase, SessionPhase::Connected);

        session.start_handshake();
        assert_eq!(session.phase, SessionPhase::Handshaking);

        session.complete_handshake();
        assert_eq!(session.phase, SessionPhase::WaitingConnect);
        assert!(session.handshake_completed_at.is_some());

        session.on_connect(ConnectParams::default(), EncoderType::Obs);
        assert_eq!(session.phase, SessionPhase::Active);
        assert!(session.is_active());
    }

    /// Stream ids start at 1, increase by one, and each allocated id is
    /// registered in the session.
    #[test]
    fn test_stream_allocation() {
        let mut session = loopback_session();

        let first = session.allocate_stream_id();
        let second = session.allocate_stream_id();
        assert_eq!((first, second), (1, 2));

        for id in [1u32, 2u32].iter() {
            assert!(session.get_stream(*id).is_some());
        }
    }
}