wecanencrypt 0.2.0

Simple Rust OpenPGP library for encryption, signing, and key management (includes rpgp).
Documentation
use std::io::{self, BufRead, Read};

use super::PacketBodyReader;
use crate::pgp::{
    composed::{DebugBufRead, PlainSessionKey},
    crypto::sym::SymmetricKeyAlgorithm,
    errors::{bail, ensure, ensure_eq, unsupported_err, Result},
    packet::{
        GnupgAeadDataConfig, PacketHeader, ProtectedDataConfig, StreamDecryptor,
        SymEncryptedProtectedDataConfig,
    },
    types::Tag,
};

/// Streaming reader over the body of a `SymEncryptedProtectedData` or
/// `GnupgAeadData` packet.
///
/// The encryption config is parsed from the start of the packet body when the
/// reader is constructed. Call `decrypt` before the first read to have the body
/// decrypted on the fly; if reading starts without it, the bytes of the packet
/// body are handed out directly, without any decryptor in between.
#[derive(Debug)]
pub struct SymEncryptedProtectedDataReader<R: DebugBufRead> {
    /// Encryption parameters of this packet (parsed in the constructors, or
    /// supplied by the caller of `new_done`).
    config: ProtectedDataConfig,
    /// Current state of the read state machine; owns the packet body reader.
    source: Source<R>,
}

/// Internal read state of `SymEncryptedProtectedDataReader`.
///
/// Every variant except `Error` owns the underlying packet body reader, either
/// directly or wrapped in a decryptor.
// NOTE(review): the lint is presumably silenced because `BodyDecryptor` is much
// larger than the other variants — confirm and decide whether boxing is wanted.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
enum Source<R: DebugBufRead> {
    /// Config has been parsed, no body data has been read yet. This is the
    /// only state in which `decrypt` is accepted.
    Init(PacketBodyReader<R>),
    /// Reading started without a decryptor: body bytes are passed through
    /// exactly as the packet body reader yields them.
    BodyRaw(PacketBodyReader<R>),
    /// Reading goes through the decryptor installed by `decrypt`.
    BodyDecryptor(StreamDecryptor<PacketBodyReader<R>>),
    /// The active reader reported end of data; the packet body reader is kept
    /// so it can still be handed back to the caller.
    Done(PacketBodyReader<R>),
    /// Placeholder left behind when a state transition failed or panicked.
    /// The packet body reader is gone; most accessors panic in this state.
    Error,
}

impl<R: DebugBufRead> SymEncryptedProtectedDataReader<R> {
    /// Creates a reader for the body of a `SymEncryptedProtectedData` packet.
    ///
    /// The SEIPD config is parsed from the beginning of `source`; the returned
    /// reader is in its initial state, ready for `decrypt`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from parsing the config.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the packet header of `source` does not carry
    /// the `SymEncryptedProtectedData` tag.
    pub fn new(mut source: PacketBodyReader<R>) -> Result<Self> {
        debug_assert_eq!(source.packet_header().tag(), Tag::SymEncryptedProtectedData);

        let seipd = SymEncryptedProtectedDataConfig::try_from_reader(&mut source)?;

        Ok(Self {
            config: ProtectedDataConfig::Seipd(seipd),
            source: Source::Init(source),
        })
    }

    /// Creates a reader for the body of a `GnupgAeadData` packet.
    ///
    /// The GnuPG AEAD config is parsed from the beginning of `source`; the
    /// returned reader is in its initial state, ready for `decrypt`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from parsing the config.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the packet header of `source` does not carry
    /// the `GnupgAeadData` tag.
    pub fn new_gnupg_aead(mut source: PacketBodyReader<R>) -> Result<Self> {
        debug_assert_eq!(source.packet_header().tag(), Tag::GnupgAeadData);

        let gnupg_aead = GnupgAeadDataConfig::try_from_reader(&mut source)?;

        Ok(Self {
            config: ProtectedDataConfig::GnupgAead(gnupg_aead),
            source: Source::Init(source),
        })
    }

