use std::fmt;
use std::io::{self, Write};
use crate::{
Result,
packet::prelude::*,
};
use crate::packet::header::CTB;
use crate::serialize::{
Marshal,
stream::{
writer,
Cookie,
Message,
PartialBodyFilter,
},
};
use crate::types::{
CompressionAlgorithm,
CompressionLevel,
};
/// A writer that pads the data written to it.
///
/// `build` wraps the output in a compressed data packet (algorithm
/// `Zip`, compression level "none") and, when the writer is finalized
/// via `Stackable::into_inner`, random bytes are appended after the
/// compressed stream until the size chosen by `policy` is reached.
pub struct Padder<'a> {
/// Writer the padder forwards to.  Set from the `Message` given to
/// `new`; replaced in `build` by the compression writer stacked on top
/// of it.
inner: writer::BoxStack<'a, Cookie>,
/// Padding policy: called with the larger of the uncompressed and
/// compressed sizes and must return a size that is not smaller than its
/// argument.  Defaults to `padme`.
policy: fn(u64) -> u64,
}
// NOTE(review): macro is defined elsewhere in the crate; going by its
// name it is a compile-time check that `Padder` is `Send + Sync`.
assert_send_and_sync!(Padder<'_>);
impl<'a> Padder<'a> {
    /// Creates a padder that will write to `inner`, using `padme` as the
    /// padding policy.
    ///
    /// Nothing is written until `build` is called.
    pub fn new(inner: Message<'a>) -> Self {
        Self {
            inner: inner.into(),
            policy: padme,
        }
    }

    /// Replaces the padding policy with `p`.
    ///
    /// When the writer is finalized, `p` is called with the larger of the
    /// uncompressed and compressed sizes; it must return a value that is
    /// not smaller than its argument, otherwise finalization fails.
    pub fn with_policy(self, p: fn(u64) -> u64) -> Self {
        Self { policy: p, ..self }
    }

    /// Emits the packet framing and returns the writer to feed with the
    /// data that should be padded.
    ///
    /// In order, this writes a `CompressedData` CTB, pushes a partial
    /// body filter, writes the `Zip` algorithm octet, and pushes a ZIP
    /// writer configured with `CompressionLevel::none()`.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing the CTB or the algorithm octet.
    pub fn build(self) -> Result<Message<'a>> {
        let Self { inner: mut sink, policy } = self;

        // Every writer pushed here lives one level above the writer we
        // were handed.
        let level = sink.cookie_ref().level + 1;

        // Packet header.
        CTB::new(Tag::CompressedData).serialize(&mut sink)?;

        // The body is streamed using partial body lengths.
        let mut body: Message<'a> =
            PartialBodyFilter::new(Message::from(sink), Cookie::new(level));

        // Compressed data header: the algorithm octet.
        body.as_mut().write_u8(CompressionAlgorithm::Zip.into())?;

        // Finally the compressor itself, with compression turned off.
        let compressor = writer::ZIP::new(
            body,
            Cookie::new(level),
            CompressionLevel::none(),
        );

        Ok(Message::from(Box::new(Self {
            inner: compressor.into(),
            policy,
        })))
    }
}
impl<'a> fmt::Debug for Padder<'a> {
    /// Formats the padder as a struct named `Padder` showing only the
    /// `inner` writer; the `policy` function pointer is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Padder");
        out.field("inner", &self.inner);
        out.finish()
    }
}
impl<'a> io::Write for Padder<'a> {
    /// Hands `buf` to the wrapped writer and reports how many bytes it
    /// accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let sink = &mut self.inner;
        sink.write(buf)
    }

    /// Flushes the wrapped writer.
    fn flush(&mut self) -> io::Result<()> {
        let sink = &mut self.inner;
        sink.flush()
    }
}
impl<'a> writer::Stackable<'a, Cookie> for Padder<'a> {
    /// Finalizes the stream: appends the padding and returns the writer
    /// below the packet framing.
    ///
    /// The padding policy is evaluated on the larger of the uncompressed
    /// and the compressed size, and random bytes are written after the
    /// compressed stream until the compressed size reaches the policy's
    /// answer.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidOperation` if the policy returns a size
    /// smaller than its argument, and propagates errors from the
    /// underlying writers.
    ///
    /// # Panics
    ///
    /// Panics if popping the inner writer yields no writer below it.
    fn into_inner(self: Box<Self>)
                  -> Result<Option<writer::BoxStack<'a, Cookie>>> {
        let policy = self.policy;

        // Sample our position while the compressor is still on the
        // stack: this is the number of bytes as seen before compression.
        let uncompressed_size = self.position();

        // Pop the compressor; what remains reports the size after
        // compression.
        let mut pb_writer = Box::new(self.inner).into_inner()?.unwrap();
        let compressed_size = pb_writer.position();

        // Pad relative to whichever representation is larger.
        let size = uncompressed_size.max(compressed_size);
        let padded_size = policy(size);
        if padded_size < size {
            return Err(crate::Error::InvalidOperation(
                format!("Padding policy({}) returned {}: smaller than argument",
                        size, padded_size)).into());
        }

        // Cannot underflow: padded_size >= size >= compressed_size.
        let mut amount = padded_size - compressed_size;

        // Debugging aid, disabled by default.
        if false {
            eprintln!("u: {}, c: {}, amount: {}",
                      uncompressed_size, compressed_size, amount);
        }

        // Emit the padding in chunks, drawing fresh random bytes for
        // every chunk.
        const BUFFER_SIZE: usize = 4096;
        let mut padding = vec![0; BUFFER_SIZE];
        while amount > 0 {
            let n = amount.min(BUFFER_SIZE as u64) as usize;
            crate::crypto::random(&mut padding[..n]);
            pb_writer.write_all(&padding[..n])?;
            amount -= n as u64;
        }

        // Pop the remaining writer as well and return what is below it.
        pb_writer.into_inner()
    }

    /// Not supported by this writer.
    fn pop(&mut self) -> Result<Option<writer::BoxStack<'a, Cookie>>> {
        unreachable!("Only implemented by Signer")
    }

    /// Not supported by this writer.
    fn mount(&mut self, _new: writer::BoxStack<'a, Cookie>) {
        unreachable!("Only implemented by Signer")
    }

    /// Returns a shared reference to the wrapped writer.
    fn inner_ref(&self) -> Option<&(dyn writer::Stackable<'a, Cookie> + Send + Sync)> {
        let below = self.inner.as_ref();
        Some(below)
    }

    /// Returns a mutable reference to the wrapped writer.
    fn inner_mut(&mut self) -> Option<&mut (dyn writer::Stackable<'a, Cookie> + Send + Sync)> {
        let below = self.inner.as_mut();
        Some(below)
    }

    /// Sets the wrapped writer's cookie, returning what that writer
    /// returns; the padder keeps no cookie of its own.
    fn cookie_set(&mut self, cookie: Cookie) -> Cookie {
        self.inner.cookie_set(cookie)
    }

    /// Returns the wrapped writer's cookie.
    fn cookie_ref(&self) -> &Cookie {
        self.inner.cookie_ref()
    }

    /// Returns the wrapped writer's cookie, mutably.
    fn cookie_mut(&mut self) -> &mut Cookie {
        self.inner.cookie_mut()
    }

    /// Returns the wrapped writer's position.
    fn position(&self) -> u64 {
        self.inner.position()
    }
}
/// Default padding policy: rounds `l` up so that only its most
/// significant bits carry information.
///
/// With `e = floor(log2(l))` and `s = floor(log2(e)) + 1`, the lowest
/// `e - s` bits of the result are zero; `l` is rounded up to the next
/// such value.  (This appears to be the "Padmé" scheme; see the PURBs
/// paper for the rationale.)
///
/// Special cases:
///
/// - `l < 2` yields `1`.
/// - If rounding up would not fit into a `u64` (only possible for `l`
///   within `2^57` of `u64::MAX`), `l` is returned unchanged, so the
///   result is never smaller than the argument.
// The single-letter names mirror the variables of the scheme's
// pseudocode.
#[allow(clippy::many_single_char_names)]
pub fn padme(l: u64) -> u64 {
    if l < 2 {
        return 1; // Corner case: the formula below needs e >= 1.
    }

    let e = log2(l);            // Index of l's highest set bit.
    let s = log2(e as u64) + 1; // Number of bits needed to represent e.
    let z = e - s;              // Number of low bits to clear; at most 57.

    // The mask must be computed in u64.  An untyped `1 << z` falls back
    // to i32, which overflows once z reaches 31, i.e. for l >= 2^37.
    let m = (1u64 << z) - 1;    // Mask with the z lowest bits set.

    // Round up to the next multiple of 2^z.
    match l.checked_add(m) {
        Some(bumped) => bumped & !m,
        None => l,
    }
}

