use std::{
collections::BTreeMap,
io::{self, Write as _},
path::{Path, PathBuf},
};
use flate2::read::GzDecoder;
use sequoia_openpgp::{
parse::{Parse, stream::DecryptorBuilder},
policy::StandardPolicy,
};
use tracing::{debug, info, instrument};
use walkdir::WalkDir;
use crate::{
filesystem::{check_space, get_combined_file_size},
package::{CHECKSUM_FILE, CompressionAlgorithm, DATA_FILE, Package},
progress::{ProgressDisplay, ProgressReader},
task::{Mode, Status},
};
/// Size in bytes (4 MiB, `1 << 22`) of the heap buffers used while unpacking.
///
/// It sizes both the per-file `BufWriter` in the writer thread and the chunk
/// buffer in which archive data is batched before being sent to the worker threads.
const HEAP_BUFFER_SIZE: usize = 1 << 22;
/// Options controlling how a package is decrypted.
///
/// `T` is the progress display type. `F` is the callback the decryption helper
/// uses to obtain a password.
pub struct DecryptOpts<T, F> {
    /// Secret keys available for decryption (borrowed mutably by the decryption helper).
    pub key_store: crate::openpgp::keystore::KeyStore,
    /// Certificates used to verify the package and by the decryption helper.
    pub cert_store: crate::openpgp::certstore::CertStore<'static>,
    /// Callback that turns a password hint into a secret.
    pub password: F,
    /// Base directory for the output; the current working directory is used when `None`.
    pub output: Option<PathBuf>,
    /// When `true`, the decrypted data is written as-is instead of being unpacked.
    pub decrypt_only: bool,
    /// Execution mode; in `Mode::Check` no output is written.
    pub mode: Mode,
    /// Optional progress display.
    pub progress: Option<T>,
}
impl<T, F> std::fmt::Debug for DecryptOpts<T, F> {
    /// Formats only the plain configuration fields. The key/cert stores, the
    /// password callback and the progress display are omitted (`T` and `F`
    /// carry no `Debug` bound).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let output = self.output.as_ref().map(|path| path.display());
        let mut builder = f.debug_struct("DecryptOpts");
        builder.field("output", &output);
        builder.field("decrypt_only", &self.decrypt_only);
        builder.field("mode", &self.mode);
        builder.finish()
    }
}
/// Verifies `package` against `opts.cert_store`, then decrypts it according to `opts`.
///
/// # Errors
///
/// Returns an error if verification fails, or if decryption or unpacking fails.
/// Errors are also logged at `ERROR` level by the `instrument` attribute.
#[instrument(err(Debug, level=tracing::Level::ERROR))]
pub async fn decrypt<S, T, F>(
    package: Package<S, crate::package::state::Unverified>,
    opts: DecryptOpts<T, F>,
) -> Result<Status, Error<S::Error>>
where
    S: crate::package::source::PackageStream + std::fmt::Debug + 'static,
    T: ProgressDisplay + Send + 'static,
    F: Fn(crate::openpgp::crypto::PasswordHint) -> super::secret::Secret + Send + 'static,
    <S as crate::package::source::PackageStream>::Error: Send + std::fmt::Debug + 'static,
{
    // Nothing is decrypted unless verification succeeds first.
    let verified = package.verify(&opts.cert_store).await?;
    decrypt_verified(&verified, opts).await
}
/// Decrypts an already verified package and, unless `opts.decrypt_only` is set,
/// unpacks it.
///
/// The output directory is derived from `opts.output` (or the current working
/// directory) and the package name via `get_output_path`. Free space at the
/// destination is checked before decryption starts. The blocking OpenPGP
/// decryption and extraction run inside `tokio::task::spawn_blocking`.
///
/// In `Mode::Check` the decryptor is set up, but no payload is written and
/// `Status::Checked` is returned.
///
/// # Errors
///
/// Returns an error if:
/// - the output path cannot be resolved;
/// - there is not enough space;
/// - decryption, unpacking or checksum verification fails;
/// - the blocking task cannot be joined.
pub(crate) async fn decrypt_verified<S, T, F>(
    package: &Package<S, crate::package::state::Verified>,
    mut opts: DecryptOpts<T, F>,
) -> Result<Status, Error<S::Error>>
where
    S: crate::package::source::PackageStream + std::fmt::Debug + 'static,
    <S as crate::package::source::PackageStream>::Error: Send + std::fmt::Debug + 'static,
    T: ProgressDisplay + Send + 'static,
    F: Fn(crate::openpgp::crypto::PasswordHint) -> crate::secret::Secret + Send + 'static,
{
    let output = get_output_path(
        opts.output.map(|p| p.canonicalize()).transpose()?,
        &package.name,
    )?;
    let policy = StandardPolicy::new();
    let metadata = package.metadata().await?;
    let (data_stream, data_size) = package.data().await?;
    // Fail early, before any data is decrypted, if the destination lacks space.
    // `ok_or_else` avoids allocating the error on the (common) success path.
    check_space(
        data_size,
        output.parent().ok_or_else(|| {
            io::Error::new(io::ErrorKind::NotFound, "destination directory not found")
        })?,
        opts.mode,
    )?;
    use crate::package::source::IntoAsyncRead as _;
    let data_reader = data_stream.into_async_reader();
    // The OpenPGP parser needs a blocking `Read`. `SyncIoBridge` adapts the async
    // reader, so all the work below runs on a blocking thread.
    let mut data_reader = tokio_util::io::SyncIoBridge::new(data_reader);
    let status = tokio::task::spawn_blocking(move || -> Result<_, Error<S::Error>> {
        let mut decryptor = DecryptorBuilder::from_reader(&mut data_reader)
            .map_err(crate::openpgp::error::PgpError::from)?
            .with_policy(
                &policy,
                None,
                crate::openpgp::crypto::DecryptionHelper {
                    cert_store: &opts.cert_store,
                    key_store: &mut opts.key_store,
                    password: opts.password,
                },
            )
            .map_err(crate::openpgp::error::PgpError::from)?;
        let status = if let Mode::Check = opts.mode {
            // Check mode: the decryptor was set up above, but no payload is read.
            Status::Checked {
                destination: output.to_string_lossy().to_string(),
                source_size: data_size,
            }
        } else {
            std::fs::create_dir_all(&output)?;
            if let Some(pg) = opts.progress {
                let mut progress_reader = ProgressReader::new(decryptor, pg.start(data_size));
                if opts.decrypt_only {
                    write_to_file(&mut progress_reader, &output)?;
                } else {
                    unpack(
                        &mut progress_reader,
                        &output,
                        metadata.compression_algorithm,
                    )?;
                }
            } else if opts.decrypt_only {
                write_to_file(&mut decryptor, &output)?;
            } else {
                unpack(&mut decryptor, &output, metadata.compression_algorithm)?;
            }
            // Walk errors are skipped by `flatten`, so unreadable entries do not
            // contribute to the reported destination size.
            let output_files = WalkDir::new(&output)
                .into_iter()
                .flatten()
                .filter(|entry| entry.file_type().is_file())
                .map(|entry| entry.into_path());
            Status::Completed {
                source_size: data_size,
                destination_size: get_combined_file_size(output_files)?,
                destination: output.to_string_lossy().to_string(),
                metadata,
            }
        };
        Ok(status)
    })
    .await??;
    match &status {
        Status::Checked {
            destination,
            source_size,
        } => {
            debug!(destination, source_size, "Checked decryption task input");
        }
        Status::Completed {
            destination,
            source_size,
            destination_size,
            metadata,
        } => {
            info!(
                destination,
                source_size,
                destination_size,
                metadata = metadata.to_json_or_debug(),
                "Successfully decrypted data package"
            )
        }
    }
    Ok(status)
}
/// Returns an output directory path for a package that did not exist at the
/// time of the check.
///
/// The directory name is the part of `pkg_file_name` before its first `.`
/// (e.g. `package.zip` -> `package`). It is placed under `output`, or under the
/// current working directory when `output` is `None`. If that path already
/// exists, a numeric suffix (`_1`, `_2`, ...) is appended until a free name is
/// found.
///
/// # Errors
///
/// Returns an error if the current directory cannot be determined, or if the
/// package file name has an empty base name (e.g. `""` or `".zip"`). Such a
/// name would otherwise resolve to the base directory itself.
fn get_output_path(
    output: Option<PathBuf>,
    pkg_file_name: &str,
) -> Result<PathBuf, std::io::Error> {
    let base = match output {
        Some(p) => p,
        None => std::env::current_dir()?,
    };
    // `split` always yields at least one item, so the only case to reject is an
    // empty base name.
    let pkg_base_name = pkg_file_name
        .split('.')
        .next()
        .filter(|name| !name.is_empty())
        .ok_or_else(|| std::io::Error::other("Package file name has an empty base name"))?;
    let mut output = base.join(pkg_base_name);
    let mut i = 1;
    while output.exists() {
        output = base.join(format!("{pkg_base_name}_{i}"));
        i += 1;
    }
    Ok(output)
}
/// Unpacks the tar archive read from `source` into `output`, decompressing it
/// first according to `compression_algorithm`.
///
/// `tar` stops reading at the end-of-archive marker, so bytes such as tar
/// record padding may remain unread. After extraction, the rest of `source` is
/// therefore read and discarded, which drives the underlying reader to
/// end-of-stream. For sequoia's streaming decryptor, that is where the
/// end-of-message integrity/signature checks run. Without this drain, a failure
/// there could go unreported.
///
/// # Errors
///
/// Returns an error if decompression, extraction, checksum verification, or
/// reading the remainder of `source` fails.
#[instrument(skip(source))]
fn unpack<R: io::Read + Send, E: Send + 'static + std::fmt::Debug>(
    source: &mut R,
    output: &Path,
    compression_algorithm: CompressionAlgorithm,
) -> Result<(), Error<E>> {
    // Reborrow `source` so it is still usable after the archive is dropped.
    match compression_algorithm {
        CompressionAlgorithm::Stored => unpack_tar(&mut tar::Archive::new(&mut *source), output),
        CompressionAlgorithm::Gzip(_) => {
            unpack_tar(&mut tar::Archive::new(GzDecoder::new(&mut *source)), output)
        }
        CompressionAlgorithm::Zstandard(_) => unpack_tar(
            &mut tar::Archive::new(zstd::stream::read::Decoder::new(&mut *source)?),
            output,
        ),
    }?;
    // Consume the remainder so the reader reaches EOF and reports any late error.
    io::copy(source, &mut io::sink())?;
    Ok(())
}
/// Maps an archive entry `path` to a location inside `dest`.
///
/// Prefix, root and `.` components are dropped, so absolute paths end up
/// relative to `dest`. Any `..` component is rejected so that the result cannot
/// escape `dest`.
///
/// # Errors
///
/// Returns an error if `path` contains a `..` component, or if no normal
/// component remains after sanitization.
fn sanitize_path(dest: &Path, path: &Path) -> Result<PathBuf, std::io::Error> {
    use std::path::Component;
    let mut relative = PathBuf::new();
    for component in path.components() {
        match component {
            Component::Normal(name) => relative.push(name),
            Component::ParentDir => {
                return Err(std::io::Error::other("file path contains a relative part"));
            }
            Component::Prefix(_) | Component::RootDir | Component::CurDir => {}
        }
    }
    // Only normal components were pushed, so "no parent" is equivalent to "empty".
    if relative.as_os_str().is_empty() {
        return Err(std::io::Error::other("empty file path"));
    }
    Ok(dest.join(&relative))
}
/// Messages sent from the archive-reading loop in `unpack_tar` to the checksum
/// and writer threads.
enum Message {
    /// Start of a new file. The checksum thread receives the path inside the
    /// archive; the writer thread receives the sanitized output path.
    Init(PathBuf),
    /// A chunk of the current file's contents.
    Payload(bytes::Bytes),
    /// End of the current file.
    Finalize,
}
fn unpack_tar<E: Send + 'static + std::fmt::Debug>(
archive: &mut tar::Archive<impl io::Read>,
dest: &Path,
) -> Result<(), Error<E>> {
let (tx_checksum, rx_checksum) = std::sync::mpsc::sync_channel(8);
let (tx_write, rx_write) = std::sync::mpsc::sync_channel(8);
let checksum_handle = std::thread::spawn(move || -> Result<_, Error<E>> {
use sequoia_openpgp::types::HashAlgorithm::SHA256;
let mut hasher = SHA256
.context()
.map_err(crate::openpgp::error::PgpError::from)?
.for_digest();
let mut path = None;
let mut checksums = BTreeMap::new();
while let Ok(message) = rx_checksum.recv() {
match message {
Message::Init(p) => {
path = Some(p);
}
Message::Payload(buf) => hasher.update(&buf),
Message::Finalize => {
checksums.insert(
std::mem::take(&mut path).expect("path is initialized"),
crate::utils::to_hex_string(
&std::mem::replace(
&mut hasher,
SHA256
.context()
.map_err(crate::openpgp::error::PgpError::from)?
.for_digest(),
)
.into_digest()
.map_err(crate::openpgp::error::PgpError::from)?,
),
);
}
}
}
Ok(checksums)
});
let write_handle = std::thread::spawn(move || -> io::Result<()> {
let mut writer = None;
while let Ok(message) = rx_write.recv() {
match message {
Message::Init(p) => {
if let Some(parent) = p.parent()
&& !parent.exists()
{
std::fs::create_dir_all(parent)?;
}
writer = Some(io::BufWriter::with_capacity(
HEAP_BUFFER_SIZE,
std::fs::File::create(&p)?,
));
}
Message::Payload(buf) => writer
.as_mut()
.expect("writer is initialized")
.write_all(&buf)?,
Message::Finalize => {
writer = None;
}
}
}
Ok(())
});
let read_result: Result<(), Error<E>> = (|| {
for entry in archive.entries()? {
let mut entry = entry?;
let archive_path = entry.path()?.into_owned();
let output_path = match sanitize_path(dest, &archive_path) {
Ok(p) => p,
Err(e) => {
tracing::warn!("{:?}: {}", archive_path, e);
continue;
}
};
tx_checksum.send(Message::Init(archive_path))?;
tx_write.send(Message::Init(output_path))?;
copy_to_channels(&mut entry, [&tx_checksum, &tx_write])?;
tx_checksum.send(Message::Finalize)?;
tx_write.send(Message::Finalize)?;
}
Ok(())
})();
drop(tx_checksum);
drop(tx_write);
let write_result = write_handle.join().map_err(|_| Error::Thread("write"))?;
let checksum_result = checksum_handle
.join()
.map_err(|_| Error::Thread("checksum"))?;
if let Err(error) = &read_result {
tracing::error!(?error, "unpacking loop failed");
}
write_result.inspect_err(|error| tracing::error!(?error, "writer thread failed"))?;
let mut checksums =
checksum_result.inspect_err(|error| tracing::error!(?error, "checksum thread failed"))?;
read_result?;
checksums.remove(Path::new(CHECKSUM_FILE));
verify_checksums(&checksums, &read_checksum_file(dest.join(CHECKSUM_FILE))?)?;
Ok(())
}
/// Reads `reader` to the end and broadcasts its contents to every sender in `tx`.
///
/// Data is accumulated in a buffer of `HEAP_BUFFER_SIZE` bytes. That buffer is
/// frozen and sent as a `Message::Payload` to each channel:
/// - whenever appending the next chunk would exceed its capacity;
/// - once more at end of input, if it is non-empty.
///
/// The same frozen `Bytes` handle is cloned for each channel.
///
/// Reads failing with `ErrorKind::Interrupted` are retried, as `std::io::copy` does.
///
/// # Errors
///
/// Returns an error if reading fails for any other reason, or if a receiver has
/// been dropped.
fn copy_to_channels<const N: usize, E>(
    reader: &mut impl io::Read,
    tx: [&std::sync::mpsc::SyncSender<Message>; N],
) -> Result<(), Error<E>> {
    /// Swaps the filled `bigbuf` for `buffer` and sends the filled data to every channel.
    fn exchange(
        buffer: bytes::BytesMut,
        bigbuf: &mut bytes::BytesMut,
        tx: &[&std::sync::mpsc::SyncSender<Message>],
    ) -> Result<(), std::sync::mpsc::SendError<Message>> {
        let b = std::mem::replace(bigbuf, buffer).freeze();
        for tx in tx {
            tx.send(Message::Payload(b.clone()))?;
        }
        Ok(())
    }
    let mut buf = [0; 8192];
    let mut bigbuf = bytes::BytesMut::with_capacity(HEAP_BUFFER_SIZE);
    loop {
        let n = match reader.read(&mut buf) {
            Ok(n) => n,
            // A signal interrupted the read before any data arrived; just retry.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        };
        if n == 0 {
            if !bigbuf.is_empty() {
                exchange(bytes::BytesMut::new(), &mut bigbuf, &tx)?;
            }
            break;
        }
        if bigbuf.len() + n > bigbuf.capacity() {
            exchange(
                bytes::BytesMut::with_capacity(HEAP_BUFFER_SIZE),
                &mut bigbuf,
                &tx,
            )?;
        }
        bigbuf.extend_from_slice(&buf[..n]);
    }
    Ok(())
}
/// Copies everything readable from `source` into `DATA_FILE` inside the
/// `output` directory. The file is created, or truncated if it already exists.
///
/// # Errors
///
/// Returns any I/O error from creating the file or copying the data.
fn write_to_file<R: io::Read, P: AsRef<Path>>(
    source: &mut R,
    output: P,
) -> Result<(), std::io::Error> {
    let target = output.as_ref().join(DATA_FILE);
    let mut file = std::fs::File::create(target)?;
    io::copy(source, &mut file).map(|_bytes_copied| ())
}
/// Parses a checksum file into a map from file path to checksum.
///
/// Each non-blank line must have the form `<checksum><whitespace><path>`. The
/// line is trimmed and then split at its first whitespace character.
/// Whitespace-only lines (e.g. an extra trailing newline) are ignored.
///
/// NOTE(review): a line in `sha256sum`'s two-space format would leave a leading
/// space in the parsed path. Confirm that the writer uses a single separator.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read, or if a non-blank
/// line contains no whitespace separator.
fn read_checksum_file(path: impl AsRef<Path>) -> Result<BTreeMap<PathBuf, String>, std::io::Error> {
    use std::io::BufRead as _;
    let mut reader = io::BufReader::new(std::fs::File::open(path)?);
    let mut parsed = BTreeMap::new();
    let mut buf = String::new();
    while reader.read_line(&mut buf)? > 0 {
        let line = buf.trim();
        // Blank lines carry no entry; skip them instead of rejecting the whole file.
        if !line.is_empty() {
            let (checksum, path) = line
                .split_once(char::is_whitespace)
                .ok_or_else(|| std::io::Error::other("Unable to parse the checksum file"))?;
            parsed.insert(PathBuf::from(path), checksum.to_string());
        }
        buf.clear();
    }
    Ok(parsed)
}
fn verify_checksums<E>(
source: &BTreeMap<PathBuf, String>,
reference: &BTreeMap<PathBuf, String>,
) -> Result<(), Error<E>> {
for (path, checksum) in source {
let expected = reference.get(path).ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("unable to find checksum for file: {path:?}"),
)
})?;
if !checksum.eq_ignore_ascii_case(expected) {
Err(ChecksumMismatchError {
path: path.clone(),
expected: expected.clone(),
actual: checksum.clone(),
})?;
}
}
Ok(())
}
/// Errors that can occur while decrypting a package.
///
/// `E` is the error type of the underlying package stream.
#[derive(Debug)]
pub enum Error<E> {
    /// I/O failure (file system, reading the data, archive extraction, ...).
    IO(std::io::Error),
    /// Sending to a worker thread failed because its receiver was dropped
    /// (the thread has exited). Holds the error text.
    Send(String),
    /// OpenPGP failure (decryptor setup, hashing, ...).
    Pgp(crate::openpgp::error::PgpError),
    /// The named worker thread could not be joined (it panicked).
    Thread(&'static str),
    /// A computed file checksum did not match the checksum file.
    ChecksumMismatch(ChecksumMismatchError),
    /// Package verification or metadata error.
    Verification(crate::package::error::VerificationError<E>),
    /// Error from the package stream reader in `crate::zip`.
    Zip(crate::zip::error::ReadStreamError<E>),
    /// The blocking decryption task could not be joined.
    AsyncTask(tokio::task::JoinError),
}
impl<E> std::fmt::Display for Error<E> {
    /// Writes a generic "Decryption failed" message. Only `Thread` adds detail
    /// here; most other variants expose their cause via `Error::source`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Thread(name) => write!(f, "Decryption failed: {name} thread: join error"),
            _ => f.write_str("Decryption failed"),
        }
    }
}
impl<E: core::error::Error + 'static> core::error::Error for Error<E> {
    /// Returns the wrapped error for variants that carry one. `Send` and
    /// `Thread` carry only text and have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Self::Send(_) | Self::Thread(_) => return None,
            Self::IO(e) => e,
            Self::Pgp(e) => e,
            Self::ChecksumMismatch(e) => e,
            Self::Verification(e) => e,
            Self::Zip(e) => e,
            Self::AsyncTask(e) => e,
        };
        Some(inner)
    }
}
impl<E> From<std::io::Error> for Error<E> {
fn from(value: std::io::Error) -> Self {
Self::IO(value)
}
}
impl<E, T> From<std::sync::mpsc::SendError<T>> for Error<E> {
fn from(value: std::sync::mpsc::SendError<T>) -> Self {
Self::Send(format!("{value}"))
}
}
impl<E> From<crate::openpgp::error::PgpError> for Error<E> {
fn from(value: crate::openpgp::error::PgpError) -> Self {
Self::Pgp(value)
}
}
impl<E> From<ChecksumMismatchError> for Error<E> {
fn from(value: ChecksumMismatchError) -> Self {
Self::ChecksumMismatch(value)
}
}
impl<E> From<crate::package::error::MetadataError<E>> for Error<E> {
fn from(value: crate::package::error::MetadataError<E>) -> Self {
Self::Verification(value.into())
}
}
impl<E> From<crate::package::error::VerificationError<E>> for Error<E> {
fn from(value: crate::package::error::VerificationError<E>) -> Self {
Self::Verification(value)
}
}
impl<E> From<crate::zip::error::ReadStreamError<E>> for Error<E> {
fn from(value: crate::zip::error::ReadStreamError<E>) -> Self {
Self::Zip(value)
}
}
impl<E> From<tokio::task::JoinError> for Error<E> {
fn from(value: tokio::task::JoinError) -> Self {
Self::AsyncTask(value)
}
}
/// A file whose computed checksum differs from the one recorded in the checksum file.
#[derive(Debug)]
pub struct ChecksumMismatchError {
    /// Path of the file whose checksum did not match.
    pub path: std::path::PathBuf,
    /// Checksum recorded in the checksum file.
    pub expected: String,
    /// Checksum computed from the extracted data.
    pub actual: String,
}
impl std::fmt::Display for ChecksumMismatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            path,
            expected,
            actual,
        } = self;
        write!(
            f,
            "wrong checksum for {path:?} (expected {expected}, computed {actual})"
        )
    }
}
impl core::error::Error for ChecksumMismatchError {}