use crate::error::{DecryptError, EncryptError};
use sodiumoxide::crypto::pwhash;
use sodiumoxide::crypto::secretstream::xchacha20poly1305::{Header, Key};
use sodiumoxide::crypto::secretstream::{Stream, Tag, ABYTES, HEADERBYTES, KEYBYTES};
use std::io::Write as _;
use std::sync::Arc;
/// Maximum number of plaintext bytes encrypted into a single secretstream chunk.
const CHUNKSIZE: usize = 4 * 1024;
/// Magic bytes written at the very start of every ciphertext produced by this module.
pub const SIGNATURE: [u8; 4] = *b"\xC1\x0A\x4B\xED";
/// A password used to derive encryption keys.
///
/// The text is held behind an `Arc`, so cloning a `Password` only bumps a reference
/// count; every clone points at the same string.
/// NOTE(review): `Debug` is not derived — keep it that way so the secret cannot end up
/// in log output by accident.
#[derive(Clone)]
pub struct Password {
    /// The password text, shared between clones and only read through this handle.
    password: Arc<String>,
}
impl Password {
pub fn new(password: impl Into<String>) -> Self {
Password {
password: Arc::new(password.into()),
}
}
pub fn encrypt(&self, bytes: impl AsRef<[u8]>) -> Result<Vec<u8>, EncryptError> {
let bytes = bytes.as_ref();
let mut output = Vec::default();
output.write_all(&SIGNATURE).map_err(EncryptError::write)?;
let salt = pwhash::gen_salt();
output.write_all(&salt.0).map_err(EncryptError::write)?;
let mut key = [0u8; KEYBYTES];
pwhash::derive_key_interactive(&mut key, self.password.as_bytes(), &salt).unwrap();
let key = Key(key);
let mut offset = 0;
let (mut stream, header) = Stream::init_push(&key).map_err(|_| EncryptError::Init)?;
output.write_all(&header.0).map_err(EncryptError::write)?;
while offset < bytes.len() {
let bytes_left = bytes.len().saturating_sub(offset);
let tag = match bytes_left {
0 => Tag::Final,
_ => Tag::Message,
};
let end = std::cmp::min(offset + CHUNKSIZE, bytes.len());
output
.write_all(
&stream
.push(&bytes[offset..end], None, tag)
.map_err(|_| EncryptError::EncryptChunk)?,
)
.map_err(EncryptError::write)?;
offset += CHUNKSIZE;
}
Ok(output)
}
pub fn decrypt(&self, bytes: impl AsRef<[u8]>) -> Result<Vec<u8>, DecryptError> {
let bytes = bytes.as_ref();
if bytes.len() <= (pwhash::SALTBYTES + HEADERBYTES + SIGNATURE.len()) {
return Err(DecryptError::InputTooShort);
}
let mut offset = 0;
let mut salt = [0u8; pwhash::SALTBYTES];
let mut signature = [0u8; 4];
signature.copy_from_slice(&bytes[offset..offset + SIGNATURE.len()]);
offset += signature.len();
salt.copy_from_slice(&bytes[offset..offset + pwhash::SALTBYTES]);
offset += salt.len();
let salt = pwhash::Salt(salt);
let mut header = [0u8; HEADERBYTES];
header.copy_from_slice(&bytes[offset..offset + HEADERBYTES]);
offset += header.len();
let header = Header(header);
let mut key = [0u8; KEYBYTES];
pwhash::derive_key(
&mut key,
self.password.as_bytes(),
&salt,
pwhash::OPSLIMIT_INTERACTIVE,
pwhash::MEMLIMIT_INTERACTIVE,
)
.map_err(|_| DecryptError::DeriveKey)?;
let key = Key(key);
let mut stream = Stream::init_pull(&header, &key).map_err(|_| DecryptError::Init)?;
let mut output = Vec::new();
while stream.is_not_finalized() {
if offset >= bytes.len() {
break;
}
let end = std::cmp::min(offset + CHUNKSIZE + ABYTES, bytes.len());
let (decrypted, _tag) = stream
.pull(&bytes[offset..end], None)
.map_err(|_| DecryptError::LikelyWrongPassword)?;
output.write_all(&decrypted).map_err(DecryptError::write)?;
offset = end;
}
Ok(output)
}
}
impl From<super::Password> for Password {
fn from(pass: super::Password) -> Self {
Self {
password: pass.password,
}
}
}