1use std::{
2 io,
3 convert::TryFrom,
4 fmt,
5};
6
7use sha1::{Sha1, Digest};
8use crate::{respond, Request, Response};
9use nom::{
10 IResult,
11 bits::{
12 bits,
13 complete::take,
14 },
15};
16use futures::{
17 AsyncRead,
18 AsyncWrite,
19 AsyncWriteExt,
20 AsyncReadExt,
21};
22
/// Largest payload, in bytes, that `WebSocketReader::recv` will accept.
const MAX_PAYLOAD_SIZE: u64 = 16 * 1000;
24
/// Errors produced while upgrading to, reading from, or writing to a websocket.
///
/// Derives `PartialEq`/`Eq` so callers can compare against specific variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketError {
    /// The `Connection` header does not carry the `Upgrade` token.
    ConnectionNotUpgrade,
    /// The request has no `Connection` header.
    NoConnectionHeader,
    /// The request has no `Upgrade` header.
    NoUpgradeHeader,
    /// The `Upgrade` header does not name `websocket`.
    UpgradeNotToWebSocket,
    /// `Sec-WebSocket-Version` is missing or is not `13`.
    WrongVersion,
    /// The request has no `Sec-WebSocket-Key` header.
    NoKey,
    /// A payload exceeded the allowed size.
    TooBig,
    /// A frame was malformed or violated the framing rules.
    ProtocolError,
    /// An underlying I/O error, stored as its `Debug` rendering.
    IOError(String),
    /// A frame carried an opcode this implementation does not accept.
    BadOpcode,
    /// The peer closed the connection.
    ConnectionClosed,
}
39
40impl From<io::Error> for WebSocketError {
41 fn from(err: io::Error) -> Self {
42 if err.kind() == io::ErrorKind::UnexpectedEof {
43 WebSocketError::ConnectionClosed
44 } else {
45 WebSocketError::IOError(format!("{:?}", err))
46 }
47 }
48}
49
50impl<E> From<nom::Err<E>> for WebSocketError {
51 fn from(_err: nom::Err<E>) -> Self {
52 WebSocketError::ProtocolError
53 }
54}
55
56impl fmt::Display for WebSocketError {
57 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58 write!(f, "problem establishing websocket connection")
59 }
60}
61
/// Type of a websocket frame, mapped to and from its wire opcode.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MessageType {
    /// Opcode 0x0: continues a fragmented data message.
    Continuation,
    /// Opcode 0x1: text data.
    Text,
    /// Opcode 0x2: binary data.
    Binary,
    /// Opcode 0x8: connection close.
    Close,
    /// Opcode 0x9: ping.
    Ping,
    /// Opcode 0xA: pong.
    Pong,
}

