use std::fmt;
use std::io;
use crate::{
Profile,
Result,
packet::prelude::*,
};
use crate::serialize::{
Marshal,
stream::{
writer,
Cookie,
Message,
Private,
},
};
/// A stackable writer that forwards all data to the writer below it and,
/// when it is unwound via `into_inner`, may append a Padding packet.
///
/// The Padding packet is only emitted if a writer in the stack carries an
/// `Encryptor` cookie whose profile is `Profile::RFC9580`; otherwise the
/// data passes through unchanged.
pub struct Padder<'a, 'p: 'a> {
// The writer stack this padder writes into.
inner: writer::BoxStack<'a, Cookie>,
// Maps the number of bytes written so far (as reported by `position`)
// to the desired total size.  Returning a value smaller than the
// argument makes `into_inner` fail.
policy: Box<dyn Fn(u64) -> u64 + Send + Sync + 'p>,
// This writer's cookie; `new` sets its level one above the inner writer's.
cookie: Cookie,
}
// Assert (via the crate's `assert_send_and_sync!` macro) that `Padder` is
// `Send` and `Sync`.
assert_send_and_sync!(Padder<'_, '_>);
impl<'a, 'p> Padder<'a, 'p> {
pub fn new(inner: Message<'a>) -> Self {
let level = inner.as_ref().cookie_ref().level;
let cookie = Cookie::new(level + 1);
Self {
inner: writer::BoxStack::from(inner),
policy: Box::new(padme),
cookie,
}
}
pub fn with_policy<P>(mut self, p: P) -> Self
where
P: Fn(u64) -> u64 + Send + Sync + 'p,
{
self.policy = Box::new(p);
self
}
pub fn build(self) -> Result<Message<'a>> {
Ok(Message::from(Box::new(self)))
}
}
impl<'a, 'p> fmt::Debug for Padder<'a, 'p> {
    /// Shows the inner writer stack and the cookie.  The boxed padding
    /// policy is a closure and is not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("Padder");
        builder.field("inner", &self.inner);
        builder.field("cookie", &self.cookie);
        builder.finish()
    }
}
/// Data written to the padder is passed straight through to the inner
/// writer stack.  Any padding is appended only in `into_inner`.
impl<'a, 'p> io::Write for Padder<'a, 'p> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let inner = &mut self.inner;
        inner.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        let inner = &mut self.inner;
        inner.flush()
    }
}
impl<'a, 'p> writer::Stackable<'a, Cookie> for Padder<'a, 'p>
{
fn into_inner(mut self: Box<Self>)
-> Result<Option<writer::BoxStack<'a, Cookie>>>
{
let enabled = writer::map(
self.as_ref(),
|w| match w.cookie_ref().private {
Private::Encryptor { profile, .. } =>
Some(profile == Profile::RFC9580),
_ => None,
})
.unwrap_or(false);
if enabled {
let size = self.position();
let padded_size = (self.policy)(size);
if padded_size < size {
return Err(crate::Error::InvalidOperation(
format!("Padding policy({}) returned {}: \
smaller than argument",
size, padded_size)).into());
}
let amount = padded_size - size;
Packet::from(Padding::new(amount.try_into()
.unwrap_or(usize::MAX))?)
.serialize(&mut self)?;
}
Ok(Some(self.inner))
}
fn pop(&mut self) -> Result<Option<writer::BoxStack<'a, Cookie>>> {
unreachable!("Only implemented by Signer")
}
fn mount(&mut self, _new: writer::BoxStack<'a, Cookie>) {
unreachable!("Only implemented by Signer")
}
fn inner_ref(&self) -> Option<&(dyn writer::Stackable<'a, Cookie> + Send + Sync)> {
Some(self.inner.as_ref())
}
fn inner_mut(&mut self) -> Option<&mut (dyn writer::Stackable<'a, Cookie> + Send + Sync)> {
Some(self.inner.as_mut())
}
fn cookie_set(&mut self, cookie: Cookie) -> Cookie {
std::mem::replace(&mut self.cookie, cookie)
}
fn cookie_ref(&self) -> &Cookie {
&self.cookie
}
fn cookie_mut(&mut self) -> &mut Cookie {
&mut self.cookie
}
fn position(&self) -> u64 {
self.inner.position()
}
}
/// Padmé padding policy: returns the padded size for a message of `l` bytes.
///
/// The length is rounded up by clearing its low `E - S` bits, where `E` is
/// `floor(log2(l))` and `S` is the number of bits needed to represent `E`.
/// Lengths below 2 map to 1.
///
/// The result is never smaller than `l`.  If rounding up would exceed
/// `u64::MAX`, `u64::MAX` is returned.
pub fn padme(l: u64) -> u64 {
    if l < 2 {
        // log2 is degenerate for 0 and 1; avoid the corner case.
        return 1;
    }

    let e = log2(l);                // l's floating-point exponent
    let s = log2(e as u64) + 1;     // number of bits needed to represent e
    let z = e - s;                  // number of low bits to clear (up to 57)
    // The mask must be a u64.  An untyped literal would be i32, which
    // overflows once z >= 31 (inputs of 2^37 bytes and more).
    let m: u64 = (1u64 << z) - 1;
    // Round up to a multiple of 2^z.  Saturate instead of wrapping when the
    // rounded value does not fit.
    l.checked_add(m).map_or(u64::MAX, |v| v & !m)
}

