use std::io::{Read as _, Write as _};
use std::path::Path;
use std::str::FromStr as _;
use age::secrecy::{ExposeSecret as _, SecretString};
use age::x25519;
use zeroize::Zeroizing;
use crate::error::{Error, Result};
pub fn encrypt(plaintext: &[u8], recipients: &[x25519::Recipient]) -> Result<Vec<u8>> {
if recipients.is_empty() {
return Err(Error::Encrypt("no recipients".into()));
}
let encryptor =
age::Encryptor::with_recipients(recipients.iter().map(|r| -> &dyn age::Recipient { r }))
.map_err(|e| Error::Encrypt(e.to_string()))?;
let mut output = Vec::with_capacity(plaintext.len() + 256);
let mut writer = encryptor
.wrap_output(&mut output)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer
.write_all(plaintext)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer.finish().map_err(|e| Error::Encrypt(e.to_string()))?;
Ok(output)
}
pub fn decrypt(ciphertext: &[u8], identity: &x25519::Identity) -> Result<Zeroizing<Vec<u8>>> {
let decryptor =
age::Decryptor::new_buffered(ciphertext).map_err(|e| Error::Decrypt(e.to_string()))?;
if decryptor.is_scrypt() {
return Err(Error::Decrypt(
"file was encrypted with a passphrase, not a recipient".into(),
));
}
let identities: [&dyn age::Identity; 1] = [identity];
let mut reader = decryptor
.decrypt(identities.into_iter())
.map_err(|e| Error::Decrypt(e.to_string()))?;
let mut buf = Zeroizing::new(Vec::with_capacity(ciphertext.len()));
reader
.read_to_end(&mut buf)
.map_err(|e| Error::Decrypt(e.to_string()))?;
Ok(buf)
}
/// Generates a fresh X25519 identity and stores it at `path`, encrypted under
/// `passphrase`. Returns the new identity.
///
/// # Errors
///
/// Returns [`Error::IdentityExists`] if something already exists at `path`.
/// Otherwise propagates encryption and I/O errors from writing the file.
///
/// NOTE(review): the existence check and the write are separate steps. A
/// concurrent creator acting between them could still be overwritten.
pub fn create_identity(path: &Path, passphrase: SecretString) -> Result<x25519::Identity> {
    if path.exists() {
        return Err(Error::IdentityExists(path.to_path_buf()));
    }
    let identity = x25519::Identity::generate();
    let ciphertext = {
        // Keep the serialised secret key confined to this scope.
        let secret = identity.to_string();
        encrypt_with_passphrase(secret.expose_secret().as_bytes(), passphrase)?
    };
    write_atomic(path, &ciphertext)?;
    Ok(identity)
}
pub fn load_identity(path: &Path, passphrase: SecretString) -> Result<x25519::Identity> {
if !path.exists() {
return Err(Error::IdentityNotFound(path.to_path_buf()));
}
let ciphertext = std::fs::read(path)?;
let plaintext = decrypt_with_passphrase(&ciphertext, passphrase)?;
parse_identity(&plaintext)
}
/// Re-encrypts the identity at `path` under a new passphrase.
///
/// The identity is first loaded with `current`, so a wrong passphrase returns an
/// error before anything is written. It is then rewritten under `new` with
/// `write_atomic`.
///
/// # Errors
///
/// Propagates every error from [`load_identity`], from encryption, and from writing.
pub fn change_passphrase(path: &Path, current: SecretString, new: SecretString) -> Result<()> {
    let identity = load_identity(path, current)?;
    let secret = identity.to_string();
    let reencrypted = encrypt_with_passphrase(secret.expose_secret().as_bytes(), new)?;
    write_atomic(path, &reencrypted)
}
pub fn parse_recipients(text: &str) -> Result<Vec<x25519::Recipient>> {
let mut out = Vec::new();
for (idx, raw) in text.lines().enumerate() {
let line = raw.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let recipient = x25519::Recipient::from_str(line)
.map_err(|e| Error::InvalidRecipient(format!("line {}: {e}", idx.saturating_add(1))))?;
out.push(recipient);
}
Ok(out)
}
pub fn load_recipients(path: &Path) -> Result<Vec<x25519::Recipient>> {
if !path.exists() {
return Err(Error::NoRecipients(path.to_path_buf()));
}
let recipients = parse_recipients(&std::fs::read_to_string(path)?)?;
if recipients.is_empty() {
return Err(Error::NoRecipients(path.to_path_buf()));
}
Ok(recipients)
}
/// Writes `recipients` to `path`, one public key per line after a short `#`
/// comment header. The file is written with `write_atomic`.
///
/// The output round-trips through [`parse_recipients`].
///
/// # Errors
///
/// Propagates I/O errors from the write.
pub fn save_recipients(path: &Path, recipients: &[x25519::Recipient]) -> Result<()> {
    const HEADER: &str = "# ks recipients — public keys allowed to decrypt this store.\n\
                          # Add one with `ks recipients add <age1...>`.\n";
    let mut body = String::from(HEADER);
    body.extend(recipients.iter().map(|r| format!("{r}\n")));
    write_atomic(path, body.as_bytes())
}
/// Reports whether `target` appears in `list`.
///
/// Keys are compared by their string encoding.
#[must_use]
pub fn recipients_contain(list: &[x25519::Recipient], target: &x25519::Recipient) -> bool {
    let wanted = target.to_string();
    list.iter()
        .map(ToString::to_string)
        .any(|encoded| encoded == wanted)
}
/// Atomically replaces `path` with `bytes`.
///
/// The steps are:
/// 1. Write the data to a uniquely named, owner-only temporary file in the same
///    directory and `sync_all` it.
/// 2. Rename the temporary file over `path`.
/// 3. Sync the directory (best effort) so the rename itself is durable.
///
/// Missing parent directories are created with owner-only permissions.
///
/// # Errors
///
/// Returns an I/O error if `path` has no file name, or if creating the directory,
/// writing the temporary file, or renaming it fails. On write or rename failure the
/// temporary file is removed (best effort).
pub(crate) fn write_atomic(path: &Path, bytes: &[u8]) -> Result<()> {
    // `Path::parent` yields `Some("")` for a bare relative name like `identity.age`.
    // Treat that as the current directory; otherwise the directory fsync below is
    // silently skipped, because opening "" always fails.
    let parent = match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };
    create_dir_all_secure(parent)?;
    let file_name = path
        .file_name()
        .ok_or_else(|| Error::Io(std::io::Error::other("invalid target file name")))?;
    // Build the temp name as an `OsString` so non-UTF-8 file names also work.
    let mut tmp_name = std::ffi::OsString::from(".");
    tmp_name.push(file_name);
    tmp_name.push(format!(".{:016x}.tmp", rand::random::<u64>()));
    let tmp = parent.join(tmp_name);

    let write = || -> Result<()> {
        let mut file = open_excl_owner_only(&tmp)?;
        file.write_all(bytes)?;
        file.sync_all()?;
        Ok(())
    };
    if let Err(e) = write() {
        // Best-effort cleanup; the original error is the one worth reporting.
        std::fs::remove_file(&tmp).ok();
        return Err(e);
    }
    if let Err(e) = std::fs::rename(&tmp, path) {
        std::fs::remove_file(&tmp).ok();
        return Err(Error::Io(e));
    }
    fsync_dir(parent);
    Ok(())
}
/// Renames `src` to `dst`, then syncs `dst`'s directory (best effort) so the
/// rename is durable.
///
/// Missing parent directories of `dst` are created with owner-only permissions.
///
/// # Errors
///
/// Returns an I/O error if creating the parent directory or the rename fails.
pub(crate) fn rename_replace(src: &Path, dst: &Path) -> Result<()> {
    // A bare relative `dst` has parent `Some("")`. Use "." so the fsync can open it.
    let parent = match dst.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };
    create_dir_all_secure(parent)?;
    std::fs::rename(src, dst)?;
    fsync_dir(parent);
    Ok(())
}
fn encrypt_with_passphrase(plaintext: &[u8], passphrase: SecretString) -> Result<Vec<u8>> {
let encryptor = age::Encryptor::with_user_passphrase(passphrase);
let mut output = Vec::with_capacity(plaintext.len() + 256);
let mut writer = encryptor
.wrap_output(&mut output)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer
.write_all(plaintext)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer.finish().map_err(|e| Error::Encrypt(e.to_string()))?;
Ok(output)
}
fn decrypt_with_passphrase(
ciphertext: &[u8],
passphrase: SecretString,
) -> Result<Zeroizing<Vec<u8>>> {
let decryptor =
age::Decryptor::new_buffered(ciphertext).map_err(|e| Error::Decrypt(e.to_string()))?;
if !decryptor.is_scrypt() {
return Err(Error::Decrypt(
"file was encrypted to a recipient, not a passphrase".into(),
));
}
let identity = age::scrypt::Identity::new(passphrase);
let identities: [&dyn age::Identity; 1] = [&identity];
let mut reader = decryptor
.decrypt(identities.into_iter())
.map_err(|_| Error::WrongPassphrase)?;
let mut buf = Zeroizing::new(Vec::with_capacity(ciphertext.len()));
reader
.read_to_end(&mut buf)
.map_err(|e| Error::Decrypt(e.to_string()))?;
Ok(buf)
}
fn parse_identity(plaintext: &[u8]) -> Result<x25519::Identity> {
let text = std::str::from_utf8(plaintext)
.map_err(|e| Error::Decrypt(format!("identity is not valid UTF-8: {e}")))?;
for raw in text.lines() {
let line = raw.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
return x25519::Identity::from_str(line)
.map_err(|e| Error::Decrypt(format!("invalid identity payload: {e}")));
}
Err(Error::Decrypt("identity file is empty".into()))
}
#[cfg(unix)]
pub(crate) fn create_dir_all_secure(dir: &Path) -> Result<()> {
use std::os::unix::fs::DirBuilderExt as _;
std::fs::DirBuilder::new()
.recursive(true)
.mode(0o700)
.create(dir)
.map_err(Error::Io)
}
#[cfg(not(unix))]
pub(crate) fn create_dir_all_secure(dir: &Path) -> Result<()> {
std::fs::create_dir_all(dir).map_err(Error::Io)
}
#[cfg(unix)]
fn open_excl_owner_only(path: &Path) -> Result<std::fs::File> {
use std::os::unix::fs::OpenOptionsExt as _;
std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.mode(0o600)
.open(path)
.map_err(Error::Io)
}
#[cfg(not(unix))]
fn open_excl_owner_only(path: &Path) -> Result<std::fs::File> {
std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.open(path)
.map_err(Error::Io)
}
/// Best-effort `fsync` of a directory, so that a preceding create or rename
/// inside it is durable.
///
/// Failures, including failing to open the directory, are deliberately ignored.
#[cfg(unix)]
fn fsync_dir(dir: &Path) {
    let Ok(handle) = std::fs::File::open(dir) else {
        return;
    };
    let _ = handle.sync_all();
}
/// No-op on non-Unix platforms.
#[cfg(not(unix))]
const fn fsync_dir(_dir: &Path) {}
#[cfg(test)]
mod tests {
    use super::*;