impl MessageType {
    /// Returns `true` for the control types `Close`, `Ping` and `Pong`,
    /// and `false` for the data types.
    pub fn is_control(&self) -> bool {
        matches!(self, Self::Close | Self::Ping | Self::Pong)
    }
}
80
81impl TryFrom<u8> for MessageType {
82 type Error = WebSocketError;
83
84 fn try_from(b: u8) -> Result<Self, Self::Error> {
85 Ok(match b {
86 0x0 => MessageType::Continuation,
87 0x1 => MessageType::Text,
88 0x2 => MessageType::Binary,
89 0x8 => MessageType::Close,
90 0x9 => MessageType::Ping,
91 0xA => MessageType::Pong,
92 _ => return Err(WebSocketError::BadOpcode),
93 })
94 }
95}
96
97impl Into<u8> for MessageType {
98 fn into(self) -> u8 {
99 match self {
100 MessageType::Continuation => 0x0,
101 MessageType::Text => 0x1,
102 MessageType::Binary => 0x2,
103 MessageType::Close => 0x8,
104 MessageType::Ping => 0x9,
105 MessageType::Pong => 0xA,
106 }
107 }
108}
109
110#[derive(Debug, Clone)]
111pub struct Message{
112 pub typ: MessageType,
113 pub contents: Vec<u8>,
114}
115
116pub async fn upgrade<S>(req: &Request, mut stream: S) -> Result<(WebSocketReader<S>, WebSocketWriter<S>), WebSocketError>
117where S: AsyncRead + AsyncWrite + Clone + Unpin
118{
119 match req.headers.get("Connection") {
121 Some(header) => {
122 let mut ok = false;
123 if let Ok(txt) = std::str::from_utf8(header) {
124 if let Some(_) = txt.find("Upgrade") {
125 ok = true;
126 }
127 }
128 if !ok {
129 Err(WebSocketError::ConnectionNotUpgrade)?
130 }
131 },
132 None => Err(WebSocketError::NoConnectionHeader)?,
133 };
134 match req.headers.get("Upgrade") {
135 Some(header) => if header != b"websocket" { Err(WebSocketError::UpgradeNotToWebSocket)? },
136 None => Err(WebSocketError::NoUpgradeHeader)?,
137 };
138 match req.headers.get("Sec-WebSocket-Version") {
139 Some(header) => if header != b"13" { Err(WebSocketError::WrongVersion)? },
140 None => Err(WebSocketError::WrongVersion)?,
141 };
142 let key = match req.headers.get("Sec-WebSocket-Key") {
144 Some(k) => k,
145 None => Err(WebSocketError::NoKey)?,
146 };
147 let mut hasher = Sha1::new();
148 hasher.update(&key);
149 hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
151 let result = hasher.finalize();
152 let mut headers = vec!();
153 headers.push(("Upgrade".into(), Vec::from("websocket")));
154 headers.push(("Connection".into(), Vec::from("Upgrade")));
155 headers.push(("Sec-WebSocket-Accept".into(), base64::encode(&result[..]).into()));
156 respond(&mut stream, Response{
158 code: 101,
159 reason: "Switching Protocols",
160 headers,
161 }).await?;
162 stream.flush().await?;
163 Ok((WebSocketReader{
164 stream: stream.clone(),
165 buffered_message: None,
166 }, WebSocketWriter{
167 stream,
168 }))
169}
170
171pub struct WebSocketReader<S>
172where S: AsyncRead + Unpin
173{
174 stream: S,
175 buffered_message: Option<(MessageType, Vec<u8>)>,
176}
177
178impl<S> WebSocketReader<S>
179where S: AsyncRead + Unpin
180{
181 pub async fn recv(&mut self) -> Result<Message, WebSocketError> {
182 loop {
183 let header = read_header(&mut self.stream).await?;
184 if header.payload_len > MAX_PAYLOAD_SIZE {
185 Err(WebSocketError::TooBig)?;
186 }
187 let mut contents = vec![0u8; header.payload_len as usize];
189 self.stream.read_exact(&mut contents).await?;
190 let len = contents.len();
192 for i in 0..len {
193 contents[i] = contents[i] ^ header.masking_key[i % header.masking_key.len()];
194 }
195 let typ = MessageType::try_from(header.opcode)?;
196 if typ.is_control() {
197 return Ok(Message{contents, typ});
198 }
199 if header.fin == 0 && typ != MessageType::Continuation {
201 self.buffered_message = Some((typ, contents));
202 } else if header.fin == 0 {
203 match &mut self.buffered_message {
204 Some((_, old)) => {
205 old.append(&mut contents);
206 },
207 None => return Err(WebSocketError::BadOpcode),
208 }
209 } else {
210 let (typ, contents) = self.buffered_message.take().unwrap_or((typ, contents));
211 return Ok(Message{typ, contents});
212 }
213 }
214 }
215}
216
217pub struct WebSocketWriter<S>
218where S: AsyncWrite + Unpin
219{
220 stream: S,
221}
222
223impl<S> WebSocketWriter<S>
224where S: AsyncWrite + Unpin
225{
226 pub async fn write(&mut self, msg: &Message) -> Result<(), WebSocketError> {
227 let res = WebSocketHeader{
228 fin: 1,
229 opcode: msg.typ.into(),
230 mask: 0,
231 payload_len: msg.contents.len() as u64,
232 masking_key: vec!(),
233 };
234 self.stream.write_all(&mut res.to_vec()).await?;
235 self.stream.write_all(&msg.contents).await?;
236 self.stream.flush().await?;
237 Ok(())
238 }
239}
240
/// Decoded websocket frame header (RFC 6455 §5.2).
#[derive(Debug, Clone)]
struct WebSocketHeader {
    /// FIN bit: 1 for the final frame of a message, 0 otherwise.
    fin: u8,
    /// 4-bit frame opcode.
    opcode: u8,
    /// MASK bit: non-zero when the payload is masked with `masking_key`.
    mask: u8,
    /// Payload length in bytes. Straight from the 7-bit field it may still be
    /// the 126/127 extended-length marker; `read_header` replaces it with the
    /// decoded length.
    payload_len: u64,
    /// Masking key: four bytes when `mask` is set, otherwise empty.
    masking_key: Vec<u8>,
}

