use alloc::string::String;
use alloc::vec::Vec;
use core::fmt;
use core::fmt::Debug;
use base64::engine::{Engine, general_purpose};
#[cfg(feature = "rand")]
use rand::RngCore;
#[cfg(all(feature = "std", feature = "os-rng"))]
use rand::TryRngCore;
#[cfg(all(feature = "std", feature = "os-rng"))]
use rand::rngs::OsRng;
use super::v2::{self, ConversationKey};
#[cfg(feature = "rand")]
use crate::util;
use crate::{PublicKey, SecretKey, key};
/// Errors produced by the encryption/decryption helpers in this module.
#[derive(Debug, PartialEq)]
pub enum Error {
/// Error coming from the `key` module
Key(key::Error),
/// Error coming from the version 2 implementation
V2(v2::ErrorV2),
/// The payload is not valid base64
Base64Decode(base64::DecodeError),
/// The decrypted bytes are not valid UTF-8 (returned by `decrypt`)
Utf8Encode,
/// The version byte does not match any known [`Version`]
UnknownVersion(u8),
/// The decoded payload is empty, so it carries no version byte
VersionNotFound,
/// A named item was not found in the payload; the name is interpolated into the message
NotFound(String),
}
// Marker impl: `Debug` is derived on the enum and `Display` is implemented in this file.
impl core::error::Error for Error {}
/// Human-readable messages for every [`Error`] variant.
///
/// Wrapped errors (`Key`, `V2`) are printed as-is; the remaining variants
/// use a fixed message, optionally with the offending value interpolated.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Transparent variants: forward the inner error's message.
            Self::Key(inner) => write!(f, "{}", inner),
            Self::V2(inner) => write!(f, "{}", inner),
            Self::Base64Decode(inner) => {
                write!(f, "Error while decoding from base64: {}", inner)
            }
            Self::Utf8Encode => write!(f, "Error while encoding to UTF-8"),
            Self::UnknownVersion(byte) => write!(f, "unknown version: {}", byte),
            Self::VersionNotFound => write!(f, "Version not found in payload"),
            Self::NotFound(what) => write!(f, "{} not found in payload", what),
        }
    }
}
impl From<key::Error> for Error {
fn from(e: key::Error) -> Self {
Self::Key(e)
}
}
impl From<v2::ErrorV2> for Error {
fn from(e: v2::ErrorV2) -> Self {
Self::V2(e)
}
}
impl From<base64::DecodeError> for Error {
fn from(e: base64::DecodeError) -> Self {
Self::Base64Decode(e)
}
}
/// Payload version.
///
/// The enum is `#[repr(u8)]`, so each variant's discriminant is the byte
/// that [`Version::as_u8`] returns and that `TryFrom<u8>` accepts.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Version {
    /// Version 2 (the default), encoded as `0x02`
    #[default]
    V2 = 0x02,
}

