blitz_ws/handshake/core.rs

1//! WebSocket handshake control
2
3use std::{
4    fmt::{Debug, Display},
5    io::{Read, Write},
6};
7
8use base64::Engine;
9use sha1::{Digest, Sha1};
10
11use crate::{
12    error::{Error, Result},
13    handshake::machine::{HandshakeMachine, RoundResult, StageResult, TryParse},
14};
15
16/// A WebSocket Handshake
17#[derive(Debug)]
18pub struct MidHandshake<Role: HandshakeRole> {
19    /// The Handshake role
20    pub role: Role,
21    /// The handshake machine instance
22    pub machine: HandshakeMachine<Role::InternalStream>,
23}
24
25impl<Role: HandshakeRole> MidHandshake<Role> {
26    /// Allows access to the machine
27    pub fn get_ref(&self) -> &HandshakeMachine<Role::InternalStream> {
28        &self.machine
29    }
30
31    /// Allows mutable access to the machine
32    pub fn get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream> {
33        &mut self.machine
34    }
35
36    /// Restarts the handshake process
37    pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
38        let mut machine = self.machine;
39
40        loop {
41            machine = match machine.single_round()? {
42                RoundResult::WouldBlock(m) => {
43                    return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
44                }
45                RoundResult::Incomplete(m) => m,
46                RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
47                    ProcessingResult::Continue(m) => m,
48                    ProcessingResult::Done(res) => return Ok(res),
49                },
50            }
51        }
52    }
53}
54
55/// A handshake result
56pub enum HandshakeError<Role: HandshakeRole> {
57    /// Handshake was interrupted (would block)
58    Interrupted(MidHandshake<Role>),
59    /// Handshake failed
60    Failure(Error),
61}
62
63impl<Role: HandshakeRole> Debug for HandshakeError<Role> {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            Self::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
67            Self::Failure(e) => write!(f, "HandshakeError::Failure({:?})", e),
68        }
69    }
70}
71
72impl<Role: HandshakeRole> Display for HandshakeError<Role> {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        match self {
75            Self::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
76            Self::Failure(e) => write!(f, "{e}"),
77        }
78    }
79}
80
81impl<Role: HandshakeRole> std::error::Error for HandshakeError<Role> {}
82
83impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
84    fn from(value: Error) -> Self {
85        HandshakeError::Failure(value)
86    }
87}
88
89/// Handshake Role
90pub trait HandshakeRole {
91    #[doc(hidden)]
92    type IncomingData: TryParse;
93    #[doc(hidden)]
94    type InternalStream: Read + Write;
95    #[doc(hidden)]
96    type FinalResult;
97
98    #[doc(hidden)]
99    fn stage_finished(
100        &mut self,
101        finish: StageResult<Self::IncomingData, Self::InternalStream>,
102    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>;
103}
104
105#[doc(hidden)]
106#[derive(Debug)]
107pub enum ProcessingResult<Stream, FinalResult> {
108    Continue(HandshakeMachine<Stream>),
109    Done(FinalResult),
110}
111
112/// Derives the `Sec-WebSocket-Accept` header value from a `Sec-WebSocket-Key` request header.
113///
114/// This function can be used to perform a handshake before passing a raw TCP stream to
115/// [`WebSocket::with_config`][crate::protocol::WebSocket::with_config]
116pub fn derive_accept_key(req_key: &[u8]) -> String {
117    const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
118
119    let mut hasher = Sha1::default();
120    <Sha1 as Digest>::update(&mut hasher, req_key);
121    <Sha1 as Digest>::update(&mut hasher, WS_GUID);
122
123    base64::engine::general_purpose::STANDARD.encode(<Sha1 as Digest>::finalize(hasher))
124}