    pub fn decrypt(&mut self, session_key: &PlainSessionKey) -> Result<()> {
        ensure!(
            matches!(self.source, Source::Init(_)),
            "cannot decrypt after starting to read"
        );

        match self.config {
            ProtectedDataConfig::Seipd(SymEncryptedProtectedDataConfig::V1) => {
                let (sym_alg, session_key) = match session_key {
                    PlainSessionKey::V3_4 { sym_alg, key } => (sym_alg, key),
                    PlainSessionKey::V5 { .. } => {
                        unsupported_err!("v5 session keys are unsupported with SEIPD v1 edata");
                    }
                    PlainSessionKey::V6 { .. } => {
                        bail!("v6 session keys are not allowed with SEIPD v1 edata");
                    }
                };

                replace_with::replace_with_and_return(
                    &mut self.source,
                    || Source::Error,
                    |source| {
                        let Source::Init(source) = source else {
                            unreachable!("checked");
                        };
                        match StreamDecryptor::v1(*sym_alg, session_key.as_ref(), source) {
                            Ok(dec) => (Ok(()), Source::BodyDecryptor(dec)),
                            Err(err) => (Err(err), Source::Error),
                        }
                    },
                )?;
            }
            ProtectedDataConfig::GnupgAead(GnupgAeadDataConfig {
                sym_alg,
                aead,
                chunk_size,
                ref iv,
            }) => {
                let (sym_alg_session_key, session_key) = match session_key {
                    PlainSessionKey::V3_4 { sym_alg, key } => (Some(sym_alg), key),
                    PlainSessionKey::V5 { key } => (None, key),
                    PlainSessionKey::V6 { .. } => {
                        bail!("v6 session keys are not allowed with GnupgAead edata")
                    }
                };
                if let Some(sym_alg_session_key) = sym_alg_session_key {
                    ensure_eq!(
                        sym_alg,
                        *sym_alg_session_key,
                        "Mismatching symmetric key algorithm"
                    );
                }

                ensure_eq!(
                    session_key.len(),
                    sym_alg.key_size(),
                    "Unexpected session key length for {:?}",
                    sym_alg
                );

                replace_with::replace_with_and_return(
                    &mut self.source,
                    || Source::Error,
                    |source| {
                        let Source::Init(source) = source else {
                            unreachable!("checked");
                        };

                        match StreamDecryptor::gnupg_aead(
                            sym_alg,
                            aead,
                            chunk_size,
                            session_key.as_ref(),
                            iv,
                            source,
                        ) {
                            Ok(dec) => (Ok(()), Source::BodyDecryptor(dec)),
                            Err(err) => (Err(err), Source::Error),
                        }
                    },
                )?;
            }
            ProtectedDataConfig::Seipd(SymEncryptedProtectedDataConfig::V2 {
                sym_alg,
                aead,
                chunk_size,
                salt,
            }) => {
                let (sym_alg_session_key, session_key) = match session_key {
                    PlainSessionKey::V3_4 { .. } => {
                        bail!("v3/4 session keys are not allowed with SEIPD v2 edata");
                    }
                    PlainSessionKey::V5 { .. } => {
                        bail!("v5 session keys are not allowed with SEIPD v2 edata");
                    }
                    PlainSessionKey::V6 { key } => (None::<&SymmetricKeyAlgorithm>, key),
                };
                if let Some(sym_alg_session_key) = sym_alg_session_key {
                    ensure_eq!(
                        sym_alg,
                        *sym_alg_session_key,
                        "mismatching symmetric key algorithm"
                    );
                }

                ensure_eq!(
                    session_key.len(),
                    sym_alg.key_size(),
                    "Unexpected session key length for {:?}",
                    sym_alg
                );
                replace_with::replace_with_and_return(
                    &mut self.source,
                    || Source::Error,
                    |source| {
                        let Source::Init(source) = source else {
                            unreachable!("checked");
                        };
                        match StreamDecryptor::v2(
                            sym_alg,
                            aead,
                            chunk_size,
                            &salt,
                            session_key.as_ref(),
                            source,
                        ) {
                            Ok(dec) => (Ok(()), Source::BodyDecryptor(dec)),
                            Err(err) => (Err(err), Source::Error),
                        }
                    },
                )?;
            }
        }
        Ok(())
    }

    /// Builds a reader that is already finished: it wraps `source` in the
    /// `Done` state and stores `config` as given, without parsing anything.
    pub(crate) fn new_done(config: ProtectedDataConfig, source: PacketBodyReader<R>) -> Self {
        let source = Source::Done(source);
        Self { config, source }
    }

    /// Returns the encryption config associated with this packet.
    pub fn config(&self) -> &ProtectedDataConfig {
        &self.config
    }

    /// Returns `true` once the reader has reached its `Done` state, i.e. the
    /// active reader reported end of data or the reader was built by
    /// `new_done`.
    pub fn is_done(&self) -> bool {
        matches!(&self.source, Source::Done(_))
    }

    /// Consumes the reader and returns the underlying packet body reader,
    /// unwrapping the decryptor if one is installed.
    ///
    /// # Panics
    ///
    /// Panics if the reader is in its error state.
    pub fn into_inner(self) -> PacketBodyReader<R> {
        match self.source {
            Source::BodyDecryptor(decryptor) => decryptor.into_inner(),
            Source::Init(reader) | Source::BodyRaw(reader) | Source::Done(reader) => reader,
            Source::Error => panic!("SymEncryptedProtectedDataReader errored"),
        }
    }