impl Version {
    /// Returns the version as its `u8` discriminant.
    #[inline]
    pub fn as_u8(&self) -> u8 {
        // `Version` is `Copy`, so take it by value before the numeric cast.
        let version: Self = *self;
        version as u8
    }
}
impl TryFrom<u8> for Version {
type Error = Error;
fn try_from(version: u8) -> Result<Self, Self::Error> {
match version {
0x02 => Ok(Self::V2),
v => Err(Error::UnknownVersion(v)),
}
}
}
/// Nonce used for encryption; the variant selects which version's scheme is applied.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Nonce {
/// 32-byte nonce for version 2
V2([u8; 32]),
}
/// Encrypts `content` and returns the base64-encoded payload, drawing the
/// nonce from the operating system RNG.
///
/// This is a thin wrapper over [`encrypt_with_rng`].
///
/// # Errors
///
/// Propagates any [`Error`] returned by [`encrypt_with_rng`].
///
/// # Panics
///
/// NOTE(review): the RNG is wrapped with `unwrap_err()`, which per the
/// `rand` documentation panics if the underlying RNG reports an error;
/// confirm against the `rand` version in use.
#[cfg(all(feature = "std", feature = "os-rng"))]
#[inline]
pub fn encrypt<T>(
    secret_key: &SecretKey,
    public_key: &PublicKey,
    content: T,
    version: Version,
) -> Result<String, Error>
where
    T: AsRef<[u8]>,
{
    // Adapt the fallible OS RNG to the infallible `RngCore` interface.
    let mut os_rng = OsRng.unwrap_err();
    encrypt_with_rng(secret_key, public_key, content, version, &mut os_rng)
}
/// Encrypts `content` and returns the base64-encoded payload, drawing the
/// nonce from the supplied random number generator.
///
/// For [`Version::V2`] a 32-byte nonce is generated from `rng` and the work
/// is handed to [`encrypt_with_nonce`].
///
/// # Errors
///
/// Propagates any [`Error`] returned by [`encrypt_with_nonce`].
#[cfg(feature = "rand")]
pub fn encrypt_with_rng<R, T>(
    secret_key: &SecretKey,
    public_key: &PublicKey,
    content: T,
    version: Version,
    rng: &mut R,
) -> Result<String, Error>
where
    R: RngCore,
    T: AsRef<[u8]>,
{
    // Each version dictates its own nonce shape.
    let fresh_nonce = match version {
        Version::V2 => Nonce::V2(util::random_32_bytes(rng)),
    };

    encrypt_with_nonce(secret_key, public_key, content, fresh_nonce)
}
/// Encrypts `content` with a caller-provided nonce and returns the payload
/// encoded as standard base64.
///
/// # Errors
///
/// Propagates any [`Error`] returned by [`encrypt_to_bytes_with_nonce`].
pub fn encrypt_with_nonce<T>(
    secret_key: &SecretKey,
    public_key: &PublicKey,
    content: T,
    nonce: Nonce,
) -> Result<String, Error>
where
    T: AsRef<[u8]>,
{
    // Produce the raw payload first, then base64-encode it on success.
    encrypt_to_bytes_with_nonce(secret_key, public_key, content, nonce)
        .map(|raw| general_purpose::STANDARD.encode(raw))
}
/// Encrypts `content` with a caller-provided nonce and returns the raw
/// (not base64-encoded) payload bytes.
///
/// The conversation key is derived from `secret_key` and `public_key`, then
/// the version-specific routine selected by the [`Nonce`] variant is run.
///
/// # Errors
///
/// Returns an [`Error`] if key derivation or the version 2 encryption fails.
pub fn encrypt_to_bytes_with_nonce<T>(
    secret_key: &SecretKey,
    public_key: &PublicKey,
    content: T,
    nonce: Nonce,
) -> Result<Vec<u8>, Error>
where
    T: AsRef<[u8]>,
{
    match nonce {
        Nonce::V2(nonce_bytes) => {
            let shared_key = ConversationKey::derive(secret_key, public_key)?;
            // `?` converts the callee's error into this module's `Error`.
            Ok(v2::encrypt_to_bytes_with_nonce(
                &shared_key,
                content.as_ref(),
                nonce_bytes,
            )?)
        }
    }
}
#[inline]
pub fn decrypt<T>(
secret_key: &SecretKey,
public_key: &PublicKey,
payload: T,
) -> Result<String, Error>
where
T: AsRef<[u8]>,
{
let bytes: Vec<u8> = decrypt_to_bytes(secret_key, public_key, payload)?;
String::from_utf8(bytes).map_err(|_| Error::Utf8Encode)
}
pub fn decrypt_to_bytes<T>(
secret_key: &SecretKey,
public_key: &PublicKey,
payload: T,
) -> Result<Vec<u8>, Error>
where
T: AsRef<[u8]>,
{
let payload: Vec<u8> = general_purpose::STANDARD.decode(payload)?;
let version: u8 = *payload.first().ok_or(Error::VersionNotFound)?;
match Version::try_from(version)? {
Version::V2 => {
let conversation_key: ConversationKey =
ConversationKey::derive(secret_key, public_key)?;
v2::decrypt_to_bytes(&conversation_key, &payload)
}
}
}
#[cfg(test)]
#[cfg(all(feature = "std", feature = "os-rng"))]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::Keys;

    /// Round trip: Alice encrypts for Bob, Bob decrypts using Alice's public key.
    #[test]
    fn test_nip44_encryption_decryption() {
        // Fixed key material for both parties.
        let alice_keys = Keys::new(
            SecretKey::from_str("5c0c523f52a5b6fad39ed2403092df8cebc36318b39383bca6c00808626fab3a")
                .unwrap(),
        );
        let bob_keys = Keys::new(
            SecretKey::from_str("4b22aa260e4acb7021e32f38a6cdf4b673c6a277755bfce287e370c924dc936d")
                .unwrap(),
        );
        let alice_public = alice_keys.public_key();
        let bob_public = bob_keys.public_key();

        let plaintext = String::from("hello");

        let ciphertext = encrypt(
            alice_keys.secret_key(),
            &bob_public,
            &plaintext,
            Version::V2,
        )
        .unwrap();

        let recovered = decrypt(bob_keys.secret_key(), &alice_public, ciphertext).unwrap();
        assert_eq!(recovered, plaintext);
    }
}