use std::{hash::Hasher, io::Write};
use base64::engine::{general_purpose, Engine as _};
use crc24::Crc24Hasher;
use generic_array::typenum::U64;
use super::Headers;
use crate::pgp::{
armor::BlockType,
errors::Result,
line_writer::{LineBreak, LineWriter},
ser::Serialize,
util::TeeWriter,
};
/// Writes `source` to `writer` as one complete ASCII-armored block of type `typ`.
///
/// The output is produced in three steps: the `-----BEGIN …-----` line with
/// the optional armor `headers`, the base64-encoded serialization of `source`,
/// and finally the `-----END …-----` line.
///
/// When `include_checksum` is `true`, a CRC-24 hasher is created, handed to
/// the body writer, and then passed on to the footer writer, which emits it as
/// an `=`-prefixed line. When it is `false`, no hasher exists and that line is
/// omitted.
///
/// # Errors
///
/// Returns the first error produced by any of the three steps.
pub fn write(
    source: &impl Serialize,
    typ: BlockType,
    writer: &mut impl Write,
    headers: Option<&Headers>,
    include_checksum: bool,
) -> Result<()> {
    write_header(writer, typ, headers)?;

    // The hasher is only materialized when a checksum was requested; the body
    // step borrows it mutably and the footer step takes ownership of it.
    let mut checksum = if include_checksum {
        Some(Crc24Hasher::new())
    } else {
        None
    };

    write_body(writer, source, checksum.as_mut())?;
    write_footer(writer, typ, checksum)
}
/// Writes the opening armor line for `typ`, the optional armor headers, and
/// the blank separator line, then flushes `writer`.
///
/// Every header value gets its own `key: value` line, so a key that maps to
/// several values is repeated once per value. Passing `None` for `headers`
/// writes no header lines at all; the blank separator line is still written.
///
/// # Errors
///
/// Returns an error as soon as writing to or flushing `writer` fails, or when
/// `typ.to_writer` fails.
pub(crate) fn write_header(
    writer: &mut impl Write,
    typ: BlockType,
    headers: Option<&Headers>,
) -> Result<()> {
    // Opening line: "-----BEGIN " + block type + "-----".
    writer.write_all(b"-----BEGIN ")?;
    typ.to_writer(writer)?;
    writer.write_all(b"-----\n")?;

    // An absent header map yields an empty iterator, so the loop is skipped.
    for (key, values) in headers.into_iter().flat_map(|map| map.iter()) {
        for value in values {
            writer.write_all(key.as_bytes())?;
            writer.write_all(b": ")?;
            writer.write_all(value.as_bytes())?;
            writer.write_all(b"\n")?;
        }
    }

    // Blank line that ends the header section.
    writer.write_all(b"\n")?;
    writer.flush()?;

    Ok(())
}
/// Serializes `source` into `writer` as the base64 body of an armored block.
///
/// The bytes produced by `source.to_writer` go through a [`Base64Encoder`]
/// whose output is written to a [`LineWriter`] parameterized with `U64` and
/// `LineBreak::Lf` (presumably 64-character lines terminated by `\n` —
/// confirm against `LineWriter`).
///
/// When `crc_hasher` is `Some`, the serializer writes into a [`TeeWriter`]
/// built from the hasher and the encoder instead of into the encoder directly
/// (presumably so the hasher sees the same unencoded bytes — confirm against
/// `TeeWriter`).
///
/// # Errors
///
/// Returns any error raised by `source.to_writer`, and any I/O error raised
/// while the encoder writes its final, possibly padded, base64 group.
fn write_body(
    writer: &mut impl Write,
    source: &impl Serialize,
    crc_hasher: Option<&mut Crc24Hasher>,
) -> Result<()> {
    {
        let mut line_wrapper = LineWriter::<_, U64>::new(writer.by_ref(), LineBreak::Lf);
        let mut enc = Base64Encoder::new(&mut line_wrapper);

        if let Some(crc_hasher) = crc_hasher {
            let mut tee = TeeWriter::new(crc_hasher, &mut enc);
            source.to_writer(&mut tee)?;
        } else {
            source.to_writer(&mut enc)?;
        }

        // Finalize the encoder explicitly. Leaving this to `Drop` would
        // discard any I/O error hit while writing the trailing base64 group,
        // and this function would report success for truncated output.
        // After a successful `finish` the encoder's `Drop` has nothing left
        // to write, so the bytes emitted on the success path are unchanged.
        enc.0.finish()?;

        // NOTE(review): `line_wrapper` is still finalized only by being
        // dropped at the end of this scope; if `LineWriter` writes anything
        // on drop, errors from that step remain unobservable — confirm.
    }

    Ok(())
}
/// Writes the optional checksum line followed by the closing armor line for `typ`.
///
/// When `crc_hasher` is `Some`, the low 24 bits of the hasher's result are
/// taken as three bytes, most significant first, base64-encoded with the
/// standard alphabet, and written as `=` + encoded text + `\n`. When it is
/// `None`, only the `-----END …-----` line is written.
///
/// This function does not flush `writer`.
///
/// # Errors
///
/// Returns an error as soon as a write to `writer` or `typ.to_writer` fails.
pub(crate) fn write_footer(
    writer: &mut impl Write,
    typ: BlockType,
    crc_hasher: Option<Crc24Hasher>,
) -> Result<()> {
    if let Some(hasher) = crc_hasher {
        writer.write_all(b"=")?;

        // `finish` yields a `u64`; only its low 24 bits are emitted. The
        // big-endian byte view of the truncated `u32` is `[_, hi, mid, lo]`,
        // and the top byte is intentionally dropped.
        let [_, hi, mid, lo] = (hasher.finish() as u32).to_be_bytes();
        let encoded = general_purpose::STANDARD.encode([hi, mid, lo]);

        writer.write_all(encoded.as_bytes())?;
        writer.write_all(b"\n")?;
    }

    // Closing line: "-----END " + block type + "-----".
    writer.write_all(b"-----END ")?;
    typ.to_writer(writer)?;
    writer.write_all(b"-----\n")?;

    Ok(())
}
/// Thin wrapper around the base64 crate's streaming `EncoderWriter`, fixed to
/// the `GeneralPurpose` engine with a `'static` engine reference.
///
/// The single field is the wrapped encoder; bytes written to this type are
/// forwarded to it, and it in turn writes to the inner writer `W`.
pub(crate) struct Base64Encoder<W>(
    base64::write::EncoderWriter<'static, general_purpose::GeneralPurpose, W>,
)
where
    W: std::io::Write;
