use anyhow::{Context, Result, anyhow, bail};
use argon2::Argon2;
use aws_lc_rs::{
aead::{AES_256_GCM, Aad, Nonce, RandomizedNonceKey},
rand,
};
use bincode_next::{Decode, Encode};
use keyring::{Entry, Error as KeyringError};
use libsalus::{SetInfo, decode, encode};
use zeroize::Zeroizing;
use crate::error::Error;
/// Keyring service name under which every entry written by this module is stored.
const SERVICE: &str = "salus";
/// Keyring account holding the encoded [`Registry`] of enrolled sets.
const REGISTRY_ACCOUNT: &str = "sets";
/// Length in bytes of the random Argon2 salt that prefixes every sealed blob.
const SALT_LEN: usize = 16;
/// Length in bytes of the AES-256-GCM nonce stored right after the salt in a sealed blob.
const NONCE_LEN: usize = 12;
/// Index of all enrolled sets, persisted under [`REGISTRY_ACCOUNT`].
///
/// The derived `Encode`/`Decode` impls serialize fields in declaration order, so reordering
/// or retyping fields changes the stored format.
#[derive(Clone, Debug, Decode, Default, Encode)]
struct Registry {
    // Number of shared `auto-share-{i}` entries; 0 means no shared shares exist yet.
    shared_auto_count: u8,
    // One record per enrolled set.
    sets: Vec<SetRecord>,
}
/// Registry entry describing one enrolled set.
///
/// Field order is part of the encoded registry format; do not reorder.
#[derive(Clone, Debug, Decode, Encode)]
struct SetRecord {
    // Set name; also the prefix of its per-set keyring accounts (`{name}/...`).
    name: String,
    // Number of automatic shares (indices 0..auto_count) loaded for this set.
    auto_count: u8,
    // `true` when the automatic shares live under `{name}/auto-share-{i}`;
    // `false` when the set uses the shared `auto-share-{i}` entries.
    independent_auto: bool,
}
/// Opens the keyring entry for `account` under the module's [`SERVICE`] name.
///
/// # Errors
/// Propagates the keyring error, annotated with the account name.
fn entry(account: &str) -> Result<Entry> {
    let opened = Entry::new(SERVICE, account);
    opened.with_context(|| format!("opening keyring entry '{account}'"))
}
/// Loads the set registry from the keyring.
///
/// A missing registry entry is not an error: it yields an empty [`Registry`].
///
/// # Errors
/// Fails on any other keyring error or if the stored bytes cannot be decoded.
fn read_registry() -> Result<Registry> {
    match entry(REGISTRY_ACCOUNT)?.get_secret() {
        Err(KeyringError::NoEntry) => Ok(Registry::default()),
        Err(other) => Err(other.into()),
        Ok(bytes) => decode::<Registry>(&bytes),
    }
}
/// Encodes `registry` and stores it under [`REGISTRY_ACCOUNT`], replacing any previous value.
///
/// # Errors
/// Fails if encoding fails or the keyring write fails.
fn write_registry(registry: &Registry) -> Result<()> {
    let serialized = encode(registry.clone())?;
    let registry_entry = entry(REGISTRY_ACCOUNT)?;
    registry_entry
        .set_secret(&serialized)
        .context("writing the keyring set registry")
}
/// Derives a 256-bit key from `passphrase` and `salt` using Argon2 with the crate's default
/// parameters.
///
/// The key is returned in a [`Zeroizing`] buffer so it is wiped when dropped.
///
/// # Errors
/// Fails if Argon2 rejects the inputs.
fn derive_key(passphrase: &str, salt: &[u8]) -> Result<Zeroizing<[u8; 32]>> {
    let hasher = Argon2::default();
    let mut out = Zeroizing::new([0u8; 32]);
    match hasher.hash_password_into(passphrase.as_bytes(), salt, &mut out[..]) {
        Ok(()) => Ok(out),
        Err(e) => Err(anyhow!("argon2 key derivation failed: {e}")),
    }
}
/// Encrypts `plaintext` with AES-256-GCM under a key derived from `passphrase`.
///
/// Blob layout: `salt (SALT_LEN) || nonce (NONCE_LEN) || ciphertext || tag`. A fresh random
/// salt is drawn on every call, and the nonce is chosen by `RandomizedNonceKey`.
///
/// # Errors
/// Fails if random generation, key derivation, key setup or encryption fails.
fn seal(plaintext: &str, passphrase: &str) -> Result<Vec<u8>> {
    let salt = {
        let mut fresh = [0u8; SALT_LEN];
        rand::fill(&mut fresh)?;
        fresh
    };
    let key = derive_key(passphrase, &salt)?;
    let sealing_key = RandomizedNonceKey::new(&AES_256_GCM, key.as_slice())
        .with_context(|| Error::NonceKeyGen)?;

    // Encryption happens in place; the tag is appended to the same buffer.
    let mut sealed = Vec::from(plaintext.as_bytes());
    let nonce = sealing_key.seal_in_place_append_tag(Aad::empty(), &mut sealed)?;

    let total = SALT_LEN + NONCE_LEN + sealed.len();
    let mut blob = Vec::with_capacity(total);
    blob.extend_from_slice(&salt);
    blob.extend_from_slice(nonce.as_ref());
    blob.extend_from_slice(&sealed);
    Ok(blob)
}
pub fn unseal(blob: &[u8], passphrase: &str) -> Result<Option<String>> {
if blob.len() < SALT_LEN + NONCE_LEN {
bail!("sealed share blob is malformed (too short)");
}
let (salt, rest) = blob.split_at(SALT_LEN);
let (nonce_bytes, ciphertext) = rest.split_at(NONCE_LEN);
let nonce_arr: [u8; NONCE_LEN] = nonce_bytes
.try_into()
.map_err(|_e| anyhow!("invalid nonce length"))?;
let key = derive_key(passphrase, salt)?;
let rnkey = RandomizedNonceKey::new(&AES_256_GCM, key.as_slice())
.with_context(|| Error::NonceKeyGen)?;
let mut buf = ciphertext.to_vec();
match rnkey.open_in_place(Nonce::from(&nonce_arr), Aad::empty(), &mut buf) {
Ok(plaintext) => {
let share = String::from_utf8(plaintext.to_vec())
.context("decrypted share is not valid UTF-8")?;
Ok(Some(share))
}
Err(_e) => Ok(None),
}
}
/// Returns the number of shared automatic shares in the registry, or `None` if there are none.
///
/// # Errors
/// Fails if the registry cannot be read or decoded.
pub fn shared_auto_count() -> Result<Option<u8>> {
    let count = read_registry()?.shared_auto_count;
    Ok(match count {
        0 => None,
        n => Some(n),
    })
}
/// Enrolls set `name` from its complete list of shares.
///
/// Every share except the last is stored unsealed as an "automatic" share. The last share is
/// sealed with `passphrase` and stored as `{name}/final-blob`. With `independent`, the
/// automatic shares go to per-set entries (`{name}/auto-share-{i}`). Otherwise they become the
/// shared `auto-share-{i}` entries, which final-only sets also rely on.
///
/// When a previous enrollment is replaced (`force`), share entries the new layout no longer
/// references are deleted on a best-effort basis, so stale secret material does not linger.
///
/// # Errors
/// Fails if:
/// * fewer than two shares are given, or more than 256;
/// * `name` is already enrolled and `force` is false;
/// * a shared enrollment would replace shared automatic shares that other sets still depend
///   on with different ones;
/// * any registry or keyring write fails.
pub fn enroll_full(
    name: &str,
    shares: &[String],
    passphrase: &str,
    independent: bool,
    force: bool,
) -> Result<()> {
    if shares.len() < 2 {
        bail!("enrollment needs at least the threshold number of shares (>= 2)");
    }
    let (auto, manual) = shares.split_at(shares.len() - 1);
    let final_share = &manual[0];
    let auto_count = u8::try_from(auto.len()).context("too many shares to enroll")?;

    let mut registry = read_registry()?;
    let previous = registry.sets.iter().find(|r| r.name == name).cloned();
    if previous.is_some() && !force {
        bail!("a set named '{name}' is already enrolled; pass --force to replace it");
    }
    let old_shared_count = registry.shared_auto_count;

    if !independent {
        // Other shared sets (including final-only ones) were sealed against the current
        // shared shares. Swapping those shares for different ones would silently make every
        // such set unrecoverable.
        let other_dependents = registry
            .sets
            .iter()
            .any(|r| r.name != name && !r.independent_auto);
        if other_dependents && !stored_shared_shares_equal(old_shared_count, auto)? {
            bail!(
                "the shared automatic shares are used by other enrolled sets; \
                 enroll '{name}' with independent shares or forget those sets first"
            );
        }
    }

    if independent {
        for (i, share) in auto.iter().enumerate() {
            entry(&format!("{name}/auto-share-{i}"))?
                .set_password(share)
                .context("writing a per-set automatic share")?;
        }
    } else {
        for (i, share) in auto.iter().enumerate() {
            entry(&format!("auto-share-{i}"))?
                .set_password(share)
                .context("writing a shared automatic share")?;
        }
        registry.shared_auto_count = auto_count;
    }

    let blob = seal(final_share, passphrase)?;
    entry(&format!("{name}/final-blob"))?
        .set_secret(&blob)
        .context("writing the sealed final share")?;

    registry.sets.retain(|r| r.name != name);
    registry.sets.push(SetRecord {
        name: name.to_string(),
        auto_count,
        independent_auto: independent,
    });
    write_registry(&registry)?;

    // Best-effort cleanup of entries the replaced layout left behind. `forget`/`forget_all`
    // only delete the indices the registry still counts, so these would otherwise leak.
    if let Some(old) = previous.filter(|r| r.independent_auto) {
        let still_used = if independent { auto_count } else { 0 };
        for i in still_used..old.auto_count {
            let _del = entry(&format!("{name}/auto-share-{i}"))?.delete_credential();
        }
    }
    if !independent {
        for i in auto_count..old_shared_count {
            let _del = entry(&format!("auto-share-{i}"))?.delete_credential();
        }
    }
    Ok(())
}