impl WebSocketHeader {
    /// Serializes the header into its wire format.
    ///
    /// The payload length uses the shortest form RFC 6455 allows:
    /// - 7 bits for lengths up to 125;
    /// - a 16-bit big-endian extended length up to `u16::MAX`;
    /// - a 64-bit big-endian extended length beyond that.
    ///
    /// When `mask` is non-zero, the MASK bit is set and `masking_key` is
    /// appended after the length. RSV bits are always zero.
    fn to_vec(&self) -> Vec<u8> {
        // 2 fixed bytes + up to 8 extended-length bytes + up to 4 key bytes.
        let mut out = Vec::with_capacity(14);
        out.push(((self.fin & 1) << 7) | (self.opcode & 0x0F));
        let mask_bit: u8 = if self.mask != 0 { 0x80 } else { 0 };
        if self.payload_len <= 125 {
            out.push(mask_bit | self.payload_len as u8);
        } else if self.payload_len <= u64::from(u16::MAX) {
            out.push(mask_bit | 126);
            out.extend_from_slice(&(self.payload_len as u16).to_be_bytes());
        } else {
            out.push(mask_bit | 127);
            out.extend_from_slice(&self.payload_len.to_be_bytes());
        }
        if self.mask != 0 {
            out.extend_from_slice(&self.masking_key);
        }
        out
    }
}
268
269pub async fn handle_control<S>(msg: &Message, wrt: &mut WebSocketWriter<S>) -> Result<bool, WebSocketError>
271where S: AsyncWrite + Unpin
272{
273 match msg.typ {
274 MessageType::Pong => {
275 let msg = Message{
276 typ: MessageType::Pong,
277 contents: msg.contents.clone(),
278 };
279 wrt.write(&msg).await?;
280 Ok(true)
281 },
282 MessageType::Close => {
283 Err(WebSocketError::ConnectionClosed)
284 }
285 _ => {
286 Ok(false)
287 },
288 }
289}
290
291async fn read_header<S>(stream: &mut S) -> Result<WebSocketHeader, WebSocketError>
292where S: AsyncRead + Unpin
293{
294 let mut header_fixed = vec![0u8; 2];
297 stream.read_exact(&mut header_fixed).await?;
298 let (_, mut res) = read_header_internal(&header_fixed)?;
299 header_fixed[1] &= 0b01111111;
300 if res.payload_len == 126 {
301 let mut len = [0u8; 2];
303 stream.read_exact(&mut len).await?;
304 res.payload_len = u16::from_be_bytes(len) as u64;
305 } else if res.payload_len == 127 {
306 let mut len = [0u8; 8];
308 stream.read_exact(&mut len).await?;
309 res.payload_len = u64::from_be_bytes(len) as u64;
310 }
311 if res.mask != 0 {
312 let mut mask_key = vec![0u8; 4];
313 stream.read_exact(&mut mask_key).await?;
314 res.masking_key = mask_key;
315 }
316 Ok(res)
317}
318
/// Parses the fixed two-byte frame header in `input` into a `WebSocketHeader`.
///
/// Runs `read_header_internal_bits` in nom's bit-level mode via `bits`.
/// Returns the unconsumed byte input alongside the header. In that header,
/// `payload_len` is still the raw 7-bit field and `masking_key` is empty.
fn read_header_internal(input: &[u8]) -> IResult<&[u8], WebSocketHeader> {
    bits(read_header_internal_bits)(input)
}
322
323fn read_header_internal_bits(input: (&[u8], usize)) -> IResult<(&[u8], usize), WebSocketHeader>
324{
325 let (input, fin) = take(1usize)(input)?;
326 let (input, _rsv1): ((&[u8], usize), u8) = take(1usize)(input)?;
327 let (input, _rsv2): ((&[u8], usize), u8) = take(1usize)(input)?;
328 let (input, _rsv3): ((&[u8], usize), u8) = take(1usize)(input)?;
329 let (input, opcode) = take(4usize)(input)?;
330 let (input, mask) = take(1usize)(input)?;
331 let (input, payload_len) = take(7usize)(input)?;
332 Ok((input, WebSocketHeader{fin, opcode, mask, payload_len, masking_key: vec!()}))
333}