pub mod strategy;
use strategy::Strategy;
pub mod error;
pub use error::{EncryptionError, DecryptionError};
mod integrations;
pub mod config;
use config::Config;
mod utilities;
use utilities::base64;
#[cfg(test)]
mod testing;
use std::{fmt::Debug, marker::PhantomData};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use aes_gcm::{KeyInit as _, Aes256Gcm, AeadInPlace as _};
use secrecy::ExposeSecret as _;
/// An encrypted, JSON-serializable message: an AES-256-GCM ciphertext plus the headers
/// needed to decrypt it.
///
/// `P` is the plaintext type. It is serialized to JSON before encryption and deserialized
/// from JSON after decryption. `C` is the [`Config`] that supplies the keys and the nonce
/// [`Strategy`]. Neither `P` nor `C` is stored; both exist only at the type level, via
/// `PhantomData`.
///
/// Serialized form: `{"p": <ciphertext>, "h": {"iv": <nonce>, "at": <tag>}}`, where every
/// value is a base64 string.
///
/// With the `diesel` feature the message can be stored in a `Json` SQL column. When
/// `diesel-postgres` is also enabled, `Jsonb` is declared as an additional SQL type.
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
#[cfg_attr(feature = "diesel", derive(diesel::AsExpression, diesel::FromSqlRow))]
#[cfg_attr(feature = "diesel", diesel(sql_type = diesel::sql_types::Json))]
#[cfg_attr(all(feature = "diesel", feature = "diesel-postgres"), diesel(sql_type = diesel::sql_types::Jsonb))]
pub struct EncryptedMessage<P: Debug + DeserializeOwned + Serialize, C: Config> {
    /// Base64-encoded AES-256-GCM ciphertext of the JSON-serialized payload.
    #[serde(rename = "p")]
    payload: String,
    /// Nonce and authentication tag required for decryption.
    #[serde(rename = "h")]
    headers: EncryptedMessageHeaders,
    // Type-level marker for the plaintext type; never serialized.
    #[serde(skip)]
    payload_type: PhantomData<P>,
    // Type-level marker for the key/strategy configuration; never serialized.
    #[serde(skip)]
    config: PhantomData<C>,
}
/// Per-message parameters needed to decrypt an [`EncryptedMessage`], each stored as a
/// base64 string.
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
struct EncryptedMessageHeaders {
    /// Base64-encoded AES-GCM nonce, serialized under the short key `"iv"`.
    #[serde(rename = "iv")]
    nonce: String,
    /// Base64-encoded AES-GCM authentication tag, serialized under the short key `"at"`.
    #[serde(rename = "at")]
    tag: String,
}
impl<P: Debug + DeserializeOwned + Serialize, C: Config> EncryptedMessage<P, C> {
    /// Serializes `payload` to JSON and encrypts it with AES-256-GCM under the config's
    /// primary key.
    ///
    /// The nonce comes from `C::Strategy::generate_nonce_for`, which receives the serialized
    /// plaintext and the primary key. The ciphertext, nonce and detached tag are stored
    /// base64-encoded.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::Serialization`] if `payload` cannot be serialized to JSON.
    ///
    /// # Panics
    ///
    /// Panics if the configured primary key is not a valid AES-256 key (32 bytes), or if the
    /// plaintext exceeds the AES-GCM length limit.
    pub fn encrypt_with_config(payload: P, config: &C) -> Result<Self, EncryptionError> {
        let mut buffer = serde_json::to_vec(&payload)?;
        let key = config.primary_key();
        let nonce = C::Strategy::generate_nonce_for(&buffer, key.expose_secret());
        let cipher = Aes256Gcm::new_from_slice(key.expose_secret())
            .expect("configured keys must be 32 bytes long to be used with AES-256-GCM");
        // Encrypts `buffer` in place; the authentication tag is returned separately so it
        // can be stored in the headers.
        let tag = cipher
            .encrypt_in_place_detached(&nonce.into(), b"", &mut buffer)
            .expect("AES-GCM only rejects plaintexts above its maximum length");
        Ok(EncryptedMessage {
            payload: base64::encode(buffer),
            headers: EncryptedMessageHeaders {
                nonce: base64::encode(nonce),
                tag: base64::encode(tag),
            },
            payload_type: PhantomData,
            config: PhantomData,
        })
    }

