use base64::engine::{general_purpose, Engine as _};
use chacha20::cipher::{KeyIvInit, StreamCipher};
use hmac::{KeyInit, Mac};
use zeroize::Zeroize;
/// Every failure the NIP-44 helpers in this module can report.
///
/// Variants that carry a payload wrap the underlying library error so it can
/// be surfaced through [`std::error::Error::source`] where supported.
#[derive(Debug)]
pub enum Nip44Error {
    /// `NostrKeypair::shared_point` failed; the original error is discarded.
    SharedSecretError,
    /// Hex decoding failed.
    FromHexError(nostro2_traits::hex::HexError),
    /// An error bubbled up from `nostro2` note handling.
    NostrNoteError(nostro2::errors::NostrErrors),
    /// An input, payload or buffer had an unacceptable length.
    InvalidLength,
    /// The base64 payload could not be decoded.
    Base64DecodingError(base64::DecodeError),
    /// Decrypted bytes were not valid UTF-8 (wraps `Utf8Error` despite the name).
    FromUtf8Error(std::str::Utf8Error),
    /// HKDF expansion failed.
    HkdfError,
    /// HMAC construction or tag verification failed.
    HmacError,
    /// ChaCha20 rejected the key or nonce slice length.
    SliceError(chacha20::cipher::InvalidLength),
    /// The decrypted length prefix points past the end of the plaintext.
    InvalidPrefixLen,
    /// A slice-to-array conversion failed.
    FromArrayError(std::array::TryFromSliceError),
    /// The caller-supplied buffer is smaller than the padded message.
    BufferTooSmall,
    /// An integer narrowing conversion failed.
    FromIntError(std::num::TryFromIntError),
}
impl std::fmt::Display for Nip44Error {
    /// Renders a short, human-readable message; wrapped errors are appended
    /// after a category prefix (or shown verbatim for `NostrNoteError`).
    // NOTE(review): `crappy` is not a rustc/clippy lint; presumably it is read
    // by an external tool, and `unknown_lints` silences rustc's warning — confirm.
    #[allow(unknown_lints, crappy)]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants that wrap another error.
            Self::FromHexError(inner) => write!(f, "hex decoding error: {inner}"),
            Self::NostrNoteError(inner) => write!(f, "{inner}"),
            Self::Base64DecodingError(inner) => write!(f, "base64 decoding error: {inner}"),
            Self::FromUtf8Error(inner) => write!(f, "UTF-8 conversion error: {inner}"),
            Self::SliceError(inner) => write!(f, "ChaCha20 slice error: {inner}"),
            Self::FromArrayError(inner) => write!(f, "decryption error: {inner}"),
            Self::FromIntError(inner) => write!(f, "encryption error: {inner}"),
            // Unit variants with a fixed message.
            Self::SharedSecretError => f.write_str("shared secret error"),
            Self::InvalidLength => f.write_str("invalid input length"),
            Self::HkdfError => f.write_str("HKDF key derivation failed"),
            Self::HmacError => f.write_str("HMAC failure"),
            Self::InvalidPrefixLen => f.write_str("invalid length prefix"),
            Self::BufferTooSmall => f.write_str("buffer too small"),
        }
    }
}
impl std::error::Error for Nip44Error {
    /// Exposes the wrapped library error, if any.
    ///
    /// `SliceError` deliberately reports no source, like the unit variants.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::FromHexError(inner) => inner,
            Self::NostrNoteError(inner) => inner,
            Self::Base64DecodingError(inner) => inner,
            Self::FromUtf8Error(inner) => inner,
            Self::FromArrayError(inner) => inner,
            Self::FromIntError(inner) => inner,
            // Listed explicitly so adding a variant forces a decision here.
            Self::SharedSecretError
            | Self::InvalidLength
            | Self::HkdfError
            | Self::HmacError
            | Self::SliceError(_)
            | Self::InvalidPrefixLen
            | Self::BufferTooSmall => return None,
        };
        Some(cause)
    }
}
impl From<nostro2_traits::hex::HexError> for Nip44Error {
fn from(e: nostro2_traits::hex::HexError) -> Self {
Self::FromHexError(e)
}
}
impl From<nostro2::errors::NostrErrors> for Nip44Error {
fn from(e: nostro2::errors::NostrErrors) -> Self {
Self::NostrNoteError(e)
}
}
impl From<base64::DecodeError> for Nip44Error {
fn from(e: base64::DecodeError) -> Self {
Self::Base64DecodingError(e)
}
}
impl From<std::str::Utf8Error> for Nip44Error {
fn from(e: std::str::Utf8Error) -> Self {
Self::FromUtf8Error(e)
}
}
impl From<chacha20::cipher::InvalidLength> for Nip44Error {
fn from(e: chacha20::cipher::InvalidLength) -> Self {
Self::SliceError(e)
}
}
impl From<std::array::TryFromSliceError> for Nip44Error {
fn from(e: std::array::TryFromSliceError) -> Self {
Self::FromArrayError(e)
}
}
impl From<std::num::TryFromIntError> for Nip44Error {
fn from(e: std::num::TryFromIntError) -> Self {
Self::FromIntError(e)
}
}
/// Pieces split out of a decoded payload by `Nip44::extract_components`.
///
/// Despite the name, this holds the nonce and the ciphertext body only; the
/// trailing 32-byte MAC tag is not stored here.
pub struct MacComponents<'buf> {
    /// The 12-byte ChaCha20 nonce, copied out and wiped on drop.
    nonce: zeroize::Zeroizing<[u8; 12]>,
    /// The encrypted (padded) body, borrowed from the decoded payload.
    ciphertext: &'buf [u8],
}
/// Encryption helpers for any [`nostro2::NostrKeypair`], modelled on NIP-44.
///
/// Payload layout (base64-encoded): `version (b'1') || nonce (12) || ciphertext || mac (32)`.
///
/// NOTE(review): this format differs from the published NIP-44 v2 spec in several ways:
/// - version byte `0x02` in the spec vs. `b'1'` here;
/// - a 32-byte nonce with per-message keys in the spec vs. a 12-byte nonce here;
/// - chunked padding in the spec vs. power-of-two padding here;
/// - a MAC over `nonce || ciphertext` in the spec vs. a MAC over the ciphertext only here.
///
/// Payloads are therefore presumably not interoperable with other NIP-44
/// clients; confirm before relying on cross-client use.
pub trait Nip44: nostro2::NostrKeypair {
    /// Returns the 32-byte shared point with `peer_pubkey`, as produced by
    /// [`nostro2::NostrKeypair::shared_point`], wrapped so it is wiped on drop.
    ///
    /// # Errors
    /// [`Nip44Error::SharedSecretError`] if `shared_point` fails (the original
    /// error is discarded).
    fn shared_secret(&self, peer_pubkey: &str) -> Result<zeroize::Zeroizing<[u8; 32]>, Nip44Error> {
        Ok(nostro2::NostrKeypair::shared_point(self, peer_pubkey)
            .map_err(|_| Nip44Error::SharedSecretError)?
            .into())
    }

