// twist/proto/server/fragmented.rs

//! The `Fragmented` protocol middleware, which reassembles fragmented WebSocket messages.
2use bytes::BytesMut;
3use extension::{PerFrameExtensions, PerMessageExtensions};
4use frame::WebSocket;
5use frame::base::{Frame, OpCode};
6use futures::{Async, Poll, Sink, StartSend, Stream};
7use slog::Logger;
8use std::{io, str};
9use util;
10use uuid::Uuid;
11use vatfluid::{Success, validate};
12
13/// The `Fragmented` struct.
14pub struct Fragmented<T> {
15    /// The Uuid for the protocol chain.
16    uuid: Uuid,
17    /// The upstream protocol.
18    upstream: T,
19    /// Has the fragmented message started?
20    started: bool,
21    /// Is the fragmented message complete?
22    complete: bool,
23    /// The `OpCode` from the original message.
24    opcode: OpCode,
25    /// The buffer used to store the fragmented data.
26    buf: BytesMut,
27    /// The position in our buffer that we have validated in the case of a text frame.
28    pos: usize,
29    /// Per-message extensions
30    permessage_extensions: PerMessageExtensions,
31    /// Per-frame extensions
32    #[allow(dead_code)]
33    perframe_extensions: PerFrameExtensions,
34    /// slog stdout `Logger`
35    stdout: Option<Logger>,
36    /// slog stderr `Logger`
37    stderr: Option<Logger>,
38}
39
40impl<T> Fragmented<T> {
41    /// Create a new `Fragmented` protocol middleware.
42    pub fn new(upstream: T,
43               uuid: Uuid,
44               permessage_extensions: PerMessageExtensions,
45               perframe_extensions: PerFrameExtensions)
46               -> Fragmented<T> {
47        Fragmented {
48            uuid: uuid,
49            upstream: upstream,
50            started: false,
51            complete: false,
52            opcode: OpCode::Close,
53            buf: BytesMut::with_capacity(1024),
54            pos: 0,
55            permessage_extensions: permessage_extensions,
56            perframe_extensions: perframe_extensions,
57            stdout: None,
58            stderr: None,
59        }
60    }
61
62    /// Add a stdout slog `Logger` to this protocol.
63    pub fn stdout(&mut self, logger: Logger) -> &mut Fragmented<T> {
64        let stdout = logger.new(o!("proto" => "fragmented"));
65        self.stdout = Some(stdout);
66        self
67    }
68
69    /// Add a stderr slog `Logger` to this protocol.
70    pub fn stderr(&mut self, logger: Logger) -> &mut Fragmented<T> {
71        let stderr = logger.new(o!("proto" => "fragmented"));
72        self.stderr = Some(stderr);
73        self
74    }
75
76    /// Run the extension chain decode on the given `base::Frame`.
77    fn ext_chain_decode(&self, frame: &mut Frame) -> Result<(), io::Error> {
78        let opcode = frame.opcode();
79        // Only run the chain if this is a Text/Binary finish frame.
80        if frame.fin() && (opcode == OpCode::Text || opcode == OpCode::Binary) {
81            let pm_lock = self.permessage_extensions.clone();
82            let mut map = match pm_lock.lock() {
83                Ok(guard) => guard,
84                Err(poisoned) => poisoned.into_inner(),
85            };
86            let vec_pm_exts = map.entry(self.uuid).or_insert_with(Vec::new);
87            for ext in vec_pm_exts.iter_mut() {
88                if ext.enabled() {
89                    ext.decode(frame)?;
90                }
91            }
92        }
93        Ok(())
94    }
95}
96
97impl<T> Stream for Fragmented<T>
98    where T: Stream<Item = WebSocket, Error = io::Error>,
99          T: Sink<SinkItem = WebSocket, SinkError = io::Error>
100{
101    type Item = WebSocket;
102    type Error = io::Error;
103
104    fn poll(&mut self) -> Poll<Option<WebSocket>, io::Error> {
105        loop {
106            match try_ready!(self.upstream.poll()) {
107                Some(ref msg) if msg.is_fragment_start() => {
108                    if let Some(base) = msg.base() {
109                        try_trace!(self.stdout, "fragment start frame received");
110                        self.opcode = base.opcode();
111                        self.started = true;
112                        self.buf.extend(base.application_data());
113                        self.poll_complete()?;
114                    } else {
115                        return Err(util::other("invalid fragment start frame received"));
116                    }
117                }
118                Some(ref msg) if msg.is_fragment() => {
119                    if !self.started || self.complete {
120                        return Err(util::other("invalid fragment frame received"));
121                    }
122
123                    if let Some(base) = msg.base() {
124                        try_trace!(self.stdout, "fragment continuation frame received");
125                        self.buf.extend(base.application_data());
126
127                        if self.opcode == OpCode::Text && self.buf.len() < 8192 {
128                            try_trace!(self.stdout, "validating from pos: {}", self.pos);
129                            match validate(&self.buf[self.pos..]) {
130                                Ok(Success::Complete(pos)) => {
131                                    try_trace!(self.stdout, "complete: {}", pos);
132                                    self.pos += pos;
133                                }
134                                Ok(Success::Incomplete(_, pos)) => {
135                                    try_trace!(self.stdout, "incomplete: {}", pos);
136                                    self.pos += pos;
137                                }
138                                Err(e) => {
139                                    try_error!(self.stderr, "{}", e);
140                                    return Err(util::other("invalid utf-8 sequence"));
141                                }
142                            }
143                        }
144                        self.poll_complete()?;
145                    } else {
146                        return Err(util::other("invalid fragment frame received"));
147                    }
148                }
149                Some(ref msg) if msg.is_fragment_complete() => {
150                    if !self.started || self.complete {
151                        return Err(util::other("invalid fragment complete frame received"));
152                    }
153                    if let Some(base) = msg.base() {
154                        try_trace!(self.stdout, "fragment finish frame received");
155                        self.complete = true;
156                        self.buf.extend(base.application_data());
157                        self.poll_complete()?;
158                    } else {
159                        return Err(util::other("invalid fragment complete frame received"));
160                    }
161                }
162                Some(ref msg) if msg.is_badfragment() => {
163                    if self.started && !self.complete {
164                        return Err(util::other("invalid opcode for continuation fragment"));
165                    }
166                    return Ok(Async::Ready(Some(msg.clone())));
167                }
168                m => return Ok(Async::Ready(m)),
169            }
170        }
171    }
172}
173
174impl<T> Sink for Fragmented<T>
175    where T: Sink<SinkItem = WebSocket, SinkError = io::Error>
176{
177    type SinkItem = WebSocket;
178    type SinkError = io::Error;
179
180    fn start_send(&mut self, item: WebSocket) -> StartSend<WebSocket, io::Error> {
181        self.upstream.start_send(item)
182    }
183
184    fn poll_complete(&mut self) -> Poll<(), io::Error> {
185        if self.started && self.complete {
186            let mut message: WebSocket = Default::default();
187
188            // Setup the `Frame` to pass upstream.
189            let mut base: Frame = Default::default();
190            base.set_fin(true).set_opcode(self.opcode);
191            base.set_application_data(self.buf.to_vec());
192            base.set_payload_length(self.buf.len() as u64);
193
194            // Validate utf-8 here to allow pre-processing of appdata by extension chain.
195            if base.opcode() == OpCode::Text && base.fin() {
196                match validate(&self.buf[self.pos..]) {
197                    Ok(Success::Complete(_)) => {}
198                    Ok(Success::Incomplete(_, pos)) => {
199                        try_error!(self.stderr, "incomplete: {}", pos);
200                        return Err(util::other("invalid utf-8 sequence"));
201                    }
202                    Err(e) => {
203                        try_error!(self.stderr, "{}", e);
204                        return Err(util::other("invalid utf-8 sequence"));
205                    }
206                }
207            }
208
209            // Run the `Frame` through the extension decode chain.
210            self.ext_chain_decode(&mut base)?;
211
212            message.set_base(base);
213
214            // Send it upstream
215            self.upstream.start_send(message)?;
216
217            // Reset my state.
218            self.started = false;
219            self.complete = false;
220            self.opcode = OpCode::Close;
221            self.pos = 0;
222            self.buf.clear();
223
224            try_trace!(self.stdout, "fragment completed sending result upstream");
225        }
226        self.upstream.poll_complete()
227    }
228}