rtmp_rs/protocol/
handshake.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
26use std::time::{SystemTime, UNIX_EPOCH};
27
28use crate::error::{HandshakeError, Result};
29use crate::protocol::constants::{HANDSHAKE_SIZE, RTMP_VERSION};
30
/// Which end of the RTMP connection a [`Handshake`] is performed for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeRole {
    /// Connecting side: sends C0+C1 first, then answers S0+S1+S2 with C2.
    Client,
    /// Accepting side: waits for C0+C1, replies with S0+S1+S2, then reads C2.
    Server,
}
37
38#[derive(Debug)]
40pub struct Handshake {
41 role: HandshakeRole,
42 state: HandshakeState,
43 our_packet: Option<[u8; HANDSHAKE_SIZE]>,
45 peer_packet: Option<[u8; HANDSHAKE_SIZE]>,
47}
48
/// Internal progress of a [`Handshake`].
///
/// Client: `Initial` -> `WaitingForPeerPacket` (C0+C1 sent, awaiting S0+S1+S2) -> `Done`.
/// Server: `Initial` -> `WaitingForPeerPacket` (awaiting C0+C1) ->
/// `WaitingForPeerResponse` (S0+S1+S2 sent, awaiting C2) -> `Done`.
// `NeedToSendResponse` is not constructed anywhere in this module; the allow
// keeps the compiler quiet about it.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum HandshakeState {
    /// Freshly created; `generate_initial` has not been called yet.
    Initial,
    /// Waiting for the peer's first packet group
    /// (C0+C1 on a server, S0+S1+S2 on a client).
    WaitingForPeerPacket,
    /// Presumably meant for "reply pending transmission"; currently unused.
    NeedToSendResponse,
    /// Server only: S0+S1+S2 have been produced and C2 is awaited.
    WaitingForPeerResponse,
    /// Handshake complete for this side.
    Done,
}
63
64impl Handshake {
65 pub fn new(role: HandshakeRole) -> Self {
67 Self {
68 role,
69 state: HandshakeState::Initial,
70 our_packet: None,
71 peer_packet: None,
72 }
73 }
74
75 pub fn is_done(&self) -> bool {
77 self.state == HandshakeState::Done
78 }
79
80 pub fn bytes_needed(&self) -> usize {
82 match self.state {
83 HandshakeState::Initial => 0,
84 HandshakeState::WaitingForPeerPacket => 1 + HANDSHAKE_SIZE, HandshakeState::NeedToSendResponse => 0,
86 HandshakeState::WaitingForPeerResponse => {
87 match self.role {
88 HandshakeRole::Client => HANDSHAKE_SIZE, HandshakeRole::Server => HANDSHAKE_SIZE, }
91 }
92 HandshakeState::Done => 0,
93 }
94 }
95
96 pub fn generate_initial(&mut self) -> Option<Bytes> {
101 if self.state != HandshakeState::Initial {
102 return None;
103 }
104
105 match self.role {
106 HandshakeRole::Client => {
107 let mut buf = BytesMut::with_capacity(1 + HANDSHAKE_SIZE);
108
109 buf.put_u8(RTMP_VERSION);
111
112 let c1 = generate_packet();
114 self.our_packet = Some(c1);
115 buf.put_slice(&c1);
116
117 self.state = HandshakeState::WaitingForPeerPacket;
118 Some(buf.freeze())
119 }
120 HandshakeRole::Server => {
121 self.state = HandshakeState::WaitingForPeerPacket;
123 None
124 }
125 }
126 }
127
128 pub fn process(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
134 match self.state {
135 HandshakeState::WaitingForPeerPacket => self.process_peer_packet(data),
136 HandshakeState::WaitingForPeerResponse => self.process_peer_response(data),
137 _ => Ok(None),
138 }
139 }
140
141 fn process_peer_packet(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
143 match self.role {
144 HandshakeRole::Server => {
145 if data.remaining() < 1 + HANDSHAKE_SIZE {
147 return Ok(None); }
149
150 let version = data.get_u8();
152 if version != RTMP_VERSION {
153 if version < 3 {
155 return Err(HandshakeError::InvalidVersion(version).into());
156 }
157 }
158
159 let mut c1 = [0u8; HANDSHAKE_SIZE];
161 data.copy_to_slice(&mut c1);
162 self.peer_packet = Some(c1);
163
164 let mut response = BytesMut::with_capacity(1 + HANDSHAKE_SIZE * 2);
166
167 response.put_u8(RTMP_VERSION);
169
170 let s1 = generate_packet();
172 self.our_packet = Some(s1);
173 response.put_slice(&s1);
174
175 let s2 = generate_echo(&c1);
177 response.put_slice(&s2);
178
179 self.state = HandshakeState::WaitingForPeerResponse;
180 Ok(Some(response.freeze()))
181 }
182 HandshakeRole::Client => {
183 if data.remaining() < 1 + HANDSHAKE_SIZE * 2 {
185 return Ok(None); }
187
188 let version = data.get_u8();
190 if version != RTMP_VERSION && version < 3 {
191 return Err(HandshakeError::InvalidVersion(version).into());
192 }
193
194 let mut s1 = [0u8; HANDSHAKE_SIZE];
196 data.copy_to_slice(&mut s1);
197 self.peer_packet = Some(s1);
198
199 let mut s2 = [0u8; HANDSHAKE_SIZE];
201 data.copy_to_slice(&mut s2);
202
203 let c2 = generate_echo(&s1);
208
209 self.state = HandshakeState::Done;
210 Ok(Some(Bytes::copy_from_slice(&c2)))
211 }
212 }
213 }
214
215 fn process_peer_response(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
217 match self.role {
218 HandshakeRole::Server => {
219 if data.remaining() < HANDSHAKE_SIZE {
221 return Ok(None);
222 }
223
224 let mut c2 = [0u8; HANDSHAKE_SIZE];
226 data.copy_to_slice(&mut c2);
227
228 self.state = HandshakeState::Done;
230 Ok(None)
231 }
232 HandshakeRole::Client => {
233 self.state = HandshakeState::Done;
235 Ok(None)
236 }
237 }
238 }
239}
240
241fn generate_packet() -> [u8; HANDSHAKE_SIZE] {
248 let mut packet = [0u8; HANDSHAKE_SIZE];
249
250 let timestamp = SystemTime::now()
252 .duration_since(UNIX_EPOCH)
253 .map(|d| d.as_millis() as u32)
254 .unwrap_or(0);
255
256 packet[0..4].copy_from_slice(×tamp.to_be_bytes());
257
258 packet[4..8].copy_from_slice(&[0, 0, 0, 0]);
260
261 let mut seed = timestamp as u64;
264 for chunk in packet[8..].chunks_mut(8) {
265 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
266 let bytes = seed.to_le_bytes();
267 let len = chunk.len().min(8);
268 chunk[..len].copy_from_slice(&bytes[..len]);
269 }
270
271 packet
272}
273
274fn generate_echo(peer_packet: &[u8; HANDSHAKE_SIZE]) -> [u8; HANDSHAKE_SIZE] {
281 let mut echo = *peer_packet;
282
283 let timestamp = SystemTime::now()
285 .duration_since(UNIX_EPOCH)
286 .map(|d| d.as_millis() as u32)
287 .unwrap_or(0);
288
289 echo[4..8].copy_from_slice(×tamp.to_be_bytes());
290
291 echo
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in milliseconds truncated to 32 bits, matching the
    /// encoding of the handshake time fields.
    fn now_millis_u32() -> u32 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as u32)
            .unwrap_or(0)
    }

    #[test]
    fn test_client_server_handshake() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);

        let c0c1 = client
            .generate_initial()
            .expect("Client should generate C0C1");
        assert_eq!(c0c1.len(), 1 + HANDSHAKE_SIZE);

        server.generate_initial();
        let s0s1s2 = server
            .process(&mut c0c1.clone())
            .unwrap()
            .expect("Server should generate S0S1S2");
        assert_eq!(s0s1s2.len(), 1 + HANDSHAKE_SIZE * 2);
        assert_eq!(s0s1s2[0], RTMP_VERSION);

        // S2 must echo C1's timestamp (bytes 0..4) and random payload (8..).
        let c1 = &c0c1[1..];
        let s1 = &s0s1s2[1..1 + HANDSHAKE_SIZE];
        let s2 = &s0s1s2[1 + HANDSHAKE_SIZE..];
        assert_eq!(&s2[0..4], &c1[0..4]);
        assert_eq!(&s2[8..], &c1[8..]);

        let c2 = client
            .process(&mut s0s1s2.clone())
            .unwrap()
            .expect("Client should generate C2");
        assert_eq!(c2.len(), HANDSHAKE_SIZE);
        assert!(client.is_done());

        // C2 must echo S1 the same way.
        assert_eq!(&c2[0..4], &s1[0..4]);
        assert_eq!(&c2[8..], &s1[8..]);

        let response = server.process(&mut c2.clone()).unwrap();
        assert!(response.is_none());
        assert!(server.is_done());
    }

    #[test]
    fn test_packet_generation() {
        let packet = generate_packet();

        let timestamp = u32::from_be_bytes([packet[0], packet[1], packet[2], packet[3]]);
        assert!(timestamp > 0);
        assert_eq!(&packet[4..8], &[0, 0, 0, 0]);
    }

    #[test]
    fn test_handshake_role_enum() {
        assert_ne!(HandshakeRole::Client, HandshakeRole::Server);

        let client_role = HandshakeRole::Client;
        let server_role = HandshakeRole::Server;

        assert_eq!(client_role, HandshakeRole::Client);
        assert_eq!(server_role, HandshakeRole::Server);
    }

    #[test]
    fn test_handshake_is_done() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);
        assert!(!client.is_done());
        assert!(!server.is_done());

        let mut c0c1 = client.generate_initial().unwrap();
        assert!(!client.is_done());

        server.generate_initial();
        let mut s0s1s2 = server.process(&mut c0c1).unwrap().unwrap();
        // The server still waits for C2.
        assert!(!server.is_done());

        let mut c2 = client.process(&mut s0s1s2).unwrap().unwrap();
        assert!(client.is_done());

        server.process(&mut c2).unwrap();
        assert!(server.is_done());
    }

    #[test]
    fn test_bytes_needed() {
        let mut client = Handshake::new(HandshakeRole::Client);
        assert_eq!(client.bytes_needed(), 0);

        let mut c0c1 = client.generate_initial().unwrap();
        // Pins the current hint. The client really needs S0+S1+S2 before
        // `process` advances (see test_client_waits_for_complete_s0s1s2).
        assert_eq!(client.bytes_needed(), 1 + HANDSHAKE_SIZE);

        let mut server = Handshake::new(HandshakeRole::Server);
        assert_eq!(server.bytes_needed(), 0);

        server.generate_initial();
        assert_eq!(server.bytes_needed(), 1 + HANDSHAKE_SIZE);

        let mut s0s1s2 = server.process(&mut c0c1).unwrap().unwrap();
        assert_eq!(server.bytes_needed(), HANDSHAKE_SIZE);

        let mut c2 = client.process(&mut s0s1s2).unwrap().unwrap();
        assert_eq!(client.bytes_needed(), 0);

        server.process(&mut c2).unwrap();
        assert_eq!(server.bytes_needed(), 0);
    }

    #[test]
    fn test_server_initial_returns_none() {
        let mut server = Handshake::new(HandshakeRole::Server);

        let result = server.generate_initial();
        assert!(result.is_none());
    }

    #[test]
    fn test_client_initial_returns_c0c1() {
        let mut client = Handshake::new(HandshakeRole::Client);

        let c0c1 = client.generate_initial().unwrap();

        assert_eq!(c0c1.len(), 1 + HANDSHAKE_SIZE);
        assert_eq!(c0c1[0], RTMP_VERSION);
    }

    #[test]
    fn test_double_generate_initial_returns_none() {
        let mut client = Handshake::new(HandshakeRole::Client);

        assert!(client.generate_initial().is_some());
        assert!(client.generate_initial().is_none());
    }

    #[test]
    fn test_echo_packet_preserves_random_data() {
        let before = now_millis_u32();
        let original = generate_packet();
        let echo = generate_echo(&original);

        assert_eq!(&original[8..], &echo[8..]);
        assert_eq!(&original[0..4], &echo[0..4]);

        // Bytes 4..8 carry the local time at which the echo was built.
        let time2 = u32::from_be_bytes([echo[4], echo[5], echo[6], echo[7]]);
        assert!(time2.wrapping_sub(before) < 60_000);
    }

    #[test]
    fn test_incomplete_c0c1() {
        let mut server = Handshake::new(HandshakeRole::Server);
        server.generate_initial();

        let mut incomplete = Bytes::from(vec![RTMP_VERSION; 100]);

        let result = server.process(&mut incomplete).unwrap();
        assert!(result.is_none());
        assert_eq!(incomplete.len(), 100, "partial input must not be consumed");
        assert!(!server.is_done());
    }

    #[test]
    fn test_incomplete_s0s1s2() {
        let mut client = Handshake::new(HandshakeRole::Client);
        client.generate_initial();

        let mut incomplete = Bytes::from(vec![RTMP_VERSION; 1000]);

        let result = client.process(&mut incomplete).unwrap();
        assert!(result.is_none());
        assert_eq!(incomplete.len(), 1000, "partial input must not be consumed");
    }

    #[test]
    fn test_client_waits_for_complete_s0s1s2() {
        let mut client = Handshake::new(HandshakeRole::Client);
        client.generate_initial();

        // S0 + S1 + all but the last byte of S2.
        let mut almost = Bytes::from(vec![RTMP_VERSION; HANDSHAKE_SIZE * 2]);
        assert!(client.process(&mut almost).unwrap().is_none());
        assert_eq!(almost.len(), HANDSHAKE_SIZE * 2, "partial input must not be consumed");
        assert!(!client.is_done());
    }

    #[test]
    fn test_server_waits_for_complete_c2() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);

        let mut c0c1 = client.generate_initial().unwrap();
        server.generate_initial();
        server.process(&mut c0c1).unwrap().unwrap();

        let mut short_c2 = Bytes::from(vec![0u8; HANDSHAKE_SIZE - 1]);
        assert!(server.process(&mut short_c2).unwrap().is_none());
        assert_eq!(short_c2.len(), HANDSHAKE_SIZE - 1, "partial input must not be consumed");
        assert!(!server.is_done());
    }

    #[test]
    fn test_invalid_version_rejected() {
        let mut server = Handshake::new(HandshakeRole::Server);
        server.generate_initial();

        let mut invalid = BytesMut::with_capacity(1 + HANDSHAKE_SIZE);
        invalid.put_u8(2);
        invalid.put_slice(&[0u8; HANDSHAKE_SIZE]);

        let mut buf = invalid.freeze();
        let result = server.process(&mut buf);

        assert!(result.is_err());
    }

    #[test]
    fn test_client_rejects_invalid_version() {
        let mut client = Handshake::new(HandshakeRole::Client);
        client.generate_initial();

        let mut invalid = BytesMut::with_capacity(1 + HANDSHAKE_SIZE * 2);
        invalid.put_u8(2);
        invalid.put_slice(&[0u8; HANDSHAKE_SIZE * 2]);

        let mut buf = invalid.freeze();
        assert!(client.process(&mut buf).is_err());
        assert!(!client.is_done());
    }

    #[test]
    fn test_lenient_version_acceptance() {
        let mut server = Handshake::new(HandshakeRole::Server);
        server.generate_initial();

        let mut valid = BytesMut::with_capacity(1 + HANDSHAKE_SIZE);
        valid.put_u8(31);
        valid.put_slice(&generate_packet());

        let mut buf = valid.freeze();
        let result = server.process(&mut buf);

        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[test]
    fn test_handshake_packet_size_constant() {
        assert_eq!(HANDSHAKE_SIZE, 1536);
    }

    #[test]
    fn test_packets_have_nonzero_random_data() {
        // Only checks that the filler is populated; packets generated back to
        // back are not compared here.
        let packet1 = generate_packet();
        let packet2 = generate_packet();

        assert!(&packet1[8..100] != &[0u8; 92][..]);
        assert!(&packet2[8..100] != &[0u8; 92][..]);
    }

    #[test]
    fn test_server_c2_processing() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);

        let mut c0c1 = client.generate_initial().unwrap();
        server.generate_initial();

        let mut s0s1s2 = server.process(&mut c0c1).unwrap().unwrap();
        let mut c2 = client.process(&mut s0s1s2).unwrap().unwrap();

        let response = server.process(&mut c2).unwrap();

        assert!(response.is_none());
        assert!(server.is_done());
    }

    #[test]
    fn test_process_in_wrong_state() {
        let mut client = Handshake::new(HandshakeRole::Client);

        let mut buf = Bytes::from(vec![0u8; 1 + HANDSHAKE_SIZE * 2]);
        let result = client.process(&mut buf).unwrap();

        assert!(result.is_none());
        assert_eq!(buf.len(), 1 + HANDSHAKE_SIZE * 2);
    }

    #[test]
    fn test_process_after_done_is_noop() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);

        let mut c0c1 = client.generate_initial().unwrap();
        server.generate_initial();
        let mut s0s1s2 = server.process(&mut c0c1).unwrap().unwrap();
        client.process(&mut s0s1s2).unwrap().unwrap();
        assert!(client.is_done());

        let mut extra = Bytes::from(vec![0u8; 16]);
        assert!(client.process(&mut extra).unwrap().is_none());
        assert_eq!(extra.len(), 16);
    }
}