    /// Replaces `note.content` with its encryption for `peer_pubkey`.
    ///
    /// # Errors
    /// Propagates any error from [`Nip44::nip_44_encrypt`]; on error the note is
    /// left unchanged.
    fn nip44_encrypt_note<'a>(
        &self,
        note: &'a mut nostro2::NostrNote,
        peer_pubkey: &'a str,
    ) -> Result<(), Nip44Error> {
        // `into_owned` moves the already-owned encoding out instead of copying it.
        note.content = self.nip_44_encrypt(&note.content, peer_pubkey)?.into_owned();
        Ok(())
    }

    /// Decrypts `note.content` from `peer_pubkey` without modifying the note.
    ///
    /// # Errors
    /// Propagates any error from [`Nip44::nip_44_decrypt`].
    fn nip44_decrypt_note<'a>(
        &self,
        note: &'a nostro2::NostrNote,
        peer_pubkey: &'a str,
    ) -> Result<std::borrow::Cow<'a, str>, Nip44Error> {
        self.nip_44_decrypt(&note.content, peer_pubkey)
    }

    /// Encrypts `plaintext` for `peer_pubkey` and returns the base64 payload.
    ///
    /// # Errors
    /// * [`Nip44Error::InvalidLength`] if `plaintext` is empty or longer than 65535 bytes.
    /// * Any error from shared-secret/key derivation, padding, the cipher, or the MAC.
    ///
    /// # Panics
    /// If the OS random number generator fails (see [`Nip44::generate_nonce`]).
    fn nip_44_encrypt<'a>(
        &self,
        plaintext: &'a str,
        peer_pubkey: &'a str,
    ) -> Result<std::borrow::Cow<'a, str>, Nip44Error> {
        // Reject out-of-range input before sizing (and allocating) the pad buffer.
        if plaintext.is_empty() || plaintext.len() > 65535 {
            return Err(Nip44Error::InvalidLength);
        }
        let mut buffer =
            zeroize::Zeroizing::new(vec![0_u8; (plaintext.len() + 2).next_power_of_two().max(32)]);
        let shared_secret = self.shared_secret(peer_pubkey)?;
        let conversation_key = Self::derive_conversation_key(shared_secret, b"nip44-v2")?;
        let nonce = Self::generate_nonce();
        let ciphertext = Self::encrypt(
            plaintext.as_bytes(),
            conversation_key.as_slice(),
            nonce.as_slice(),
            buffer.as_mut_slice(),
        )?;
        let mac = Self::calculate_mac(ciphertext, conversation_key.as_slice())?;
        // `conversation_key`, `nonce` and `buffer` are `Zeroizing`, so they are wiped on drop.
        Ok(Self::base64_encode_params(b"1", nonce.as_slice(), ciphertext, &mac).into())
    }

    /// Decrypts a payload produced by [`Nip44::nip_44_encrypt`] from `peer_pubkey`.
    ///
    /// The trailing HMAC tag is verified in constant time *before* any
    /// decryption, so tampered or mis-keyed payloads are rejected up front.
    ///
    /// # Errors
    /// * [`Nip44Error::Base64DecodingError`] for malformed base64.
    /// * [`Nip44Error::InvalidLength`] if the payload is too short.
    /// * [`Nip44Error::HmacError`] if the MAC tag does not match.
    /// * [`Nip44Error::InvalidPrefixLen`] / [`Nip44Error::FromUtf8Error`] for a
    ///   malformed plaintext.
    fn nip_44_decrypt<'a>(
        &self,
        ciphertext: &'a str,
        peer_pubkey: &'a str,
    ) -> Result<std::borrow::Cow<'a, str>, Nip44Error> {
        let shared_secret = self.shared_secret(peer_pubkey)?;
        let conversation_key = Self::derive_conversation_key(shared_secret, b"nip44-v2")?;
        let decoded = zeroize::Zeroizing::new(general_purpose::STANDARD.decode(ciphertext)?);
        let MacComponents { nonce, ciphertext: body } = Self::extract_components(&decoded)?;
        // extract_components guarantees at least 1 + 12 + 32 bytes, so the
        // trailing 32 bytes are the tag appended by base64_encode_params.
        let tag = &decoded[decoded.len() - 32..];
        // NOTE(review): the version byte and the nonce are not covered by the MAC
        // (only the ciphertext is, matching calculate_mac on the encrypt side);
        // kept as-is so payloads from existing encryptors still verify.
        Self::verify_mac(body, conversation_key.as_slice(), tag)?;
        let mut buffer = zeroize::Zeroizing::new(vec![0_u8; body.len()]);
        let decrypted = Self::decrypt(body, conversation_key, nonce, buffer.as_mut_slice())?;
        Ok(std::str::from_utf8(decrypted)?.to_owned().into())
    }

    /// Pads `content` into `buffer` and XORs it in place with the ChaCha20
    /// keystream for `key`/`nonce`, returning the padded ciphertext slice.
    ///
    /// # Errors
    /// Padding errors from [`Nip44::pad_string`], or [`Nip44Error::SliceError`]
    /// if `key`/`nonce` have the wrong length for ChaCha20.
    fn encrypt<'a>(
        content: &[u8],
        key: &[u8],
        nonce: &[u8],
        buffer: &'a mut [u8],
    ) -> Result<&'a [u8], Nip44Error> {
        let padded = Self::pad_string(content, buffer)?;
        let mut cipher = chacha20::ChaCha20::new_from_slices(key, nonce)?;
        cipher.apply_keystream(padded);
        Ok(&padded[..])
    }

    /// Inverse of [`Nip44::encrypt`]: decrypts `ciphertext` into `buffer` and
    /// returns the unpadded plaintext bytes named by the 2-byte big-endian
    /// length prefix. Does **not** check the MAC; authenticate first.
    ///
    /// # Errors
    /// * [`Nip44Error::InvalidLength`] if `ciphertext` is shorter than the prefix
    ///   or `buffer` is smaller than `ciphertext`.
    /// * [`Nip44Error::InvalidPrefixLen`] if the prefix exceeds the decrypted data.
    fn decrypt<'a>(
        ciphertext: &[u8],
        key: zeroize::Zeroizing<[u8; 32]>,
        nonce: zeroize::Zeroizing<[u8; 12]>,
        buffer: &'a mut [u8],
    ) -> Result<&'a [u8], Nip44Error> {
        // Need the 2-byte length prefix and room to decrypt in place.
        // (`key`/`nonce` lengths are fixed by their array types.)
        if ciphertext.len() < 2 || buffer.len() < ciphertext.len() {
            return Err(Nip44Error::InvalidLength);
        }
        let padded = &mut buffer[..ciphertext.len()];
        padded.copy_from_slice(ciphertext);
        let mut cipher = chacha20::ChaCha20::new_from_slices(key.as_slice(), nonce.as_slice())?;
        cipher.apply_keystream(padded);
        let len = usize::from(u16::from_be_bytes([padded[0], padded[1]]));
        if len > padded.len() - 2 {
            return Err(Nip44Error::InvalidPrefixLen);
        }
        // `key` and `nonce` are `Zeroizing` and are wiped when they drop here.
        Ok(&buffer[2..2 + len])
    }

    /// Derives the 32-byte conversation key with HKDF-SHA256 (extract with
    /// `salt`, then expand with empty info). Consumes and wipes `shared_secret`.
    ///
    /// # Errors
    /// [`Nip44Error::HkdfError`] if HKDF expansion fails.
    fn derive_conversation_key(
        mut shared_secret: zeroize::Zeroizing<[u8; 32]>,
        salt: &[u8],
    ) -> Result<zeroize::Zeroizing<[u8; 32]>, Nip44Error> {
        let hkdf = hkdf::Hkdf::<sha2::Sha256>::new(Some(salt), shared_secret.as_slice());
        shared_secret.zeroize();
        // Expand straight into a zeroizing buffer: copying a plain `[u8; 32]`
        // into `Zeroizing` would leave an un-wiped copy of the key on the stack.
        let mut okm = zeroize::Zeroizing::new([0_u8; 32]);
        hkdf.expand(&[], okm.as_mut_slice())
            .map_err(|_| Nip44Error::HkdfError)?;
        Ok(okm)
    }

    /// Splits a decoded payload `version (1) || nonce (12) || ciphertext || mac (32)`
    /// into its nonce and ciphertext; the version byte and tag are skipped.
    ///
    /// # Errors
    /// [`Nip44Error::InvalidLength`] if `decoded` is shorter than 45 bytes.
    fn extract_components(decoded: &[u8]) -> Result<MacComponents<'_>, Nip44Error> {
        if decoded.len() < 1 + 12 + 32 {
            return Err(Nip44Error::InvalidLength);
        }
        Ok(MacComponents {
            nonce: zeroize::Zeroizing::new(decoded[1..13].try_into()?),
            ciphertext: &decoded[13..decoded.len() - 32],
        })
    }

    /// Computes HMAC-SHA256 of `data` under `key`.
    ///
    /// # Errors
    /// [`Nip44Error::HmacError`] if the HMAC cannot be keyed.
    fn calculate_mac(data: &[u8], key: &[u8]) -> Result<[u8; 32], Nip44Error> {
        let mut mac =
            hmac::Hmac::<sha2::Sha256>::new_from_slice(key).map_err(|_| Nip44Error::HmacError)?;
        mac.update(data);
        let result = mac.finalize().into_bytes();
        Ok(result.into())
    }

    /// Checks, in constant time, that `tag` is the HMAC-SHA256 of `data` under `key`.
    ///
    /// # Errors
    /// [`Nip44Error::HmacError`] if the HMAC cannot be keyed or the tag does not match.
    fn verify_mac(data: &[u8], key: &[u8], tag: &[u8]) -> Result<(), Nip44Error> {
        let mut mac =
            hmac::Hmac::<sha2::Sha256>::new_from_slice(key).map_err(|_| Nip44Error::HmacError)?;
        mac.update(data);
        mac.verify_slice(tag).map_err(|_| Nip44Error::HmacError)
    }

    /// Writes `len_be16 || plaintext || zeros` into `buffer`, padded to the next
    /// power of two of `len + 2` (minimum 32), and returns that padded prefix.
    ///
    /// # Errors
    /// * [`Nip44Error::InvalidLength`] if `plaintext` is empty or over 65535 bytes.
    /// * [`Nip44Error::BufferTooSmall`] if `buffer` cannot hold the padded message.
    fn pad_string<'a>(plaintext: &[u8], buffer: &'a mut [u8]) -> Result<&'a mut [u8], Nip44Error> {
        if plaintext.is_empty() || plaintext.len() > 65535 {
            return Err(Nip44Error::InvalidLength);
        }
        let total_len = (plaintext.len() + 2).next_power_of_two().max(32);
        if buffer.len() < total_len {
            return Err(Nip44Error::BufferTooSmall);
        }
        let len_bytes = u16::try_from(plaintext.len())?.to_be_bytes();
        let body_end = 2 + plaintext.len();
        buffer[..2].copy_from_slice(&len_bytes);
        buffer[2..body_end].copy_from_slice(plaintext);
        // Zero the tail explicitly: the caller's buffer is not guaranteed to be clean.
        buffer[body_end..total_len].fill(0);
        Ok(&mut buffer[..total_len])
    }

    /// Returns a fresh random 12-byte nonce from the OS RNG.
    ///
    /// # Panics
    /// If `getrandom` fails.
    #[must_use]
    fn generate_nonce() -> zeroize::Zeroizing<[u8; 12]> {
        let mut nonce = [0_u8; 12];
        getrandom::fill(&mut nonce).expect("getrandom failed");
        nonce.into()
    }

    /// Concatenates `version || nonce || ciphertext || mac` and encodes it as
    /// standard (padded) base64.
    #[must_use]
    fn base64_encode_params(version: &[u8], nonce: &[u8], ciphertext: &[u8], mac: &[u8]) -> String {
        let mut buf =
            Vec::with_capacity(version.len() + nonce.len() + ciphertext.len() + mac.len());
        buf.extend_from_slice(version);
        buf.extend_from_slice(nonce);
        buf.extend_from_slice(ciphertext);
        buf.extend_from_slice(mac);
        // Padded base64 emits 4 chars per started 3-byte group.
        let mut out = String::with_capacity(buf.len().div_ceil(3) * 4);
        general_purpose::STANDARD.encode_string(&buf, &mut out);
        out
    }
}
/// Every `NostrKeypair` (including unsized ones such as trait objects) gets the
/// default `Nip44` methods.
impl<T> Nip44 for T where T: nostro2::NostrKeypair + ?Sized {}
#[cfg(test)]
mod tests {
    use super::*;
    use nostro2::{NostrKeypair, NostrSigner};
    use std::fmt::Write as _;