/// Returns `floor(log2(x))`, i.e. the index of the highest set bit.
/// By convention, `log2(0)` is 0.
fn log2(x: u64) -> usize {
    if x == 0 {
        0
    } else {
        63 - x.leading_zeros() as usize
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// `log2` must agree with the position of the highest set bit around
    /// every power of two.
    #[test]
    fn log2_test() {
        for bit in 0..64 {
            let power = 1u64 << bit;
            assert_eq!(log2(power), bit);
            if bit == 0 {
                continue;
            }
            assert_eq!(log2(power - 1), bit - 1);
            assert_eq!(log2(power + 1), bit);
        }
    }

    /// Relative size increase `(padme(p) - p) / p` for a message of `p` bytes.
    fn padme_multiplicative_overhead(p: u64) -> f32 {
        let padded = padme(p) as f32;
        let original = p as f32;
        (padded - original) / original
    }

    /// Overhead ceiling used by the tests.  It must stay above the overhead
    /// at p = 129 (15 / 129 ≈ 0.11628).
    const MAX_OVERHEAD: f32 = 0.1163;

    #[test]
    fn padme_max_overhead() {
        let at_nine = padme_multiplicative_overhead(9);
        assert!(0.111 < at_nine);
        assert!(at_nine < 0.112);
        assert!(padme_multiplicative_overhead(129) < MAX_OVERHEAD);
    }

    quickcheck! {
        fn padme_overhead(l: u32) -> bool {
            // padme maps every length below 2 to 1, so the bounds do not apply.
            if l < 2 {
                return true;
            }

            let overhead = padme_multiplicative_overhead(l as u64);

            // Per-length bound: (2^(E - S) - 1) / L, where
            // E = floor(log2 L) and S = floor(log2 E) + 1.
            let len = l as f32;
            let exponent = len.log2().floor();
            let exponent_bits = exponent.log2().floor() + 1.;
            let max_overhead = (2.0_f32.powf(exponent - exponent_bits) - 1.) / len;

            assert!(overhead < MAX_OVERHEAD,
                    "padme({}) => {}: overhead {} exceeds maximum overhead {}",
                    l, padme(l.into()), overhead, MAX_OVERHEAD);
            assert!(overhead <= max_overhead,
                    "padme({}) => {}: overhead {} exceeds maximum overhead {}",
                    l, padme(l.into()), overhead, max_overhead);
            true
        }
    }

    /// A random message (shorter than 1024 bytes) written through a padder
    /// must parse back to the same literal body.
    #[test]
    fn roundtrip() {
        use std::io::Write;
        use crate::crypto::random;
        use crate::parse::Parse;
        use crate::serialize::stream::*;

        let mut two_bytes = [0; 2];
        random(&mut two_bytes).expect("Have RNG");
        let size = u16::from_be_bytes(two_bytes) as usize;
        let mut msg = vec![0; size % 1024];
        random(&mut msg).unwrap();

        let mut padded = vec![];
        {
            let message = Message::new(&mut padded);
            let padder = Padder::new(message).with_policy(padme).build().unwrap();
            let mut w = LiteralWriter::new(padder).build().unwrap();
            w.write_all(&msg).unwrap();
            w.finalize().unwrap();
        }

        let parsed = crate::Message::from_bytes(&padded).unwrap();
        assert_eq!(parsed.body().unwrap().body(), &msg[..]);
    }

    /// The padder must not compress: the literal bytes must appear verbatim
    /// in the output.
    #[test]
    fn no_compression() {
        use std::io::Write;
        use crate::serialize::stream::*;

        const MSG: &[u8] = b"@@@@@@@@@@@@@@";

        let mut padded = vec![];
        {
            let message = Message::new(&mut padded);
            let padder = Padder::new(message).build().unwrap();
            let mut w = LiteralWriter::new(padder).build().unwrap();
            w.write_all(MSG).unwrap();
            w.finalize().unwrap();
        }

        let found = padded.windows(MSG.len()).any(|ch| ch == MSG);
        assert!(found,
                "Could not find uncompressed message");
    }
}