use std::hash::Hasher;
use std::io::Write;
use base64::engine::{general_purpose, Engine as _};
use crc24::Crc24Hasher;
use generic_array::typenum::U64;
use crate::armor::BlockType;
use crate::errors::Result;
use crate::line_writer::{LineBreak, LineWriter};
use crate::ser::Serialize;
use crate::util::TeeWriter;
use super::Headers;
/// Writes `source` into `writer` as an ASCII-armored block of kind `typ`.
///
/// The output is, in order: the `-----BEGIN …-----` line with any `headers`, the
/// base64-encoded body, an optional `=XXXX` checksum line (only when
/// `include_checksum` is set), and the `-----END …-----` line.
///
/// # Errors
/// Returns an error if serializing `source` or writing to `writer` fails.
pub fn write(
    source: &impl Serialize,
    typ: BlockType,
    writer: &mut impl Write,
    headers: Option<&Headers>,
    include_checksum: bool,
) -> Result<()> {
    write_header(writer, typ, headers)?;

    // A hasher only exists when a checksum was requested: the body feeds it with the
    // raw serialized bytes, and the footer consumes it to emit the checksum line.
    let mut checksum = if include_checksum {
        Some(Crc24Hasher::new())
    } else {
        None
    };
    write_body(writer, source, checksum.as_mut())?;

    write_footer(writer, typ, checksum)
}
/// Writes the armor header section and then flushes `writer`.
///
/// The section consists of the `-----BEGIN <typ>-----` line, one `key: value` line for
/// every value of every header key (a key with several values is repeated), and the
/// empty line that separates the headers from the body.
fn write_header(writer: &mut impl Write, typ: BlockType, headers: Option<&Headers>) -> Result<()> {
    writer.write_all(b"-----BEGIN ")?;
    typ.to_writer(writer)?;
    writer.write_all(b"-----\n")?;

    // A missing header map behaves exactly like an empty one.
    let entries = headers.into_iter().flat_map(|map| map.iter());
    for (key, values) in entries {
        for value in values {
            writer.write_all(key.as_bytes())?;
            writer.write_all(b": ")?;
            writer.write_all(value.as_bytes())?;
            writer.write_all(b"\n")?;
        }
    }

    // Blank line terminating the header section.
    writer.write_all(b"\n")?;
    writer.flush()?;
    Ok(())
}
/// Writes the base64 encoding of `source`'s serialization into `writer`.
///
/// The encoded text is wrapped into lines by `LineWriter` (width parameter `U64`) using
/// LF line breaks. If `crc_hasher` is provided, the raw (unencoded) serialized bytes are
/// also fed into it, so the caller can append the armor checksum afterwards.
///
/// # Errors
/// Returns an error if serializing `source` fails or if writing the encoded output
/// fails. This includes the final, padded base64 group. That group is written explicitly
/// via `EncoderWriter::finish`, not implicitly by the encoder's `Drop` (which would
/// discard any error).
fn write_body(
    writer: &mut impl Write,
    source: &impl Serialize,
    crc_hasher: Option<&mut Crc24Hasher>,
) -> Result<()> {
    let mut line_wrapper = LineWriter::<_, U64>::new(writer.by_ref(), LineBreak::Lf);
    // `ZeroWrapper` makes `write_all` retry when the encoder returns `Ok(0)`, which it
    // does while it is still draining previously buffered output.
    let mut enc = ZeroWrapper(base64::write::EncoderWriter::new(
        &mut line_wrapper,
        &general_purpose::STANDARD,
    ));

    if let Some(crc_hasher) = crc_hasher {
        // `TeeWriter` duplicates the serialized bytes into the hasher and the encoder.
        let mut tee = TeeWriter::new(crc_hasher, &mut enc);
        source.to_writer(&mut tee)?;
    } else {
        source.to_writer(&mut enc)?;
    }

    // Encode and write the trailing partial input group (with padding) now, so that a
    // failing write is reported to the caller. Leaving it to the encoder's `Drop` would
    // perform the same write but silently ignore errors. After `finish`, the encoder's
    // drop is a no-op.
    enc.0.finish()?;

    // NOTE(review): `line_wrapper` is finalized by its own `Drop` when it goes out of
    // scope below; confirm whether `LineWriter` offers a fallible finish worth calling.
    Ok(())
}
/// Writes the armor trailer.
///
/// When `crc_hasher` is present, a checksum line is written first: `=` followed by the
/// base64 encoding of the low 24 bits of the hash, in big-endian byte order. The
/// closing `-----END <typ>-----` line always follows.
fn write_footer(
    writer: &mut impl Write,
    typ: BlockType,
    crc_hasher: Option<Crc24Hasher>,
) -> Result<()> {
    if let Some(hasher) = crc_hasher {
        // `Hasher::finish` yields a u64; the checksum is its low 24 bits, i.e. the last
        // three bytes of the big-endian u32 representation.
        let [_, high, middle, low] = (hasher.finish() as u32).to_be_bytes();
        let encoded = general_purpose::STANDARD.encode([high, middle, low]);

        writer.write_all(b"=")?;
        writer.write_all(encoded.as_bytes())?;
        writer.write_all(b"\n")?;
    }

    writer.write_all(b"-----END ")?;
    typ.to_writer(writer)?;
    writer.write_all(b"-----\n")?;
    Ok(())
}
/// Writer adapter whose `write_all` treats `Ok(0)` from the inner writer as "try
/// again" rather than as a failure.
///
/// The standard `Write::write_all` turns `Ok(0)` into an `ErrorKind::WriteZero` error.
/// This `write_all` retries on `Ok(0)`, just as it does for `ErrorKind::Interrupted`.
/// Caveat: it loops forever if the inner writer never makes progress.
struct ZeroWrapper<W>(W);

