use std::time::Duration;
use crate::compression::WebSocketExtensions;
/// Compression level used by the permessage-deflate settings; an alias of `flate2::Compression`.
pub type CompressionLevel = flate2::Compression;
/// Connection options, assembled with the builder-style `with_*` methods on `impl Options`.
///
/// `Default` leaves every optional setting as `None` and every flag as `false`; how an
/// unset limit is interpreted is decided by the code that consumes these options.
#[derive(Clone, Default)]
pub struct Options {
/// Upper bound on the size of a received payload, if set.
/// NOTE(review): unit presumably bytes — confirm against the reader.
pub max_payload_read: Option<usize>,
/// Upper bound on the read buffer size, if set.
/// NOTE(review): unit presumably bytes — confirm against the reader.
pub max_read_buffer: Option<usize>,
/// permessage-deflate settings; `None` means compression is not configured
/// (`without_compression` resets it to `None`).
pub compression: Option<DeflateOptions>,
/// Fragmentation settings, created on demand by the fragment-related builders.
pub fragmentation: Option<Fragmentation>,
/// Set by `with_utf8`; presumably enables UTF-8 validation of text payloads — confirm.
pub check_utf8: bool,
/// Set by `with_no_delay`; presumably maps to TCP_NODELAY on the socket — confirm.
pub no_delay: bool,
/// Set by `with_backpressure_boundary`.
/// NOTE(review): presumably the buffered-write size at which backpressure applies — confirm.
pub max_backpressure_write_boundary: Option<usize>,
}
/// Fragmentation settings stored in `Options::fragmentation`.
///
/// Both fields start as `None` (via `Default`) and are filled in by
/// `Options::with_fragment_timeout` and `Options::with_max_fragment_size`.
#[derive(Clone, Default, Debug)]
pub struct Fragmentation {
/// NOTE(review): presumably how long to wait for the remaining fragments of a message — confirm.
pub timeout: Option<Duration>,
/// NOTE(review): presumably the maximum size of a single fragment — confirm.
pub fragment_size: Option<usize>,
}
impl Options {
pub fn with_limits(self, max_payload: usize, max_buffer: usize) -> Self {
Self {
max_payload_read: Some(max_payload),
max_read_buffer: Some(max_buffer),
..self
}
}
pub fn with_low_latency_compression(self) -> Self {
Self {
compression: Some(DeflateOptions::low_latency()),
..self
}
}
pub fn with_high_compression(self) -> Self {
Self {
compression: Some(DeflateOptions::high_compression()),
..self
}
}
pub fn with_balanced_compression(self) -> Self {
Self {
compression: Some(DeflateOptions::balanced()),
..self
}
}
pub fn with_compression_level(self, level: CompressionLevel) -> Self {
let mut compression = self.compression.unwrap_or_default();
compression.level = level;
Self {
compression: Some(compression),
..self
}
}
pub fn without_compression(self) -> Self {
Self {
compression: None,
..self
}
}
pub fn with_max_payload_read(self, size: usize) -> Self {
Self {
max_payload_read: Some(size),
..self
}
}
pub fn with_max_read_buffer(self, size: usize) -> Self {
Self {
max_read_buffer: Some(size),
..self
}
}
pub fn with_utf8(self) -> Self {
Self {
check_utf8: true,
..self
}
}
pub fn with_no_delay(self) -> Self {
Self {
no_delay: true,
..self
}
}
pub fn with_fragment_timeout(self, timeout: Duration) -> Self {
let mut fragmentation = self.fragmentation.unwrap_or_default();
fragmentation.timeout = Some(timeout);
Self {
fragmentation: Some(fragmentation),
..self
}
}
pub fn with_max_fragment_size(self, size: usize) -> Self {
let mut fragmentation = self.fragmentation.unwrap_or_default();
fragmentation.fragment_size = Some(size);
Self {
fragmentation: Some(fragmentation),
..self
}
}
pub fn with_backpressure_boundary(self, size: usize) -> Self {
Self {
max_backpressure_write_boundary: Some(size),
..self
}
}
#[cfg(feature = "zlib")]
pub fn with_client_max_window_bits(self, max_window_bits: u8) -> Self {
let mut compression = self.compression.unwrap_or_default();
compression.client_max_window_bits = Some(max_window_bits);
Self {
compression: Some(compression),
..self
}
}
#[cfg(feature = "zlib")]
pub fn with_server_max_window_bits(self, max_window_bits: u8) -> Self {
let mut compression = self.compression.unwrap_or_default();
compression.server_max_window_bits = Some(max_window_bits);
Self {
compression: Some(compression),
..self
}
}
pub fn server_no_context_takeover(self) -> Self {
let mut compression = self.compression.unwrap_or_default();
compression.server_no_context_takeover = true;
Self {
compression: Some(compression),
..self
}
}
pub fn client_no_context_takeover(self) -> Self {
let mut compression = self.compression.unwrap_or_default();
compression.client_no_context_takeover = true;
Self {
compression: Some(compression),
..self
}
}
}
/// This side's `permessage-deflate` preferences, combined with a client's offer by `merge`.
///
/// `Default` yields the default compression level, no window-bits preference and
/// context takeover allowed in both directions.
#[derive(Clone, Default, Debug)]
pub struct DeflateOptions {
    /// Compression level to deflate with.
    pub level: CompressionLevel,
    /// Preferred limit for the RFC 7692 `server_max_window_bits` parameter; when the
    /// client also offers a value, the smaller of the two is negotiated.
    #[cfg(feature = "zlib")]
    pub server_max_window_bits: Option<u8>,
    /// Preferred limit for the RFC 7692 `client_max_window_bits` parameter; when the
    /// client also offers a value, the smaller of the two is negotiated.
    #[cfg(feature = "zlib")]
    pub client_max_window_bits: Option<u8>,
    /// Require `server_no_context_takeover`; OR-ed with the client's offer.
    pub server_no_context_takeover: bool,
    /// Require `client_no_context_takeover`; OR-ed with the client's offer.
    pub client_no_context_takeover: bool,
}
impl DeflateOptions {
    /// Window bits sent back when the client advertises a `*_max_window_bits`
    /// parameter without a value and this side has no preference of its own.
    // NOTE(review): 9 is presumably chosen because it is the smallest window zlib's
    // deflate actually uses (zlib promotes a request for 8 to 9) — confirm intent.
    #[cfg(feature = "zlib")]
    const UNSPECIFIED_WINDOW_BITS: u8 = 9;