    #[test]
    fn test_encrypt_decrypt_success() {
        let sender = crate::tests::NipTester::generate();
        let receiver = crate::tests::NipTester::generate();
        let plaintext = "Hello NIP-44 encryption!";
        let receiver_pk = receiver.public_key();
        let sender_pk = sender.public_key();
        let ciphertext = sender.nip_44_encrypt(plaintext, &receiver_pk).unwrap();
        let decrypted = receiver.nip_44_decrypt(&ciphertext, &sender_pk).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_invalid_decryption_key() {
        let sender = crate::tests::NipTester::generate();
        let receiver = crate::tests::NipTester::generate();
        let wrong_receiver = crate::tests::NipTester::generate();
        let plaintext = "Hello NIP-44 encryption!";
        let receiver_pk = receiver.public_key();
        let sender_pk = sender.public_key();
        let ciphertext = sender.nip_44_encrypt(plaintext, &receiver_pk).unwrap();
        let result = wrong_receiver.nip_44_decrypt(&ciphertext, &sender_pk);
        assert!(result.is_err());
    }

    /// The concatenated decimal numbers 0..15329 total exactly 65535 bytes,
    /// the largest plaintext the 2-byte length prefix allows.
    #[test]
    fn encrypt_very_large_note() {
        let sender = crate::tests::NipTester::generate();
        let receiver = crate::tests::NipTester::generate();
        let mut plaintext = String::new();
        for i in 0..15329 {
            let _ = write!(plaintext, "{i}");
        }
        let receiver_pk = receiver.public_key();
        let sender_pk = sender.public_key();
        let ciphertext = sender.nip_44_encrypt(&plaintext, &receiver_pk).unwrap();
        let decrypted = receiver.nip_44_decrypt(&ciphertext, &sender_pk).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// A real `Utf8Error` (0xFF is never valid UTF-8).
    fn utf8_err() -> std::str::Utf8Error {
        let bad = [0xff_u8];
        std::str::from_utf8(bad.as_slice()).unwrap_err()
    }

    /// A real `TryFromSliceError` (3 bytes into a 4-byte array).
    fn slice_err() -> std::array::TryFromSliceError {
        <[u8; 4]>::try_from([0_u8; 3].as_slice()).unwrap_err()
    }

    /// A real `TryFromIntError` (256 does not fit in `u8`).
    fn int_err() -> std::num::TryFromIntError {
        u8::try_from(256_u16).unwrap_err()
    }

    #[test]
    fn error_display_covers_all_variants() {
        let cases: Vec<Nip44Error> = vec![
            Nip44Error::SharedSecretError,
            Nip44Error::FromHexError(nostro2_traits::hex::HexError::OddLength),
            Nip44Error::NostrNoteError(nostro2::errors::NostrErrors::MissingId),
            Nip44Error::InvalidLength,
            Nip44Error::Base64DecodingError(
                base64::engine::general_purpose::STANDARD
                    .decode("!!!")
                    .unwrap_err(),
            ),
            Nip44Error::FromUtf8Error(utf8_err()),
            Nip44Error::HkdfError,
            Nip44Error::HmacError,
            Nip44Error::SliceError(chacha20::cipher::InvalidLength),
            Nip44Error::InvalidPrefixLen,
            Nip44Error::FromArrayError(slice_err()),
            Nip44Error::BufferTooSmall,
            Nip44Error::FromIntError(int_err()),
        ];
        for err in &cases {
            let msg = format!("{err}");
            assert!(
                !msg.is_empty(),
                "Display must produce non-empty output for {err:?}"
            );
        }
    }

    #[test]
    fn error_source_delegates_correctly() {
        use std::error::Error;
        assert!(Nip44Error::SharedSecretError.source().is_none());
        assert!(Nip44Error::InvalidLength.source().is_none());
        assert!(Nip44Error::HkdfError.source().is_none());
        assert!(Nip44Error::HmacError.source().is_none());
        assert!(Nip44Error::InvalidPrefixLen.source().is_none());
        assert!(Nip44Error::BufferTooSmall.source().is_none());
        assert!(Nip44Error::SliceError(chacha20::cipher::InvalidLength)
            .source()
            .is_none());
        assert!(
            Nip44Error::FromHexError(nostro2_traits::hex::HexError::OddLength)
                .source()
                .is_some()
        );
        assert!(
            Nip44Error::NostrNoteError(nostro2::errors::NostrErrors::MissingId)
                .source()
                .is_some()
        );
        assert!(Nip44Error::Base64DecodingError(
            base64::engine::general_purpose::STANDARD
                .decode("!!!")
                .unwrap_err()
        )
        .source()
        .is_some());
        assert!(Nip44Error::FromUtf8Error(utf8_err()).source().is_some());
        assert!(Nip44Error::FromArrayError(slice_err()).source().is_some());
        assert!(Nip44Error::FromIntError(int_err()).source().is_some());
    }

    mod proptests {
        use super::*;
        use nostro2::{NostrKeypair, NostrSigner};
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn encrypt_decrypt_round_trip(plaintext in ".{1,256}") {
                let sender = crate::tests::NipTester::generate();
                let receiver = crate::tests::NipTester::generate();
                let receiver_pk = receiver.public_key();
                let sender_pk = sender.public_key();
                let ciphertext = sender.nip_44_encrypt(&plaintext, &receiver_pk).unwrap();
                let decrypted = receiver.nip_44_decrypt(&ciphertext, &sender_pk).unwrap();
                prop_assert_eq!(&plaintext, decrypted.as_ref());
            }

            #[test]
            fn encrypt_is_non_deterministic(plaintext in ".{1,64}") {
                let sender = crate::tests::NipTester::generate();
                let receiver = crate::tests::NipTester::generate();
                let receiver_pk = receiver.public_key();
                let a = sender.nip_44_encrypt(&plaintext, &receiver_pk).unwrap();
                let b = sender.nip_44_encrypt(&plaintext, &receiver_pk).unwrap();
                prop_assert_ne!(a, b, "same plaintext must produce different ciphertexts");
            }

            #[test]
            fn pad_string_is_power_of_two(plaintext in ".{1,1024}") {
                let total = (plaintext.len() + 2).next_power_of_two().max(32);
                let mut buf = vec![0_u8; total];
                let padded = crate::tests::NipTester::pad_string(
                    plaintext.as_bytes(), &mut buf
                ).unwrap();
                // 32 is itself a power of two, so one check covers the minimum case.
                prop_assert!(padded.len().is_power_of_two());
                prop_assert!(padded.len() >= 32);
            }
        }
    }
}