1use crate::error::{Error, ProtocolError};
7use crate::protocol::constants::*;
8use crate::protocol::http_header::*;
9use crate::protocol::http_method;
10use crate::protocol::http_status::*;
11use crate::protocol::http_value;
12use base64::{engine::general_purpose, Engine as _};
13use sha1::{Digest, Sha1};
14use std::collections::HashMap;
15
/// Credentials attached to the client handshake as an `Authorization` header.
///
/// `Debug` is implemented by hand so that secrets (the Basic password and the
/// Bearer token) never end up in logs; only the username is printed.
#[derive(Clone)]
pub enum Auth {
    /// HTTP Basic authentication, sent as `Basic base64(username:password)`.
    Basic {
        /// User name (shown by `Debug`).
        username: String,
        /// Password (redacted by `Debug`).
        password: String,
    },
    /// Bearer-token authentication, sent as `Bearer <token>`.
    Bearer {
        /// Opaque token (redacted by `Debug`).
        token: String,
    },
}

impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed instead of the real secret.
        const REDACTED: &str = "<redacted>";
        match self {
            Auth::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &REDACTED)
                .finish(),
            Auth::Bearer { .. } => f.debug_struct("Bearer").field("token", &REDACTED).finish(),
        }
    }
}
32
/// An HTTP/1.1 WebSocket opening-handshake request (RFC 6455 §4.1).
///
/// Built by `create_client_handshake` or produced by `parse_client_handshake`;
/// the parser stores header names lower-cased.
#[derive(Clone, Debug)]
pub struct HandshakeRequest {
    /// Request method; `parse_client_handshake` only accepts `GET`.
    pub method: String,
    /// Request target exactly as it appears in the request line.
    pub uri: String,
    /// HTTP version token from the request line, e.g. `HTTP/1.1`.
    pub version: String,
    /// Header fields keyed by name.
    pub headers: HashMap<String, String>,
    /// Request body; left empty by the builders and parsers in this module.
    pub body: Vec<u8>,
}
47
/// Settings for the `permessage-deflate` extension (RFC 7692).
///
/// In this module the settings are only read when the crate is built with the
/// `compression` feature.
#[derive(Clone, Debug)]
pub struct CompressionConfig {
    /// Whether `permessage-deflate` is offered (client) / accepted (server).
    pub enabled: bool,
    /// Value advertised as `client_max_window_bits`; `None` omits the parameter.
    pub client_max_window_bits: Option<u8>,
    /// Value advertised as `server_max_window_bits`; `None` omits the parameter.
    pub server_max_window_bits: Option<u8>,
    /// Compression level; not read in this module (presumably passed to the
    /// deflate encoder elsewhere — confirm).
    pub compression_level: Option<u32>,
}

