use std::io::{self, Write};
use zstd::stream::write::{Decoder, Encoder};
/// zstd compression level used for WebSocket streams (3 is zstd's own default level).
pub const WS_COMPRESSION_LEVEL: i32 = 3;

/// Streaming zstd compressor for a sequence of WebSocket messages.
///
/// A single zstd `Encoder` is reused for every call and is never finished. zstd therefore
/// keeps its stream state between messages. The frames it produces are meant to be fed, in
/// order, to one long-lived [`WsStreamDecompressor`].
pub struct WsStreamCompressor {
    /// zstd encoder writing into an in-memory buffer that `compress` drains after each message.
    encoder: Encoder<'static, Vec<u8>>,
}

impl WsStreamCompressor {
    /// Creates a compressor that uses zstd compression `level`
    /// (typically [`WS_COMPRESSION_LEVEL`]).
    ///
    /// # Errors
    /// Returns any error zstd reports while creating the encoder context.
    pub fn new(level: i32) -> io::Result<Self> {
        let encoder = Encoder::new(Vec::new(), level)?;
        Ok(Self { encoder })
    }

    /// Compresses `message` and returns the bytes to transmit as one frame.
    ///
    /// # Errors
    /// Returns any error zstd reports while compressing or flushing.
    pub fn compress(&mut self, message: &[u8]) -> io::Result<Vec<u8>> {
        let encoder = &mut self.encoder;
        encoder.write_all(message)?;
        // Flush so that all output for `message` reaches the buffer now, instead of
        // staying buffered inside zstd until a later write.
        encoder.flush()?;
        Ok(take_buffer(encoder.get_mut()))
    }
}
/// Streaming zstd decompressor, the counterpart of [`WsStreamCompressor`].
///
/// One zstd `Decoder` is kept for the whole stream. Frames are expected to arrive in the
/// order the compressor produced them.
pub struct WsStreamDecompressor {
    /// zstd decoder writing decompressed bytes into an in-memory buffer drained per frame.
    decoder: Decoder<'static, Vec<u8>>,
}

impl WsStreamDecompressor {
    /// Creates a decompressor with a fresh zstd decoding context.
    ///
    /// # Errors
    /// Returns any error zstd reports while creating the decoder context.
    pub fn new() -> io::Result<Self> {
        let decoder = Decoder::new(Vec::new())?;
        Ok(Self { decoder })
    }

    /// Decompresses one frame produced by [`WsStreamCompressor::compress`] and returns the
    /// recovered message bytes.
    ///
    /// # Errors
    /// Returns any error zstd reports while decoding `frame`.
    ///
    /// NOTE(review): on error, nothing here discards output decoded before the failure, and
    /// the shared stream state is likely unusable. Callers should presumably drop the stream
    /// rather than keep calling this. Confirm against zstd's error semantics.
    pub fn decompress(&mut self, frame: &[u8]) -> io::Result<Vec<u8>> {
        let decoder = &mut self.decoder;
        decoder.write_all(frame)?;
        // Flush so that every decompressed byte for this frame is in the buffer before
        // it is handed back.
        decoder.flush()?;
        Ok(take_buffer(decoder.get_mut()))
    }
}
/// Moves all bytes accumulated in `buf` out to the caller and leaves `buf` empty.
///
/// The returned `Vec` is the allocation that received this call's output, so its capacity
/// tracks its own length. `buf` is left as an unallocated empty `Vec` that grows on demand
/// during the next write.
fn take_buffer(buf: &mut Vec<u8>) -> Vec<u8> {
    // `mem::take` swaps in `Vec::new()`, which does not allocate. Pre-sizing the replacement
    // to the old capacity would ratchet up to the largest message ever seen. Every later call
    // would then allocate, and hand back, a buffer of that high-water size, even for tiny
    // frames.
    std::mem::take(buf)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` JSON-RPC `transactionNotification` payloads that differ only in their
    /// slot and signature fields.
    fn sample_messages(count: usize) -> Vec<Vec<u8>> {
        let mut out = Vec::with_capacity(count);
        for idx in 0..count {
            let slot = 423_816_307 + idx;
            let text = format!(
                r#"{{"jsonrpc":"2.0","method":"transactionNotification","params":{{"subscription":1,"result":{{"slot":{},"transaction":{{"signatures":["sig{}"],"meta":{{"fee":5000,"computeUnitsConsumed":1234,"logMessages":["Program X invoke [1]","Program X success"]}}}}}}}}}}"#,
                slot, idx
            );
            out.push(text.into_bytes());
        }
        out
    }

    #[test]
    fn round_trips_a_stream_of_messages() {
        let messages = sample_messages(50);
        let mut compressor = WsStreamCompressor::new(WS_COMPRESSION_LEVEL).unwrap();
        let mut decompressor = WsStreamDecompressor::new().unwrap();
        // One compressor/decompressor pair is reused for every message, in order.
        messages.iter().for_each(|original| {
            let frame = compressor.compress(original).unwrap();
            let restored = decompressor.decompress(&frame).unwrap();
            assert_eq!(&restored, original, "decompressed frame must match the original");
        });
    }

    #[test]
    fn handles_empty_and_large_messages() {
        let mut compressor = WsStreamCompressor::new(WS_COMPRESSION_LEVEL).unwrap();
        let mut decompressor = WsStreamDecompressor::new().unwrap();

        let empty_frame = compressor.compress(b"").unwrap();
        assert_eq!(decompressor.decompress(&empty_frame).unwrap(), b"");

        // A 5 MiB payload of a single repeated byte.
        let big = vec![b'a'; 5 * 1024 * 1024];
        let big_frame = compressor.compress(&big).unwrap();
        assert_eq!(decompressor.decompress(&big_frame).unwrap(), big);
    }
}