rtmp_rs/protocol/
handshake.rs

//! RTMP handshake implementation
//!
//! The RTMP handshake exchanges three rounds of packets:
//!
//! ```text
//! Client                                   Server
//!   |                                        |
//!   |------- C0 (1 byte: version) --------->|
//!   |------- C1 (1536 bytes: time+random) ->|
//!   |                                        |
//!   |<------ S0 (1 byte: version) ----------|
//!   |<------ S1 (1536 bytes: time+random) --|
//!   |<------ S2 (1536 bytes: echo C1) ------|
//!   |                                        |
//!   |------- C2 (1536 bytes: echo S1) ----->|
//!   |                                        |
//!   |          [Handshake Complete]          |
//! ```
//!
//! This implementation uses the "simple" handshake (no HMAC digest). The
//! "complex" handshake, which embeds HMAC-SHA256 digests in C1/S1, is used by
//! some servers and players but is not implemented here.
//!
//! Reference: Adobe RTMP Specification 1.0, Section 5.2 (Handshake)
24
25use bytes::{Buf, BufMut, Bytes, BytesMut};
26use std::time::{SystemTime, UNIX_EPOCH};
27
28use crate::error::{HandshakeError, Result};
29use crate::protocol::constants::{HANDSHAKE_SIZE, RTMP_VERSION};
30
/// Which side of the connection this handshake state machine plays.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeRole {
    /// Initiating side: sends C0+C1 first.
    Client,
    /// Accepting side: waits for C0+C1 before replying.
    Server,
}
37
38/// Handshake state machine
39#[derive(Debug)]
40pub struct Handshake {
41    role: HandshakeRole,
42    state: HandshakeState,
43    /// Our C1/S1 packet (saved for verification)
44    our_packet: Option<[u8; HANDSHAKE_SIZE]>,
45    /// Peer's C1/S1 packet (saved for echo in C2/S2)
46    peer_packet: Option<[u8; HANDSHAKE_SIZE]>,
47}
48
/// Internal progress of a [`Handshake`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HandshakeState {
    /// Nothing sent or received yet; `generate_initial` has not been called.
    Initial,
    /// Waiting for the peer's version byte plus first packet (C0+C1 or S0+S1).
    WaitingForPeerPacket,
    /// Peer packet received and a C2/S2 reply still owed.
    ///
    /// Never entered: the reply is built and returned in the same `process`
    /// call that reads the peer packet. Kept to document the protocol step,
    /// so the lint is silenced for this variant only.
    #[allow(dead_code)]
    NeedToSendResponse,
    /// Waiting for the peer's echo packet (C2 or S2).
    WaitingForPeerResponse,
    /// Handshake complete.
    Done,
}
63
64impl Handshake {
65    /// Create a new handshake state machine
66    pub fn new(role: HandshakeRole) -> Self {
67        Self {
68            role,
69            state: HandshakeState::Initial,
70            our_packet: None,
71            peer_packet: None,
72        }
73    }
74
75    /// Check if handshake is complete
76    pub fn is_done(&self) -> bool {
77        self.state == HandshakeState::Done
78    }
79
80    /// Get bytes needed before next state transition
81    pub fn bytes_needed(&self) -> usize {
82        match self.state {
83            HandshakeState::Initial => 0,
84            HandshakeState::WaitingForPeerPacket => 1 + HANDSHAKE_SIZE, // C0C1 or S0S1
85            HandshakeState::NeedToSendResponse => 0,
86            HandshakeState::WaitingForPeerResponse => {
87                match self.role {
88                    HandshakeRole::Client => HANDSHAKE_SIZE, // S2 only (S0S1 already received)
89                    HandshakeRole::Server => HANDSHAKE_SIZE, // C2 only
90                }
91            }
92            HandshakeState::Done => 0,
93        }
94    }
95
96    /// Generate initial packet (C0C1 for client, nothing for server initially)
97    ///
98    /// For client: returns C0+C1 (1 + 1536 bytes)
99    /// For server: returns None (server waits for C0C1 first)
100    pub fn generate_initial(&mut self) -> Option<Bytes> {
101        if self.state != HandshakeState::Initial {
102            return None;
103        }
104
105        match self.role {
106            HandshakeRole::Client => {
107                let mut buf = BytesMut::with_capacity(1 + HANDSHAKE_SIZE);
108
109                // C0: Version
110                buf.put_u8(RTMP_VERSION);
111
112                // C1: Time + Zero + Random
113                let c1 = generate_packet();
114                self.our_packet = Some(c1);
115                buf.put_slice(&c1);
116
117                self.state = HandshakeState::WaitingForPeerPacket;
118                Some(buf.freeze())
119            }
120            HandshakeRole::Server => {
121                // Server waits for client's C0C1 first
122                self.state = HandshakeState::WaitingForPeerPacket;
123                None
124            }
125        }
126    }
127
128    /// Process received data and return response if ready
129    ///
130    /// For server receiving C0C1: returns S0+S1+S2
131    /// For client receiving S0S1S2: returns C2
132    /// For server receiving C2: returns None (handshake done)
133    pub fn process(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
134        match self.state {
135            HandshakeState::WaitingForPeerPacket => self.process_peer_packet(data),
136            HandshakeState::WaitingForPeerResponse => self.process_peer_response(data),
137            _ => Ok(None),
138        }
139    }
140
141    /// Process peer's initial packet (C0C1 or S0S1S2)
142    fn process_peer_packet(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
143        match self.role {
144            HandshakeRole::Server => {
145                // Expecting C0 + C1
146                if data.remaining() < 1 + HANDSHAKE_SIZE {
147                    return Ok(None); // Need more data
148                }
149
150                // C0: Version check
151                let version = data.get_u8();
152                if version != RTMP_VERSION {
153                    // Be lenient - accept version 3-31 (some encoders send different values)
154                    if version < 3 {
155                        return Err(HandshakeError::InvalidVersion(version).into());
156                    }
157                }
158
159                // C1: Save peer packet
160                let mut c1 = [0u8; HANDSHAKE_SIZE];
161                data.copy_to_slice(&mut c1);
162                self.peer_packet = Some(c1);
163
164                // Generate S0 + S1 + S2
165                let mut response = BytesMut::with_capacity(1 + HANDSHAKE_SIZE * 2);
166
167                // S0: Version
168                response.put_u8(RTMP_VERSION);
169
170                // S1: Our packet
171                let s1 = generate_packet();
172                self.our_packet = Some(s1);
173                response.put_slice(&s1);
174
175                // S2: Echo C1 with our timestamp
176                let s2 = generate_echo(&c1);
177                response.put_slice(&s2);
178
179                self.state = HandshakeState::WaitingForPeerResponse;
180                Ok(Some(response.freeze()))
181            }
182            HandshakeRole::Client => {
183                // Expecting S0 + S1 + S2
184                if data.remaining() < 1 + HANDSHAKE_SIZE * 2 {
185                    return Ok(None); // Need more data
186                }
187
188                // S0: Version check
189                let version = data.get_u8();
190                if version != RTMP_VERSION && version < 3 {
191                    return Err(HandshakeError::InvalidVersion(version).into());
192                }
193
194                // S1: Save peer packet
195                let mut s1 = [0u8; HANDSHAKE_SIZE];
196                data.copy_to_slice(&mut s1);
197                self.peer_packet = Some(s1);
198
199                // S2: Verify echo of C1 (lenient - just consume)
200                let mut s2 = [0u8; HANDSHAKE_SIZE];
201                data.copy_to_slice(&mut s2);
202
203                // In lenient mode, don't strictly verify S2 matches C1
204                // Some servers don't echo correctly
205
206                // Generate C2: Echo S1
207                let c2 = generate_echo(&s1);
208
209                self.state = HandshakeState::Done;
210                Ok(Some(Bytes::copy_from_slice(&c2)))
211            }
212        }
213    }
214
215    /// Process peer's response (C2 for server)
216    fn process_peer_response(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
217        match self.role {
218            HandshakeRole::Server => {
219                // Expecting C2
220                if data.remaining() < HANDSHAKE_SIZE {
221                    return Ok(None);
222                }
223
224                // C2: Verify echo of S1 (lenient)
225                let mut c2 = [0u8; HANDSHAKE_SIZE];
226                data.copy_to_slice(&mut c2);
227
228                // Lenient: don't strictly verify C2 matches S1
229                self.state = HandshakeState::Done;
230                Ok(None)
231            }
232            HandshakeRole::Client => {
233                // Client shouldn't be in this state
234                self.state = HandshakeState::Done;
235                Ok(None)
236            }
237        }
238    }
239}
240
241/// Generate a handshake packet (C1 or S1)
242///
243/// Format (1536 bytes):
244/// - Bytes 0-3: Timestamp (32-bit, big-endian)
245/// - Bytes 4-7: Zero (for simple handshake) or version (for complex)
246/// - Bytes 8-1535: Random data
247fn generate_packet() -> [u8; HANDSHAKE_SIZE] {
248    let mut packet = [0u8; HANDSHAKE_SIZE];
249
250    // Timestamp: milliseconds since some epoch
251    let timestamp = SystemTime::now()
252        .duration_since(UNIX_EPOCH)
253        .map(|d| d.as_millis() as u32)
254        .unwrap_or(0);
255
256    packet[0..4].copy_from_slice(&timestamp.to_be_bytes());
257
258    // Zero field (simple handshake)
259    packet[4..8].copy_from_slice(&[0, 0, 0, 0]);
260
261    // Random data - use simple PRNG seeded with timestamp
262    // Not cryptographically secure, but RTMP handshake doesn't require it
263    let mut seed = timestamp as u64;
264    for chunk in packet[8..].chunks_mut(8) {
265        seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
266        let bytes = seed.to_le_bytes();
267        let len = chunk.len().min(8);
268        chunk[..len].copy_from_slice(&bytes[..len]);
269    }
270
271    packet
272}
273
274/// Generate echo packet (C2 or S2)
275///
276/// Format:
277/// - Bytes 0-3: Peer's timestamp (from their C1/S1)
278/// - Bytes 4-7: Our timestamp
279/// - Bytes 8-1535: Copy of peer's random data
280fn generate_echo(peer_packet: &[u8; HANDSHAKE_SIZE]) -> [u8; HANDSHAKE_SIZE] {
281    let mut echo = *peer_packet;
282
283    // Bytes 4-7: Our receive timestamp
284    let timestamp = SystemTime::now()
285        .duration_since(UNIX_EPOCH)
286        .map(|d| d.as_millis() as u32)
287        .unwrap_or(0);
288
289    echo[4..8].copy_from_slice(&timestamp.to_be_bytes());
290
291    echo
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a C0+C1 buffer from an explicit version byte and C1 body.
    fn c0c1_with(version: u8, c1: &[u8]) -> Bytes {
        let mut buf = BytesMut::with_capacity(1 + c1.len());
        buf.put_u8(version);
        buf.put_slice(c1);
        buf.freeze()
    }

    /// A server that has left `Initial` and is waiting for C0+C1.
    fn listening_server() -> Handshake {
        let mut server = Handshake::new(HandshakeRole::Server);
        server.generate_initial();
        server
    }

    #[test]
    fn test_client_server_handshake() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = listening_server();

        // C0C1: client -> server
        let mut to_server = client
            .generate_initial()
            .expect("Client should generate C0C1");
        assert_eq!(to_server.len(), 1 + HANDSHAKE_SIZE);

        // S0S1S2: server -> client
        let mut to_client = server
            .process(&mut to_server)
            .unwrap()
            .expect("Server should generate S0S1S2");
        assert_eq!(to_client.len(), 1 + HANDSHAKE_SIZE * 2);

        // C2: client -> server
        let mut c2 = client
            .process(&mut to_client)
            .unwrap()
            .expect("Client should generate C2");
        assert_eq!(c2.len(), HANDSHAKE_SIZE);
        assert!(client.is_done());

        assert!(server.process(&mut c2).unwrap().is_none());
        assert!(server.is_done());
    }

    #[test]
    fn test_packet_generation() {
        let packet = generate_packet();

        let stamp = u32::from_be_bytes([packet[0], packet[1], packet[2], packet[3]]);
        assert!(stamp > 0, "timestamp should be non-zero for a sane system clock");

        // Simple-handshake marker
        assert!(packet[4..8].iter().all(|&b| b == 0));
    }

    #[test]
    fn test_handshake_role_enum() {
        let (client_role, server_role) = (HandshakeRole::Client, HandshakeRole::Server);

        assert_ne!(client_role, server_role);
        assert_eq!(client_role, HandshakeRole::Client);
        assert_eq!(server_role, HandshakeRole::Server);
    }

    #[test]
    fn test_handshake_is_done() {
        let mut client = Handshake::new(HandshakeRole::Client);
        assert!(!client.is_done());

        let mut c0c1 = client.generate_initial().unwrap();
        assert!(!client.is_done(), "sending C0C1 alone must not finish");

        let mut server = listening_server();
        let mut s0s1s2 = server.process(&mut c0c1).unwrap().unwrap();

        let mut c2 = client.process(&mut s0s1s2).unwrap().unwrap();
        assert!(client.is_done());

        server.process(&mut c2).unwrap();
        assert!(server.is_done());
    }

    #[test]
    fn test_bytes_needed() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);

        // Nothing to read before the exchange starts.
        assert_eq!(client.bytes_needed(), 0);
        assert_eq!(server.bytes_needed(), 0);

        client.generate_initial();
        server.generate_initial();

        // Both roles report one version byte plus one packet.
        assert_eq!(client.bytes_needed(), 1 + HANDSHAKE_SIZE);
        assert_eq!(server.bytes_needed(), 1 + HANDSHAKE_SIZE);
    }

    #[test]
    fn test_server_initial_returns_none() {
        // The server speaks only after the client's C0C1.
        assert!(Handshake::new(HandshakeRole::Server)
            .generate_initial()
            .is_none());
    }

    #[test]
    fn test_client_initial_returns_c0c1() {
        let c0c1 = Handshake::new(HandshakeRole::Client)
            .generate_initial()
            .unwrap();

        assert_eq!(c0c1.len(), 1 + HANDSHAKE_SIZE);
        assert_eq!(c0c1[0], RTMP_VERSION);
    }

    #[test]
    fn test_double_generate_initial_returns_none() {
        let mut client = Handshake::new(HandshakeRole::Client);

        let first = client.generate_initial();
        let second = client.generate_initial();

        assert!(first.is_some());
        assert!(second.is_none(), "second call happens in the wrong state");
    }

    #[test]
    fn test_echo_packet_preserves_random_data() {
        let original = generate_packet();
        let echo = generate_echo(&original);

        // Peer timestamp and random body survive; bytes 4..8 are our read time.
        assert_eq!(&original[..4], &echo[..4]);
        assert_eq!(&original[8..], &echo[8..]);
    }

    #[test]
    fn test_incomplete_c0c1() {
        let mut server = listening_server();
        let mut partial = Bytes::from(vec![RTMP_VERSION; 100]);

        assert!(server.process(&mut partial).unwrap().is_none());
    }

    #[test]
    fn test_incomplete_s0s1s2() {
        let mut client = Handshake::new(HandshakeRole::Client);
        client.generate_initial();
        let mut partial = Bytes::from(vec![RTMP_VERSION; 1000]);

        assert!(client.process(&mut partial).unwrap().is_none());
    }

    #[test]
    fn test_invalid_version_rejected() {
        let mut server = listening_server();
        let mut c0c1 = c0c1_with(2, &[0u8; HANDSHAKE_SIZE]);

        assert!(server.process(&mut c0c1).is_err());
    }

    #[test]
    fn test_lenient_version_acceptance() {
        let mut server = listening_server();
        let mut c0c1 = c0c1_with(31, &generate_packet());

        let outcome = server.process(&mut c0c1);
        assert!(matches!(outcome, Ok(Some(_))));
    }

    #[test]
    fn test_handshake_packet_size_constant() {
        assert_eq!(HANDSHAKE_SIZE, 1536);
    }

    #[test]
    fn test_multiple_packets_different_random_data() {
        // Only checks that the random region is populated, not that it differs.
        let zeros = [0u8; 92];
        for packet in [generate_packet(), generate_packet()].iter() {
            assert_ne!(&packet[8..100], &zeros[..]);
        }
    }

    #[test]
    fn test_server_c2_processing() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = listening_server();

        let mut c0c1 = client.generate_initial().unwrap();
        let mut s0s1s2 = server.process(&mut c0c1).unwrap().unwrap();
        let mut c2 = client.process(&mut s0s1s2).unwrap().unwrap();

        // No reply is owed after C2.
        let reply = server.process(&mut c2).unwrap();
        assert!(reply.is_none());
        assert!(server.is_done());
    }

    #[test]
    fn test_process_in_wrong_state() {
        // generate_initial was never called, so the client ignores input.
        let mut idle_client = Handshake::new(HandshakeRole::Client);
        let mut payload = Bytes::from(vec![0u8; 3073]);

        assert!(idle_client.process(&mut payload).unwrap().is_none());
    }
}