use crate::{
base::{Data, Header, OpCode},
connection::Mode,
extension::{Extension, Param}
};
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress};
use log::debug;
use smallvec::SmallVec;
// Names of the `permessage-deflate` extension parameters (RFC 7692, section 7.1),
// as they appear in the offer and in the response.

/// Parameter `server_no_context_takeover` (no value).
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
/// Parameter `server_max_window_bits` (window size the server compresses with).
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
/// Parameter `client_no_context_takeover` (no value).
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
/// Parameter `client_max_window_bits` (window size the client compresses with).
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
/// The `permessage-deflate` WebSocket extension.
///
/// `encode` and `decode` create a fresh compressor / decompressor for every
/// message, so no compression context is carried from one message to the next.
#[derive(Debug)]
pub struct Deflate {
    /// Whether we are the client or the server of the connection; decides which
    /// parameters are offered and how `configure` reads the peer's parameters.
    mode: Mode,
    /// Set to `true` by `configure` once the peer's parameters were accepted.
    enabled: bool,
    /// Scratch buffer reused by `encode` and `decode` for the transformed payload.
    buffer: Vec<u8>,
    /// Parameters returned by `params()`: our offer in client mode, or the
    /// response assembled by `configure` in server mode.
    params: SmallVec<[Param<'static>; 2]>,
    /// Window bits used when compressing outgoing messages (`encode`).
    our_max_window_bits: u8,
    /// Window bits used when decompressing incoming messages (`decode`).
    their_max_window_bits: u8
}
impl Deflate {
    /// Create a new, not yet enabled, `permessage-deflate` extension.
    ///
    /// In client mode the initial offer consists of `server_no_context_takeover`,
    /// `client_no_context_takeover` and a value-less `client_max_window_bits`.
    /// In server mode no parameters are pre-set; `configure` produces them from
    /// the client's offer. Both window sizes start out at 15 bits.
    pub fn new(mode: Mode) -> Self {
        let params = match mode {
            Mode::Server => SmallVec::new(),
            Mode::Client => {
                let mut params = SmallVec::new();
                params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER));
                params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER));
                params.push(Param::new(CLIENT_MAX_WINDOW_BITS));
                params
            }
        };
        Deflate {
            mode,
            enabled: false,
            buffer: Vec::new(),
            params,
            our_max_window_bits: 15,
            their_max_window_bits: 15
        }
    }

    /// Offer `server_max_window_bits = max`, i.e. ask the server to compress
    /// with at most `max` window bits, and decompress with that window size.
    ///
    /// Calling this again replaces the previously offered value instead of
    /// adding the parameter a second time.
    ///
    /// # Panics
    ///
    /// If the extension is not in client mode or `max` is not within `9 ..= 15`.
    pub fn set_max_server_window_bits(&mut self, max: u8) {
        assert!(self.mode == Mode::Client, "setting max. server window bits requires client mode");
        assert!(max > 8 && max <= 15, "max. server window bits have to be within 9 ..= 15");
        self.their_max_window_bits = max;
        self.set_param_value(SERVER_MAX_WINDOW_BITS, max)
    }

    /// Offer `client_max_window_bits = max` and compress with at most `max`
    /// window bits.
    ///
    /// Calling this again replaces the previously offered value.
    ///
    /// # Panics
    ///
    /// If the extension is not in client mode or `max` is not within `9 ..= 15`.
    pub fn set_max_client_window_bits(&mut self, max: u8) {
        assert!(self.mode == Mode::Client, "setting max. client window bits requires client mode");
        assert!(max > 8 && max <= 15, "max. client window bits have to be within 9 ..= 15");
        self.our_max_window_bits = max;
        self.set_param_value(CLIENT_MAX_WINDOW_BITS, max)
    }

    /// Set the value of the parameter `name`, appending the parameter if it is
    /// not part of `self.params` yet (so it never appears twice).
    fn set_param_value(&mut self, name: &'static str, value: u8) {
        if let Some(p) = self.params.iter_mut().find(|p| p.name() == name) {
            p.set_value(Some(value.to_string()));
        } else {
            let mut p = Param::new(name);
            p.set_value(Some(value.to_string()));
            self.params.push(p)
        }
    }

    /// Adopt the peer's window size from a `*_max_window_bits` parameter.
    ///
    /// A missing or non-numeric value is ignored (returns `Ok` without changes).
    /// A numeric value outside `8 ..= 15`, or above `expected` if given, is
    /// rejected with `Err(())`. Otherwise `their_max_window_bits` is set to the
    /// value, raised to at least 9 (a larger window can decode anything
    /// produced with a smaller one).
    fn set_their_max_window_bits(&mut self, p: &Param, expected: Option<u8>) -> Result<(), ()> {
        if let Some(Ok(v)) = p.value().map(|s| s.parse::<u8>()) {
            if v < 8 || v > 15 {
                debug!("invalid {}: {} (expected range: 8 ..= 15)", p.name(), v);
                return Err(())
            }
            if let Some(x) = expected {
                if v > x {
                    debug!("invalid {}: {} (expected: {} <= {})", p.name(), v, v, x);
                    return Err(())
                }
            }
            self.their_max_window_bits = std::cmp::max(9, v);
        }
        Ok(())
    }
}
impl Extension for Deflate {
fn name(&self) -> &str {
"permessage-deflate"
}
fn is_enabled(&self) -> bool {
self.enabled
}
fn params(&self) -> &[Param] {
&self.params
}
fn configure(&mut self, params: &[Param]) -> Result<(), crate::BoxError> {
match self.mode {
Mode::Server => {
self.params.clear();
for p in params {
match p.name() {
CLIENT_MAX_WINDOW_BITS =>
if self.set_their_max_window_bits(&p, None).is_err() {
return Ok(())
}
SERVER_MAX_WINDOW_BITS => {
if let Some(Ok(v)) = p.value().map(|s| s.parse::<u8>()) {
if v < 9 || v > 15 {
debug!("unacceptable server_max_window_bits: {}", v);
return Ok(())
}
let mut x = Param::new(SERVER_MAX_WINDOW_BITS);
x.set_value(Some(v.to_string()));
self.params.push(x);
self.our_max_window_bits = v;
} else {
debug!("invalid server_max_window_bits: {:?}", p.value());
return Ok(())
}
}
CLIENT_NO_CONTEXT_TAKEOVER =>
self.params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)),
SERVER_NO_CONTEXT_TAKEOVER =>
self.params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)),
_ => {
debug!("{}: unknown parameter: {}", self.name(), p.name());
return Ok(())
}
}
}
}
Mode::Client => {
let mut server_no_context_takeover = false;
for p in params {
match p.name() {
SERVER_NO_CONTEXT_TAKEOVER => server_no_context_takeover = true,
CLIENT_NO_CONTEXT_TAKEOVER => {} SERVER_MAX_WINDOW_BITS => {
let expected = Some(self.their_max_window_bits);
if self.set_their_max_window_bits(&p, expected).is_err() {
return Ok(())
}
}
CLIENT_MAX_WINDOW_BITS =>
if let Some(Ok(v)) = p.value().map(|s| s.parse::<u8>()) {
if v < 8 || v > 15 {
debug!("unacceptable client_max_window_bits: {}", v);
return Ok(())
}
use std::cmp::{min, max};
self.our_max_window_bits = min(self.our_max_window_bits, max(9, v));
}
_ => {
debug!("{}: unknown parameter: {}", self.name(), p.name());
return Ok(())
}
}
}
if !server_no_context_takeover {
debug!("{}: server did not confirm no context takeover", self.name());
return Ok(())
}
}
}
self.enabled = true;
Ok(())
}
fn reserved_bits(&self) -> (bool, bool, bool) {
(true, false, false)
}
fn decode(&mut self, hdr: &mut Header, data: &mut Option<Data>) -> Result<(), crate::BoxError> {
match hdr.opcode() {
OpCode::Binary | OpCode::Text if hdr.is_rsv1() && hdr.is_fin() => {}
OpCode::Continue if hdr.is_fin() => {}
_ => return Ok(())
}
if let Some(data) = data {
data.bytes_mut().extend_from_slice(&[0, 0, 0xFF, 0xFF]); self.buffer.clear();
let mut d = Decompress::new_with_window_bits(false, self.their_max_window_bits);
while (d.total_in() as usize) < data.as_ref().len() {
let off = d.total_in() as usize;
self.buffer.reserve(data.as_ref().len() - off);
d.decompress_vec(&data.as_ref()[off ..], &mut self.buffer, FlushDecompress::Sync)?;
}
data.bytes_mut().clear();
data.bytes_mut().extend_from_slice(&self.buffer);
hdr.set_rsv1(false);
}
Ok(())
}
fn encode(&mut self, hdr: &mut Header, data: &mut Option<Data>) -> Result<(), crate::BoxError> {
match hdr.opcode() {
OpCode::Text | OpCode::Binary => {},
_ => return Ok(())
}
if let Some(data) = data {
let mut c = Compress::new_with_window_bits(Compression::fast(), false, self.our_max_window_bits);
self.buffer.clear();
while (c.total_in() as usize) < data.as_ref().len() {
let off = c.total_in() as usize;
self.buffer.reserve(data.as_ref().len() - off);
c.compress_vec(&data.as_ref()[off ..], &mut self.buffer, FlushCompress::Sync)?;
}
if self.buffer.capacity() - self.buffer.len() < 5 {
self.buffer.reserve(5); c.compress_vec(&[], &mut self.buffer, FlushCompress::Sync)?;
}
let n = self.buffer.len() - 4;
self.buffer.truncate(n); data.bytes_mut().clear();
data.bytes_mut().extend_from_slice(&self.buffer);
hdr.set_rsv1(true);
}
Ok(())
}
}