    /// Preset favouring speed: `CompressionLevel::fast()`, no window-bits preference,
    /// context takeover allowed in both directions.
    pub fn low_latency() -> Self {
        Self {
            level: CompressionLevel::fast(),
            ..Self::default()
        }
    }

    /// Preset favouring ratio: `CompressionLevel::best()`, no window-bits preference,
    /// context takeover allowed in both directions.
    pub fn high_compression() -> Self {
        Self {
            level: CompressionLevel::best(),
            ..Self::default()
        }
    }

    /// Preset using `CompressionLevel::default()`; identical to `DeflateOptions::default()`.
    pub fn balanced() -> Self {
        Self::default()
    }

    /// Builds the extension parameters for the response to a client's
    /// `permessage-deflate` offer.
    ///
    /// * `*_no_context_takeover` is set when either the offer or this side asks for it.
    /// * `server_max_window_bits` follows the offer (capped by our preference) and may
    ///   be announced even when the client did not offer it (RFC 7692 §7.1.2.1).
    /// * `client_max_window_bits` is only included when the offer contained it; RFC 7692
    ///   §7.1.2.2 forbids sending it otherwise, since the client may not support it.
    ///
    /// Without the `zlib` feature, window bits are never negotiated (always `None`).
    pub(super) fn merge(&self, offered: &WebSocketExtensions) -> WebSocketExtensions {
        WebSocketExtensions {
            client_no_context_takeover: offered.client_no_context_takeover
                || self.client_no_context_takeover,
            server_no_context_takeover: offered.server_no_context_takeover
                || self.server_no_context_takeover,
            #[cfg(feature = "zlib")]
            client_max_window_bits: Self::negotiate_window_bits(
                offered.client_max_window_bits,
                self.client_max_window_bits,
                false,
            ),
            #[cfg(feature = "zlib")]
            server_max_window_bits: Self::negotiate_window_bits(
                offered.server_max_window_bits,
                self.server_max_window_bits,
                true,
            ),
            #[cfg(not(feature = "zlib"))]
            client_max_window_bits: None,
            #[cfg(not(feature = "zlib"))]
            server_max_window_bits: None,
        }
    }

