use bytes::Bytes;
use log::error;
use rand::distr::Alphanumeric;
use rand::{Rng, rng};
use std::io::{Error, ErrorKind};
use url::Url;
/// Error type used by this module; every variant carries a human-readable message.
///
/// `Display` (generated by `thiserror`) renders as `<VariantName>: <message>`,
/// e.g. `ValidationError: Encryption key must be ...`.
///
/// Marked `#[non_exhaustive]`, so `match`es in other crates need a wildcard arm.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum DatapipeError {
/// Invalid or unavailable configuration. Produced by the `From` conversions for
/// `std::env::VarError`, `rustls::Error` and `rustls::client::VerifierBuilderError`.
#[error("ConfigurationError: {0}")]
ConfigurationError(String),
/// I/O failure. The `From<std::io::Error>` conversion stores the error together with
/// its full `source()` chain (formatted by `error_root_cause`).
#[error("InputOutputError: {0}")]
InputOutputError(String),
/// Failure reported by the `chacha20poly1305` AEAD (via `From<chacha20poly1305::Error>`).
#[error("EncryptionError: {0}")]
EncryptionError(String),
/// Input rejected by validation, e.g. `EncryptionKey::new` given a key of the wrong length.
#[error("ValidationError: {0}")]
ValidationError(String),
}
impl From<chacha20poly1305::Error> for DatapipeError {
fn from(error: chacha20poly1305::Error) -> Self {
Self::EncryptionError(format!("{error}"))
}
}
impl From<std::io::Error> for DatapipeError {
fn from(error: std::io::Error) -> Self {
let error_string = error_root_cause(&error);
Self::InputOutputError(error_string)
}
}
impl From<std::env::VarError> for DatapipeError {
fn from(error: std::env::VarError) -> Self {
Self::ConfigurationError(format!("{error}"))
}
}
impl From<rustls::Error> for DatapipeError {
fn from(error: rustls::Error) -> Self {
Self::ConfigurationError(format!("{error}"))
}
}
impl From<rustls::client::VerifierBuilderError> for DatapipeError {
fn from(error: rustls::client::VerifierBuilderError) -> Self {
Self::ConfigurationError(format!("{error}"))
}
}
/// Renders an error together with its entire `source()` chain.
///
/// Despite the name, this returns the whole chain, not only the innermost cause. The
/// first line is `err`'s own `Display` text. Each underlying source follows on its own
/// line, prefixed with `"\n\tCaused by: "`, ordered from outermost to innermost.
pub fn error_root_cause(err: &(dyn std::error::Error + 'static)) -> String {
    let mut rendered = err.to_string();
    let mut next_cause = err.source();
    while let Some(cause) = next_cause {
        rendered.push_str("\n\tCaused by: ");
        rendered.push_str(&cause.to_string());
        next_cause = cause.source();
    }
    rendered
}
/// Asynchronous source of byte chunks.
///
/// NOTE(review): how end-of-input is signalled (e.g. an empty `Bytes`) is not defined
/// here; confirm the convention with the implementors.
// The `async_fn_in_trait` lint is silenced. It warns that callers of a public
// `async fn` in a trait cannot require the returned future to be `Send`.
#[allow(async_fn_in_trait)]
pub trait InputReader {
/// Reads the next chunk of input; I/O failures are returned as `std::io::Error`.
async fn read(&mut self) -> Result<Bytes, Error>;
}
/// Asynchronous sink for byte chunks.
///
/// NOTE(review): whether `write` must consume the whole slice or may write partially
/// is not specified here; confirm with the implementors.
// The `async_fn_in_trait` lint is silenced. It warns that callers of a public
// `async fn` in a trait cannot require the returned future to be `Send`.
#[allow(async_fn_in_trait)]
pub trait OutputWriter {
/// Writes `bytes` to the sink; I/O failures are returned as `std::io::Error`.
async fn write(&mut self, bytes: &[u8]) -> Result<(), Error>;
}
/// Returns a string of exactly `len` characters, each drawn independently from the
/// ASCII alphanumerics `[A-Za-z0-9]` using `rand::rng()`.
///
/// Because every character is ASCII, the result is also exactly `len` bytes long.
///
/// NOTE(review): `EncryptionKey::generate` uses this for key material. rand documents
/// its thread-local generator as a CSPRNG; confirm for the rand version in use.
pub fn generate_random_string(len: usize) -> String {
    let mut generator = rng();
    (0..len)
        .map(|_| char::from(generator.sample(Alphanumeric)))
        .collect()
}
/// Length in bytes of the cipher-key portion of an `EncryptionKey`.
const KEY_LENGTH: usize = 32;
/// Length in bytes of the nonce portion of an `EncryptionKey`.
///
/// NOTE(review): 19 bytes equals the nonce size of the `aead::stream` BE32 construction
/// over XChaCha20-Poly1305 (24-byte nonce minus 5 bytes of counter/last-block flag).
/// That is presumably how the nonce is consumed; confirm at the cipher call site.
const NONCE_LENGTH: usize = 19;
/// Total length in bytes of the key string accepted by `EncryptionKey::new`
/// (key bytes immediately followed by nonce bytes).
const REQUIRED_LENGTH: usize = KEY_LENGTH + NONCE_LENGTH;

/// Fixed-size array holding the cipher key.
type KeyBytes = [u8; KEY_LENGTH];
/// Fixed-size array holding the nonce.
type NonceBytes = [u8; NONCE_LENGTH];
/// Symmetric key material: a `KEY_LENGTH`-byte key plus a `NONCE_LENGTH`-byte nonce.
///
/// `Debug` is implemented by hand and redacts both fields, so `{:?}` output cannot leak
/// the secret. This matters when `{:?}` is reached indirectly, e.g. through a derived
/// `Debug` on an enclosing struct. The fields are still `pub`, and the `Display` impl
/// renders the full key string, so handle those paths with care.
#[derive(Clone)]
pub struct EncryptionKey {
    /// Cipher key bytes (the first `KEY_LENGTH` bytes of the key string).
    pub key: KeyBytes,
    /// Nonce bytes (the remaining `NONCE_LENGTH` bytes of the key string).
    pub nonce: NonceBytes,
}

impl std::fmt::Debug for EncryptionKey {
    /// Prints the struct shape only; key and nonce bytes are never shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptionKey")
            .field("key", &"<redacted>")
            .field("nonce", &"<redacted>")
            .finish()
    }
}
impl EncryptionKey {
    /// Builds a key from a string of exactly `REQUIRED_LENGTH` bytes.
    ///
    /// The first `KEY_LENGTH` bytes become [`EncryptionKey::key`]. The remaining
    /// `NONCE_LENGTH` bytes become [`EncryptionKey::nonce`]. Length is measured in UTF-8
    /// bytes (`str::len`), not in characters.
    ///
    /// # Errors
    ///
    /// Returns [`DatapipeError::ValidationError`] when `encryption_key` is not exactly
    /// `REQUIRED_LENGTH` bytes long. The same message is also logged at `error` level.
    pub fn new(encryption_key: &str) -> Result<Self, DatapipeError> {
        if encryption_key.len() != REQUIRED_LENGTH {
            let error_message = format!(
                "Encryption key must be {} bytes long; provided encryption key is {} bytes long",
                REQUIRED_LENGTH,
                encryption_key.len()
            );
            error!("{error_message}");
            return Err(DatapipeError::ValidationError(error_message));
        }
        Ok(Self::from_validated_bytes(encryption_key.as_bytes()))
    }

