wecanencrypt 0.2.0

Simple Rust OpenPGP library for encryption, signing, and key management (includes rpgp).
Documentation
use std::io::{self, BufRead, Read};

use bytes::{Buf, BytesMut};

use super::PacketBodyReader;
use crate::pgp::{
    composed::DebugBufRead,
    packet::{Decompressor, PacketHeader},
    types::Tag,
    util::fill_buffer_bytes,
};

const BUFFER_SIZE: usize = 8 * 1024;

#[derive(Debug)]
pub enum CompressedDataReader<R: DebugBufRead> {
    Body {
        source: MaybeDecompress<PacketBodyReader<R>>,
        buffer: BytesMut,
    },
    Done {
        source: PacketBodyReader<R>,
    },
    Error,
}

impl<R: DebugBufRead> CompressedDataReader<R> {
    pub fn new(source: PacketBodyReader<R>, decompress: bool) -> io::Result<Self> {
        debug_assert_eq!(source.packet_header().tag(), Tag::CompressedData);

        let source = if decompress {
            let dec = Decompressor::from_reader(source)?;
            MaybeDecompress::Decompress(dec)
        } else {
            MaybeDecompress::Raw(source)
        };

        Ok(Self::Body {
            source,
            buffer: BytesMut::with_capacity(BUFFER_SIZE),
        })
    }

    pub fn new_done(source: PacketBodyReader<R>) -> Self {
        Self::Done { source }
    }

    pub fn is_done(&self) -> bool {
        matches!(self, Self::Done { .. })
    }

    pub fn into_inner(self) -> PacketBodyReader<R> {
        match self {
            Self::Body { source, .. } => source.into_inner(),
            Self::Done { source, .. } => source,
            Self::Error => {
                panic!("CompressedDataReader errored")
            }
        }
    }

    pub fn get_mut(&mut self) -> &mut PacketBodyReader<R> {
        match self {
            Self::Body { source, .. } => source.get_mut(),
            Self::Done { source, .. } => source,
            Self::Error => {
                panic!("CompressedDataReader errored")
            }
        }
    }

    pub fn packet_header(&self) -> PacketHeader {
        match self {
            Self::Body { ref source, .. } => match source {
                MaybeDecompress::Raw(r) => r.packet_header(),
                MaybeDecompress::Decompress(r) => r.get_ref().packet_header(),
            },
            Self::Done { ref source, .. } => source.packet_header(),
            Self::Error => {
                panic!("CompressedDataReader errored")
            }
        }
    }

    /// Enables decompression
    pub fn decompress(self) -> io::Result<Self> {
        match self {
            Self::Body { source, buffer } => Ok(Self::Body {
                source: source.decompress()?,
                buffer,
            }),
            Self::Done { .. } => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "already finished",
            )),
            Self::Error => Err(io::Error::other("CompressedDataReader errored")),
        }
    }
}

impl<R: DebugBufRead> BufRead for CompressedDataReader<R> {
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        self.fill_inner()?;
        match self {
            Self::Body { ref mut buffer, .. } => Ok(&buffer[..]),
            Self::Done { .. } => Ok(&[][..]),
            Self::Error => Err(io::Error::other("CompressedDataReader errored")),
        }
    }

    fn consume(&mut self, amt: usize) {
        match self {
            Self::Body { ref mut buffer, .. } => {
                buffer.advance(amt);
            }
            Self::Done { .. } => {}
            Self::Error => {
                panic!("CompressedDataReader errored");
            }
        }
    }
}

impl<R: DebugBufRead> Read for CompressedDataReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.fill_inner()?;
        match self {
            Self::Body { ref mut buffer, .. } => {
                let to_write = buffer.remaining().min(buf.len());
                buffer.copy_to_slice(&mut buf[..to_write]);
                Ok(to_write)
            }
            Self::Done { .. } => Ok(0),
            Self::Error => Err(io::Error::other("CompressedDataReader errored")),
        }
    }
}

impl<R: DebugBufRead> CompressedDataReader<R> {
    fn fill_inner(&mut self) -> io::Result<()> {
        if matches!(self, Self::Done { .. }) {
            return Ok(());
        }

        match std::mem::replace(self, Self::Error) {
            Self::Body {
                mut source,
                mut buffer,
            } => {
                if buffer.has_remaining() {
                    *self = Self::Body { source, buffer };
                    return Ok(());
                }

                let read = fill_buffer_bytes(&mut source, &mut buffer, BUFFER_SIZE)?;
                if read == 0 {
                    let source = source.into_inner();

                    *self = Self::Done { source };
                } else {
                    *self = Self::Body { source, buffer };
                }
                Ok(())
            }
            Self::Done { source } => {
                *self = Self::Done { source };
                Ok(())
            }
            Self::Error => Err(io::Error::other("CompressedDataReader errored")),
        }
    }
}

#[derive(Debug)]
pub enum MaybeDecompress<R: DebugBufRead> {
    Raw(R),
    Decompress(Decompressor<R>),
}

impl<R: DebugBufRead> MaybeDecompress<R> {
    fn decompress(self) -> io::Result<Self> {
        match self {
            Self::Raw(r) => Ok(Self::Decompress(Decompressor::from_reader(r)?)),
            Self::Decompress(_) => {
                // already decompressing
                Ok(self)
            }
        }
    }
}

impl<R: DebugBufRead> MaybeDecompress<R> {
    fn into_inner(self) -> R {
        match self {
            Self::Raw(r) => r,
            Self::Decompress(r) => r.into_inner(),
        }
    }
    fn get_mut(&mut self) -> &mut R {
        match self {
            Self::Raw(r) => r,
            Self::Decompress(r) => r.get_mut(),
        }
    }
}

impl<R: DebugBufRead> BufRead for MaybeDecompress<R> {
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        match self {
            Self::Raw(ref mut r) => r.fill_buf(),
            Self::Decompress(ref mut r) => r.fill_buf(),
        }
    }
    fn consume(&mut self, amt: usize) {
        match self {
            Self::Raw(ref mut r) => r.consume(amt),
            Self::Decompress(ref mut r) => r.consume(amt),
        }
    }
}

impl<R: DebugBufRead> Read for MaybeDecompress<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Self::Raw(ref mut r) => r.read(buf),
            Self::Decompress(ref mut r) => r.read(buf),
        }
    }
}