// twist/proto/server/fragmented.rs
use bytes::BytesMut;
use extension::{PerFrameExtensions, PerMessageExtensions};
use frame::WebSocket;
use frame::base::{Frame, OpCode};
use futures::{Async, Poll, Sink, StartSend, Stream};
use slog::Logger;
use std::{io, str};
use util;
use uuid::Uuid;
use vatfluid::{Success, validate};

13pub struct Fragmented<T> {
15 uuid: Uuid,
17 upstream: T,
19 started: bool,
21 complete: bool,
23 opcode: OpCode,
25 buf: BytesMut,
27 pos: usize,
29 permessage_extensions: PerMessageExtensions,
31 #[allow(dead_code)]
33 perframe_extensions: PerFrameExtensions,
34 stdout: Option<Logger>,
36 stderr: Option<Logger>,
38}
39
40impl<T> Fragmented<T> {
41 pub fn new(upstream: T,
43 uuid: Uuid,
44 permessage_extensions: PerMessageExtensions,
45 perframe_extensions: PerFrameExtensions)
46 -> Fragmented<T> {
47 Fragmented {
48 uuid: uuid,
49 upstream: upstream,
50 started: false,
51 complete: false,
52 opcode: OpCode::Close,
53 buf: BytesMut::with_capacity(1024),
54 pos: 0,
55 permessage_extensions: permessage_extensions,
56 perframe_extensions: perframe_extensions,
57 stdout: None,
58 stderr: None,
59 }
60 }
61
62 pub fn stdout(&mut self, logger: Logger) -> &mut Fragmented<T> {
64 let stdout = logger.new(o!("proto" => "fragmented"));
65 self.stdout = Some(stdout);
66 self
67 }
68
69 pub fn stderr(&mut self, logger: Logger) -> &mut Fragmented<T> {
71 let stderr = logger.new(o!("proto" => "fragmented"));
72 self.stderr = Some(stderr);
73 self
74 }
75
76 fn ext_chain_decode(&self, frame: &mut Frame) -> Result<(), io::Error> {
78 let opcode = frame.opcode();
79 if frame.fin() && (opcode == OpCode::Text || opcode == OpCode::Binary) {
81 let pm_lock = self.permessage_extensions.clone();
82 let mut map = match pm_lock.lock() {
83 Ok(guard) => guard,
84 Err(poisoned) => poisoned.into_inner(),
85 };
86 let vec_pm_exts = map.entry(self.uuid).or_insert_with(Vec::new);
87 for ext in vec_pm_exts.iter_mut() {
88 if ext.enabled() {
89 ext.decode(frame)?;
90 }
91 }
92 }
93 Ok(())
94 }
95}
96
97impl<T> Stream for Fragmented<T>
98 where T: Stream<Item = WebSocket, Error = io::Error>,
99 T: Sink<SinkItem = WebSocket, SinkError = io::Error>
100{
101 type Item = WebSocket;
102 type Error = io::Error;
103
104 fn poll(&mut self) -> Poll<Option<WebSocket>, io::Error> {
105 loop {
106 match try_ready!(self.upstream.poll()) {
107 Some(ref msg) if msg.is_fragment_start() => {
108 if let Some(base) = msg.base() {
109 try_trace!(self.stdout, "fragment start frame received");
110 self.opcode = base.opcode();
111 self.started = true;
112 self.buf.extend(base.application_data());
113 self.poll_complete()?;
114 } else {
115 return Err(util::other("invalid fragment start frame received"));
116 }
117 }
118 Some(ref msg) if msg.is_fragment() => {
119 if !self.started || self.complete {
120 return Err(util::other("invalid fragment frame received"));
121 }
122
123 if let Some(base) = msg.base() {
124 try_trace!(self.stdout, "fragment continuation frame received");
125 self.buf.extend(base.application_data());
126
127 if self.opcode == OpCode::Text && self.buf.len() < 8192 {
128 try_trace!(self.stdout, "validating from pos: {}", self.pos);
129 match validate(&self.buf[self.pos..]) {
130 Ok(Success::Complete(pos)) => {
131 try_trace!(self.stdout, "complete: {}", pos);
132 self.pos += pos;
133 }
134 Ok(Success::Incomplete(_, pos)) => {
135 try_trace!(self.stdout, "incomplete: {}", pos);
136 self.pos += pos;
137 }
138 Err(e) => {
139 try_error!(self.stderr, "{}", e);
140 return Err(util::other("invalid utf-8 sequence"));
141 }
142 }
143 }
144 self.poll_complete()?;
145 } else {
146 return Err(util::other("invalid fragment frame received"));
147 }
148 }
149 Some(ref msg) if msg.is_fragment_complete() => {
150 if !self.started || self.complete {
151 return Err(util::other("invalid fragment complete frame received"));
152 }
153 if let Some(base) = msg.base() {
154 try_trace!(self.stdout, "fragment finish frame received");
155 self.complete = true;
156 self.buf.extend(base.application_data());
157 self.poll_complete()?;
158 } else {
159 return Err(util::other("invalid fragment complete frame received"));
160 }
161 }
162 Some(ref msg) if msg.is_badfragment() => {
163 if self.started && !self.complete {
164 return Err(util::other("invalid opcode for continuation fragment"));
165 }
166 return Ok(Async::Ready(Some(msg.clone())));
167 }
168 m => return Ok(Async::Ready(m)),
169 }
170 }
171 }
172}
173
174impl<T> Sink for Fragmented<T>
175 where T: Sink<SinkItem = WebSocket, SinkError = io::Error>
176{
177 type SinkItem = WebSocket;
178 type SinkError = io::Error;
179
180 fn start_send(&mut self, item: WebSocket) -> StartSend<WebSocket, io::Error> {
181 self.upstream.start_send(item)
182 }
183
184 fn poll_complete(&mut self) -> Poll<(), io::Error> {
185 if self.started && self.complete {
186 let mut message: WebSocket = Default::default();
187
188 let mut base: Frame = Default::default();
190 base.set_fin(true).set_opcode(self.opcode);
191 base.set_application_data(self.buf.to_vec());
192 base.set_payload_length(self.buf.len() as u64);
193
194 if base.opcode() == OpCode::Text && base.fin() {
196 match validate(&self.buf[self.pos..]) {
197 Ok(Success::Complete(_)) => {}
198 Ok(Success::Incomplete(_, pos)) => {
199 try_error!(self.stderr, "incomplete: {}", pos);
200 return Err(util::other("invalid utf-8 sequence"));
201 }
202 Err(e) => {
203 try_error!(self.stderr, "{}", e);
204 return Err(util::other("invalid utf-8 sequence"));
205 }
206 }
207 }
208
209 self.ext_chain_decode(&mut base)?;
211
212 message.set_base(base);
213
214 self.upstream.start_send(message)?;
216
217 self.started = false;
219 self.complete = false;
220 self.opcode = OpCode::Close;
221 self.pos = 0;
222 self.buf.clear();
223
224 try_trace!(self.stdout, "fragment completed sending result upstream");
225 }
226 self.upstream.poll_complete()
227 }
228}