1use crate::error::{Error, ProtocolError};
7use crate::protocol::constants::*;
8use crate::protocol::http_header::*;
9use crate::protocol::http_method;
10use crate::protocol::http_status::*;
11use crate::protocol::http_value;
12use base64::{engine::general_purpose, Engine as _};
13use sha1::{Digest, Sha1};
14use std::collections::HashMap;
15
/// An HTTP/1.1 WebSocket opening-handshake request (RFC 6455 §4.1).
///
/// Built by `create_client_handshake` on the client side and produced by
/// `parse_client_handshake` on the server side.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeRequest {
    /// HTTP method from the request line; the handshake helpers only accept `GET`.
    pub method: String,
    /// Request target exactly as it appears in the request line (e.g. `/chat`).
    pub uri: String,
    /// HTTP version token from the request line (e.g. `HTTP/1.1`).
    pub version: String,
    /// Header fields keyed by name. `parse_client_handshake` stores lowercased names.
    pub headers: HashMap<String, String>,
    /// Request body; always empty for requests built or parsed by this module.
    pub body: Vec<u8>,
}
30
/// An HTTP/1.1 WebSocket opening-handshake response (RFC 6455 §4.2.2).
///
/// Built by `create_server_handshake` on the server side and produced by
/// `parse_server_handshake` on the client side.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeResponse {
    /// HTTP status code; an accepted handshake uses `SWITCHING_PROTOCOLS`.
    pub status: u16,
    /// Reason phrase following the status code (e.g. `Switching Protocols`).
    pub status_message: String,
    /// Header fields keyed by name. `parse_server_handshake` stores lowercased names.
    pub headers: HashMap<String, String>,
    /// Response body; always empty for responses built or parsed by this module.
    pub body: Vec<u8>,
}
43
/// Options shared by the client- and server-side handshake helpers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HandshakeConfig {
    /// Subprotocols. A client offers them, comma-joined, in `Sec-WebSocket-Protocol`.
    /// A server lists the ones it supports in preference order; the first one the client
    /// also offered is echoed back, and a client offering none of them is rejected.
    pub protocols: Vec<String>,
    /// Extensions a client requests, comma-joined, in `Sec-WebSocket-Extensions`.
    /// Not consulted when building the server response.
    pub extensions: Vec<String>,
    /// Client side: value sent as `Origin`. Server side: the value a client's `Origin`
    /// header must equal when the client sends one.
    pub origin: Option<String>,
    /// Value a client sends as `Host`.
    pub host: Option<String>,
    /// Additional headers copied verbatim into the built request or response. They are
    /// inserted last, replacing any header stored under exactly the same key.
    pub extra_headers: HashMap<String, String>,
}
58
59pub fn generate_key() -> String {
61 use rand::RngCore;
62 let mut key_bytes = [0u8; 16];
63 rand::thread_rng().fill_bytes(&mut key_bytes);
64 general_purpose::STANDARD.encode(key_bytes)
65}
66
67pub fn compute_accept_key(client_key: &str) -> Result<String, Error> {
69 let combined = format!("{}{}", client_key, WEBSOCKET_MAGIC);
70 let hash = Sha1::digest(combined.as_bytes());
71 Ok(general_purpose::STANDARD.encode(hash))
72}
73
74pub fn validate_key(key: &str) -> bool {
76 key.len() == 24 && general_purpose::STANDARD.decode(key).is_ok()
77}
78
79pub fn validate_version(version: &str) -> bool {
81 version == WEBSOCKET_VERSION
82}
83
84pub fn create_client_handshake(
86 uri: &str,
87 config: &HandshakeConfig,
88) -> Result<HandshakeRequest, Error> {
89 let mut headers = HashMap::new();
90
91 headers.insert(
93 HEADER_UPGRADE.to_string(),
94 http_value::WEBSOCKET.to_string(),
95 );
96 headers.insert(
97 HEADER_CONNECTION.to_string(),
98 http_value::UPGRADE.to_string(),
99 );
100 headers.insert(HEADER_SEC_WEBSOCKET_KEY.to_string(), generate_key());
101 headers.insert(
102 HEADER_SEC_WEBSOCKET_VERSION.to_string(),
103 WEBSOCKET_VERSION.to_string(),
104 );
105
106 if let Some(host) = &config.host {
108 headers.insert(HOST.to_string(), host.clone());
109 }
110
111 if let Some(origin) = &config.origin {
112 headers.insert(ORIGIN.to_string(), origin.clone());
113 }
114
115 if !config.protocols.is_empty() {
116 headers.insert(
117 HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(),
118 config.protocols.join(", "),
119 );
120 }
121
122 if !config.extensions.is_empty() {
123 headers.insert(
124 HEADER_SEC_WEBSOCKET_EXTENSIONS.to_string(),
125 config.extensions.join(", "),
126 );
127 }
128
129 for (key, value) in &config.extra_headers {
131 headers.insert(key.clone(), value.clone());
132 }
133
134 Ok(HandshakeRequest {
135 method: http_method::GET.to_string(),
136 uri: uri.to_string(),
137 version: "HTTP/1.1".to_string(),
138 headers,
139 body: vec![],
140 })
141}
142
143pub fn parse_client_handshake(request: &str) -> Result<HandshakeRequest, Error> {
145 let mut lines = request.lines();
146
147 let request_line = lines.next().ok_or_else(|| {
149 Error::Protocol(ProtocolError::InvalidFormat(
150 "Missing request line".to_string(),
151 ))
152 })?;
153
154 let mut parts = request_line.split_whitespace();
155 let method = parts
156 .next()
157 .ok_or_else(|| Error::Protocol(ProtocolError::InvalidFormat("Missing method".to_string())))?
158 .to_string();
159
160 let uri = parts
161 .next()
162 .ok_or_else(|| Error::Protocol(ProtocolError::InvalidFormat("Missing URI".to_string())))?
163 .to_string();
164
165 let version = parts
166 .next()
167 .ok_or_else(|| {
168 Error::Protocol(ProtocolError::InvalidFormat(
169 "Missing HTTP version".to_string(),
170 ))
171 })?
172 .to_string();
173
174 if method != http_method::GET {
176 return Err(Error::Protocol(ProtocolError::InvalidMethod(method)));
177 }
178
179 let mut headers = HashMap::new();
181 for line in lines {
182 if line.is_empty() {
183 break; }
185
186 if let Some((key, value)) = line.split_once(':') {
187 headers.insert(key.trim().to_lowercase(), value.trim().to_string());
188 } else {
189 return Err(Error::Protocol(ProtocolError::InvalidHeader {
190 header: "unknown".to_string(),
191 value: line.to_string(),
192 }));
193 }
194 }
195
196 Ok(HandshakeRequest {
197 method,
198 uri,
199 version,
200 headers,
201 body: vec![],
202 })
203}
204
205pub fn validate_client_handshake(
207 request: &HandshakeRequest,
208 config: &HandshakeConfig,
209) -> Result<(), Error> {
210 let upgrade = request
212 .headers
213 .get(HEADER_UPGRADE)
214 .ok_or_else(|| Error::Protocol(ProtocolError::MissingHeader(HEADER_UPGRADE.to_string())))?;
215
216 if upgrade.to_lowercase() != http_value::WEBSOCKET {
217 return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
218 header: HEADER_UPGRADE.to_string(),
219 value: upgrade.clone(),
220 }));
221 }
222
223 let connection = request.headers.get(HEADER_CONNECTION).ok_or_else(|| {
224 Error::Protocol(ProtocolError::MissingHeader(HEADER_CONNECTION.to_string()))
225 })?;
226
227 if !connection.to_lowercase().contains("upgrade") {
228 return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
229 header: HEADER_CONNECTION.to_string(),
230 value: connection.clone(),
231 }));
232 }
233
234 let key = request
235 .headers
236 .get(HEADER_SEC_WEBSOCKET_KEY)
237 .ok_or_else(|| {
238 Error::Protocol(ProtocolError::MissingHeader(
239 HEADER_SEC_WEBSOCKET_KEY.to_string(),
240 ))
241 })?;
242
243 if !validate_key(key) {
244 return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
245 header: HEADER_SEC_WEBSOCKET_KEY.to_string(),
246 value: key.clone(),
247 }));
248 }
249
250 let version = request
251 .headers
252 .get(HEADER_SEC_WEBSOCKET_VERSION)
253 .ok_or_else(|| {
254 Error::Protocol(ProtocolError::MissingHeader(
255 HEADER_SEC_WEBSOCKET_VERSION.to_string(),
256 ))
257 })?;
258
259 if !validate_version(version) {
260 return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
261 header: HEADER_SEC_WEBSOCKET_VERSION.to_string(),
262 value: version.clone(),
263 }));
264 }
265
266 if let Some(origin) = &config.origin {
268 if let Some(client_origin) = request.headers.get(ORIGIN) {
269 if client_origin != origin {
270 return Err(Error::Protocol(ProtocolError::InvalidOrigin {
271 expected: origin.clone(),
272 received: client_origin.clone(),
273 }));
274 }
275 }
276 }
277
278 if !config.protocols.is_empty() {
279 if let Some(protocol_header) = request.headers.get(HEADER_SEC_WEBSOCKET_PROTOCOL) {
280 let client_protocols: Vec<&str> =
281 protocol_header.split(',').map(|s| s.trim()).collect();
282 if !client_protocols
283 .iter()
284 .any(|p| config.protocols.contains(&p.to_string()))
285 {
286 return Err(Error::Protocol(ProtocolError::UnsupportedProtocol(
287 protocol_header.clone(),
288 )));
289 }
290 } else {
291 return Err(Error::Protocol(ProtocolError::MissingHeader(
292 HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(),
293 )));
294 }
295 }
296
297 Ok(())
298}
299
300pub fn create_server_handshake(
302 request: &HandshakeRequest,
303 config: &HandshakeConfig,
304) -> Result<HandshakeResponse, Error> {
305 let mut headers = HashMap::new();
306
307 headers.insert(
309 HEADER_UPGRADE.to_string(),
310 http_value::WEBSOCKET.to_string(),
311 );
312 headers.insert(
313 HEADER_CONNECTION.to_string(),
314 http_value::UPGRADE.to_string(),
315 );
316
317 if let Some(client_key) = request.headers.get(HEADER_SEC_WEBSOCKET_KEY) {
319 let accept_key = compute_accept_key(client_key)?;
320 headers.insert(HEADER_SEC_WEBSOCKET_ACCEPT.to_string(), accept_key);
321 } else {
322 return Err(Error::Protocol(ProtocolError::MissingHeader(
323 HEADER_SEC_WEBSOCKET_KEY.to_string(),
324 )));
325 }
326
327 if !config.protocols.is_empty() {
329 if let Some(protocol_header) = request.headers.get(HEADER_SEC_WEBSOCKET_PROTOCOL) {
330 let client_protocols: Vec<&str> =
331 protocol_header.split(',').map(|s| s.trim()).collect();
332 for protocol in &config.protocols {
333 if client_protocols.contains(&protocol.as_str()) {
334 headers.insert(HEADER_SEC_WEBSOCKET_PROTOCOL.to_string(), protocol.clone());
335 break;
336 }
337 }
338 }
339 }
340
341 for (key, value) in &config.extra_headers {
343 headers.insert(key.clone(), value.clone());
344 }
345
346 Ok(HandshakeResponse {
347 status: SWITCHING_PROTOCOLS,
348 status_message: "Switching Protocols".to_string(),
349 headers,
350 body: vec![],
351 })
352}
353
354pub fn parse_server_handshake(response: &str) -> Result<HandshakeResponse, Error> {
356 let mut lines = response.lines();
357
358 let status_line = lines.next().ok_or_else(|| {
360 Error::Protocol(ProtocolError::InvalidFormat(
361 "Missing status line".to_string(),
362 ))
363 })?;
364
365 let mut parts = status_line.split_whitespace();
366 let _version = parts
367 .next()
368 .ok_or_else(|| {
369 Error::Protocol(ProtocolError::InvalidFormat(
370 "Missing HTTP version".to_string(),
371 ))
372 })?
373 .to_string();
374
375 let status_str = parts.next().ok_or_else(|| {
376 Error::Protocol(ProtocolError::InvalidFormat(
377 "Missing status code".to_string(),
378 ))
379 })?;
380
381 let status = status_str.parse::<u16>().map_err(|_| {
382 Error::Protocol(ProtocolError::InvalidFormat(
383 "Invalid status code".to_string(),
384 ))
385 })?;
386
387 let status_message = parts.collect::<Vec<&str>>().join(" ");
388
389 let mut headers = HashMap::new();
391 for line in lines {
392 if line.is_empty() {
393 break; }
395
396 if let Some((key, value)) = line.split_once(':') {
397 headers.insert(key.trim().to_lowercase(), value.trim().to_string());
398 } else {
399 return Err(Error::Protocol(ProtocolError::InvalidHeader {
400 header: "unknown".to_string(),
401 value: line.to_string(),
402 }));
403 }
404 }
405
406 Ok(HandshakeResponse {
407 status,
408 status_message,
409 headers,
410 body: vec![],
411 })
412}
413
414pub fn validate_server_handshake(
416 response: &HandshakeResponse,
417 client_key: &str,
418) -> Result<(), Error> {
419 if response.status != SWITCHING_PROTOCOLS {
421 return Err(Error::Protocol(ProtocolError::UnexpectedStatus(
422 response.status,
423 )));
424 }
425
426 let upgrade = response
428 .headers
429 .get(HEADER_UPGRADE)
430 .ok_or_else(|| Error::Protocol(ProtocolError::MissingHeader(HEADER_UPGRADE.to_string())))?;
431
432 if upgrade.to_lowercase() != http_value::WEBSOCKET {
433 return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
434 header: HEADER_UPGRADE.to_string(),
435 value: upgrade.clone(),
436 }));
437 }
438
439 let connection = response.headers.get(HEADER_CONNECTION).ok_or_else(|| {
440 Error::Protocol(ProtocolError::MissingHeader(HEADER_CONNECTION.to_string()))
441 })?;
442
443 if !connection.to_lowercase().contains("upgrade") {
444 return Err(Error::Protocol(ProtocolError::InvalidHeaderValue {
445 header: HEADER_CONNECTION.to_string(),
446 value: connection.clone(),
447 }));
448 }
449
450 let accept = response
451 .headers
452 .get(HEADER_SEC_WEBSOCKET_ACCEPT)
453 .ok_or_else(|| {
454 Error::Protocol(ProtocolError::MissingHeader(
455 HEADER_SEC_WEBSOCKET_ACCEPT.to_string(),
456 ))
457 })?;
458
459 let expected_accept = compute_accept_key(client_key)?;
460 if accept.as_str() != expected_accept {
461 return Err(Error::Protocol(ProtocolError::InvalidAcceptKey {
462 expected: expected_accept,
463 received: accept.clone(),
464 }));
465 }
466
467 Ok(())
468}
469
470pub fn request_to_string(request: &HandshakeRequest) -> String {
472 let mut lines = vec![format!(
473 "{} {} {}",
474 request.method, request.uri, request.version
475 )];
476
477 for (key, value) in &request.headers {
478 lines.push(format!("{}: {}", key, value));
479 }
480
481 lines.push(String::new()); lines.join("\r\n")
483}
484
485pub fn response_to_string(response: &HandshakeResponse) -> String {
487 let mut lines = vec![format!(
488 "HTTP/1.1 {} {}",
489 response.status, response.status_message
490 )];
491
492 for (key, value) in &response.headers {
493 lines.push(format!("{}: {}", key, value));
494 }
495
496 lines.push(String::new()); lines.join("\r\n")
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let generated = generate_key();
        assert_eq!(24, generated.len());
        assert!(
            validate_key(&generated),
            "generated key failed validation: {}",
            generated
        );
    }

    #[test]
    fn test_accept_key_calculation() {
        // Sample key/accept pair from RFC 6455 §1.3.
        let sample_key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = compute_accept_key(sample_key).expect("accept key should be computable");
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_client_handshake_creation() {
        let config = HandshakeConfig {
            host: Some("example.com".to_string()),
            protocols: vec!["chat".to_string()],
            ..Default::default()
        };

        let request = create_client_handshake("ws://example.com/chat", &config)
            .expect("client handshake should build");
        let header = |name: &str| request.headers.get(name).map(String::as_str);

        assert_eq!(request.method, "GET");
        assert_eq!(request.uri, "ws://example.com/chat");
        assert_eq!(header("upgrade"), Some("websocket"));
        assert_eq!(header("sec-websocket-protocol"), Some("chat"));
    }

    #[test]
    fn test_client_handshake_parsing() {
        let raw_request = r#"GET /chat HTTP/1.1
Host: example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13

"#;

        let parsed = parse_client_handshake(raw_request).expect("sample request should parse");
        assert_eq!(parsed.method, "GET");
        assert_eq!(parsed.uri, "/chat");
        assert_eq!(
            parsed.headers.get("upgrade").map(String::as_str),
            Some("websocket")
        );
    }
}