Skip to main content

aerosocket_core/
handshake.rs

1//! WebSocket handshake implementation
2//!
3//! This module provides the WebSocket handshake functionality as defined in RFC 6455.
4//! It includes both client and server handshake logic.
5
6use crate::error::{Error, ProtocolError};
7use crate::protocol::constants::*;
8use crate::protocol::http_header::*;
9use crate::protocol::http_method;
10use crate::protocol::http_status::*;
11use crate::protocol::http_value;
12use base64::{engine::general_purpose, Engine as _};
13use sha1::{Digest, Sha1};
14use std::collections::HashMap;
15
/// Authentication methods for WebSocket handshake.
///
/// The `Debug` implementation is written by hand so that secrets (the Basic
/// password and the Bearer token) are redacted rather than leaking into logs,
/// panic messages, or the derived `Debug` output of types embedding `Auth`.
#[derive(Clone)]
pub enum Auth {
    /// HTTP Basic authentication
    Basic {
        /// Username
        username: String,
        /// Password
        password: String,
    },
    /// Bearer token authentication
    Bearer {
        /// Bearer token
        token: String,
    },
}

impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The username is not secret and is useful when debugging.
            Auth::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
            Auth::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &"<redacted>")
                .finish(),
        }
    }
}
32
/// An HTTP request carrying a WebSocket opening handshake.
///
/// Values are produced either by [`create_client_handshake`] (to be sent) or
/// by [`parse_client_handshake`] (received), which lowercases header names.
#[derive(Debug, Clone)]
pub struct HandshakeRequest {
    /// Request method; a WebSocket opening handshake uses `GET`.
    pub method: String,
    /// Request target taken from the request line.
    pub uri: String,
    /// Protocol version token from the request line, e.g. `HTTP/1.1`.
    pub version: String,
    /// Header fields, keyed by field name.
    pub headers: HashMap<String, String>,
    /// Message body; expected to be empty for a WebSocket handshake.
    pub body: Vec<u8>,
}
47
/// `permessage-deflate` settings used when negotiating WebSocket compression.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Turns compression negotiation on or off.
    pub enabled: bool,
    /// Maximum window size for decompression (client to server).
    pub client_max_window_bits: Option<u8>,
    /// Maximum window size for decompression (server to client).
    pub server_max_window_bits: Option<u8>,
    /// Compression level, 0-9 (9 = strongest compression).
    pub compression_level: Option<u32>,
}