    /// Decrypts the message and deserializes the JSON plaintext into `P`.
    ///
    /// Every key returned by `config.keys()` is tried in order. The first key whose
    /// AES-256-GCM authentication succeeds is used, so messages encrypted under a rotated
    /// key remain readable as long as that key is still configured.
    ///
    /// # Errors
    ///
    /// - [`DecryptionError::Base64Decoding`] if the payload, nonce or tag is not valid base64.
    /// - [`DecryptionError::Decryption`] if the nonce or tag has the wrong length for
    ///   AES-256-GCM, or if no configured key authenticates the message.
    /// - [`DecryptionError::Deserialization`] if the decrypted bytes do not deserialize into `P`.
    ///
    /// # Panics
    ///
    /// Panics if a configured key is not a valid AES-256 key (32 bytes).
    pub fn decrypt_with_config(&self, config: &C) -> Result<P, DecryptionError> {
        // AES-256-GCM (as `Aes256Gcm`) uses a 96-bit nonce and a 128-bit tag.
        const NONCE_LEN: usize = 12;
        const TAG_LEN: usize = 16;

        let payload = base64::decode(&self.payload)?;
        let nonce = base64::decode(&self.headers.nonce)?;
        let tag = base64::decode(&self.headers.tag)?;

        // Converting a `&[u8]` into a `&GenericArray` panics on a length mismatch. These
        // headers come from stored or transmitted data, so treat a malformed length as a
        // decryption failure instead of crashing.
        if nonce.len() != NONCE_LEN || tag.len() != TAG_LEN {
            return Err(DecryptionError::Decryption);
        }

        for key in config.keys() {
            let cipher = Aes256Gcm::new_from_slice(key.expose_secret())
                .expect("configured keys must be 32 bytes long to be used with AES-256-GCM");
            // Decryption happens in place, so each attempt needs a fresh copy of the ciphertext.
            let mut buffer = payload.clone();
            let authenticated = cipher
                .decrypt_in_place_detached(nonce.as_slice().into(), b"", &mut buffer, tag.as_slice().into())
                .is_ok();
            if authenticated {
                return Ok(serde_json::from_slice(&buffer)?);
            }
        }
        Err(DecryptionError::Decryption)
    }
}
impl<P, C> EncryptedMessage<P, C>
where
    P: Debug + DeserializeOwned + Serialize,
    C: Config + Default,
{
    /// Convenience wrapper around [`Self::encrypt_with_config`] that uses
    /// `C::default()` as the configuration.
    pub fn encrypt(payload: P) -> Result<Self, EncryptionError> {
        let default_config = C::default();
        Self::encrypt_with_config(payload, &default_config)
    }

    /// Convenience wrapper around [`Self::decrypt_with_config`] that uses
    /// `C::default()` as the configuration.
    pub fn decrypt(&self) -> Result<P, DecryptionError> {
        let default_config = C::default();
        self.decrypt_with_config(&default_config)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `EncryptedMessage`, using the two configs from `crate::testing`:
    // `TestConfigDeterministic` (same input -> same output, per the `deterministic` test)
    // and `TestConfigRandomized` (same input -> different output, per the `randomized` test).
    use super::*;
    use serde_json::json;
    use crate::testing::{TestConfigDeterministic, TestConfigRandomized};

    mod encrypt {
        use super::*;

        // Known-answer test: with the deterministic config, this payload must encrypt to
        // exactly this ciphertext, nonce and tag.
        #[test]
        fn deterministic() {
            assert_eq!(
                EncryptedMessage::<String, TestConfigDeterministic>::encrypt("rigo does pretty codes".to_string()).unwrap(),
                EncryptedMessage {
                    payload: "K6FbTsR8lNt9osq7vfvpDl4gPOxaQUhH".to_string(),
                    headers: EncryptedMessageHeaders {
                        nonce: "1WOXnWc3iX5iA3wd".to_string(),
                        tag: "fdnw5HvNImSdBm0nTFiRFw==".to_string(),
                    },
                    payload_type: PhantomData,
                    config: PhantomData,
                },
            );
        }

        // Encrypting the same payload twice with the randomized config must not produce equal
        // messages (presumably because its strategy draws a fresh nonce per call).
        #[test]
        fn randomized() {
            let payload = "much secret much secure".to_string();
            assert_ne!(
                EncryptedMessage::<String, TestConfigRandomized>::encrypt(payload.clone()).unwrap(),
                EncryptedMessage::<String, TestConfigRandomized>::encrypt(payload).unwrap(),
            );
        }

        // serde_json cannot serialize a map with array keys (JSON object keys must be
        // strings), so encryption must fail with `EncryptionError::Serialization`.
        #[test]
        fn test_serialization_error() {
            let map = std::collections::HashMap::<[u8; 2], String>::from([([1, 2], "Hi".to_string())]);
            assert!(matches!(EncryptedMessage::<_, TestConfigDeterministic>::encrypt(map).unwrap_err(), EncryptionError::Serialization(_)));
        }
    }

    mod decrypt {
        use super::*;

        // Round trip: decrypting a freshly encrypted message yields the original payload.
        #[test]
        fn decrypts_correctly() {
            let payload = "hi :D".to_string();
            let message = EncryptedMessage::<String, TestConfigDeterministic>::encrypt(payload.clone()).unwrap();
            assert_eq!(message.decrypt().unwrap(), payload);
        }

        // Corrupting any one of the three base64 fields (payload, nonce, tag) with a string
        // that `base64::decode` rejects must surface as `DecryptionError::Base64Decoding`.
        #[test]
        fn test_base64_decoding_error() {
            fn generate() -> EncryptedMessage<String, TestConfigDeterministic> {
                EncryptedMessage::encrypt("hi :)".to_string()).unwrap()
            }
            let mut message = generate();
            message.payload = "invalid".to_string();
            assert!(matches!(message.decrypt().unwrap_err(), DecryptionError::Base64Decoding(_)));
            let mut message = generate();
            message.headers.nonce = "invalid".to_string();
            assert!(matches!(message.decrypt().unwrap_err(), DecryptionError::Base64Decoding(_)));
            let mut message = generate();
            message.headers.tag = "invalid".to_string();
            assert!(matches!(message.decrypt().unwrap_err(), DecryptionError::Base64Decoding(_)));
        }

        // Well-formed base64 fields that no configured key authenticates must yield
        // `DecryptionError::Decryption`.
        #[test]
        fn test_decryption_error() {
            let message = EncryptedMessage {
                payload: "2go7QdfuErm53fOI2jiNnHcPunwGWHpM".to_string(),
                headers: EncryptedMessageHeaders {
                    nonce: "Exz8Fa9hKHEWvvmZ".to_string(),
                    tag: "r/AdKM4Dp0YAr/7dzAqujw==".to_string(),
                },
                payload_type: PhantomData::<String>,
                config: PhantomData::<TestConfigDeterministic>,
            };
            assert!(matches!(message.decrypt().unwrap_err(), DecryptionError::Decryption));
        }

        // Re-typing an encrypted JSON string as `u8`: authentication succeeds, but the
        // plaintext is not a valid `u8`, so deserialization must fail.
        #[test]
        fn test_deserialization_error() {
            let message = EncryptedMessage::<String, TestConfigDeterministic>::encrypt("hi :)".to_string()).unwrap();
            let message = EncryptedMessage {
                payload: message.payload,
                headers: message.headers,
                payload_type: PhantomData::<u8>,
                config: message.config,
            };
            assert!(matches!(message.decrypt().unwrap_err(), DecryptionError::Deserialization(_)));
        }
    }

    // The fixture differs from what the deterministic config currently produces for the same
    // payload, yet it still decrypts. Presumably it was encrypted under a non-primary key that
    // `TestConfigDeterministic::keys()` still returns (verify against crate::testing).
    #[test]
    fn allows_rotating_keys() {
        let message = EncryptedMessage {
            payload: "DT6PJ1ROSA==".to_string(),
            headers: EncryptedMessageHeaders {
                nonce: "nv6rH50Sn2Po320K".to_string(),
                tag: "ZtAoub/4fB30QetW+O7oaA==".to_string(),
            },
            payload_type: PhantomData::<String>,
            config: PhantomData::<TestConfigDeterministic>,
        };
        let expected_payload = "hi :)".to_string();
        assert_ne!(
            EncryptedMessage::<String, TestConfigDeterministic>::encrypt(expected_payload.clone()).unwrap(),
            message,
        );
        assert_eq!(message.decrypt().unwrap(), expected_payload);
    }

    // An empty string round-trips through encryption and decryption.
    #[test]
    fn handles_empty_payload() {
        let message = EncryptedMessage::<String, TestConfigDeterministic>::encrypt("".to_string()).unwrap();
        assert_eq!(message.decrypt().unwrap(), "");
    }

    // Round trips for a range of JSON shapes: null/option, bool, integer, float, string,
    // array and object.
    #[test]
    fn handles_json_types() {
        let encrypted = EncryptedMessage::<Option<String>, TestConfigRandomized>::encrypt(None).unwrap();
        assert_eq!(encrypted.decrypt().unwrap(), None);
        let encrypted = EncryptedMessage::<Option<String>, TestConfigRandomized>::encrypt(Some("rigo is cool".to_string())).unwrap();
        assert_eq!(encrypted.decrypt().unwrap(), Some("rigo is cool".to_string()));
        let encrypted = EncryptedMessage::<bool, TestConfigRandomized>::encrypt(true).unwrap();
        assert_eq!(encrypted.decrypt().unwrap() as u8, 1);
        let encrypted = EncryptedMessage::<u8, TestConfigRandomized>::encrypt(255).unwrap();
        assert_eq!(encrypted.decrypt().unwrap(), 255);
        let encrypted = EncryptedMessage::<f64, TestConfigRandomized>::encrypt(0.12345).unwrap();
        assert_eq!(encrypted.decrypt().unwrap(), 0.12345);
        let encrypted = EncryptedMessage::<String, TestConfigRandomized>::encrypt("rigo is cool".to_string()).unwrap();
        assert_eq!(encrypted.decrypt().unwrap(), "rigo is cool");
        let encrypted = EncryptedMessage::<Vec<u8>, TestConfigRandomized>::encrypt(vec![1, 2, 3]).unwrap();
        assert_eq!(encrypted.decrypt().unwrap(), vec![1, 2, 3]);
        let encrypted = EncryptedMessage::<serde_json::Value, TestConfigRandomized>::encrypt(json!({ "a": 1, "b": "hello", "c": false })).unwrap();
        assert_eq!(encrypted.decrypt().unwrap(), json!({ "a": 1, "b": "hello", "c": false }));
    }

    // Pins the wire format: short keys `p`/`h`/`iv`/`at`, no PhantomData fields. Also checks
    // that deserializing the JSON gives back an equal message.
    #[test]
    fn to_and_from_json() {
        let message = EncryptedMessage {
            payload: "SBwByX5cxBSMgPlixDEf0pYEa6W41TIA".to_string(),
            headers: EncryptedMessageHeaders {
                nonce: "xg172uWMpjJqmWro".to_string(),
                tag: "S88wdO9tf/381mZQ88kMNw==".to_string(),
            },
            payload_type: PhantomData::<String>,
            config: PhantomData::<TestConfigRandomized>,
        };
        let message_json = serde_json::to_value(&message).unwrap();
        assert_eq!(
            message_json,
            json!({
                "p": "SBwByX5cxBSMgPlixDEf0pYEa6W41TIA",
                "h": {
                    "iv": "xg172uWMpjJqmWro",
                    "at": "S88wdO9tf/381mZQ88kMNw==",
                },
            }),
        );
        assert_eq!(
            serde_json::from_value::<EncryptedMessage::<_, _>>(message_json).unwrap(),
            message,
        );
    }
}