impl<W> Base64Encoder<W>
where
    W: std::io::Write,
{
    /// Creates an encoder that writes base64 text to `writer`, using the
    /// `general_purpose::STANDARD` engine.
    pub(crate) fn new(writer: W) -> Self {
        let inner = base64::write::EncoderWriter::new(writer, &general_purpose::STANDARD);
        Self(inner)
    }
}
impl<W: std::io::Write> std::io::Write for Base64Encoder<W> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.write(buf)
}
fn write_all(&mut self, mut buf: &[u8]) -> std::io::Result<()> {
while !buf.is_empty() {
match self.0.write(buf) {
Ok(0) => {}
Ok(n) => buf = &buf[n..],
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(())
}
fn flush(&mut self) -> std::io::Result<()> {
self.0.flush()
}
}
#[cfg(test)]
mod tests {
    // Test code: panicking on an unexpected `Err`/`None` is the intended
    // failure mode, so `unwrap` is allowed throughout this module.
    #![allow(clippy::unwrap_used)]

    use std::io;

    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;
    use rand_xorshift::XorShiftRng;

    use super::*;
    use crate::pgp::util::test::ChaosReader;

    /// Minimal `Serialize` implementation that emits a fixed byte buffer.
    struct TestSource {
        content: Vec<u8>,
    }

    impl TestSource {
        pub fn new(content: Vec<u8>) -> Self {
            Self { content }
        }
    }

    impl Serialize for TestSource {
        fn to_writer<W: io::Write>(&self, w: &mut W) -> Result<()> {
            w.write_all(&self.content).unwrap();
            Ok(())
        }

        fn write_len(&self) -> usize {
            self.content.len()
        }
    }

    /// Armors `len` random bytes drawn from `rng` as a `Message` block without
    /// headers and returns the armored text.
    fn armor_random_message(rng: &mut XorShiftRng, len: usize, include_checksum: bool) -> String {
        let payload: Vec<u8> = (0..len).map(|_| rng.gen()).collect();
        let source = TestSource::new(payload);

        let mut armored = Vec::new();
        write(
            &source,
            BlockType::Message,
            &mut armored,
            None,
            include_checksum,
        )
        .unwrap();

        String::from_utf8(armored).unwrap()
    }

    /// With a checksum, the line before the checksum line must not be empty
    /// and the checksum line itself must be exactly five characters long.
    #[test]
    fn writes_no_doubleline() {
        let mut rng = XorShiftRng::seed_from_u64(0);

        for len in 2..1024 {
            let armored = armor_random_message(&mut rng, len, true);
            let lines: Vec<&str> = armored.lines().collect();
            let count = lines.len();

            assert_eq!(lines[0], "-----BEGIN PGP MESSAGE-----");
            assert!(!lines[count - 3].is_empty(), "last line must not be empty");

            let checksum_line = lines[count - 2];
            assert_eq!(
                checksum_line.len(),
                5,
                "invalid checksum line: '{}'",
                checksum_line
            );

            assert_eq!(lines[count - 1], "-----END PGP MESSAGE-----");
        }
    }

    /// Without a checksum, the line directly before the footer must not be empty.
    #[test]
    fn writes_no_checksum() {
        let mut rng = XorShiftRng::seed_from_u64(0);

        for len in 2..1024 {
            let armored = armor_random_message(&mut rng, len, false);
            let lines: Vec<&str> = armored.lines().collect();
            let count = lines.len();

            assert_eq!(lines[0], "-----BEGIN PGP MESSAGE-----");
            assert!(!lines[count - 2].is_empty(), "last line must not be empty");
            assert_eq!(lines[count - 1], "-----END PGP MESSAGE-----");
        }
    }

    /// Copying data from a `ChaosReader` through `Base64Encoder` must produce
    /// the same text as encoding the whole buffer in one call.
    #[test]
    fn test_base64_encoder() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        for size in 1..=500 {
            let mut input = vec![0u8; size];
            rng.fill(&mut input[..]);

            let mut reader = ChaosReader::new(rng.clone(), input.clone());
            let mut encoded = Vec::new();
            {
                // Scoped so the encoder is dropped before `encoded` is read.
                let mut encoder = Base64Encoder::new(&mut encoded);
                io::copy(&mut reader, &mut encoder).unwrap();
            }

            let actual = String::from_utf8(encoded).unwrap();
            let expected = general_purpose::STANDARD.encode(input);
            assert_eq!(actual, expected);
        }
    }
}