impl Default for CompressionConfig {
    /// Compression disabled; when enabled, uses 15 window bits in both
    /// directions and compression level 6.
    fn default() -> Self {
        const WINDOW_BITS: u8 = 15;
        const LEVEL: u32 = 6;
        CompressionConfig {
            compression_level: Some(LEVEL),
            server_max_window_bits: Some(WINDOW_BITS),
            client_max_window_bits: Some(WINDOW_BITS),
            enabled: false,
        }
    }
}
71
/// An HTTP response carrying the server side of a WebSocket opening handshake.
#[derive(Debug, Clone)]
pub struct HandshakeResponse {
    /// Numeric status code (`101` when the upgrade is accepted).
    pub status: u16,
    /// Reason phrase that follows the status code.
    pub status_message: String,
    /// Header fields, keyed by field name.
    pub headers: HashMap<String, String>,
    /// Message body; expected to be empty for a WebSocket handshake.
    pub body: Vec<u8>,
}
84
85/// WebSocket handshake configuration
86#[derive(Debug, Clone, Default)]
87pub struct HandshakeConfig {
88    /// WebSocket protocols to offer/accept
89    pub protocols: Vec<String>,
90    /// WebSocket extensions to offer/accept
91    pub extensions: Vec<String>,
92    /// Origin to send (client only)
93    pub origin: Option<String>,
94    /// Allowed origins for CORS (server only, empty means allow all)
95    pub allowed_origins: Vec<String>,
96    /// Host header value (client only)
97    pub host: Option<String>,
98    /// Authentication
99    pub auth: Option<Auth>,
100    /// Compression configuration
101    pub compression: CompressionConfig,
102    /// Additional headers
103    pub extra_headers: HashMap<String, String>,
104}
105
106/// Generate a random WebSocket key
107pub fn generate_key() -> String {
108    use rand::RngCore;
109    let mut key_bytes = [0u8; 16];
110    rand::rng().fill_bytes(&mut key_bytes);
111    general_purpose::STANDARD.encode(key_bytes)
112}
113
114/// Compute WebSocket accept key from client key
115pub fn compute_accept_key(client_key: &str) -> Result<String, Error> {
116    let combined = format!("{}{}", client_key, WEBSOCKET_MAGIC);
117    let hash = Sha1::digest(combined.as_bytes());
118    Ok(general_purpose::STANDARD.encode(hash))
119}
120
121/// Validate WebSocket key format
122pub fn validate_key(key: &str) -> bool {
123    key.len() == 24 && general_purpose::STANDARD.decode(key).is_ok()
124}
125
126/// Validate WebSocket version
127pub fn validate_version(version: &str) -> bool {
128    version == WEBSOCKET_VERSION
129}
130
131/// Create a client handshake request
132pub fn create_client_handshake(
133    uri: &str,
134    config: &HandshakeConfig,
135) -> Result<HandshakeRequest, Error> {
136    let mut headers = HashMap::new();
137
138    // Required headers
139    headers.insert(
140        HEADER_UPGRADE.to_string(),
141        http_value::WEBSOCKET.to_string(),
142    );
143    headers.insert(
144        HEADER_CONNECTION.to_string(),
145        http_value::UPGRADE.to_string(),
146    );
147    headers.insert(HEADER_SEC_WEBSOCKET_KEY.to_string(), generate_key());
148    headers.insert(
149        HEADER_SEC_WEBSOCKET_VERSION.to_string(),
150        WEBSOCKET_VERSION.to_string(),
151    );
152
153    // Optional headers
154    if let Some(host) = &config.host {
155        headers.insert(HOST.to_string(), host.clone());
156    }
157
158    if let Some(origin) = &config.origin {
159        headers.insert(ORIGIN.to_string(), origin.clone());
160    }
161
162    if !config.protocols.is_empty() {
163        headers.insert(
164            HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(),
165            config.protocols.join(", "),
166        );
167    }
168
169    if !config.extensions.is_empty() {
170        headers.insert(
171            HEADER_SEC_WEBSOCKET_EXTENSIONS.to_string(),
172            config.extensions.join(", "),
173        );
174    }
175
176    // Add compression extension if enabled
177    #[cfg(feature = "compression")]
178    if config.compression.enabled {
179        let mut ext_parts: Vec<String> = vec!["permessage-deflate".to_string()];
180        if let Some(bits) = config.compression.client_max_window_bits {
181            ext_parts.push(format!("client_max_window_bits={}", bits));
182        }
183        if let Some(bits) = config.compression.server_max_window_bits {
184            ext_parts.push(format!("server_max_window_bits={}", bits));
185        }
186        let compression_ext = ext_parts.join("; ");
187        let existing = headers
188            .get(HEADER_SEC_WEBSOCKET_EXTENSIONS)
189            .cloned()
190            .unwrap_or_default();
191        let new_value = if existing.is_empty() {
192            compression_ext
193        } else {
194            format!("{}, {}", existing, compression_ext)
195        };
196        headers.insert(HEADER_SEC_WEBSOCKET_EXTENSIONS.to_string(), new_value);
197    }
198
199    // Add authentication header
200    if let Some(auth) = &config.auth {
201        let value = match auth {
202            Auth::Basic { username, password } => {
203                let credentials = format!("{}:{}", username, password);
204                format!("Basic {}", general_purpose::STANDARD.encode(credentials))
205            }
206            Auth::Bearer { token } => format!("Bearer {}", token),
207        };
208        headers.insert(AUTHORIZATION.to_string(), value);
209    }
210
211    // Add extra headers
212    for (key, value) in &config.extra_headers {
213        headers.insert(key.clone(), value.clone());
214    }
215
216    Ok(HandshakeRequest {
217        method: http_method::GET.to_string(),
218        uri: uri.to_string(),
219        version: "HTTP/1.1".to_string(),
220        headers,
221        body: vec![],
222    })
223}
224
225/// Parse a client handshake request
226pub fn parse_client_handshake(request: &str) -> Result<HandshakeRequest, Error> {
227    let mut lines = request.lines();
228
229    // Parse request line
230    let request_line = lines.next().ok_or_else(|| {
231        Error::Protocol(ProtocolError::InvalidFormat(
232            "Missing request line".to_string(),
233        ))
234    })?;
235
236    let mut parts = request_line.split_whitespace();
237    let method = parts
238        .next()
239        .ok_or_else(|| Error::Protocol(ProtocolError::InvalidFormat("Missing method".to_string())))?
240        .to_string();
241
242    let uri = parts
243        .next()
244        .ok_or_else(|| Error::Protocol(ProtocolError::InvalidFormat("Missing URI".to_string())))?
245        .to_string();
246
247    let version = parts
248        .next()
249        .ok_or_else(|| {
250            Error::Protocol(ProtocolError::InvalidFormat(
251                "Missing HTTP version".to_string(),
252            ))
253        })?
254        .to_string();
255
256    // Validate method
257    if method != http_method::GET {
258        return Err(Error::Protocol(ProtocolError::InvalidMethod(method)));
259    }
260
261    // Parse headers
262    let mut headers = HashMap::new();
263    for line in lines {
264        if line.is_empty() {
265            break; // End of headers
266        }
267
268        if let Some((key, value)) = line.split_once(':') {
269            headers.insert(key.trim().to_lowercase(), value.trim().to_string());
270        } else {
271            return Err(Error::Protocol(ProtocolError::InvalidHeader {
272                header: "unknown".to_string(),
273                value: line.to_string(),
274            }));
275        }
276    }
277
278    Ok(HandshakeRequest {
279        method,
280        uri,
281        version,
282        headers,
283        body: vec![],
284    })
285}
286
287/// Validate a client handshake request
288pub fn validate_client_handshake(
289    request: &HandshakeRequest,
290    config: &HandshakeConfig,
291) -> Result<(), Error> {
292    // Check required headers
293    let upgrade = request
294        .headers
295        .get(HEADER_UPGRADE)
296        .ok_or_else(|| Error::Protocol(ProtocolError::MissingHeader(HEADER_UPGRADE.to_string())))?;
297
298    if upgrade.to_lowercase() != http_value::WEBSOCKET {
299        return Err(Error::Protocol(ProtocolError::InvalidHeader {
300            header: HEADER_UPGRADE.to_string(),
301            value: upgrade.clone(),
302        }));
303    }
304
305    let connection = request.headers.get(HEADER_CONNECTION).ok_or_else(|| {
306        Error::Protocol(ProtocolError::MissingHeader(HEADER_CONNECTION.to_string()))
307    })?;
308
309    if !connection.to_lowercase().contains("upgrade") {
310        return Err(Error::Protocol(ProtocolError::InvalidHeader {
311            header: HEADER_CONNECTION.to_string(),
312            value: connection.clone(),
313        }));
314    }
315
316    let key = request
317        .headers
318        .get(HEADER_SEC_WEBSOCKET_KEY)
319        .ok_or_else(|| {
320            Error::Protocol(ProtocolError::MissingHeader(
321                HEADER_SEC_WEBSOCKET_KEY.to_string(),
322            ))
323        })?;
324
325    if !validate_key(key) {
326        return Err(Error::Protocol(ProtocolError::InvalidHeader {
327            header: HEADER_SEC_WEBSOCKET_KEY.to_string(),
328            value: key.clone(),
329        }));
330    }
331
332    let version = request
333        .headers
334        .get(HEADER_SEC_WEBSOCKET_VERSION)
335        .ok_or_else(|| {
336            Error::Protocol(ProtocolError::MissingHeader(
337                HEADER_SEC_WEBSOCKET_VERSION.to_string(),
338            ))
339        })?;
340
341    if !validate_version(version) {
342        return Err(Error::Protocol(ProtocolError::InvalidHeader {
343            header: HEADER_SEC_WEBSOCKET_VERSION.to_string(),
344            value: version.clone(),
345        }));
346    }
347
348    // Check optional headers
349    if !config.allowed_origins.is_empty() {
350        if let Some(client_origin) = request.headers.get(ORIGIN) {
351            if !config.allowed_origins.contains(client_origin) {
352                return Err(Error::Protocol(ProtocolError::InvalidOrigin {
353                    expected: config.allowed_origins.join(", "),
354                    received: client_origin.clone(),
355                }));
356            }
357        }
358    }
359
360    if !config.protocols.is_empty() {
361        if let Some(protocol_header) = request.headers.get(HEADER_SEC_WEBSOCKET_PROTOCOL) {
362            let client_protocols: Vec<&str> =
363                protocol_header.split(',').map(|s| s.trim()).collect();
364            if !client_protocols
365                .iter()
366                .any(|p| config.protocols.contains(&p.to_string()))
367            {
368                return Err(Error::Protocol(ProtocolError::UnsupportedProtocol(
369                    protocol_header.clone(),
370                )));
371            }
372        } else {
373            return Err(Error::Protocol(ProtocolError::MissingHeader(
374                HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(),
375            )));
376        }
377    }
378
379    Ok(())
380}
381
382/// Create a server handshake response
383pub fn create_server_handshake(
384    request: &HandshakeRequest,
385    config: &HandshakeConfig,
386) -> Result<HandshakeResponse, Error> {
387    let mut headers = HashMap::new();
388
389    // Required headers
390    headers.insert(
391        HEADER_UPGRADE.to_string(),
392        http_value::WEBSOCKET.to_string(),
393    );
394    headers.insert(
395        HEADER_CONNECTION.to_string(),
396        http_value::UPGRADE.to_string(),
397    );
398
399    // Compute accept key
400    if let Some(client_key) = request.headers.get(HEADER_SEC_WEBSOCKET_KEY) {
401        let accept_key = compute_accept_key(client_key)?;
402        headers.insert(HEADER_SEC_WEBSOCKET_ACCEPT.to_string(), accept_key);
403    } else {
404        return Err(Error::Protocol(ProtocolError::MissingHeader(
405            HEADER_SEC_WEBSOCKET_KEY.to_string(),
406        )));
407    }
408
409    // Protocol negotiation
410    if !config.protocols.is_empty() {
411        if let Some(protocol_header) = request.headers.get(HEADER_SEC_WEBSOCKET_PROTOCOL) {
412            let client_protocols: Vec<&str> =
413                protocol_header.split(',').map(|s| s.trim()).collect();
414            for protocol in &config.protocols {
415                if client_protocols.contains(&protocol.as_str()) {
416                    headers.insert(HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(), protocol.clone());
417                    break;
418                }
419            }
420        }
421    }
422
423    // Extension negotiation
424    #[cfg(feature = "compression")]
425    if config.compression.enabled {
426        if let Some(ext_header) = request.headers.get(HEADER_SEC_WEBSOCKET_EXTENSIONS) {
427            if ext_header.contains("permessage-deflate") {
428                let mut ext_parts: Vec<String> = vec!["permessage-deflate".to_string()];
429                if let Some(bits) = config.compression.server_max_window_bits {
430                    ext_parts.push(format!("server_max_window_bits={}", bits));
431                }
432                if let Some(bits) = config.compression.client_max_window_bits {
433                    ext_parts.push(format!("client_max_window_bits={}", bits));
434                }
435                headers.insert(
436                    HEADER_SEC_WEBSOCKET_EXTENSIONS.to_string(),
437                    ext_parts.join("; "),
438                );
439            }
440        }
441    }
442
443    // Add extra headers
444    for (key, value) in &config.extra_headers {
445        headers.insert(key.clone(), value.clone());
446    }
447
448    Ok(HandshakeResponse {
449        status: SWITCHING_PROTOCOLS,
450        status_message: "Switching Protocols".to_string(),
451        headers,
452        body: vec![],
453    })
454}
455
456/// Parse a server handshake response
457pub fn parse_server_handshake(response: &str) -> Result<HandshakeResponse, Error> {
458    let mut lines = response.lines();
459
460    // Parse status line
461    let status_line = lines.next().ok_or_else(|| {
462        Error::Protocol(ProtocolError::InvalidFormat(
463            "Missing status line".to_string(),
464        ))
465    })?;
466
467    let mut parts = status_line.split_whitespace();
468    let _version = parts
469        .next()
470        .ok_or_else(|| {
471            Error::Protocol(ProtocolError::InvalidFormat(
472                "Missing HTTP version".to_string(),
473            ))
474        })?
475        .to_string();
476
477    let status_str = parts.next().ok_or_else(|| {
478        Error::Protocol(ProtocolError::InvalidFormat(
479            "Missing status code".to_string(),
480        ))
481    })?;
482
483    let status = status_str.parse::<u16>().map_err(|_| {
484        Error::Protocol(ProtocolError::InvalidFormat(
485            "Invalid status code".to_string(),
486        ))
487    })?;
488
489    let status_message = parts.collect::<Vec<&str>>().join(" ");
490
491    // Parse headers
492    let mut headers = HashMap::new();
493    for line in lines {
494        if line.is_empty() {
495            break; // End of headers
496        }
497
498        if let Some((key, value)) = line.split_once(':') {
499            headers.insert(key.trim().to_lowercase(), value.trim().to_string());
500        } else {
501            return Err(Error::Protocol(ProtocolError::InvalidHeader {
502                header: "unknown".to_string(),
503                value: line.to_string(),
504            }));
505        }
506    }
507
508    Ok(HandshakeResponse {
509        status,
510        status_message,
511        headers,
512        body: vec![],
513    })
514}
515
516/// Validate a server handshake response
517pub fn validate_server_handshake(
518    response: &HandshakeResponse,
519    client_key: &str,
520) -> Result<(), Error> {
521    // Check status code
522    if response.status != SWITCHING_PROTOCOLS {
523        return Err(Error::Protocol(ProtocolError::UnexpectedStatus(
524            response.status,
525        )));
526    }
527
528    // Check required headers
529    let upgrade = response
530        .headers
531        .get(HEADER_UPGRADE)
532        .ok_or_else(|| Error::Protocol(ProtocolError::MissingHeader(HEADER_UPGRADE.to_string())))?;
533
534    if upgrade.to_lowercase() != http_value::WEBSOCKET {
535        return Err(Error::Protocol(ProtocolError::InvalidHeader {
536            header: HEADER_UPGRADE.to_string(),
537            value: upgrade.clone(),
538        }));
539    }
540
541    let connection = response.headers.get(HEADER_CONNECTION).ok_or_else(|| {
542        Error::Protocol(ProtocolError::MissingHeader(HEADER_CONNECTION.to_string()))
543    })?;
544
545    if !connection.to_lowercase().contains("upgrade") {
546        return Err(Error::Protocol(ProtocolError::InvalidHeader {
547            header: HEADER_CONNECTION.to_string(),
548            value: connection.clone(),
549        }));
550    }
551
552    let accept = response
553        .headers
554        .get(HEADER_SEC_WEBSOCKET_ACCEPT)
555        .ok_or_else(|| {
556            Error::Protocol(ProtocolError::MissingHeader(
557                HEADER_SEC_WEBSOCKET_ACCEPT.to_string(),
558            ))
559        })?;
560
561    let expected_accept = compute_accept_key(client_key)?;
562    if accept.as_str() != expected_accept {
563        return Err(Error::Protocol(ProtocolError::InvalidAcceptKey {
564            expected: expected_accept,
565            received: accept.clone(),
566        }));
567    }
568
569    Ok(())
570}
571
572/// Convert handshake request to HTTP string
573pub fn request_to_string(request: &HandshakeRequest) -> String {
574    let mut lines = vec![format!(
575        "{} {} {}",
576        request.method, request.uri, request.version
577    )];
578
579    for (key, value) in &request.headers {
580        lines.push(format!("{}: {}", key, value));
581    }
582
583    lines.push(String::new()); // Empty line after headers
584    lines.join("\r\n")
585}
586
587/// Convert handshake response to HTTP string
588pub fn response_to_string(response: &HandshakeResponse) -> String {
589    let mut lines = vec![format!(
590        "HTTP/1.1 {} {}",
591        response.status, response.status_message
592    )];
593
594    for (key, value) in &response.headers {
595        lines.push(format!("{}: {}", key, value));
596    }
597
598    lines.push(String::new()); // Empty line after headers
599    lines.join("\r\n")
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let generated = generate_key();
        assert_eq!(generated.len(), 24);
        assert!(validate_key(&generated));
    }

    #[test]
    fn test_accept_key_calculation() {
        // Worked example from RFC 6455 §1.3; the key is base64("the sample nonce").
        let accept = compute_accept_key("dGhlIHNhbXBsZSBub25jZQ==").unwrap();
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_client_handshake_creation() {
        let config = HandshakeConfig {
            protocols: vec!["chat".to_string()],
            host: Some("example.com".to_string()),
            ..Default::default()
        };

        let request = create_client_handshake("ws://example.com/chat", &config).unwrap();
        let header = |name: &str| request.headers.get(name).map(String::as_str);

        assert_eq!(request.method, "GET");
        assert_eq!(request.uri, "ws://example.com/chat");
        assert_eq!(header("upgrade"), Some("websocket"));
        assert_eq!(header("sec-websocket-protocol"), Some("chat"));
    }

    #[test]
    fn test_client_handshake_parsing() {
        // Same bytes as a raw multi-line literal: LF-separated, blank line last.
        let raw_request = [
            "GET /chat HTTP/1.1",
            "Host: example.com",
            "Upgrade: websocket",
            "Connection: Upgrade",
            "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
            "Sec-WebSocket-Version: 13",
            "",
            "",
        ]
        .join("\n");

        let request = parse_client_handshake(&raw_request).unwrap();
        assert_eq!(request.method, "GET");
        assert_eq!(request.uri, "/chat");
        assert_eq!(
            request.headers.get("upgrade").map(String::as_str),
            Some("websocket")
        );
    }
}