aerosocket_core/
handshake.rs

1//! WebSocket handshake implementation
2//!
3//! This module provides the WebSocket handshake functionality as defined in RFC 6455.
4//! It includes both client and server handshake logic.
5
6use crate::error::{Error, ProtocolError};
7use crate::protocol::constants::*;
8use crate::protocol::http_header::*;
9use crate::protocol::http_method;
10use crate::protocol::http_status::*;
11use crate::protocol::http_value;
12use base64::{engine::general_purpose, Engine as _};
13use sha1::{Digest, Sha1};
14use std::collections::HashMap;
15
/// An HTTP request carrying a WebSocket opening handshake (client → server).
///
/// Produced either locally by [`create_client_handshake`] or from raw text by
/// [`parse_client_handshake`].
#[derive(Clone, Debug)]
pub struct HandshakeRequest {
    /// Request method from the request line; a WebSocket handshake uses `GET`.
    pub method: String,
    /// Request target exactly as written in the request line.
    pub uri: String,
    /// HTTP version token from the request line, e.g. `HTTP/1.1`.
    pub version: String,
    /// Header fields keyed by name (the parser stores names lower-cased).
    pub headers: HashMap<String, String>,
    /// Message body; a WebSocket handshake is expected to carry none.
    pub body: Vec<u8>,
}
30
/// An HTTP response carrying a WebSocket opening handshake (server → client).
///
/// Produced either locally by [`create_server_handshake`] or from raw text by
/// [`parse_server_handshake`].
#[derive(Clone, Debug)]
pub struct HandshakeResponse {
    /// Numeric status code from the status line (101 on success).
    pub status: u16,
    /// Reason phrase that follows the status code, e.g. `Switching Protocols`.
    pub status_message: String,
    /// Header fields keyed by name (the parser stores names lower-cased).
    pub headers: HashMap<String, String>,
    /// Message body; a WebSocket handshake is expected to carry none.
    pub body: Vec<u8>,
}
43
/// Options that shape how a handshake is built and checked.
///
/// `Default` yields no protocols, no extensions, no origin, no host and no
/// extra headers.
#[derive(Clone, Debug, Default)]
pub struct HandshakeConfig {
    /// Sub-protocols: offered by the client, and (in server preference
    /// order) accepted by the server.
    pub protocols: Vec<String>,
    /// Extensions offered by the client in `Sec-WebSocket-Extensions`.
    pub extensions: Vec<String>,
    /// Origin: sent by the client in its request, and compared by the server
    /// against the client's `Origin` header.
    pub origin: Option<String>,
    /// Value for the client's `Host` header.
    pub host: Option<String>,
    /// Extra headers appended (last, overriding same-named entries) to the
    /// request or response being built.
    pub extra_headers: HashMap<String, String>,
}
58
59/// Generate a random WebSocket key
60pub fn generate_key() -> String {
61    use rand::RngCore;
62    let mut key_bytes = [0u8; 16];
63    rand::thread_rng().fill_bytes(&mut key_bytes);
64    general_purpose::STANDARD.encode(key_bytes)
65}
66
67/// Compute WebSocket accept key from client key
68pub fn compute_accept_key(client_key: &str) -> Result<String, Error> {
69    let combined = format!("{}{}", client_key, WEBSOCKET_MAGIC);
70    let hash = Sha1::digest(combined.as_bytes());
71    Ok(general_purpose::STANDARD.encode(hash))
72}
73
74/// Validate WebSocket key format
75pub fn validate_key(key: &str) -> bool {
76    key.len() == 24 && general_purpose::STANDARD.decode(key).is_ok()
77}
78
79/// Validate WebSocket version
80pub fn validate_version(version: &str) -> bool {
81    version == WEBSOCKET_VERSION
82}
83
84/// Create a client handshake request
85pub fn create_client_handshake(
86    uri: &str,
87    config: &HandshakeConfig,
88) -> Result<HandshakeRequest, Error> {
89    let mut headers = HashMap::new();
90
91    // Required headers
92    headers.insert(
93        HEADER_UPGRADE.to_string(),
94        http_value::WEBSOCKET.to_string(),
95    );
96    headers.insert(
97        HEADER_CONNECTION.to_string(),
98        http_value::UPGRADE.to_string(),
99    );
100    headers.insert(HEADER_SEC_WEBSOCKET_KEY.to_string(), generate_key());
101    headers.insert(
102        HEADER_SEC_WEBSOCKET_VERSION.to_string(),
103        WEBSOCKET_VERSION.to_string(),
104    );
105
106    // Optional headers
107    if let Some(host) = &config.host {
108        headers.insert(HOST.to_string(), host.clone());
109    }
110
111    if let Some(origin) = &config.origin {
112        headers.insert(ORIGIN.to_string(), origin.clone());
113    }
114
115    if !config.protocols.is_empty() {
116        headers.insert(
117            HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(),
118            config.protocols.join(", "),
119        );
120    }
121
122    if !config.extensions.is_empty() {
123        headers.insert(
124            HEADER_SEC_WEBSOCKET_EXTENSIONS.to_string(),
125            config.extensions.join(", "),
126        );
127    }
128
129    // Add extra headers
130    for (key, value) in &config.extra_headers {
131        headers.insert(key.clone(), value.clone());
132    }
133
134    Ok(HandshakeRequest {
135        method: http_method::GET.to_string(),
136        uri: uri.to_string(),
137        version: "HTTP/1.1".to_string(),
138        headers,
139        body: vec![],
140    })
141}
142
143/// Parse a client handshake request
144pub fn parse_client_handshake(request: &str) -> Result<HandshakeRequest, Error> {
145    let mut lines = request.lines();
146
147    // Parse request line
148    let request_line = lines.next().ok_or_else(|| {
149        Error::Protocol(ProtocolError::InvalidFormat(
150            "Missing request line".to_string(),
151        ))
152    })?;
153
154    let mut parts = request_line.split_whitespace();
155    let method = parts
156        .next()
157        .ok_or_else(|| Error::Protocol(ProtocolError::InvalidFormat("Missing method".to_string())))?
158        .to_string();
159
160    let uri = parts
161        .next()
162        .ok_or_else(|| Error::Protocol(ProtocolError::InvalidFormat("Missing URI".to_string())))?
163        .to_string();
164
165    let version = parts
166        .next()
167        .ok_or_else(|| {
168            Error::Protocol(ProtocolError::InvalidFormat(
169                "Missing HTTP version".to_string(),
170            ))
171        })?
172        .to_string();
173
174    // Validate method
175    if method != http_method::GET {
176        return Err(Error::Protocol(ProtocolError::InvalidMethod(method)));
177    }
178
179    // Parse headers
180    let mut headers = HashMap::new();
181    for line in lines {
182        if line.is_empty() {
183            break; // End of headers
184        }
185
186        if let Some((key, value)) = line.split_once(':') {
187            headers.insert(key.trim().to_lowercase(), value.trim().to_string());
188        } else {
189            return Err(Error::Protocol(ProtocolError::InvalidHeader {
190                header: "unknown".to_string(),
191                value: line.to_string(),
192            }));
193        }
194    }
195
196    Ok(HandshakeRequest {
197        method,
198        uri,
199        version,
200        headers,
201        body: vec![],
202    })
203}
204
205/// Validate a client handshake request
206pub fn validate_client_handshake(
207    request: &HandshakeRequest,
208    config: &HandshakeConfig,
209) -> Result<(), Error> {
210    // Check required headers
211    let upgrade = request
212        .headers
213        .get(HEADER_UPGRADE)
214        .ok_or_else(|| Error::Protocol(ProtocolError::MissingHeader(HEADER_UPGRADE.to_string())))?;
215
216    if upgrade.to_lowercase() != http_value::WEBSOCKET {
217        return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
218            header: HEADER_UPGRADE.to_string(),
219            value: upgrade.clone(),
220        }));
221    }
222
223    let connection = request.headers.get(HEADER_CONNECTION).ok_or_else(|| {
224        Error::Protocol(ProtocolError::MissingHeader(HEADER_CONNECTION.to_string()))
225    })?;
226
227    if !connection.to_lowercase().contains("upgrade") {
228        return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
229            header: HEADER_CONNECTION.to_string(),
230            value: connection.clone(),
231        }));
232    }
233
234    let key = request
235        .headers
236        .get(HEADER_SEC_WEBSOCKET_KEY)
237        .ok_or_else(|| {
238            Error::Protocol(ProtocolError::MissingHeader(
239                HEADER_SEC_WEBSOCKET_KEY.to_string(),
240            ))
241        })?;
242
243    if !validate_key(key) {
244        return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
245            header: HEADER_SEC_WEBSOCKET_KEY.to_string(),
246            value: key.clone(),
247        }));
248    }
249
250    let version = request
251        .headers
252        .get(HEADER_SEC_WEBSOCKET_VERSION)
253        .ok_or_else(|| {
254            Error::Protocol(ProtocolError::MissingHeader(
255                HEADER_SEC_WEBSOCKET_VERSION.to_string(),
256            ))
257        })?;
258
259    if !validate_version(version) {
260        return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
261            header: HEADER_SEC_WEBSOCKET_VERSION.to_string(),
262            value: version.clone(),
263        }));
264    }
265
266    // Check optional headers
267    if let Some(origin) = &config.origin {
268        if let Some(client_origin) = request.headers.get(ORIGIN) {
269            if client_origin != origin {
270                return Err(Error::Protocol(ProtocolError::InvalidOrigin {
271                    expected: origin.clone(),
272                    received: client_origin.clone(),
273                }));
274            }
275        }
276    }
277
278    if !config.protocols.is_empty() {
279        if let Some(protocol_header) = request.headers.get(HEADER_SEC_WEBSOCKET_PROTOCOL) {
280            let client_protocols: Vec<&str> =
281                protocol_header.split(',').map(|s| s.trim()).collect();
282            if !client_protocols
283                .iter()
284                .any(|p| config.protocols.contains(&p.to_string()))
285            {
286                return Err(Error::Protocol(ProtocolError::UnsupportedProtocol(
287                    protocol_header.clone(),
288                )));
289            }
290        } else {
291            return Err(Error::Protocol(ProtocolError::MissingHeader(
292                HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(),
293            )));
294        }
295    }
296
297    Ok(())
298}
299
300/// Create a server handshake response
301pub fn create_server_handshake(
302    request: &HandshakeRequest,
303    config: &HandshakeConfig,
304) -> Result<HandshakeResponse, Error> {
305    let mut headers = HashMap::new();
306
307    // Required headers
308    headers.insert(
309        HEADER_UPGRADE.to_string(),
310        http_value::WEBSOCKET.to_string(),
311    );
312    headers.insert(
313        HEADER_CONNECTION.to_string(),
314        http_value::UPGRADE.to_string(),
315    );
316
317    // Compute accept key
318    if let Some(client_key) = request.headers.get(HEADER_SEC_WEBSOCKET_KEY) {
319        let accept_key = compute_accept_key(client_key)?;
320        headers.insert(HEADER_SEC_WEBSOCKET_ACCEPT.to_string(), accept_key);
321    } else {
322        return Err(Error::Protocol(ProtocolError::MissingHeader(
323            HEADER_SEC_WEBSOCKET_KEY.to_string(),
324        )));
325    }
326
327    // Protocol negotiation
328    if !config.protocols.is_empty() {
329        if let Some(protocol_header) = request.headers.get(HEADER_SEC_WEBSOCKET_PROTOCOL) {
330            let client_protocols: Vec<&str> =
331                protocol_header.split(',').map(|s| s.trim()).collect();
332            for protocol in &config.protocols {
333                if client_protocols.contains(&protocol.as_str()) {
334                    headers.insert(HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(), protocol.clone());
335                    break;
336                }
337            }
338        }
339    }
340
341    // Add extra headers
342    for (key, value) in &config.extra_headers {
343        headers.insert(key.clone(), value.clone());
344    }
345
346    Ok(HandshakeResponse {
347        status: SWITCHING_PROTOCOLS,
348        status_message: "Switching Protocols".to_string(),
349        headers,
350        body: vec![],
351    })
352}
353
354/// Parse a server handshake response
355pub fn parse_server_handshake(response: &str) -> Result<HandshakeResponse, Error> {
356    let mut lines = response.lines();
357
358    // Parse status line
359    let status_line = lines.next().ok_or_else(|| {
360        Error::Protocol(ProtocolError::InvalidFormat(
361            "Missing status line".to_string(),
362        ))
363    })?;
364
365    let mut parts = status_line.split_whitespace();
366    let _version = parts
367        .next()
368        .ok_or_else(|| {
369            Error::Protocol(ProtocolError::InvalidFormat(
370                "Missing HTTP version".to_string(),
371            ))
372        })?
373        .to_string();
374
375    let status_str = parts.next().ok_or_else(|| {
376        Error::Protocol(ProtocolError::InvalidFormat(
377            "Missing status code".to_string(),
378        ))
379    })?;
380
381    let status = status_str.parse::<u16>().map_err(|_| {
382        Error::Protocol(ProtocolError::InvalidFormat(
383            "Invalid status code".to_string(),
384        ))
385    })?;
386
387    let status_message = parts.collect::<Vec<&str>>().join(" ");
388
389    // Parse headers
390    let mut headers = HashMap::new();
391    for line in lines {
392        if line.is_empty() {
393            break; // End of headers
394        }
395
396        if let Some((key, value)) = line.split_once(':') {
397            headers.insert(key.trim().to_lowercase(), value.trim().to_string());
398        } else {
399            return Err(Error::Protocol(ProtocolError::InvalidHeader {
400                header: "unknown".to_string(),
401                value: line.to_string(),
402            }));
403        }
404    }
405
406    Ok(HandshakeResponse {
407        status,
408        status_message,
409        headers,
410        body: vec![],
411    })
412}
413
414/// Validate a server handshake response
415pub fn validate_server_handshake(
416    response: &HandshakeResponse,
417    client_key: &str,
418) -> Result<(), Error> {
419    // Check status code
420    if response.status != SWITCHING_PROTOCOLS {
421        return Err(Error::Protocol(ProtocolError::UnexpectedStatus(
422            response.status,
423        )));
424    }
425
426    // Check required headers
427    let upgrade = response
428        .headers
429        .get(HEADER_UPGRADE)
430        .ok_or_else(|| Error::Protocol(ProtocolError::MissingHeader(HEADER_UPGRADE.to_string())))?;
431
432    if upgrade.to_lowercase() != http_value::WEBSOCKET {
433        return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
434            header: HEADER_UPGRADE.to_string(),
435            value: upgrade.clone(),
436        }));
437    }
438
439    let connection = response.headers.get(HEADER_CONNECTION).ok_or_else(|| {
440        Error::Protocol(ProtocolError::MissingHeader(HEADER_CONNECTION.to_string()))
441    })?;
442
443    if !connection.to_lowercase().contains("upgrade") {
444        return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
445            header: HEADER_CONNECTION.to_string(),
446            value: connection.clone(),
447        }));
448    }
449
450    let accept = response
451        .headers
452        .get(HEADER_SEC_WEBSOCKET_ACCEPT)
453        .ok_or_else(|| {
454            Error::Protocol(ProtocolError::MissingHeader(
455                HEADER_SEC_WEBSOCKET_ACCEPT.to_string(),
456            ))
457        })?;
458
459    let expected_accept = compute_accept_key(client_key)?;
460    if accept.as_str() != expected_accept {
461        return Err(Error::Protocol(ProtocolError::InvalidAcceptKey {
462            expected: expected_accept,
463            received: accept.clone(),
464        }));
465    }
466
467    Ok(())
468}
469
470/// Convert handshake request to HTTP string
471pub fn request_to_string(request: &HandshakeRequest) -> String {
472    let mut lines = vec![format!(
473        "{} {} {}",
474        request.method, request.uri, request.version
475    )];
476
477    for (key, value) in &request.headers {
478        lines.push(format!("{}: {}", key, value));
479    }
480
481    lines.push(String::new()); // Empty line after headers
482    lines.join("\r\n")
483}
484
485/// Convert handshake response to HTTP string
486pub fn response_to_string(response: &HandshakeResponse) -> String {
487    let mut lines = vec![format!(
488        "HTTP/1.1 {} {}",
489        response.status, response.status_message
490    )];
491
492    for (key, value) in &response.headers {
493        lines.push(format!("{}: {}", key, value));
494    }
495
496    lines.push(String::new()); // Empty line after headers
497    lines.join("\r\n")
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let generated = generate_key();
        assert_eq!(24, generated.len());
        assert!(validate_key(&generated), "generated key rejected: {}", generated);
    }

    #[test]
    fn test_accept_key_calculation() {
        // Worked example from RFC 6455 §1.3; the key encodes "the sample nonce".
        let client_key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = compute_accept_key(client_key).expect("accept key computation");
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_client_handshake_creation() {
        let cfg = HandshakeConfig {
            host: Some("example.com".to_string()),
            protocols: vec!["chat".to_string()],
            ..Default::default()
        };

        let built = create_client_handshake("ws://example.com/chat", &cfg)
            .expect("client handshake should build");
        let header = |name: &str| built.headers.get(name).map(String::as_str);

        assert_eq!(built.method, "GET");
        assert_eq!(built.uri, "ws://example.com/chat");
        assert_eq!(header("upgrade"), Some("websocket"));
        assert_eq!(header("sec-websocket-protocol"), Some("chat"));
    }

    #[test]
    fn test_client_handshake_parsing() {
        let raw_request = r#"GET /chat HTTP/1.1
Host: example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13

"#;

        let parsed = parse_client_handshake(raw_request).expect("request should parse");
        assert_eq!(parsed.method, "GET");
        assert_eq!(parsed.uri, "/chat");
        assert_eq!(
            parsed.headers.get("upgrade").map(String::as_str),
            Some("websocket")
        );
    }
}