/// Returns whether the stored shared automatic shares are exactly `auto` (same count, same
/// values). Any read failure counts as a mismatch.
fn stored_shared_shares_equal(stored_count: u8, auto: &[String]) -> Result<bool> {
    if usize::from(stored_count) != auto.len() {
        return Ok(false);
    }
    for (i, share) in auto.iter().enumerate() {
        let stored = entry(&format!("auto-share-{i}"))?.get_password().ok();
        if stored.as_deref() != Some(share.as_str()) {
            return Ok(false);
        }
    }
    Ok(true)
}
/// Enrolls set `name` with only its final share, reusing the existing shared automatic shares.
///
/// `final_share` is sealed with `passphrase` and stored as `{name}/final-blob`. The new record
/// uses the registry's current shared automatic share count. If a forced replacement overwrites
/// a set that used independent automatic shares, those per-set entries are no longer referenced
/// and are deleted on a best-effort basis.
///
/// # Errors
/// Fails if no shared automatic shares exist, if `name` is already enrolled and `force` is
/// false, or if sealing or any keyring/registry write fails.
pub fn enroll_final_only(
    name: &str,
    final_share: &str,
    passphrase: &str,
    force: bool,
) -> Result<()> {
    let mut registry = read_registry()?;
    let shared_auto_count = registry.shared_auto_count;
    if shared_auto_count == 0 {
        bail!("no shared automatic shares exist yet; enroll a first set fully");
    }
    let previous = registry.sets.iter().find(|r| r.name == name).cloned();
    if previous.is_some() && !force {
        bail!("a set named '{name}' is already enrolled; pass --force to replace it");
    }

    let blob = seal(final_share, passphrase)?;
    entry(&format!("{name}/final-blob"))?
        .set_secret(&blob)
        .context("writing the sealed final share")?;

    registry.sets.retain(|r| r.name != name);
    registry.sets.push(SetRecord {
        name: name.to_string(),
        auto_count: shared_auto_count,
        independent_auto: false,
    });
    write_registry(&registry)?;

    // The new record uses the shared shares, so nothing will ever delete the old set's per-set
    // automatic shares unless we do it here.
    if let Some(old) = previous.filter(|r| r.independent_auto) {
        for i in 0..old.auto_count {
            let _del = entry(&format!("{name}/auto-share-{i}"))?.delete_credential();
        }
    }
    Ok(())
}
/// Removes set `name` and its keyring entries.
///
/// Deletes the set's sealed final share and, for independent sets, its per-set automatic
/// shares. If this was the last enrolled set, the shared automatic shares and the registry
/// entry are deleted as well; otherwise the updated registry is written back. Individual
/// `delete_credential` failures are ignored (best-effort cleanup).
///
/// Returns `Ok(false)` if no set named `name` is enrolled, `Ok(true)` otherwise.
///
/// # Errors
/// Fails if the registry cannot be read or written, or a keyring entry cannot be opened.
pub fn forget(name: &str) -> Result<bool> {
    let mut registry = read_registry()?;
    let Some(pos) = registry.sets.iter().position(|r| r.name == name) else {
        return Ok(false);
    };
    let record = registry.sets.remove(pos);
    let _del = entry(&format!("{name}/final-blob"))?.delete_credential();
    if record.independent_auto {
        for i in 0..record.auto_count {
            let _del = entry(&format!("{name}/auto-share-{i}"))?.delete_credential();
        }
    }
    if registry.sets.is_empty() {
        for i in 0..registry.shared_auto_count {
            let _del = entry(&format!("auto-share-{i}"))?.delete_credential();
        }
        let _del = entry(REGISTRY_ACCOUNT)?.delete_credential();
    } else {
        write_registry(&registry)?;
    }
    Ok(true)
}
/// Removes every enrolled set, all shared automatic shares and the registry entry.
///
/// Individual `delete_credential` failures are ignored (best-effort cleanup).
///
/// # Errors
/// Fails if the registry cannot be read or a keyring entry cannot be opened.
pub fn forget_all() -> Result<()> {
    let registry = read_registry()?;
    for record in &registry.sets {
        let _del = entry(&format!("{}/final-blob", record.name))?.delete_credential();
        if record.independent_auto {
            for i in 0..record.auto_count {
                let _del = entry(&format!("{}/auto-share-{i}", record.name))?.delete_credential();
            }
        }
    }
    for i in 0..registry.shared_auto_count {
        let _del = entry(&format!("auto-share-{i}"))?.delete_credential();
    }
    let _del = entry(REGISTRY_ACCOUNT)?.delete_credential();
    Ok(())
}
pub fn list_sets() -> Result<Vec<SetInfo>> {
let registry = read_registry()?;
Ok(registry
.sets
.iter()
.map(|r| SetInfo {
name: r.name.clone(),
auto_count: r.auto_count,
})
.collect())
}
/// Reads the automatic shares for set `name`, in index order.
///
/// Independent sets read `{name}/auto-share-{i}`; shared sets read `auto-share-{i}`.
///
/// # Errors
/// Fails if `name` is not enrolled or any share cannot be read; stops at the first failure.
pub fn load_auto_shares(name: &str) -> Result<Vec<String>> {
    let registry = read_registry()?;
    let record = match registry.sets.iter().find(|r| r.name == name) {
        Some(found) => found,
        None => {
            bail!("no enrolled set named '{name}'")
        }
    };
    (0..record.auto_count)
        .map(|i| {
            let account = if record.independent_auto {
                format!("{name}/auto-share-{i}")
            } else {
                format!("auto-share-{i}")
            };
            entry(&account)?
                .get_password()
                .with_context(|| format!("reading automatic share '{account}'"))
        })
        .collect()
}
/// Reads the sealed final-share blob for set `name`.
///
/// Returns `Ok(None)` when no `{name}/final-blob` entry exists.
///
/// # Errors
/// Fails on any keyring error other than a missing entry.
pub fn load_sealed_blob(name: &str) -> Result<Option<Vec<u8>>> {
    let account = format!("{name}/final-blob");
    match entry(&account)?.get_secret() {
        Ok(blob) => Ok(Some(blob)),
        Err(KeyringError::NoEntry) => Ok(None),
        Err(other) => Err(other.into()),
    }
}
/// Unit tests for sealing and for the keyring-backed enrollment registry.
#[cfg(test)]
mod test {
    use anyhow::Result;
    use super::{
        enroll_final_only, enroll_full, forget, forget_all, list_sets, load_auto_shares,
        load_sealed_blob, read_registry, seal, shared_auto_count, unseal,
    };
    // NOTE(review): tests touching the keyring hold `guard()` for their whole body;
    // presumably it serializes access and resets test keyring state — confirm in
    // `crate::test_keyring`.
    use crate::test_keyring::guard;
    // Builds `n` distinct placeholder shares: "share-0", "share-1", ...
    fn shares(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("share-{i}")).collect()
    }
    // Sealing then unsealing with the same passphrase returns the original plaintext.
    #[test]
    fn seal_unseal_round_trip() {
        let blob = seal("share-value", "correct horse battery staple").unwrap();
        let out = unseal(&blob, "correct horse battery staple").unwrap();
        assert_eq!(out.as_deref(), Some("share-value"));
    }
    // A wrong passphrase is reported as `Ok(None)`, not as an error.
    #[test]
    fn wrong_passphrase_returns_none() {
        let blob = seal("share-value", "right-passphrase").unwrap();
        let out = unseal(&blob, "wrong-passphrase").unwrap();
        assert!(out.is_none());
    }
    // A blob too short to contain salt and nonce is rejected with an error.
    #[test]
    fn malformed_blob_errors() {
        assert!(unseal(b"too-short", "whatever").is_err());
    }
    // With no registry entry, reading yields an empty default registry.
    #[test]
    fn read_registry_defaults_when_absent() -> Result<()> {
        let _g = guard();
        let registry = read_registry()?;
        assert_eq!(registry.shared_auto_count, 0);
        assert!(registry.sets.is_empty());
        Ok(())
    }
    // Shared full enrollment: all but the last share become shared auto shares; the last is sealed.
    #[test]
    fn enroll_full_shared_round_trips() -> Result<()> {
        let _g = guard();
        enroll_full("alpha", &shares(3), "pass", false, false)?;
        let sets = list_sets()?;
        assert_eq!(sets.len(), 1);
        assert_eq!(sets[0].name, "alpha");
        assert_eq!(sets[0].auto_count, 2);
        assert_eq!(shared_auto_count()?, Some(2));
        let auto = load_auto_shares("alpha")?;
        assert_eq!(auto, vec!["share-0".to_string(), "share-1".to_string()]);
        let blob = load_sealed_blob("alpha")?.expect("sealed blob present");
        assert_eq!(unseal(&blob, "pass")?.as_deref(), Some("share-2"));
        Ok(())
    }
    // Independent enrollment stores per-set auto shares and leaves the shared count unset.
    #[test]
    fn enroll_full_independent_uses_per_set_shares() -> Result<()> {
        let _g = guard();
        enroll_full("beta", &shares(3), "pass", true, false)?;
        assert_eq!(shared_auto_count()?, None);
        assert_eq!(load_auto_shares("beta")?.len(), 2);
        Ok(())
    }
    // Fewer than two shares cannot be enrolled.
    #[test]
    fn enroll_full_rejects_single_share() {
        let _g = guard();
        assert!(enroll_full("gamma", &shares(1), "pass", false, false).is_err());
    }
    // Re-enrolling an existing name needs `force`; the forced replacement wins entirely.
    #[test]
    fn enroll_full_rejects_duplicate_without_force() -> Result<()> {
        let _g = guard();
        enroll_full("delta", &shares(3), "pass", false, false)?;
        assert!(enroll_full("delta", &shares(3), "pass", false, false).is_err());
        enroll_full("delta", &shares(4), "pass2", false, true)?;
        let sets = list_sets()?;
        assert_eq!(sets.len(), 1);
        assert_eq!(sets[0].auto_count, 3);
        let blob = load_sealed_blob("delta")?.expect("sealed blob present");
        assert_eq!(unseal(&blob, "pass2")?.as_deref(), Some("share-3"));
        Ok(())
    }
    // A final-only set attaches to the shared shares created by an earlier full enrollment.
    #[test]
    fn enroll_final_only_reuses_shared_shares() -> Result<()> {
        let _g = guard();
        enroll_full("alpha", &shares(3), "pass", false, false)?;
        enroll_final_only("epsilon", "final-share", "secret", false)?;
        let sets = list_sets()?;
        assert_eq!(sets.len(), 2);
        let blob = load_sealed_blob("epsilon")?.expect("sealed blob present");
        assert_eq!(unseal(&blob, "secret")?.as_deref(), Some("final-share"));
        Ok(())
    }
    // Final-only enrollment fails when no shared shares exist yet.
    #[test]
    fn enroll_final_only_requires_shared_shares() {
        let _g = guard();
        assert!(enroll_final_only("epsilon", "final-share", "secret", false).is_err());
    }
    // Forgetting a name that was never enrolled reports `false`.
    #[test]
    fn forget_unknown_returns_false() -> Result<()> {
        let _g = guard();
        assert!(!forget("nope")?);
        Ok(())
    }
    // Shared shares survive while any set remains and are cleared with the last one.
    #[test]
    fn forget_removes_set_and_last_clears_shared() -> Result<()> {
        let _g = guard();
        enroll_full("alpha", &shares(3), "pass", false, false)?;
        enroll_final_only("beta", "final", "secret", false)?;
        assert!(forget("beta")?);
        assert_eq!(list_sets()?.len(), 1);
        assert_eq!(shared_auto_count()?, Some(2));
        assert!(forget("alpha")?);
        assert!(list_sets()?.is_empty());
        assert_eq!(shared_auto_count()?, None);
        Ok(())
    }
    // `forget_all` removes both shared and independent sets along with the shared count.
    #[test]
    fn forget_all_clears_everything() -> Result<()> {
        let _g = guard();
        enroll_full("alpha", &shares(3), "pass", false, false)?;
        enroll_full("beta", &shares(3), "pass", true, false)?;
        forget_all()?;
        assert!(list_sets()?.is_empty());
        assert_eq!(shared_auto_count()?, None);
        Ok(())
    }
    // Loading auto shares for an unknown set is an error.
    #[test]
    fn load_auto_shares_unknown_errors() {
        let _g = guard();
        assert!(load_auto_shares("missing").is_err());
    }
    // A missing sealed blob is reported as `None`, not an error.
    #[test]
    fn load_sealed_blob_absent_returns_none() -> Result<()> {
        let _g = guard();
        assert!(load_sealed_blob("missing")?.is_none());
        Ok(())
    }
}