//! RTMP handshake state machine (`rtmp_rs/protocol/handshake.rs`).

use std::time::{SystemTime, UNIX_EPOCH};

use bytes::{Buf, BufMut, Bytes, BytesMut};

use crate::error::{HandshakeError, Result};
use crate::protocol::constants::{HANDSHAKE_SIZE, RTMP_VERSION};

/// Which side of the RTMP handshake this endpoint plays.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeRole {
    /// Connecting side: sends C0 + C1 first, then answers S0 + S1 + S2 with C2.
    Client,
    /// Accepting side: waits for C0 + C1, answers with S0 + S1 + S2, then waits for C2.
    Server,
}

38#[derive(Debug)]
40pub struct Handshake {
41 role: HandshakeRole,
42 state: HandshakeState,
43 our_packet: Option<[u8; HANDSHAKE_SIZE]>,
45 peer_packet: Option<[u8; HANDSHAKE_SIZE]>,
47}
48
/// Steps of the handshake state machine.
// `NeedToSendResponse` is matched on but never constructed, which is what the
// `dead_code` allow below silences.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[allow(dead_code)]
enum HandshakeState {
    /// Nothing has been sent or received yet.
    Initial,
    /// Waiting for C0 + C1 (server) or S0 + S1 + S2 (client).
    WaitingForPeerPacket,
    /// Not entered by the current implementation.
    NeedToSendResponse,
    /// Server has sent S0 + S1 + S2 and is waiting for C2.
    WaitingForPeerResponse,
    /// Handshake complete.
    Done,
}