    /// Returns a mutable reference to the underlying packet body reader,
    /// looking through the decryptor if one is installed.
    ///
    /// # Panics
    ///
    /// Panics if the reader is in its error state.
    pub fn get_mut(&mut self) -> &mut PacketBodyReader<R> {
        match &mut self.source {
            Source::BodyDecryptor(decryptor) => decryptor.get_mut(),
            Source::Init(reader) | Source::BodyRaw(reader) | Source::Done(reader) => reader,
            Source::Error => panic!("SymEncryptedProtectedDataReader errored"),
        }
    }

    /// Returns the packet header reported by the underlying packet body
    /// reader.
    ///
    /// # Panics
    ///
    /// Panics if the reader is in its error state.
    pub fn packet_header(&self) -> PacketHeader {
        match &self.source {
            Source::BodyDecryptor(decryptor) => decryptor.get_ref().packet_header(),
            Source::Init(reader) | Source::BodyRaw(reader) | Source::Done(reader) => {
                reader.packet_header()
            }
            Source::Error => panic!("SymEncryptedProtectedDataReader errored"),
        }
    }

    /// Advances the state machine so the current state reflects whether more
    /// data is available.
    ///
    /// - `Init` becomes `BodyRaw` (no decryptor was installed), then the new
    ///   state is probed.
    /// - `BodyRaw` / `BodyDecryptor` are probed with `fill_buf`; an empty
    ///   buffer moves the reader to `Done`.
    /// - `Done` is left untouched.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the active reader; the state is not changed
    /// in that case.
    ///
    /// # Panics
    ///
    /// Panics if the reader is in its error state.
    fn fill_inner(&mut self) -> io::Result<()> {
        loop {
            // `None`: nothing has been read yet, so a raw read must be set up.
            // `Some(empty)`: the active reader was probed; `empty` says
            // whether it has run dry.
            let probe = match &mut self.source {
                Source::Done(_) => return Ok(()),
                Source::Error => panic!("SymEncryptedProtectedDataReader errored"),
                Source::Init(_) => None,
                Source::BodyRaw(reader) => Some(reader.fill_buf()?.is_empty()),
                Source::BodyDecryptor(decryptor) => Some(decryptor.fill_buf()?.is_empty()),
            };

            let must_advance = probe.unwrap_or(true);
            if must_advance {
                replace_with::replace_with(
                    &mut self.source,
                    || Source::Error,
                    |state| match state {
                        Source::Init(reader) => Source::BodyRaw(reader),
                        Source::BodyDecryptor(decryptor) => Source::Done(decryptor.into_inner()),
                        Source::BodyRaw(reader) | Source::Done(reader) => Source::Done(reader),
                        Source::Error => panic!("SymEncryptedProtectedDataReader errored"),
                    },
                );
            }

            // A probed reader settles the question; only the `Init` step
            // needs another round to probe the freshly installed raw reader.
            if probe.is_some() {
                return Ok(());
            }
        }
    }
}

impl<R: DebugBufRead> BufRead for SymEncryptedProtectedDataReader<R> {
    /// Returns the buffered data of the active reader, or an empty slice once
    /// the reader is done.
    ///
    /// # Panics
    ///
    /// Panics if the reader is in its error state.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        // Settles the state first; after this call `Init` can no longer occur.
        self.fill_inner()?;

        match self.source {
            Source::BodyRaw(ref mut reader) => reader.fill_buf(),
            Source::BodyDecryptor(ref mut decryptor) => decryptor.fill_buf(),
            Source::Done(_) => Ok(&[][..]),
            Source::Init(..) => unreachable!("invalid state"),
            Source::Error => panic!("SymEncryptedProtectedDataReader errored"),
        }
    }

    /// Marks `amt` bytes of the active reader's buffer as consumed.
    ///
    /// # Panics
    ///
    /// Panics if the reader is still in its initial state, if `amt` is
    /// non-zero after the reader is done, or if the reader is in its error
    /// state.
    fn consume(&mut self, amt: usize) {
        match self.source {
            Source::BodyRaw(ref mut reader) => reader.consume(amt),
            Source::BodyDecryptor(ref mut decryptor) => decryptor.consume(amt),
            Source::Done(_) => assert!(amt == 0, "consume after done: {amt}"),
            Source::Init(..) => unreachable!("invalid state"),
            Source::Error => panic!("SymEncryptedProtectedDataReader errored"),
        }
    }
}

impl<R: DebugBufRead> Read for SymEncryptedProtectedDataReader<R> {
    /// Reads from the active reader into `buf`; returns `Ok(0)` once the
    /// reader is done.
    ///
    /// # Panics
    ///
    /// Panics if the reader is in its error state.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Settles the state first; after this call `Init` can no longer occur.
        self.fill_inner()?;

        match self.source {
            Source::Done(_) => Ok(0),
            Source::BodyRaw(ref mut reader) => reader.read(buf),
            Source::BodyDecryptor(ref mut decryptor) => decryptor.read(buf),
            Source::Init(..) => unreachable!("invalid state"),
            Source::Error => panic!("SymEncryptedProtectedDataReader errored"),
        }
    }
}