impl Default for CompressionConfig {
    /// Compression disabled, 15-bit windows in both directions, level 6.
    fn default() -> Self {
        let full_window = Some(15);
        CompressionConfig {
            enabled: false,
            client_max_window_bits: full_window,
            server_max_window_bits: full_window,
            compression_level: Some(6),
        }
    }
}
71
/// An HTTP response to a WebSocket opening handshake (RFC 6455 §4.2.2).
///
/// Built by `create_server_handshake` or produced by `parse_server_handshake`;
/// the parser stores header names lower-cased.
#[derive(Clone, Debug)]
pub struct HandshakeResponse {
    /// Numeric status code; `SWITCHING_PROTOCOLS` when the upgrade is accepted.
    pub status: u16,
    /// Reason phrase that follows the status code on the status line.
    pub status_message: String,
    /// Header fields keyed by name.
    pub headers: HashMap<String, String>,
    /// Response body; left empty by the builders and parsers in this module.
    pub body: Vec<u8>,
}
84
85#[derive(Debug, Clone, Default)]
87pub struct HandshakeConfig {
88 pub protocols: Vec<String>,
90 pub extensions: Vec<String>,
92 pub origin: Option<String>,
94 pub allowed_origins: Vec<String>,
96 pub host: Option<String>,
98 pub auth: Option<Auth>,
100 pub compression: CompressionConfig,
102 pub extra_headers: HashMap<String, String>,
104}
105
106pub fn generate_key() -> String {
108 use rand::RngCore;
109 let mut key_bytes = [0u8; 16];
110 rand::rng().fill_bytes(&mut key_bytes);
111 general_purpose::STANDARD.encode(key_bytes)
112}
113
114pub fn compute_accept_key(client_key: &str) -> Result<String, Error> {
116 let combined = format!("{}{}", client_key, WEBSOCKET_MAGIC);
117 let hash = Sha1::digest(combined.as_bytes());
118 Ok(general_purpose::STANDARD.encode(hash))
119}
120
121pub fn validate_key(key: &str) -> bool {
123 key.len() == 24 && general_purpose::STANDARD.decode(key).is_ok()
124}
125
126pub fn validate_version(version: &str) -> bool {
128 version == WEBSOCKET_VERSION
129}
130
131pub fn create_client_handshake(
133 uri: &str,
134 config: &HandshakeConfig,
135) -> Result<HandshakeRequest, Error> {
136 let mut headers = HashMap::new();
137
138 headers.insert(
140 HEADER_UPGRADE.to_string(),
141 http_value::WEBSOCKET.to_string(),
142 );
143 headers.insert(
144 HEADER_CONNECTION.to_string(),
145 http_value::UPGRADE.to_string(),
146 );
147 headers.insert(HEADER_SEC_WEBSOCKET_KEY.to_string(), generate_key());
148 headers.insert(
149 HEADER_SEC_WEBSOCKET_VERSION.to_string(),
150 WEBSOCKET_VERSION.to_string(),
151 );
152
153 if let Some(host) = &config.host {
155 headers.insert(HOST.to_string(), host.clone());
156 }
157
158 if let Some(origin) = &config.origin {
159 headers.insert(ORIGIN.to_string(), origin.clone());
160 }
161
162 if !config.protocols.is_empty() {
163 headers.insert(
164 HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(),
165 config.protocols.join(", "),
166 );
167 }
168
169 if !config.extensions.is_empty() {
170 headers.insert(
171 HEADER_SEC_WEBSOCKET_EXTENSIONS.to_string(),
172 config.extensions.join(", "),
173 );
174 }
175
176 #[cfg(feature = "compression")]
178 if config.compression.enabled {
179 let mut ext_parts: Vec<String> = vec!["permessage-deflate".to_string()];
180 if let Some(bits) = config.compression.client_max_window_bits {
181 ext_parts.push(format!("client_max_window_bits={}", bits));
182 }
183 if let Some(bits) = config.compression.server_max_window_bits {
184 ext_parts.push(format!("server_max_window_bits={}", bits));
185 }
186 let compression_ext = ext_parts.join("; ");
187 let existing = headers
188 .get(HEADER_SEC_WEBSOCKET_EXTENSIONS)
189 .cloned()
190 .unwrap_or_default();
191 let new_value = if existing.is_empty() {
192 compression_ext
193 } else {
194 format!("{}, {}", existing, compression_ext)
195 };
196 headers.insert(HEADER_SEC_WEBSOCKET_EXTENSIONS.to_string(), new_value);
197 }
198
199 if let Some(auth) = &config.auth {
201 let value = match auth {
202 Auth::Basic { username, password } => {
203 let credentials = format!("{}:{}", username, password);
204 format!("Basic {}", general_purpose::STANDARD.encode(credentials))
205 }
206 Auth::Bearer { token } => format!("Bearer {}", token),
207 };
208 headers.insert(AUTHORIZATION.to_string(), value);
209 }
210
211 for (key, value) in &config.extra_headers {
213 headers.insert(key.clone(), value.clone());
214 }
215
216 Ok(HandshakeRequest {
217 method: http_method::GET.to_string(),
218 uri: uri.to_string(),
219 version: "HTTP/1.1".to_string(),
220 headers,
221 body: vec![],
222 })
223}
224
225pub fn parse_client_handshake(request: &str) -> Result<HandshakeRequest, Error> {
227 let mut lines = request.lines();
228
229 let request_line = lines.next().ok_or_else(|| {
231 Error::Protocol(ProtocolError::InvalidFormat(
232 "Missing request line".to_string(),
233 ))
234 })?;
235
236 let mut parts = request_line.split_whitespace();
237 let method = parts
238 .next()
239 .ok_or_else(|| Error::Protocol(ProtocolError::InvalidFormat("Missing method".to_string())))?
240 .to_string();
241
242 let uri = parts
243 .next()
244 .ok_or_else(|| Error::Protocol(ProtocolError::InvalidFormat("Missing URI".to_string())))?
245 .to_string();
246
247 let version = parts
248 .next()
249 .ok_or_else(|| {
250 Error::Protocol(ProtocolError::InvalidFormat(
251 "Missing HTTP version".to_string(),
252 ))
253 })?
254 .to_string();
255
256 if method != http_method::GET {
258 return Err(Error::Protocol(ProtocolError::InvalidMethod(method)));
259 }
260
261 let mut headers = HashMap::new();
263 for line in lines {
264 if line.is_empty() {
265 break; }
267
268 if let Some((key, value)) = line.split_once(':') {
269 headers.insert(key.trim().to_lowercase(), value.trim().to_string());
270 } else {
271 return Err(Error::Protocol(ProtocolError::InvalidHeader {
272 header: "unknown".to_string(),
273 value: line.to_string(),
274 }));
275 }
276 }
277
278 Ok(HandshakeRequest {
279 method,
280 uri,
281 version,
282 headers,
283 body: vec![],
284 })
285}
286
287pub fn validate_client_handshake(
289 request: &HandshakeRequest,
290 config: &HandshakeConfig,
291) -> Result<(), Error> {
292 let upgrade = request
294 .headers
295 .get(HEADER_UPGRADE)
296 .ok_or_else(|| Error::Protocol(ProtocolError::MissingHeader(HEADER_UPGRADE.to_string())))?;
297
298 if upgrade.to_lowercase() != http_value::WEBSOCKET {
299 return Err(Error::Protocol(ProtocolError::InvalidHeader {
300 header: HEADER_UPGRADE.to_string(),
301 value: upgrade.clone(),
302 }));
303 }
304
305 let connection = request.headers.get(HEADER_CONNECTION).ok_or_else(|| {
306 Error::Protocol(ProtocolError::MissingHeader(HEADER_CONNECTION.to_string()))
307 })?;
308
309 if !connection.to_lowercase().contains("upgrade") {
310 return Err(Error::Protocol(ProtocolError::InvalidHeader {
311 header: HEADER_CONNECTION.to_string(),
312 value: connection.clone(),
313 }));
314 }
315
316 let key = request
317 .headers
318 .get(HEADER_SEC_WEBSOCKET_KEY)
319 .ok_or_else(|| {
320 Error::Protocol(ProtocolError::MissingHeader(
321 HEADER_SEC_WEBSOCKET_KEY.to_string(),
322 ))
323 })?;
324
325 if !validate_key(key) {
326 return Err(Error::Protocol(ProtocolError::InvalidHeader {
327 header: HEADER_SEC_WEBSOCKET_KEY.to_string(),
328 value: key.clone(),
329 }));
330 }
331
332 let version = request
333 .headers
334 .get(HEADER_SEC_WEBSOCKET_VERSION)
335 .ok_or_else(|| {
336 Error::Protocol(ProtocolError::MissingHeader(
337 HEADER_SEC_WEBSOCKET_VERSION.to_string(),
338 ))
339 })?;
340
341 if !validate_version(version) {
342 return Err(Error::Protocol(ProtocolError::InvalidHeader {
343 header: HEADER_SEC_WEBSOCKET_VERSION.to_string(),
344 value: version.clone(),
345 }));
346 }
347
348 if !config.allowed_origins.is_empty() {
350 if let Some(client_origin) = request.headers.get(ORIGIN) {
351 if !config.allowed_origins.contains(client_origin) {
352 return Err(Error::Protocol(ProtocolError::InvalidOrigin {
353 expected: config.allowed_origins.join(", "),
354 received: client_origin.clone(),
355 }));
356 }
357 }
358 }
359
360 if !config.protocols.is_empty() {
361 if let Some(protocol_header) = request.headers.get(HEADER_SEC_WEBSOCKET_PROTOCOL) {
362 let client_protocols: Vec<&str> =
363 protocol_header.split(',').map(|s| s.trim()).collect();
364 if !client_protocols
365 .iter()
366 .any(|p| config.protocols.contains(&p.to_string()))
367 {
368 return Err(Error::Protocol(ProtocolError::UnsupportedProtocol(
369 protocol_header.clone(),
370 )));
371 }
372 } else {
373 return Err(Error::Protocol(ProtocolError::MissingHeader(
374 HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(),
375 )));
376 }
377 }
378
379 Ok(())
380}
381
382pub fn create_server_handshake(
384 request: &HandshakeRequest,
385 config: &HandshakeConfig,
386) -> Result<HandshakeResponse, Error> {
387 let mut headers = HashMap::new();
388
389 headers.insert(
391 HEADER_UPGRADE.to_string(),
392 http_value::WEBSOCKET.to_string(),
393 );
394 headers.insert(
395 HEADER_CONNECTION.to_string(),
396 http_value::UPGRADE.to_string(),
397 );
398
399 if let Some(client_key) = request.headers.get(HEADER_SEC_WEBSOCKET_KEY) {
401 let accept_key = compute_accept_key(client_key)?;
402 headers.insert(HEADER_SEC_WEBSOCKET_ACCEPT.to_string(), accept_key);
403 } else {
404 return Err(Error::Protocol(ProtocolError::MissingHeader(
405 HEADER_SEC_WEBSOCKET_KEY.to_string(),
406 )));
407 }
408
409 if !config.protocols.is_empty() {
411 if let Some(protocol_header) = request.headers.get(HEADER_SEC_WEBSOCKET_PROTOCOL) {
412 let client_protocols: Vec<&str> =
413 protocol_header.split(',').map(|s| s.trim()).collect();
414 for protocol in &config.protocols {
415 if client_protocols.contains(&protocol.as_str()) {
416 headers.insert(HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(), protocol.clone());
417 break;
418 }
419 }
420 }
421 }
422
423 #[cfg(feature = "compression")]
425 if config.compression.enabled {
426 if let Some(ext_header) = request.headers.get(HEADER_SEC_WEBSOCKET_EXTENSIONS) {
427 if ext_header.contains("permessage-deflate") {
428 let mut ext_parts: Vec<String> = vec!["permessage-deflate".to_string()];
429 if let Some(bits) = config.compression.server_max_window_bits {
430 ext_parts.push(format!("server_max_window_bits={}", bits));
431 }
432 if let Some(bits) = config.compression.client_max_window_bits {
433 ext_parts.push(format!("client_max_window_bits={}", bits));
434 }
435 headers.insert(
436 HEADER_SEC_WEBSOCKET_EXTENSIONS.to_string(),
437 ext_parts.join("; "),
438 );
439 }
440 }
441 }
442
443 for (key, value) in &config.extra_headers {
445 headers.insert(key.clone(), value.clone());
446 }
447
448 Ok(HandshakeResponse {
449 status: SWITCHING_PROTOCOLS,
450 status_message: "Switching Protocols".to_string(),
451 headers,
452 body: vec![],
453 })
454}
455
456pub fn parse_server_handshake(response: &str) -> Result<HandshakeResponse, Error> {
458 let mut lines = response.lines();
459
460 let status_line = lines.next().ok_or_else(|| {
462 Error::Protocol(ProtocolError::InvalidFormat(
463 "Missing status line".to_string(),
464 ))
465 })?;
466
467 let mut parts = status_line.split_whitespace();
468 let _version = parts
469 .next()
470 .ok_or_else(|| {
471 Error::Protocol(ProtocolError::InvalidFormat(
472 "Missing HTTP version".to_string(),
473 ))
474 })?
475 .to_string();
476
477 let status_str = parts.next().ok_or_else(|| {
478 Error::Protocol(ProtocolError::InvalidFormat(
479 "Missing status code".to_string(),
480 ))
481 })?;
482
483 let status = status_str.parse::<u16>().map_err(|_| {
484 Error::Protocol(ProtocolError::InvalidFormat(
485 "Invalid status code".to_string(),
486 ))
487 })?;
488
489 let status_message = parts.collect::<Vec<&str>>().join(" ");
490
491 let mut headers = HashMap::new();
493 for line in lines {
494 if line.is_empty() {
495 break; }
497
498 if let Some((key, value)) = line.split_once(':') {
499 headers.insert(key.trim().to_lowercase(), value.trim().to_string());
500 } else {
501 return Err(Error::Protocol(ProtocolError::InvalidHeader {
502 header: "unknown".to_string(),
503 value: line.to_string(),
504 }));
505 }
506 }
507
508 Ok(HandshakeResponse {
509 status,
510 status_message,
511 headers,
512 body: vec![],
513 })
514}
515
516pub fn validate_server_handshake(
518 response: &HandshakeResponse,
519 client_key: &str,
520) -> Result<(), Error> {
521 if response.status != SWITCHING_PROTOCOLS {
523 return Err(Error::Protocol(ProtocolError::UnexpectedStatus(
524 response.status,
525 )));
526 }
527
528 let upgrade = response
530 .headers
531 .get(HEADER_UPGRADE)
532 .ok_or_else(|| Error::Protocol(ProtocolError::MissingHeader(HEADER_UPGRADE.to_string())))?;
533
534 if upgrade.to_lowercase() != http_value::WEBSOCKET {
535 return Err(Error::Protocol(ProtocolError::InvalidHeader {
536 header: HEADER_UPGRADE.to_string(),
537 value: upgrade.clone(),
538 }));
539 }
540
541 let connection = response.headers.get(HEADER_CONNECTION).ok_or_else(|| {
542 Error::Protocol(ProtocolError::MissingHeader(HEADER_CONNECTION.to_string()))
543 })?;
544
545 if !connection.to_lowercase().contains("upgrade") {
546 return Err(Error::Protocol(ProtocolError::InvalidHeader {
547 header: HEADER_CONNECTION.to_string(),
548 value: connection.clone(),
549 }));
550 }
551
552 let accept = response
553 .headers
554 .get(HEADER_SEC_WEBSOCKET_ACCEPT)
555 .ok_or_else(|| {
556 Error::Protocol(ProtocolError::MissingHeader(
557 HEADER_SEC_WEBSOCKET_ACCEPT.to_string(),
558 ))
559 })?;
560
561 let expected_accept = compute_accept_key(client_key)?;
562 if accept.as_str() != expected_accept {
563 return Err(Error::Protocol(ProtocolError::InvalidAcceptKey {
564 expected: expected_accept,
565 received: accept.clone(),
566 }));
567 }
568
569 Ok(())
570}
571
572pub fn request_to_string(request: &HandshakeRequest) -> String {
574 let mut lines = vec![format!(
575 "{} {} {}",
576 request.method, request.uri, request.version
577 )];
578
579 for (key, value) in &request.headers {
580 lines.push(format!("{}: {}", key, value));
581 }
582
583 lines.push(String::new()); lines.join("\r\n")
585}
586
587pub fn response_to_string(response: &HandshakeResponse) -> String {
589 let mut lines = vec![format!(
590 "HTTP/1.1 {} {}",
591 response.status, response.status_message
592 )];
593
594 for (key, value) in &response.headers {
595 lines.push(format!("{}: {}", key, value));
596 }
597
598 lines.push(String::new()); lines.join("\r\n")
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated keys have the RFC 6455 shape and pass our own validator.
    #[test]
    fn test_key_generation() {
        let generated = generate_key();
        assert_eq!(generated.len(), 24);
        assert!(validate_key(&generated));
    }

    /// Worked example from RFC 6455 §1.3.
    #[test]
    fn test_accept_key_calculation() {
        let accept = compute_accept_key("dGhlIHNhbXBsZSBub25jZQ==").unwrap();
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_client_handshake_creation() {
        let config = HandshakeConfig {
            host: Some(String::from("example.com")),
            protocols: vec![String::from("chat")],
            ..HandshakeConfig::default()
        };

        let request = create_client_handshake("ws://example.com/chat", &config).unwrap();
        let header = |name: &str| request.headers.get(name).map(String::as_str);

        assert_eq!(request.method, "GET");
        assert_eq!(request.uri, "ws://example.com/chat");
        assert_eq!(header("upgrade"), Some("websocket"));
        assert_eq!(header("sec-websocket-protocol"), Some("chat"));
    }

    #[test]
    fn test_client_handshake_parsing() {
        let raw_request = r#"GET /chat HTTP/1.1
Host: example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13

"#;

        let parsed = parse_client_handshake(raw_request).unwrap();
        assert_eq!(parsed.method, "GET");
        assert_eq!(parsed.uri, "/chat");
        assert_eq!(
            parsed.headers.get("upgrade").map(String::as_str),
            Some("websocket")
        );
    }
}