/// Returns the floor of the base-2 logarithm of `x`, i.e. the index of
/// its highest set bit.
///
/// By convention, `log2(0)` is `0`.
fn log2(x: u64) -> usize {
    if x == 0 {
        0
    } else {
        63 - x.leading_zeros() as usize
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// `log2` returns the floor of the base-2 logarithm at, just below,
    /// and just above every power of two that fits into a `u64`.
    #[test]
    fn log2_test() {
        for shift in 0..64 {
            let power = 1u64 << shift;
            assert_eq!(log2(power), shift);
            if shift > 0 {
                assert_eq!(log2(power - 1), shift - 1);
                assert_eq!(log2(power + 1), shift);
            }
        }
    }

    /// Relative overhead `(padded - unpadded) / unpadded` that `padme`
    /// adds to a message of `p` bytes.
    fn padme_multiplicative_overhead(p: u64) -> f32 {
        let unpadded = p as f32;
        let padded = padme(p) as f32;
        (padded - unpadded) / unpadded
    }

    /// Upper bound on the relative overhead checked by the tests below.
    const MAX_OVERHEAD: f32 = 0.1163;

    /// Spot-checks the overhead at the sizes 9 and 129.
    #[test]
    fn padme_max_overhead() {
        let at_nine = padme_multiplicative_overhead(9);
        assert!(0.111 < at_nine);
        assert!(at_nine < 0.112);

        let at_129 = padme_multiplicative_overhead(129);
        assert!(at_129 < MAX_OVERHEAD);
    }

    quickcheck! {
        /// For arbitrary sizes, the overhead stays below both the global
        /// bound and the size-specific bound `(2^(e - s) - 1) / l`.
        fn padme_overhead(l: u32) -> bool {
            // Sizes below two are special-cased by padme.
            if l < 2 {
                return true;
            }

            let observed = padme_multiplicative_overhead(l as u64);

            let len = l as f32;
            let exponent = len.log2().floor();
            let exponent_bits = exponent.log2().floor() + 1.;
            let bound = (2.0_f32.powf(exponent - exponent_bits) - 1.) / len;

            assert!(observed < MAX_OVERHEAD,
                    "padme({}) => {}: overhead {} exceeds maximum overhead {}",
                    l, padme(l.into()), observed, MAX_OVERHEAD);
            assert!(observed <= bound,
                    "padme({}) => {}: overhead {} exceeds maximum overhead {}",
                    l, padme(l.into()), observed, bound);
            true
        }
    }

    /// Random data of random length (below 1024 bytes) written through a
    /// padder parses back to the same literal body.
    #[test]
    fn roundtrip() {
        use std::io::Write;
        use crate::parse::Parse;
        use crate::serialize::stream::*;

        let mut msg = vec![0; rand::random::<usize>() % 1024];
        crate::crypto::random(&mut msg);

        let mut padded = vec![];
        {
            // The writer stack borrows `padded`; the scope ends the
            // borrow before the result is parsed.
            let sink = Message::new(&mut padded);
            let padder = Padder::new(sink)
                .with_policy(padme)
                .build()
                .unwrap();
            let mut literal = LiteralWriter::new(padder).build().unwrap();
            literal.write_all(&msg).unwrap();
            literal.finalize().unwrap();
        }

        let parsed = crate::Message::from_bytes(&padded).unwrap();
        assert_eq!(parsed.body().unwrap().body(), &msg[..]);
    }

    /// The padder must not actually compress: a highly compressible
    /// message has to appear verbatim in the output.
    #[test]
    fn no_compression() {
        use std::io::Write;
        use crate::serialize::stream::*;

        const MSG: &[u8] = b"@@@@@@@@@@@@@@";

        let mut padded = vec![];
        {
            let sink = Message::new(&mut padded);
            let padder = Padder::new(sink).build().unwrap();
            let mut literal = LiteralWriter::new(padder).build().unwrap();
            literal.write_all(MSG).unwrap();
            literal.finalize().unwrap();
        }

        let found_verbatim = padded
            .windows(MSG.len())
            .any(|window| window == MSG);
        assert!(found_verbatim,
                "Could not find uncompressed message");
    }
}