64impl Handshake {
65 pub fn new(role: HandshakeRole) -> Self {
67 Self {
68 role,
69 state: HandshakeState::Initial,
70 our_packet: None,
71 peer_packet: None,
72 }
73 }
74
75 pub fn is_done(&self) -> bool {
77 self.state == HandshakeState::Done
78 }
79
80 pub fn bytes_needed(&self) -> usize {
82 match self.state {
83 HandshakeState::Initial => 0,
84 HandshakeState::WaitingForPeerPacket => 1 + HANDSHAKE_SIZE, HandshakeState::NeedToSendResponse => 0,
86 HandshakeState::WaitingForPeerResponse => {
87 match self.role {
88 HandshakeRole::Client => HANDSHAKE_SIZE, HandshakeRole::Server => HANDSHAKE_SIZE, }
91 }
92 HandshakeState::Done => 0,
93 }
94 }
95
96 pub fn generate_initial(&mut self) -> Option<Bytes> {
101 if self.state != HandshakeState::Initial {
102 return None;
103 }
104
105 match self.role {
106 HandshakeRole::Client => {
107 let mut buf = BytesMut::with_capacity(1 + HANDSHAKE_SIZE);
108
109 buf.put_u8(RTMP_VERSION);
111
112 let c1 = generate_packet();
114 self.our_packet = Some(c1);
115 buf.put_slice(&c1);
116
117 self.state = HandshakeState::WaitingForPeerPacket;
118 Some(buf.freeze())
119 }
120 HandshakeRole::Server => {
121 self.state = HandshakeState::WaitingForPeerPacket;
123 None
124 }
125 }
126 }
127
128 pub fn process(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
134 match self.state {
135 HandshakeState::WaitingForPeerPacket => self.process_peer_packet(data),
136 HandshakeState::WaitingForPeerResponse => self.process_peer_response(data),
137 _ => Ok(None),
138 }
139 }
140
141 fn process_peer_packet(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
143 match self.role {
144 HandshakeRole::Server => {
145 if data.remaining() < 1 + HANDSHAKE_SIZE {
147 return Ok(None); }
149
150 let version = data.get_u8();
152 if version != RTMP_VERSION {
153 if version < 3 {
155 return Err(HandshakeError::InvalidVersion(version).into());
156 }
157 }
158
159 let mut c1 = [0u8; HANDSHAKE_SIZE];
161 data.copy_to_slice(&mut c1);
162 self.peer_packet = Some(c1);
163
164 let mut response = BytesMut::with_capacity(1 + HANDSHAKE_SIZE * 2);
166
167 response.put_u8(RTMP_VERSION);
169
170 let s1 = generate_packet();
172 self.our_packet = Some(s1);
173 response.put_slice(&s1);
174
175 let s2 = generate_echo(&c1);
177 response.put_slice(&s2);
178
179 self.state = HandshakeState::WaitingForPeerResponse;
180 Ok(Some(response.freeze()))
181 }
182 HandshakeRole::Client => {
183 if data.remaining() < 1 + HANDSHAKE_SIZE * 2 {
185 return Ok(None); }
187
188 let version = data.get_u8();
190 if version != RTMP_VERSION && version < 3 {
191 return Err(HandshakeError::InvalidVersion(version).into());
192 }
193
194 let mut s1 = [0u8; HANDSHAKE_SIZE];
196 data.copy_to_slice(&mut s1);
197 self.peer_packet = Some(s1);
198
199 let mut s2 = [0u8; HANDSHAKE_SIZE];
201 data.copy_to_slice(&mut s2);
202
203 let c2 = generate_echo(&s1);
208
209 self.state = HandshakeState::Done;
210 Ok(Some(Bytes::copy_from_slice(&c2)))
211 }
212 }
213 }
214
215 fn process_peer_response(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
217 match self.role {
218 HandshakeRole::Server => {
219 if data.remaining() < HANDSHAKE_SIZE {
221 return Ok(None);
222 }
223
224 let mut c2 = [0u8; HANDSHAKE_SIZE];
226 data.copy_to_slice(&mut c2);
227
228 self.state = HandshakeState::Done;
230 Ok(None)
231 }
232 HandshakeRole::Client => {
233 self.state = HandshakeState::Done;
235 Ok(None)
236 }
237 }
238 }
239}
240
241fn generate_packet() -> [u8; HANDSHAKE_SIZE] {
248 let mut packet = [0u8; HANDSHAKE_SIZE];
249
250 let timestamp = SystemTime::now()
252 .duration_since(UNIX_EPOCH)
253 .map(|d| d.as_millis() as u32)
254 .unwrap_or(0);
255
256 packet[0..4].copy_from_slice(×tamp.to_be_bytes());
257
258 packet[4..8].copy_from_slice(&[0, 0, 0, 0]);
260
261 let mut seed = timestamp as u64;
264 for chunk in packet[8..].chunks_mut(8) {
265 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
266 let bytes = seed.to_le_bytes();
267 let len = chunk.len().min(8);
268 chunk[..len].copy_from_slice(&bytes[..len]);
269 }
270
271 packet
272}
273
274fn generate_echo(peer_packet: &[u8; HANDSHAKE_SIZE]) -> [u8; HANDSHAKE_SIZE] {
281 let mut echo = *peer_packet;
282
283 let timestamp = SystemTime::now()
285 .duration_since(UNIX_EPOCH)
286 .map(|d| d.as_millis() as u32)
287 .unwrap_or(0);
288
289 echo[4..8].copy_from_slice(×tamp.to_be_bytes());
290
291 echo
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_server_handshake() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);

        let c0c1 = client
            .generate_initial()
            .expect("Client should generate C0C1");
        assert_eq!(c0c1.len(), 1 + HANDSHAKE_SIZE);

        // The server sends nothing up front; it only starts waiting for C0C1.
        assert!(server.generate_initial().is_none());

        let mut c0c1_buf = c0c1;
        let s0s1s2 = server
            .process(&mut c0c1_buf)
            .unwrap()
            .expect("Server should generate S0S1S2");
        assert_eq!(s0s1s2.len(), 1 + HANDSHAKE_SIZE * 2);
        assert!(c0c1_buf.is_empty());

        let mut s0s1s2_buf = s0s1s2;
        let c2 = client
            .process(&mut s0s1s2_buf)
            .unwrap()
            .expect("Client should generate C2");
        assert_eq!(c2.len(), HANDSHAKE_SIZE);
        assert!(client.is_done());
        assert!(s0s1s2_buf.is_empty());

        let mut c2_buf = c2;
        let response = server.process(&mut c2_buf).unwrap();
        assert!(response.is_none());
        assert!(server.is_done());
        assert!(c2_buf.is_empty());
    }

    #[test]
    fn test_echo_packets_mirror_peer_packets() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);

        let c0c1 = client.generate_initial().expect("Client should generate C0C1");
        server.generate_initial();
        let s0s1s2 = server
            .process(&mut c0c1.clone())
            .unwrap()
            .expect("Server should generate S0S1S2");
        let c2 = client
            .process(&mut s0s1s2.clone())
            .unwrap()
            .expect("Client should generate C2");

        assert_eq!(c0c1[0], RTMP_VERSION);
        assert_eq!(s0s1s2[0], RTMP_VERSION);

        let c1 = &c0c1[1..];
        let s1 = &s0s1s2[1..1 + HANDSHAKE_SIZE];
        let s2 = &s0s1s2[1 + HANDSHAKE_SIZE..];

        // Each echo keeps the peer's timestamp and filler; only bytes 4..8 are rewritten.
        assert_eq!(&s2[0..4], &c1[0..4]);
        assert_eq!(&s2[8..], &c1[8..]);
        assert_eq!(&c2[0..4], &s1[0..4]);
        assert_eq!(&c2[8..], &s1[8..]);
    }

    #[test]
    fn test_incomplete_input_is_not_consumed() {
        let mut server = Handshake::new(HandshakeRole::Server);
        server.generate_initial();

        let mut partial = Bytes::from(vec![RTMP_VERSION; 10]);
        assert!(server.process(&mut partial).unwrap().is_none());
        assert_eq!(partial.len(), 10);
        assert!(!server.is_done());
    }

    #[test]
    fn test_rejects_low_version() {
        let mut server = Handshake::new(HandshakeRole::Server);
        server.generate_initial();

        let mut c0c1 = BytesMut::with_capacity(1 + HANDSHAKE_SIZE);
        c0c1.put_u8(2);
        c0c1.put_slice(&[0u8; HANDSHAKE_SIZE]);
        assert!(server.process(&mut c0c1.freeze()).is_err());
        assert!(!server.is_done());
    }

    #[test]
    fn test_generate_initial_only_once() {
        let mut client = Handshake::new(HandshakeRole::Client);
        assert!(client.generate_initial().is_some());
        assert!(client.generate_initial().is_none());
    }

    #[test]
    fn test_packet_generation() {
        let packet = generate_packet();

        let timestamp = u32::from_be_bytes([packet[0], packet[1], packet[2], packet[3]]);
        assert!(timestamp > 0);
        assert_eq!(&packet[4..8], &[0, 0, 0, 0]);
    }
}