use std::io::{Read as _, Write as _};
use std::path::Path;
use std::str::FromStr as _;
use age::secrecy::{ExposeSecret as _, SecretString};
use age::x25519;
use zeroize::Zeroizing;
use crate::error::{Error, Result};
pub fn encrypt(plaintext: &[u8], recipients: &[x25519::Recipient]) -> Result<Vec<u8>> {
if recipients.is_empty() {
return Err(Error::Encrypt("no recipients".into()));
}
let encryptor =
age::Encryptor::with_recipients(recipients.iter().map(|r| -> &dyn age::Recipient { r }))
.map_err(|e| Error::Encrypt(e.to_string()))?;
let mut output = Vec::with_capacity(plaintext.len() + 256);
let mut writer = encryptor
.wrap_output(&mut output)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer
.write_all(plaintext)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer.finish().map_err(|e| Error::Encrypt(e.to_string()))?;
Ok(output)
}
pub fn decrypt(ciphertext: &[u8], identity: &x25519::Identity) -> Result<Zeroizing<Vec<u8>>> {
let decryptor =
age::Decryptor::new_buffered(ciphertext).map_err(|e| Error::Decrypt(e.to_string()))?;
if decryptor.is_scrypt() {
return Err(Error::Decrypt(
"file was encrypted with a passphrase, not a recipient".into(),
));
}
let identities: [&dyn age::Identity; 1] = [identity];
let mut reader = decryptor
.decrypt(identities.into_iter())
.map_err(|e| Error::Decrypt(e.to_string()))?;
let mut buf = Zeroizing::new(Vec::with_capacity(ciphertext.len()));
reader
.read_to_end(&mut buf)
.map_err(|e| Error::Decrypt(e.to_string()))?;
Ok(buf)
}
/// Generates a new x25519 identity and stores it at `path`, encrypted under `passphrase`.
///
/// Returns the freshly generated identity.
///
/// # Errors
/// - [`Error::IdentityExists`] if `path` already exists.
/// - Any error from passphrase encryption or from writing the file.
pub fn create_identity(path: &Path, passphrase: SecretString) -> Result<x25519::Identity> {
    if path.exists() {
        return Err(Error::IdentityExists(path.to_path_buf()));
    }
    let fresh = x25519::Identity::generate();
    let wrapped = {
        // Keep the secret-key text confined to this scope.
        let secret = fresh.to_string();
        encrypt_with_passphrase(secret.expose_secret().as_bytes(), passphrase)?
    };
    write_atomic(path, &wrapped)?;
    Ok(fresh)
}
pub fn load_identity(path: &Path, passphrase: SecretString) -> Result<x25519::Identity> {
if !path.exists() {
return Err(Error::IdentityNotFound(path.to_path_buf()));
}
identity_from_ciphertext(&std::fs::read(path)?, passphrase)
}
/// Decrypts a passphrase-encrypted identity blob and parses the x25519 identity inside.
fn identity_from_ciphertext(
    ciphertext: &[u8],
    passphrase: SecretString,
) -> Result<x25519::Identity> {
    decrypt_with_passphrase(ciphertext, passphrase).and_then(|plaintext| parse_identity(&plaintext))
}
/// Re-encrypts the identity at `path` under the `new` passphrase.
///
/// The existing file is first decrypted with `current`. It is replaced only if that
/// succeeds.
///
/// # Errors
/// Any error from loading the identity (including a wrong `current` passphrase),
/// re-encrypting it, or writing the file back.
pub fn change_passphrase(path: &Path, current: SecretString, new: SecretString) -> Result<()> {
    let identity = load_identity(path, current)?;
    let rewrapped = encrypt_with_passphrase(identity.to_string().expose_secret().as_bytes(), new)?;
    write_atomic(path, &rewrapped)
}
/// Copies the encrypted identity file `src` to a new file `dst`.
///
/// When `armor` is true, the copy is ASCII-armored. The identity stays encrypted either
/// way; the file bytes are copied without being decrypted.
///
/// # Errors
/// - [`Error::IdentityNotFound`] if `src` is missing.
/// - [`Error::IdentityExists`] if `dst` already exists.
/// - Otherwise any read, armoring, or write error.
pub fn export_identity(src: &Path, dst: &Path, armor: bool) -> Result<()> {
    if !src.exists() {
        return Err(Error::IdentityNotFound(src.to_path_buf()));
    }
    if dst.exists() {
        return Err(Error::IdentityExists(dst.to_path_buf()));
    }
    let encrypted = std::fs::read(src)?;
    if armor {
        write_atomic(dst, &to_armor(&encrypted)?)
    } else {
        write_atomic(dst, &encrypted)
    }
}
/// Returns the encrypted identity file at `src` as an ASCII-armored string.
///
/// # Errors
/// - [`Error::IdentityNotFound`] if `src` is missing.
/// - I/O or armoring errors.
/// - [`Error::Encrypt`] if the armored bytes are not UTF-8.
pub fn armored_identity(src: &Path) -> Result<String> {
    if !src.exists() {
        return Err(Error::IdentityNotFound(src.to_path_buf()));
    }
    let encrypted = std::fs::read(src)?;
    to_armor(&encrypted).and_then(|bytes| {
        String::from_utf8(bytes)
            .map_err(|e| Error::Encrypt(format!("armored identity is not valid UTF-8: {e}")))
    })
}
/// Restores an identity backup (armored or binary) to `dst`.
///
/// The backup is decrypted with `passphrase` first, so a corrupt backup or a wrong
/// passphrase is rejected before anything is written. The binary (de-armored) form is
/// what gets stored.
///
/// # Errors
/// - Decoding or decryption errors (e.g. [`Error::WrongPassphrase`]).
/// - [`Error::IdentityExists`] if `dst` exists and `force` is false.
/// - Write errors.
pub fn import_identity(
    backup: &[u8],
    dst: &Path,
    passphrase: SecretString,
    force: bool,
) -> Result<x25519::Identity> {
    let encrypted = from_armor(backup)?;
    let identity = identity_from_ciphertext(&encrypted, passphrase)?;
    let would_clobber = !force && dst.exists();
    if would_clobber {
        return Err(Error::IdentityExists(dst.to_path_buf()));
    }
    write_atomic(dst, &encrypted).map(|()| identity)
}
fn to_armor(raw: &[u8]) -> Result<Vec<u8>> {
use age::armor::{ArmoredWriter, Format};
let mut out = Vec::with_capacity(raw.len());
let mut writer = ArmoredWriter::wrap_output(&mut out, Format::AsciiArmor)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer
.write_all(raw)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer.finish().map_err(|e| Error::Encrypt(e.to_string()))?;
Ok(out)
}
/// Decodes `input` through age's `ArmoredReader` and returns the resulting bytes.
///
/// The tests in this file rely on this accepting both armored and plain binary input.
/// Failures are reported as [`Error::Decrypt`].
fn from_armor(input: &[u8]) -> Result<Vec<u8>> {
    let mut decoded = Vec::with_capacity(input.len());
    age::armor::ArmoredReader::new(input)
        .read_to_end(&mut decoded)
        .map_err(|e| Error::Decrypt(e.to_string()))?;
    Ok(decoded)
}
/// Parses a recipients file: one `age1…` x25519 public key per line.
///
/// Each line is trimmed first. Blank lines and lines starting with `#` are skipped.
/// Duplicate keys are compared by their canonical string form, and only the first
/// occurrence is kept, so the output preserves first-seen order.
///
/// # Errors
/// Returns [`Error::InvalidRecipient`] for the first line that is not a valid key. The
/// message carries the 1-based line number.
pub fn parse_recipients(text: &str) -> Result<Vec<x25519::Recipient>> {
    // Canonical strings of the keys accepted so far. Each duplicate check is O(1), so
    // the whole accepted list is not re-stringified for every line.
    let mut seen = std::collections::HashSet::new();
    let mut out = Vec::new();
    for (idx, raw) in text.lines().enumerate() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let recipient = x25519::Recipient::from_str(line)
            .map_err(|e| Error::InvalidRecipient(format!("line {}: {e}", idx.saturating_add(1))))?;
        // `insert` returns false for a key we already hold.
        if seen.insert(recipient.to_string()) {
            out.push(recipient);
        }
    }
    Ok(out)
}
pub fn load_recipients(path: &Path) -> Result<Vec<x25519::Recipient>> {
if !path.exists() {
return Err(Error::NoRecipients(path.to_path_buf()));
}
let recipients = parse_recipients(&std::fs::read_to_string(path)?)?;
if recipients.is_empty() {
return Err(Error::NoRecipients(path.to_path_buf()));
}
Ok(recipients)
}
/// Writes `recipients` to `path`.
///
/// The file starts with a short comment header, followed by one key per line. Keys are
/// deduplicated and written in lexicographic order, so the file content is canonical.
///
/// # Errors
/// Any error from [`write_atomic`].
pub fn save_recipients(path: &Path, recipients: &[x25519::Recipient]) -> Result<()> {
    // A BTreeSet both removes duplicates and yields the keys in sorted order.
    let unique: std::collections::BTreeSet<String> =
        recipients.iter().map(ToString::to_string).collect();
    let mut body = String::from(
        "# ks recipients — public keys allowed to decrypt this store.\n\
         # Add one with `ks recipients add <age1...>`.\n",
    );
    unique.iter().for_each(|key| {
        body.push_str(key);
        body.push('\n');
    });
    write_atomic(path, body.as_bytes())
}
/// Reports whether `target` appears in `list`, comparing keys by their string encoding.
#[must_use]
pub fn recipients_contain(list: &[x25519::Recipient], target: &x25519::Recipient) -> bool {
    let wanted = target.to_string();
    list.iter()
        .map(ToString::to_string)
        .any(|candidate| candidate == wanted)
}
/// Atomically replaces the file at `path` with `bytes`.
///
/// The steps are:
/// 1. Write the data to a uniquely named, owner-only temporary file in the same
///    directory.
/// 2. Sync that file and rename it over `path`.
/// 3. Sync the containing directory, best effort, so the rename itself is persisted.
///
/// Missing parent directories are created with `create_dir_all_secure`.
///
/// # Errors
/// Returns [`Error::Io`] in these cases:
/// - `path` has no file-name component.
/// - Creating the directory or the temporary file fails.
/// - Writing, syncing, or renaming fails.
///
/// When a step after creating the temporary file fails, that file is removed.
pub(crate) fn write_atomic(path: &Path, bytes: &[u8]) -> Result<()> {
    // `Path::parent` yields `Some("")` for a bare relative name like `identity.age`.
    // Use "." instead so `fsync_dir` can actually open the directory.
    let parent = match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };
    create_dir_all_secure(parent)?;
    // Build the temp name from the raw `OsStr` so non-UTF-8 file names also work.
    let file_name = path
        .file_name()
        .ok_or_else(|| Error::Io(std::io::Error::other("invalid target file name")))?;
    let mut tmp_name = std::ffi::OsString::from(".");
    tmp_name.push(file_name);
    tmp_name.push(format!(".{:016x}.tmp", rand::random::<u64>()));
    let tmp = parent.join(tmp_name);

    // If exclusive creation fails, the name may belong to a file we did not create, so
    // propagate the error without deleting anything.
    let mut file = open_excl_owner_only(&tmp)?;
    let written = file.write_all(bytes).and_then(|()| file.sync_all());
    // Close the temp file before it is removed or renamed.
    drop(file);
    if let Err(e) = written {
        std::fs::remove_file(&tmp).ok();
        return Err(Error::Io(e));
    }
    if let Err(e) = std::fs::rename(&tmp, path) {
        std::fs::remove_file(&tmp).ok();
        return Err(Error::Io(e));
    }
    fsync_dir(parent);
    Ok(())
}
/// Renames `src` to `dst`, replacing any existing file at `dst`.
///
/// `dst`'s parent directory is created if missing. Afterwards the directory is synced,
/// best effort, so the rename is persisted.
///
/// # Errors
/// Returns an I/O error if the directory cannot be created or the rename fails.
pub(crate) fn rename_replace(src: &Path, dst: &Path) -> Result<()> {
    // `Path::parent` yields `Some("")` for a bare relative name. Map it to "." so the
    // directory sync below opens a real directory instead of silently failing.
    let parent = match dst.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };
    create_dir_all_secure(parent)?;
    std::fs::rename(src, dst)?;
    fsync_dir(parent);
    Ok(())
}
fn encrypt_with_passphrase(plaintext: &[u8], passphrase: SecretString) -> Result<Vec<u8>> {
let encryptor = age::Encryptor::with_user_passphrase(passphrase);
let mut output = Vec::with_capacity(plaintext.len() + 256);
let mut writer = encryptor
.wrap_output(&mut output)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer
.write_all(plaintext)
.map_err(|e| Error::Encrypt(e.to_string()))?;
writer.finish().map_err(|e| Error::Encrypt(e.to_string()))?;
Ok(output)
}
fn decrypt_with_passphrase(
ciphertext: &[u8],
passphrase: SecretString,
) -> Result<Zeroizing<Vec<u8>>> {
let decryptor =
age::Decryptor::new_buffered(ciphertext).map_err(|e| Error::Decrypt(e.to_string()))?;
if !decryptor.is_scrypt() {
return Err(Error::Decrypt(
"file was encrypted to a recipient, not a passphrase".into(),
));
}
let identity = age::scrypt::Identity::new(passphrase);
let identities: [&dyn age::Identity; 1] = [&identity];
let mut reader = decryptor
.decrypt(identities.into_iter())
.map_err(|_| Error::WrongPassphrase)?;
let mut buf = Zeroizing::new(Vec::with_capacity(ciphertext.len()));
reader
.read_to_end(&mut buf)
.map_err(|e| Error::Decrypt(e.to_string()))?;
Ok(buf)
}
fn parse_identity(plaintext: &[u8]) -> Result<x25519::Identity> {
let text = std::str::from_utf8(plaintext)
.map_err(|e| Error::Decrypt(format!("identity is not valid UTF-8: {e}")))?;
for raw in text.lines() {
let line = raw.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
return x25519::Identity::from_str(line)
.map_err(|e| Error::Decrypt(format!("invalid identity payload: {e}")));
}
Err(Error::Decrypt("identity file is empty".into()))
}
#[cfg(unix)]
pub(crate) fn create_dir_all_secure(dir: &Path) -> Result<()> {
use std::os::unix::fs::DirBuilderExt as _;
std::fs::DirBuilder::new()
.recursive(true)
.mode(0o700)
.create(dir)
.map_err(Error::Io)
}
#[cfg(not(unix))]
pub(crate) fn create_dir_all_secure(dir: &Path) -> Result<()> {
std::fs::create_dir_all(dir).map_err(Error::Io)
}
#[cfg(unix)]
fn open_excl_owner_only(path: &Path) -> Result<std::fs::File> {
use std::os::unix::fs::OpenOptionsExt as _;
std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.mode(0o600)
.open(path)
.map_err(Error::Io)
}
#[cfg(not(unix))]
fn open_excl_owner_only(path: &Path) -> Result<std::fs::File> {
std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.open(path)
.map_err(Error::Io)
}
/// Best-effort sync of a directory, so that renames inside it are persisted.
///
/// Errors from opening or syncing are deliberately ignored.
#[cfg(unix)]
fn fsync_dir(dir: &Path) {
    let _ = std::fs::File::open(dir).and_then(|handle| handle.sync_all());
}
/// No-op on non-Unix targets.
#[cfg(not(unix))]
const fn fsync_dir(_dir: &Path) {}
// Unit tests for the encryption, identity-file, and recipients-file helpers above.
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, randomly named directory under the system temp dir.
    // NOTE(review): these directories are never removed after the tests finish.
    fn tempdir() -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ks-crypto-{}", rand::random::<u64>()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        dir
    }

    /// Encrypting to a recipient and decrypting with the matching identity yields the
    /// original bytes.
    #[test]
    fn recipient_roundtrip() {
        let identity = x25519::Identity::generate();
        let ct = encrypt(b"super secret api token", &[identity.to_public()]).expect("encrypt");
        let pt = decrypt(&ct, &identity).expect("decrypt");
        assert_eq!(&pt[..], b"super secret api token");
    }

    /// An identity written by `create_identity` loads back with the same public key.
    #[test]
    fn identity_create_load_roundtrip() {
        let path = tempdir().join("identity.age");
        let pp = SecretString::from("hunter2".to_owned());
        let created = create_identity(&path, pp.clone()).expect("create");
        let loaded = load_identity(&path, pp).expect("load");
        assert_eq!(
            created.to_public().to_string(),
            loaded.to_public().to_string()
        );
    }

    /// A second `create_identity` at the same path fails with `IdentityExists`.
    #[test]
    fn identity_refuses_overwrite() {
        let path = tempdir().join("identity.age");
        let pp = SecretString::from("pw".to_owned());
        create_identity(&path, pp.clone()).expect("first");
        assert!(matches!(
            create_identity(&path, pp),
            Err(Error::IdentityExists(_))
        ));
    }

    /// A wrong passphrase surfaces as the dedicated `WrongPassphrase` error, not a
    /// generic decrypt failure.
    #[test]
    fn identity_wrong_passphrase_distinguishable() {
        let path = tempdir().join("identity.age");
        create_identity(&path, SecretString::from("right".to_owned())).expect("create");
        let err = load_identity(&path, SecretString::from("wrong".to_owned()))
            .err()
            .expect("must fail");
        assert!(matches!(err, Error::WrongPassphrase));
    }

    /// After a passphrase change, only the new passphrase unlocks the identity.
    #[test]
    fn change_passphrase_works() {
        let path = tempdir().join("identity.age");
        let one = SecretString::from("one".to_owned());
        let two = SecretString::from("two".to_owned());
        create_identity(&path, one.clone()).expect("create");
        change_passphrase(&path, one.clone(), two.clone()).expect("change");
        assert!(load_identity(&path, one).is_err());
        assert!(load_identity(&path, two).is_ok());
    }

    /// Binary export then import restores the same identity.
    /// A repeat export to the same destination is refused.
    #[test]
    fn identity_export_import_roundtrip_binary() {
        let dir = tempdir();
        let src = dir.join("identity.age");
        let pp = SecretString::from("backup-pw".to_owned());
        let created = create_identity(&src, pp.clone()).expect("create");
        let backup = dir.join("backup.age");
        export_identity(&src, &backup, false).expect("export");
        assert!(
            matches!(
                export_identity(&src, &backup, false),
                Err(Error::IdentityExists(_))
            ),
            "export must refuse to overwrite an existing backup"
        );
        let restored_path = dir.join("restored.age");
        let bytes = std::fs::read(&backup).expect("read backup");
        let restored = import_identity(&bytes, &restored_path, pp, false).expect("import");
        assert_eq!(
            created.to_public().to_string(),
            restored.to_public().to_string()
        );
        // The restored file must be a normal identity file, not just a decryptable blob.
        load_identity(&restored_path, SecretString::from("backup-pw".to_owned()))
            .expect("restored identity loads normally");
    }

    /// An ASCII-armored backup carries the age armor header and imports back to the same
    /// identity.
    #[test]
    fn identity_export_import_roundtrip_armored() {
        let dir = tempdir();
        let src = dir.join("identity.age");
        let pp = SecretString::from("backup-pw".to_owned());
        let created = create_identity(&src, pp.clone()).expect("create");
        let armored = armored_identity(&src).expect("armor");
        assert!(
            armored.contains("BEGIN AGE ENCRYPTED FILE"),
            "armored output must carry the age armor header"
        );
        let restored_path = dir.join("restored.age");
        let restored =
            import_identity(armored.as_bytes(), &restored_path, pp, false).expect("import armored");
        assert_eq!(
            created.to_public().to_string(),
            restored.to_public().to_string()
        );
    }

    /// Importing with the wrong passphrase fails with `WrongPassphrase` and leaves the
    /// destination untouched.
    #[test]
    fn identity_import_rejects_wrong_passphrase() {
        let dir = tempdir();
        let src = dir.join("identity.age");
        create_identity(&src, SecretString::from("right".to_owned())).expect("create");
        let bytes = std::fs::read(&src).expect("read");
        let restored = dir.join("restored.age");
        let err = import_identity(
            &bytes,
            &restored,
            SecretString::from("wrong".to_owned()),
            false,
        )
        .err()
        .expect("must fail");
        assert!(matches!(err, Error::WrongPassphrase));
        assert!(
            !restored.exists(),
            "a failed import must not write the destination"
        );
    }

    /// Import refuses an existing destination unless `force` is set.
    /// With `force`, the forced result is loadable.
    #[test]
    fn identity_import_refuses_overwrite_without_force() {
        let dir = tempdir();
        let src = dir.join("identity.age");
        let pp = SecretString::from("pw".to_owned());
        create_identity(&src, pp.clone()).expect("create");
        let bytes = std::fs::read(&src).expect("read");
        let dst = dir.join("existing.age");
        std::fs::write(&dst, b"do not clobber").expect("write existing");
        assert!(matches!(
            import_identity(&bytes, &dst, pp.clone(), false),
            Err(Error::IdentityExists(_))
        ));
        import_identity(&bytes, &dst, pp, true).expect("force import");
        load_identity(&dst, SecretString::from("pw".to_owned())).expect("load forced");
    }

    /// Comment and blank lines are ignored when parsing a recipients file.
    #[test]
    fn recipients_parse_skips_comments() {
        let id = x25519::Identity::generate();
        let pubkey = id.to_public().to_string();
        let parsed = parse_recipients(&format!("# c\n\n{pubkey}\n")).expect("parse");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed.first().expect("one recipient").to_string(), pubkey);
    }

    /// A saved recipients file loads back with the same key.
    #[test]
    fn recipients_save_load_roundtrip() {
        let path = tempdir().join(".age-recipients");
        let id = x25519::Identity::generate();
        let r = id.to_public();
        save_recipients(&path, std::slice::from_ref(&r)).expect("save");
        let loaded = load_recipients(&path).expect("load");
        assert_eq!(loaded.len(), 1);
        assert!(recipients_contain(&loaded, &r));
    }

    /// Text that is not an age public key is rejected.
    #[test]
    fn recipients_reject_invalid() {
        assert!(parse_recipients("not-a-key").is_err());
    }

    /// A key listed twice is returned once.
    #[test]
    fn recipients_parse_collapses_duplicates() {
        let pubkey = x25519::Identity::generate().to_public().to_string();
        let parsed = parse_recipients(&format!("{pubkey}\n{pubkey}\n")).expect("parse");
        assert_eq!(parsed.len(), 1, "duplicate keys must be collapsed");
    }

    /// Saving writes each key once, in sorted order, regardless of input order.
    #[test]
    fn recipients_save_is_sorted_and_deduped() {
        let path = tempdir().join(".age-recipients");
        let a = x25519::Identity::generate().to_public();
        let b = x25519::Identity::generate().to_public();
        save_recipients(&path, &[b.clone(), a, b]).expect("save");
        let body = std::fs::read_to_string(&path).expect("read");
        // Keep only key lines, dropping the comment header.
        let keys: Vec<&str> = body
            .lines()
            .filter(|l| !l.is_empty() && !l.starts_with('#'))
            .collect();
        assert_eq!(keys.len(), 2, "duplicates must be removed on save");
        let mut sorted = keys.clone();
        sorted.sort_unstable();
        assert_eq!(
            keys, sorted,
            "keys must be written in canonical sorted order"
        );
    }
}