impl<W: std::io::Write> std::io::Write for ZeroWrapper<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.0.write(buf)
    }

    fn write_all(&mut self, mut buf: &[u8]) -> std::io::Result<()> {
        while !buf.is_empty() {
            let consumed = match self.0.write(buf) {
                Ok(count) => count,
                // An interrupted write consumed nothing; simply retry.
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => 0,
                Err(err) => return Err(err),
            };
            // `consumed == 0` leaves `buf` untouched, so the next iteration retries.
            buf = &buf[consumed..];
        }
        Ok(())
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.0.flush()
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;
    use rand::{Rng, SeedableRng};
    use rand_xorshift::XorShiftRng;
    use std::io;

    /// A `Serialize` implementation that writes a fixed byte buffer verbatim.
    struct TestSource {
        content: Vec<u8>,
    }

    impl TestSource {
        pub fn new(content: Vec<u8>) -> Self {
            Self { content }
        }
    }

    impl Serialize for TestSource {
        fn to_writer<W: io::Write>(&self, w: &mut W) -> Result<()> {
            w.write_all(&self.content).unwrap();
            Ok(())
        }
    }

    /// Armors `len` bytes drawn from `rng` as a PGP message and returns the output text.
    fn armor_random(rng: &mut XorShiftRng, len: usize, include_checksum: bool) -> String {
        let content: Vec<u8> = (0..len).map(|_| rng.gen()).collect();
        let source = TestSource::new(content);
        let mut dest = Vec::new();
        write(&source, BlockType::Message, &mut dest, None, include_checksum).unwrap();
        String::from_utf8(dest).unwrap()
    }

    #[test]
    fn writes_no_doubleline() {
        let mut rng = XorShiftRng::seed_from_u64(0);
        for len in 2..1024 {
            let armored = armor_random(&mut rng, len, true);
            let lines: Vec<&str> = armored.lines().collect();
            let last = lines.len() - 1;

            assert_eq!(lines[0], "-----BEGIN PGP MESSAGE-----");
            // The final body line sits directly above the checksum line.
            assert!(!lines[last - 2].is_empty(), "last line must not be empty");
            assert_eq!(
                lines[last - 1].len(),
                5,
                "invalid checksum line: '{}'",
                lines[last - 1]
            );
            assert_eq!(lines[last], "-----END PGP MESSAGE-----");
        }
    }

    #[test]
    fn writes_no_checksum() {
        let mut rng = XorShiftRng::seed_from_u64(0);
        for len in 2..1024 {
            let armored = armor_random(&mut rng, len, false);
            let lines: Vec<&str> = armored.lines().collect();
            let last = lines.len() - 1;

            assert_eq!(lines[0], "-----BEGIN PGP MESSAGE-----");
            // Without a checksum, the final body line sits directly above the END line.
            assert!(!lines[last - 1].is_empty(), "last line must not be empty");
            assert_eq!(lines[last], "-----END PGP MESSAGE-----");
        }
    }
}