use nostr_sdk::prelude::*;
use crate::error::{MostroError, ServiceError};
/// Newtype around [`Keys`] holding a key shared between two peers.
///
/// Build one with [`SharedKey::derive`] from one peer's secret key and the other
/// peer's public key, restore one with [`SharedKey::from_hex`], or wrap existing
/// keys with [`SharedKey::from_keys`].
#[derive(Clone, Debug)]
pub struct SharedKey(Keys);
impl SharedKey {
pub fn derive(secret: &SecretKey, counterparty: &PublicKey) -> Result<Self, MostroError> {
let bytes = nostr_sdk::util::generate_shared_key(secret, counterparty).map_err(|e| {
MostroError::MostroInternalErr(ServiceError::EncryptionError(format!(
"shared key derivation failed: {e}"
)))
})?;
let secret = SecretKey::from_slice(&bytes).map_err(|e| {
MostroError::MostroInternalErr(ServiceError::EncryptionError(format!(
"invalid shared secret: {e}"
)))
})?;
Ok(Self(Keys::new(secret)))
}
pub fn from_keys(keys: Keys) -> Self {
Self(keys)
}
pub fn keys(&self) -> &Keys {
&self.0
}
pub fn public_key(&self) -> PublicKey {
self.0.public_key()
}
pub fn secret_key(&self) -> &SecretKey {
self.0.secret_key()
}
pub fn to_hex(&self) -> String {
self.0.secret_key().to_secret_hex()
}
pub fn from_hex(hex: &str) -> Result<Self, MostroError> {
let secret = SecretKey::from_hex(hex).map_err(|e| {
MostroError::MostroInternalErr(ServiceError::EncryptionError(format!(
"invalid shared key hex: {e}"
)))
})?;
Ok(Self(Keys::new(secret)))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn derive_is_symmetric_between_peers() {
        let alice = Keys::generate();
        let bob = Keys::generate();
        let from_alice = SharedKey::derive(alice.secret_key(), &bob.public_key()).unwrap();
        let from_bob = SharedKey::derive(bob.secret_key(), &alice.public_key()).unwrap();
        assert_eq!(from_alice.public_key(), from_bob.public_key());
        assert_eq!(from_alice.to_hex(), from_bob.to_hex());
    }

    #[test]
    fn derive_shared_key_hex_roundtrip() {
        let alice = Keys::generate();
        let bob = Keys::generate();
        let derived = SharedKey::derive(alice.secret_key(), &bob.public_key()).unwrap();
        let hex = derived.to_hex();
        let restored = SharedKey::from_hex(&hex).unwrap();
        assert_eq!(derived.public_key(), restored.public_key());
        assert_eq!(derived.to_hex(), restored.to_hex());
    }

    #[test]
    fn derive_shared_key_different_peers_produce_different_keys() {
        let alice = Keys::generate();
        let bob = Keys::generate();
        let carol = Keys::generate();
        let with_bob = SharedKey::derive(alice.secret_key(), &bob.public_key()).unwrap();
        let with_carol = SharedKey::derive(alice.secret_key(), &carol.public_key()).unwrap();
        assert_ne!(with_bob.public_key(), with_carol.public_key());
    }

    #[test]
    fn from_keys_preserves_wrapped_keys() {
        let keys = Keys::generate();
        let shared = SharedKey::from_keys(keys.clone());
        assert_eq!(shared.public_key(), keys.public_key());
        assert_eq!(shared.to_hex(), keys.secret_key().to_secret_hex());
    }

    #[test]
    fn from_hex_rejects_invalid_input() {
        // An all-zero scalar is syntactically valid hex but not a usable secret key.
        let zero_key = "0".repeat(64);
        for bad in ["not-a-hex-string", "", "abcd", zero_key.as_str()] {
            let err = SharedKey::from_hex(bad).unwrap_err();
            // Pin the exact variant, not just the outer wrapper, so a regression
            // that reports a different internal error is caught.
            assert!(
                matches!(
                    err,
                    MostroError::MostroInternalErr(ServiceError::EncryptionError(_))
                ),
                "unexpected error variant for input {bad:?}"
            );
        }
    }
}