    /// Resolves the response value of one `*_max_window_bits` parameter.
    ///
    /// * `offered` — the client's parameter: `None` if absent, `Some(None)` if present
    ///   without a value, `Some(Some(n))` if present with value `n`.
    /// * `configured` — this side's own limit, if any.
    /// * `may_send_unsolicited` — whether the parameter may appear in the response
    ///   when the client did not offer it.
    ///
    /// Returns `None` to omit the parameter, otherwise `Some(Some(bits))`.
    #[cfg(feature = "zlib")]
    fn negotiate_window_bits(
        offered: Option<Option<u8>>,
        configured: Option<u8>,
        may_send_unsolicited: bool,
    ) -> Option<Option<u8>> {
        match (offered, configured) {
            // Both sides name a limit: the smaller one satisfies both.
            (Some(Some(requested)), Some(limit)) => Some(Some(requested.min(limit))),
            (Some(Some(requested)), None) => Some(Some(requested)),
            // Advertised without a value: this side picks the value.
            (Some(None), Some(limit)) => Some(Some(limit)),
            (Some(None), None) => Some(Some(Self::UNSPECIFIED_WINDOW_BITS)),
            // Not offered: only include it where the RFC allows an unsolicited value.
            (None, limit) if may_send_unsolicited => limit.map(Some),
            (None, _) => None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_merge_no_context_takeover() {
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            #[cfg(feature = "zlib")]
            server_max_window_bits: None,
            #[cfg(feature = "zlib")]
            client_max_window_bits: None,
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let client_offer = WebSocketExtensions {
            server_max_window_bits: None,
            client_max_window_bits: None,
            server_no_context_takeover: true,
            client_no_context_takeover: true,
        };
        let merged = server.merge(&client_offer);
        assert!(merged.server_no_context_takeover);
        assert!(merged.client_no_context_takeover);
    }

    #[test]
    fn test_merge_no_context_takeover_server_requires() {
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            #[cfg(feature = "zlib")]
            server_max_window_bits: None,
            #[cfg(feature = "zlib")]
            client_max_window_bits: None,
            server_no_context_takeover: true,
            client_no_context_takeover: false,
        };
        let client_offer = WebSocketExtensions {
            server_max_window_bits: None,
            client_max_window_bits: None,
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let merged = server.merge(&client_offer);
        assert!(merged.server_no_context_takeover);
        assert!(!merged.client_no_context_takeover);
    }

    #[cfg(feature = "zlib")]
    #[test]
    fn test_merge_window_bits_takes_minimum() {
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            server_max_window_bits: Some(15),
            client_max_window_bits: Some(15),
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let client_offer = WebSocketExtensions {
            server_max_window_bits: Some(Some(12)),
            client_max_window_bits: Some(Some(10)),
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let merged = server.merge(&client_offer);
        assert_eq!(merged.server_max_window_bits, Some(Some(12)));
        assert_eq!(merged.client_max_window_bits, Some(Some(10)));
    }

    #[cfg(feature = "zlib")]
    #[test]
    fn test_merge_window_bits_client_only() {
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            server_max_window_bits: None,
            client_max_window_bits: None,
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let client_offer = WebSocketExtensions {
            server_max_window_bits: Some(Some(12)),
            client_max_window_bits: Some(Some(10)),
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let merged = server.merge(&client_offer);
        assert_eq!(merged.server_max_window_bits, Some(Some(12)));
        assert_eq!(merged.client_max_window_bits, Some(Some(10)));
    }

    #[cfg(feature = "zlib")]
    #[test]
    fn test_merge_window_bits_server_only() {
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            server_max_window_bits: Some(14),
            client_max_window_bits: Some(13),
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let client_offer = WebSocketExtensions {
            server_max_window_bits: None,
            client_max_window_bits: None,
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let merged = server.merge(&client_offer);
        // RFC 7692 §7.1.2.1: the server may announce its own window unprompted.
        assert_eq!(merged.server_max_window_bits, Some(Some(14)));
        // RFC 7692 §7.1.2.2: client_max_window_bits MUST NOT appear unless offered.
        assert_eq!(merged.client_max_window_bits, None);
    }

    #[cfg(feature = "zlib")]
    #[test]
    fn test_merge_window_bits_none_both() {
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            server_max_window_bits: None,
            client_max_window_bits: None,
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let client_offer = WebSocketExtensions {
            server_max_window_bits: None,
            client_max_window_bits: None,
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let merged = server.merge(&client_offer);
        assert_eq!(merged.server_max_window_bits, None);
        assert_eq!(merged.client_max_window_bits, None);
    }

    #[test]
    fn test_merge_mixed_options() {
        #[cfg(feature = "zlib")]
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            server_max_window_bits: Some(15),
            client_max_window_bits: Some(14),
            server_no_context_takeover: true,
            client_no_context_takeover: false,
        };
        #[cfg(not(feature = "zlib"))]
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            server_no_context_takeover: true,
            client_no_context_takeover: false,
        };
        #[cfg(feature = "zlib")]
        let client_offer = WebSocketExtensions {
            server_max_window_bits: Some(Some(12)),
            client_max_window_bits: Some(Some(13)),
            server_no_context_takeover: false,
            client_no_context_takeover: true,
        };
        #[cfg(not(feature = "zlib"))]
        let client_offer = WebSocketExtensions {
            server_max_window_bits: None,
            client_max_window_bits: None,
            server_no_context_takeover: false,
            client_no_context_takeover: true,
        };
        let merged = server.merge(&client_offer);
        assert!(merged.server_no_context_takeover);
        assert!(merged.client_no_context_takeover);
        #[cfg(feature = "zlib")]
        {
            assert_eq!(merged.server_max_window_bits, Some(Some(12)));
            assert_eq!(merged.client_max_window_bits, Some(Some(13)));
        }
        #[cfg(not(feature = "zlib"))]
        {
            assert_eq!(merged.server_max_window_bits, None);
            assert_eq!(merged.client_max_window_bits, None);
        }
    }

    #[cfg(feature = "zlib")]
    #[test]
    fn test_merge_client_offers_no_value() {
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            server_max_window_bits: Some(14),
            client_max_window_bits: Some(13),
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let client_offer = WebSocketExtensions {
            server_max_window_bits: Some(None),
            client_max_window_bits: Some(None),
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let merged = server.merge(&client_offer);
        assert_eq!(merged.server_max_window_bits, Some(Some(14)));
        assert_eq!(merged.client_max_window_bits, Some(Some(13)));
    }

    #[cfg(feature = "zlib")]
    #[test]
    fn test_merge_client_offers_no_value_server_no_preference() {
        let server = DeflateOptions {
            level: CompressionLevel::default(),
            server_max_window_bits: None,
            client_max_window_bits: None,
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let client_offer = WebSocketExtensions {
            server_max_window_bits: Some(None),
            client_max_window_bits: Some(None),
            server_no_context_takeover: false,
            client_no_context_takeover: false,
        };
        let merged = server.merge(&client_offer);
        assert_eq!(merged.server_max_window_bits, Some(Some(9)));
        assert_eq!(merged.client_max_window_bits, Some(Some(9)));
    }

    #[test]
    fn test_parse_and_merge_client_offers_client_max_window_bits_no_value() {
        use std::str::FromStr;
        let client_offer =
            WebSocketExtensions::from_str("permessage-deflate; client_max_window_bits").unwrap();
        assert_eq!(client_offer.client_max_window_bits, Some(None));
        assert_eq!(client_offer.server_max_window_bits, None);
        #[cfg(feature = "zlib")]
        {
            let server = DeflateOptions {
                level: CompressionLevel::default(),
                server_max_window_bits: Some(15),
                client_max_window_bits: Some(12),
                server_no_context_takeover: false,
                client_no_context_takeover: false,
            };
            let merged = server.merge(&client_offer);
            assert_eq!(merged.client_max_window_bits, Some(Some(12)));
            assert_eq!(merged.server_max_window_bits, Some(Some(15)));
            let response = merged.to_string();
            assert!(response.contains("client_max_window_bits=12"));
            assert!(response.contains("server_max_window_bits=15"));
        }
        #[cfg(not(feature = "zlib"))]
        {
            let server = DeflateOptions {
                level: CompressionLevel::default(),
                server_no_context_takeover: false,
                client_no_context_takeover: false,
            };
            let merged = server.merge(&client_offer);
            assert_eq!(merged.client_max_window_bits, None);
            assert_eq!(merged.server_max_window_bits, None);
        }
    }
}