rtmp_rs/protocol/
handshake.rs

1//! RTMP handshake implementation
2//!
3//! The RTMP handshake consists of three phases:
4//!
5//! ```text
6//! Client                                   Server
7//!   |                                        |
8//!   |------- C0 (1 byte: version) --------->|
9//!   |------- C1 (1536 bytes: time+random) ->|
10//!   |                                        |
11//!   |<------ S0 (1 byte: version) ----------|
12//!   |<------ S1 (1536 bytes: time+random) --|
13//!   |<------ S2 (1536 bytes: echo C1) ------|
14//!   |                                        |
15//!   |------- C2 (1536 bytes: echo S1) ----->|
16//!   |                                        |
17//!   |          [Handshake Complete]          |
18//! ```
19//!
20//! This implementation uses the "simple" handshake (no HMAC digest).
21//! Complex handshake with HMAC-SHA256 is used by some servers but not required.
22//!
23//! Reference: RTMP Specification Section 5.2
24
25use bytes::{Buf, BufMut, Bytes, BytesMut};
26use std::time::{SystemTime, UNIX_EPOCH};
27
28use crate::error::{HandshakeError, Result};
29use crate::protocol::constants::{HANDSHAKE_SIZE, RTMP_VERSION};
30
/// Which side of the RTMP connection a [`Handshake`] is running on.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeRole {
    /// Connecting side: sends C0C1 first, then answers S0S1S2 with C2.
    Client,
    /// Accepting side: waits for C0C1, then answers with S0S1S2.
    Server,
}
37
38/// Handshake state machine
39#[derive(Debug)]
40pub struct Handshake {
41    role: HandshakeRole,
42    state: HandshakeState,
43    /// Our C1/S1 packet (saved for verification)
44    our_packet: Option<[u8; HANDSHAKE_SIZE]>,
45    /// Peer's C1/S1 packet (saved for echo in C2/S2)
46    peer_packet: Option<[u8; HANDSHAKE_SIZE]>,
47}
48
/// Progress of a [`Handshake`] through the simple-handshake exchange.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HandshakeState {
    /// Nothing sent or awaited yet: the client has not produced C0C1, and the
    /// server has not started waiting for C0C1.
    Initial,
    /// Waiting for the peer's first message: C0C1 (server) or S0S1S2 (client).
    WaitingForPeerPacket,
    /// Peer packet received, response not yet sent.
    ///
    /// Currently never constructed: `Handshake::process` builds the response in
    /// the same call that consumes the peer packet. Kept so the state set stays
    /// complete; the lint allowance is scoped to this variant only so any other
    /// variant that becomes unused is still reported.
    #[allow(dead_code)]
    NeedToSendResponse,
    /// Waiting for the peer's echo (C2). Only the server enters this state; the
    /// client finishes as soon as it has produced C2.
    WaitingForPeerResponse,
    /// Handshake complete.
    Done,
}
63
64impl Handshake {
65    /// Create a new handshake state machine
66    pub fn new(role: HandshakeRole) -> Self {
67        Self {
68            role,
69            state: HandshakeState::Initial,
70            our_packet: None,
71            peer_packet: None,
72        }
73    }
74
75    /// Check if handshake is complete
76    pub fn is_done(&self) -> bool {
77        self.state == HandshakeState::Done
78    }
79
80    /// Get bytes needed before next state transition
81    pub fn bytes_needed(&self) -> usize {
82        match self.state {
83            HandshakeState::Initial => 0,
84            HandshakeState::WaitingForPeerPacket => 1 + HANDSHAKE_SIZE, // C0C1 or S0S1
85            HandshakeState::NeedToSendResponse => 0,
86            HandshakeState::WaitingForPeerResponse => {
87                match self.role {
88                    HandshakeRole::Client => HANDSHAKE_SIZE, // S2 only (S0S1 already received)
89                    HandshakeRole::Server => HANDSHAKE_SIZE, // C2 only
90                }
91            }
92            HandshakeState::Done => 0,
93        }
94    }
95
96    /// Generate initial packet (C0C1 for client, nothing for server initially)
97    ///
98    /// For client: returns C0+C1 (1 + 1536 bytes)
99    /// For server: returns None (server waits for C0C1 first)
100    pub fn generate_initial(&mut self) -> Option<Bytes> {
101        if self.state != HandshakeState::Initial {
102            return None;
103        }
104
105        match self.role {
106            HandshakeRole::Client => {
107                let mut buf = BytesMut::with_capacity(1 + HANDSHAKE_SIZE);
108
109                // C0: Version
110                buf.put_u8(RTMP_VERSION);
111
112                // C1: Time + Zero + Random
113                let c1 = generate_packet();
114                self.our_packet = Some(c1);
115                buf.put_slice(&c1);
116
117                self.state = HandshakeState::WaitingForPeerPacket;
118                Some(buf.freeze())
119            }
120            HandshakeRole::Server => {
121                // Server waits for client's C0C1 first
122                self.state = HandshakeState::WaitingForPeerPacket;
123                None
124            }
125        }
126    }
127
128    /// Process received data and return response if ready
129    ///
130    /// For server receiving C0C1: returns S0+S1+S2
131    /// For client receiving S0S1S2: returns C2
132    /// For server receiving C2: returns None (handshake done)
133    pub fn process(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
134        match self.state {
135            HandshakeState::WaitingForPeerPacket => self.process_peer_packet(data),
136            HandshakeState::WaitingForPeerResponse => self.process_peer_response(data),
137            _ => Ok(None),
138        }
139    }
140
141    /// Process peer's initial packet (C0C1 or S0S1S2)
142    fn process_peer_packet(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
143        match self.role {
144            HandshakeRole::Server => {
145                // Expecting C0 + C1
146                if data.remaining() < 1 + HANDSHAKE_SIZE {
147                    return Ok(None); // Need more data
148                }
149
150                // C0: Version check
151                let version = data.get_u8();
152                if version != RTMP_VERSION {
153                    // Be lenient - accept version 3-31 (some encoders send different values)
154                    if version < 3 {
155                        return Err(HandshakeError::InvalidVersion(version).into());
156                    }
157                }
158
159                // C1: Save peer packet
160                let mut c1 = [0u8; HANDSHAKE_SIZE];
161                data.copy_to_slice(&mut c1);
162                self.peer_packet = Some(c1);
163
164                // Generate S0 + S1 + S2
165                let mut response = BytesMut::with_capacity(1 + HANDSHAKE_SIZE * 2);
166
167                // S0: Version
168                response.put_u8(RTMP_VERSION);
169
170                // S1: Our packet
171                let s1 = generate_packet();
172                self.our_packet = Some(s1);
173                response.put_slice(&s1);
174
175                // S2: Echo C1 with our timestamp
176                let s2 = generate_echo(&c1);
177                response.put_slice(&s2);
178
179                self.state = HandshakeState::WaitingForPeerResponse;
180                Ok(Some(response.freeze()))
181            }
182            HandshakeRole::Client => {
183                // Expecting S0 + S1 + S2
184                if data.remaining() < 1 + HANDSHAKE_SIZE * 2 {
185                    return Ok(None); // Need more data
186                }
187
188                // S0: Version check
189                let version = data.get_u8();
190                if version != RTMP_VERSION && version < 3 {
191                    return Err(HandshakeError::InvalidVersion(version).into());
192                }
193
194                // S1: Save peer packet
195                let mut s1 = [0u8; HANDSHAKE_SIZE];
196                data.copy_to_slice(&mut s1);
197                self.peer_packet = Some(s1);
198
199                // S2: Verify echo of C1 (lenient - just consume)
200                let mut s2 = [0u8; HANDSHAKE_SIZE];
201                data.copy_to_slice(&mut s2);
202
203                // In lenient mode, don't strictly verify S2 matches C1
204                // Some servers don't echo correctly
205
206                // Generate C2: Echo S1
207                let c2 = generate_echo(&s1);
208
209                self.state = HandshakeState::Done;
210                Ok(Some(Bytes::copy_from_slice(&c2)))
211            }
212        }
213    }
214
215    /// Process peer's response (C2 for server)
216    fn process_peer_response(&mut self, data: &mut Bytes) -> Result<Option<Bytes>> {
217        match self.role {
218            HandshakeRole::Server => {
219                // Expecting C2
220                if data.remaining() < HANDSHAKE_SIZE {
221                    return Ok(None);
222                }
223
224                // C2: Verify echo of S1 (lenient)
225                let mut c2 = [0u8; HANDSHAKE_SIZE];
226                data.copy_to_slice(&mut c2);
227
228                // Lenient: don't strictly verify C2 matches S1
229                self.state = HandshakeState::Done;
230                Ok(None)
231            }
232            HandshakeRole::Client => {
233                // Client shouldn't be in this state
234                self.state = HandshakeState::Done;
235                Ok(None)
236            }
237        }
238    }
239}
240
241/// Generate a handshake packet (C1 or S1)
242///
243/// Format (1536 bytes):
244/// - Bytes 0-3: Timestamp (32-bit, big-endian)
245/// - Bytes 4-7: Zero (for simple handshake) or version (for complex)
246/// - Bytes 8-1535: Random data
247fn generate_packet() -> [u8; HANDSHAKE_SIZE] {
248    let mut packet = [0u8; HANDSHAKE_SIZE];
249
250    // Timestamp: milliseconds since some epoch
251    let timestamp = SystemTime::now()
252        .duration_since(UNIX_EPOCH)
253        .map(|d| d.as_millis() as u32)
254        .unwrap_or(0);
255
256    packet[0..4].copy_from_slice(&timestamp.to_be_bytes());
257
258    // Zero field (simple handshake)
259    packet[4..8].copy_from_slice(&[0, 0, 0, 0]);
260
261    // Random data - use simple PRNG seeded with timestamp
262    // Not cryptographically secure, but RTMP handshake doesn't require it
263    let mut seed = timestamp as u64;
264    for chunk in packet[8..].chunks_mut(8) {
265        seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
266        let bytes = seed.to_le_bytes();
267        let len = chunk.len().min(8);
268        chunk[..len].copy_from_slice(&bytes[..len]);
269    }
270
271    packet
272}
273
274/// Generate echo packet (C2 or S2)
275///
276/// Format:
277/// - Bytes 0-3: Peer's timestamp (from their C1/S1)
278/// - Bytes 4-7: Our timestamp
279/// - Bytes 8-1535: Copy of peer's random data
280fn generate_echo(peer_packet: &[u8; HANDSHAKE_SIZE]) -> [u8; HANDSHAKE_SIZE] {
281    let mut echo = *peer_packet;
282
283    // Bytes 4-7: Our receive timestamp
284    let timestamp = SystemTime::now()
285        .duration_since(UNIX_EPOCH)
286        .map(|d| d.as_millis() as u32)
287        .unwrap_or(0);
288
289    echo[4..8].copy_from_slice(&timestamp.to_be_bytes());
290
291    echo
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Full client/server exchange: checks message sizes, the S0 version, that
    /// each message is fully consumed, and that S2 echoes C1 and C2 echoes S1.
    #[test]
    fn test_client_server_handshake() {
        let mut client = Handshake::new(HandshakeRole::Client);
        let mut server = Handshake::new(HandshakeRole::Server);

        // Client generates C0C1.
        let c0c1 = client
            .generate_initial()
            .expect("Client should generate C0C1");
        assert_eq!(c0c1.len(), 1 + HANDSHAKE_SIZE);
        assert_eq!(c0c1[0], RTMP_VERSION);

        // The server sends nothing first; this only moves it to the waiting state.
        assert!(server.generate_initial().is_none());

        // Server receives C0C1, generates S0S1S2.
        let mut inbound = c0c1.clone();
        let s0s1s2 = server
            .process(&mut inbound)
            .unwrap()
            .expect("Server should generate S0S1S2");
        assert!(inbound.is_empty(), "server should consume all of C0C1");
        assert_eq!(s0s1s2.len(), 1 + HANDSHAKE_SIZE * 2);
        assert_eq!(s0s1s2[0], RTMP_VERSION);

        // S2 echoes C1: same timestamp (bytes 0-3) and random data (bytes 8..).
        let c1 = &c0c1[1..];
        let s1 = &s0s1s2[1..1 + HANDSHAKE_SIZE];
        let s2 = &s0s1s2[1 + HANDSHAKE_SIZE..];
        assert_eq!(&s2[..4], &c1[..4]);
        assert_eq!(&s2[8..], &c1[8..]);

        // Client receives S0S1S2, generates C2.
        let mut inbound = s0s1s2.clone();
        let c2 = client
            .process(&mut inbound)
            .unwrap()
            .expect("Client should generate C2");
        assert!(inbound.is_empty(), "client should consume all of S0S1S2");
        assert_eq!(c2.len(), HANDSHAKE_SIZE);
        assert!(client.is_done());

        // C2 echoes S1.
        assert_eq!(&c2[..4], &s1[..4]);
        assert_eq!(&c2[8..], &s1[8..]);

        // Server receives C2.
        let mut inbound = c2;
        let response = server.process(&mut inbound).unwrap();
        assert!(response.is_none());
        assert!(inbound.is_empty(), "server should consume all of C2");
        assert!(server.is_done());
    }

    /// A server given one byte less than C0C1 must wait without consuming input.
    #[test]
    fn test_server_waits_for_complete_c0c1() {
        let mut server = Handshake::new(HandshakeRole::Server);
        assert!(server.generate_initial().is_none());

        let mut partial = Bytes::from(vec![RTMP_VERSION; HANDSHAKE_SIZE]);
        assert!(server.process(&mut partial).unwrap().is_none());
        assert_eq!(partial.len(), HANDSHAKE_SIZE);
        assert!(!server.is_done());
    }

    /// A client holding only S0S1 (no S2 yet) must wait without consuming input.
    #[test]
    fn test_client_waits_for_complete_s0s1s2() {
        let mut client = Handshake::new(HandshakeRole::Client);
        assert!(client.generate_initial().is_some());

        let mut partial = Bytes::from(vec![RTMP_VERSION; 1 + HANDSHAKE_SIZE]);
        assert!(client.process(&mut partial).unwrap().is_none());
        assert_eq!(partial.len(), 1 + HANDSHAKE_SIZE);
        assert!(!client.is_done());
    }

    /// Deprecated version bytes (0-2) must be rejected.
    #[test]
    fn test_server_rejects_deprecated_version() {
        let mut server = Handshake::new(HandshakeRole::Server);
        assert!(server.generate_initial().is_none());

        let mut raw = vec![0u8; 1 + HANDSHAKE_SIZE];
        raw[0] = 2;
        let mut c0c1 = Bytes::from(raw);
        assert!(server.process(&mut c0c1).is_err());
        assert!(!server.is_done());
    }

    #[test]
    fn test_packet_generation() {
        let packet = generate_packet();

        // Should have timestamp in first 4 bytes
        let timestamp = u32::from_be_bytes([packet[0], packet[1], packet[2], packet[3]]);
        assert!(timestamp > 0); // Should be non-zero for reasonable system time

        // Bytes 4-7 should be zero (simple handshake)
        assert_eq!(&packet[4..8], &[0, 0, 0, 0]);
    }
}