    /// A uniquely named scratch directory that is removed recursively on drop,
    /// so test runs do not accumulate `ks-crypto-*` directories in the temp dir.
    ///
    /// Bind it to a local (`let dir = TempDir::new();`) so it outlives the paths
    /// derived from it.
    struct TempDir(std::path::PathBuf);

    impl TempDir {
        fn new() -> Self {
            let dir = std::env::temp_dir().join(format!("ks-crypto-{}", rand::random::<u64>()));
            std::fs::create_dir_all(&dir).expect("create temp dir");
            Self(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }

        fn join(&self, name: &str) -> std::path::PathBuf {
            self.0.join(name)
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not mask the test's own result.
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    #[test]
    fn recipient_roundtrip() {
        let identity = x25519::Identity::generate();
        let ct = encrypt(b"super secret api token", &[identity.to_public()]).expect("encrypt");
        let pt = decrypt(&ct, &identity).expect("decrypt");
        assert_eq!(&pt[..], b"super secret api token");
    }

    #[test]
    fn encrypt_rejects_empty_recipients() {
        assert!(matches!(encrypt(b"data", &[]), Err(Error::Encrypt(_))));
    }

    #[test]
    fn identity_create_load_roundtrip() {
        let dir = TempDir::new();
        let path = dir.join("identity.age");
        let pp = SecretString::from("hunter2".to_owned());
        let created = create_identity(&path, pp.clone()).expect("create");
        let loaded = load_identity(&path, pp).expect("load");
        assert_eq!(
            created.to_public().to_string(),
            loaded.to_public().to_string()
        );
    }

    #[test]
    fn identity_refuses_overwrite() {
        let dir = TempDir::new();
        let path = dir.join("identity.age");
        let pp = SecretString::from("pw".to_owned());
        create_identity(&path, pp.clone()).expect("first");
        assert!(matches!(
            create_identity(&path, pp),
            Err(Error::IdentityExists(_))
        ));
    }

    #[test]
    fn identity_missing_is_not_found() {
        let dir = TempDir::new();
        let err = load_identity(&dir.join("missing.age"), SecretString::from("pw".to_owned()))
            .err()
            .expect("must fail");
        assert!(matches!(err, Error::IdentityNotFound(_)));
    }

    #[test]
    fn identity_wrong_passphrase_distinguishable() {
        let dir = TempDir::new();
        let path = dir.join("identity.age");
        create_identity(&path, SecretString::from("right".to_owned())).expect("create");
        let err = load_identity(&path, SecretString::from("wrong".to_owned()))
            .err()
            .expect("must fail");
        assert!(matches!(err, Error::WrongPassphrase));
    }

    #[test]
    fn change_passphrase_works() {
        let dir = TempDir::new();
        let path = dir.join("identity.age");
        let one = SecretString::from("one".to_owned());
        let two = SecretString::from("two".to_owned());
        create_identity(&path, one.clone()).expect("create");
        change_passphrase(&path, one.clone(), two.clone()).expect("change");
        assert!(load_identity(&path, one).is_err());
        assert!(load_identity(&path, two).is_ok());
    }

    #[test]
    fn recipients_parse_skips_comments() {
        let id = x25519::Identity::generate();
        let pubkey = id.to_public().to_string();
        let parsed = parse_recipients(&format!("# c\n\n{pubkey}\n")).expect("parse");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed.first().expect("one recipient").to_string(), pubkey);
    }

    #[test]
    fn recipients_save_load_roundtrip() {
        let dir = TempDir::new();
        let path = dir.join(".age-recipients");
        let id = x25519::Identity::generate();
        let r = id.to_public();
        save_recipients(&path, std::slice::from_ref(&r)).expect("save");
        let loaded = load_recipients(&path).expect("load");
        assert_eq!(loaded.len(), 1);
        assert!(recipients_contain(&loaded, &r));
    }

    #[test]
    fn recipients_missing_file_is_no_recipients() {
        let dir = TempDir::new();
        assert!(matches!(
            load_recipients(&dir.join(".age-recipients")),
            Err(Error::NoRecipients(_))
        ));
    }

    #[test]
    fn recipients_comment_only_file_is_no_recipients() {
        let dir = TempDir::new();
        let path = dir.join(".age-recipients");
        save_recipients(&path, &[]).expect("save");
        assert!(matches!(
            load_recipients(&path),
            Err(Error::NoRecipients(_))
        ));
    }

    #[test]
    fn recipients_reject_invalid() {
        assert!(parse_recipients("not-a-key").is_err());
    }

    #[test]
    fn write_atomic_replaces_and_leaves_no_temp_files() {
        let dir = TempDir::new();
        let path = dir.join("blob");
        write_atomic(&path, b"one").expect("first write");
        write_atomic(&path, b"two").expect("second write");
        assert_eq!(std::fs::read(&path).expect("read back"), b"two");
        let entries = std::fs::read_dir(dir.path()).expect("list dir").count();
        assert_eq!(entries, 1, "temporary files must not be left behind");
    }
}