    /// Generates a fresh random key.
    ///
    /// The material comes from [`generate_random_string`], so every byte is an ASCII
    /// alphanumeric. As a result, the key is valid UTF-8 that `Display` can render and
    /// [`EncryptionKey::new`] can parse back. The trade-off is fewer than 8 bits of
    /// entropy per byte.
    pub fn generate() -> Self {
        // ASCII-only output: REQUIRED_LENGTH characters == REQUIRED_LENGTH bytes.
        let encryption_key = generate_random_string(REQUIRED_LENGTH);
        Self::from_validated_bytes(encryption_key.as_bytes())
    }

    /// Splits exactly `REQUIRED_LENGTH` bytes into the key and nonce arrays.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() != REQUIRED_LENGTH`. Both callers above guarantee this
    /// invariant before calling.
    fn from_validated_bytes(bytes: &[u8]) -> Self {
        let (key_part, nonce_part) = bytes.split_at(KEY_LENGTH);
        Self {
            key: <KeyBytes>::try_from(key_part)
                .expect("split_at(KEY_LENGTH) yields exactly KEY_LENGTH bytes"),
            nonce: <NonceBytes>::try_from(nonce_part)
                .expect("caller validated the input is REQUIRED_LENGTH bytes long"),
        }
    }
}
/// Renders the key back into the string form accepted by `EncryptionKey::new`: the key
/// bytes followed by the nonce bytes, interpreted as UTF-8.
///
/// The fields are `pub`, so they can hold arbitrary bytes. If the concatenation is not
/// valid UTF-8, this writes the bytes as lowercase hex instead of panicking inside
/// `fmt`. The hex form is twice as long, so `EncryptionKey::new` rejects it rather than
/// silently accepting a different key.
impl std::fmt::Display for EncryptionKey {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Join key and nonce first: a multi-byte UTF-8 character may straddle the
        // boundary, so the halves cannot be validated separately.
        let mut bytes = [0u8; REQUIRED_LENGTH];
        let (key_part, nonce_part) = bytes.split_at_mut(KEY_LENGTH);
        key_part.copy_from_slice(&self.key);
        nonce_part.copy_from_slice(&self.nonce);
        match std::str::from_utf8(&bytes) {
            Ok(text) => f.write_str(text),
            Err(_) => bytes.iter().try_for_each(|byte| write!(f, "{byte:02x}")),
        }
    }
}
pub fn good_url(maybe_url: &str, prefix: &str) -> Result<url::Url, Error> {
match maybe_url.starts_with(prefix) {
true => match Url::parse(maybe_url) {
Ok(url) => Ok(url),
Err(error) => {
let error_message = format!("Error parsing URL '{}': {}", maybe_url, error);
error!("{}", error_message);
Err(Error::new(ErrorKind::InvalidInput, error_message))
}
},
false => {
let error_message = format!("URL '{}' must start with '{}'", maybe_url, prefix);
error!("{}", error_message);
Err(Error::new(ErrorKind::InvalidInput, error_message))
}
}
}