use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use hmac::{Hmac, Mac};
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, Iterable};
use sha2::Sha256;
use crate::{Record, RecordError, querying::AsyncQuerying};
use rustrails_support::{database, runtime};
/// HMAC keyed over SHA-256; the MAC this module uses to sign and verify record ids.
type HmacSha256 = Hmac<Sha256>;
/// Tamper-evident, URL-safe tokens that stand in for a persisted record's id.
///
/// A signed id wraps the record's primary key together with an HMAC-SHA256 over it. The token
/// can be handed to clients and later resolved back to the record, and clients cannot substitute
/// a different id without knowing the secret.
///
/// NOTE(review): the token carries only the id. It has no model name, purpose or expiry, and the
/// secret is shared by every `Record` type. So a token issued for one type resolves the same id on
/// any other type. Confirm that callers only accept tokens where that is acceptable.
// NOTE(review): presumably `AsyncQuerying` (used as a bound on `find_signed`) or `Record` is
// less visible than this public trait, which is what `private_bounds` warns about. TODO confirm.
#[allow(private_bounds)]
pub trait SignedId: Record {
    /// Produces the signed token for this record.
    ///
    /// # Errors
    ///
    /// Returns `RecordError::NotSaved` when the record has no id yet. Otherwise returns whatever
    /// error signing reports.
    fn signed_id(&self) -> Result<String, RecordError> {
        match self.id() {
            Some(id) => sign_id(id, &Self::signing_secret()),
            None => Err(RecordError::NotSaved),
        }
    }

    /// Verifies `token` and loads the record it refers to through `db`.
    ///
    /// # Errors
    ///
    /// Returns `RecordError::Invalid` for malformed or tampered tokens. Otherwise returns whatever
    /// error the lookup reports.
    async fn find_signed(token: &str, db: &DatabaseConnection) -> Result<Self, RecordError>
    where
        Self: AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        match verify_signed_id(token, &Self::signing_secret()) {
            Ok(id) => Self::find(id, db).await,
            Err(error) => Err(error),
        }
    }

    /// Blocking counterpart of [`SignedId::find_signed`], using the shared connection from
    /// `database::with_db` and driving the future with `runtime::block_on`.
    fn find_signed_sync(token: &str) -> Result<Self, RecordError>
    where
        Self: AsyncQuerying,
        <Self::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|connection| runtime::block_on(Self::find_signed(token, connection)))
    }

    /// Secret used to key the HMAC, read from `RUSTRAILS_SECRET_KEY_BASE` on every call.
    ///
    /// NOTE(review): if the variable is unset or not valid Unicode, this silently falls back to
    /// the hard-coded `"development-secret"`. Tokens are then forgeable by anyone who has read
    /// this source. Because of the blanket `impl<T: Record> SignedId for T`, individual types
    /// cannot override this default.
    fn signing_secret() -> String {
        match std::env::var("RUSTRAILS_SECRET_KEY_BASE") {
            Ok(secret) => secret,
            Err(_) => "development-secret".to_owned(),
        }
    }
}
// Every `Record` gets the signed-id helpers purely through the trait's default methods.
impl<T> SignedId for T where T: Record {}
/// Builds the signed token for `id`.
///
/// The token is the unpadded URL-safe base64 encoding of the bytes
/// `<decimal id>.<raw HMAC-SHA256(secret, decimal id)>`. [`verify_signed_id`] reverses it.
///
/// # Errors
///
/// Propagates the error from [`sign_hmac`] if the MAC cannot be keyed with `secret`.
fn sign_id(id: i64, secret: &str) -> Result<String, RecordError> {
    let digits = id.to_string();
    let mac = sign_hmac(secret.as_bytes(), digits.as_bytes())?;
    let payload: Vec<u8> = digits
        .bytes()
        .chain(std::iter::once(b'.'))
        .chain(mac)
        .collect();
    Ok(URL_SAFE_NO_PAD.encode(payload))
}
/// Decodes and authenticates a token produced by [`sign_id`], returning the id it carries.
///
/// The token is the unpadded URL-safe base64 encoding of `<decimal id>.<raw HMAC-SHA256>`. The
/// MAC is checked before the id bytes are parsed, so unauthenticated input is never interpreted
/// as a number.
///
/// # Errors
///
/// Returns `RecordError::Invalid("invalid signed id")` when the token:
/// - is longer than any token [`sign_id`] can produce,
/// - is not valid base64,
/// - lacks the `.` separator or a signature,
/// - fails MAC verification, or
/// - does not carry a decimal `i64`.
///
/// Propagates the error from [`verify_hmac`] if the MAC cannot be keyed with `secret`.
fn verify_signed_id(token: &str, secret: &str) -> Result<i64, RecordError> {
    // Longest decimal rendering of an `i64`: "-9223372036854775808".
    const MAX_ID_DIGITS: usize = 20;
    // HMAC-SHA256 output size in bytes.
    const SIGNATURE_LEN: usize = 32;
    // Unpadded base64 length of the largest payload `sign_id` emits (digits + '.' + MAC) = 71.
    const MAX_TOKEN_LEN: usize = ((MAX_ID_DIGITS + 1 + SIGNATURE_LEN) * 4).div_ceil(3);

    // The token is caller-supplied; the URL-safe encoding suggests it travels in URLs. Refuse
    // oversized input before paying for base64 decoding and an HMAC over it. No token produced
    // by `sign_id` exceeds this bound.
    if token.len() > MAX_TOKEN_LEN {
        return Err(invalid_signed_id_error());
    }

    let decoded = URL_SAFE_NO_PAD
        .decode(token)
        .map_err(|_| invalid_signed_id_error())?;

    // In tokens from `sign_id` the id part is ASCII digits with an optional '-'. So the first
    // '.' is the separator even if the raw MAC bytes after it happen to contain '.'.
    let Some(separator) = decoded.iter().position(|&byte| byte == b'.') else {
        return Err(invalid_signed_id_error());
    };
    let (id_bytes, rest) = decoded.split_at(separator);
    // `rest` starts with the '.' separator, so slicing from 1 cannot panic.
    let signature = &rest[1..];
    if signature.is_empty() {
        return Err(invalid_signed_id_error());
    }

    verify_hmac(secret.as_bytes(), id_bytes, signature)?;

    std::str::from_utf8(id_bytes)
        .ok()
        .and_then(|digits| digits.parse::<i64>().ok())
        .ok_or_else(invalid_signed_id_error)
}
/// Computes HMAC-SHA256 of `payload` keyed by `secret` and returns the raw tag bytes.
///
/// # Errors
///
/// Returns `RecordError::Invalid("invalid signing secret")` if the MAC constructor rejects the key.
fn sign_hmac(secret: &[u8], payload: &[u8]) -> Result<Vec<u8>, RecordError> {
    HmacSha256::new_from_slice(secret)
        .map(|mut mac| {
            mac.update(payload);
            mac.finalize().into_bytes().to_vec()
        })
        .map_err(|_| RecordError::Invalid("invalid signing secret".to_owned()))
}
/// Checks that `signature` is the HMAC-SHA256 of `payload` under `secret`.
///
/// The comparison is delegated to `Mac::verify_slice`, which the `digest` crate documents as
/// constant-time. It also rejects tags whose length differs from the MAC output.
///
/// # Errors
///
/// - `RecordError::Invalid("invalid signing secret")` if the key is rejected.
/// - `RecordError::Invalid("invalid signed id")` if the signature does not match.
fn verify_hmac(secret: &[u8], payload: &[u8], signature: &[u8]) -> Result<(), RecordError> {
    let Ok(mut mac) = HmacSha256::new_from_slice(secret) else {
        return Err(RecordError::Invalid("invalid signing secret".to_owned()));
    };
    mac.update(payload);
    match mac.verify_slice(signature) {
        Ok(()) => Ok(()),
        Err(_) => Err(invalid_signed_id_error()),
    }
}
/// Shared error for malformed, tampered or unparsable signed-id tokens.
///
/// Every such rejection reports the same message, so callers cannot tell which check failed.
fn invalid_signed_id_error() -> RecordError {
    RecordError::Invalid(String::from("invalid signed id"))
}
#[cfg(test)]
mod tests {
    use super::{SignedId, sign_id, verify_signed_id};
    use crate::{
        Querying, RecordError,
        base::test_support::{TestUser, seed_users, test_user},
    };
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
    use rustrails_support::{database, runtime};
    use sea_orm::{ConnectionTrait, Schema};

    /// Explicit secret for the pure signing tests, independent of `RUSTRAILS_SECRET_KEY_BASE`.
    const TEST_SECRET: &str = "test-secret";

    /// Runs `test` on a dedicated thread after these setup steps:
    /// - initialise the runtime,
    /// - connect to an in-memory SQLite database,
    /// - create the `test_users` table, and seed it when `seed` is true.
    fn run_sync_test(seed: bool, test: impl FnOnce() + Send + 'static) {
        std::thread::spawn(move || {
            let _runtime = runtime::init_runtime();
            database::establish("sqlite::memory:")
                .expect("sqlite in-memory connection should succeed");
            runtime::block_on(async {
                let db = database::db();
                let schema = Schema::new(db.get_database_backend());
                db.execute(&schema.create_table_from_entity(test_user::Entity))
                    .await
                    .expect("test_users table should be created");
                if seed {
                    seed_users(&db).await;
                }
            });
            test();
        })
        .join()
        .expect("signed-id sync test should not panic");
    }

    #[tokio::test]
    async fn signed_id_generates_non_empty_token_for_persisted_record() {
        let db = crate::base::test_support::setup_db().await;
        let user = seed_users(&db).await.remove(0);
        let token = user.signed_id().expect("signed_id should succeed");
        assert!(!token.is_empty());
    }

    #[tokio::test]
    async fn find_signed_recovers_original_record() {
        let db = crate::base::test_support::setup_db().await;
        let original = seed_users(&db).await.remove(1);
        let token = original.signed_id().expect("signed_id should succeed");
        let found = TestUser::find_signed(&token, &db)
            .await
            .expect("find_signed should succeed");
        assert_eq!(found, original);
    }

    // Appending a character may already break base64 decoding. The MAC check itself is
    // exercised directly by the `verify_signed_id_*` tests below.
    #[tokio::test]
    async fn find_signed_rejects_tampered_tokens() {
        let db = crate::base::test_support::setup_db().await;
        let user = seed_users(&db).await.remove(0);
        let mut token = user.signed_id().expect("signed_id should succeed");
        token.push('x');
        let error = TestUser::find_signed(&token, &db)
            .await
            .expect_err("tampered token should be rejected");
        assert!(matches!(error, RecordError::Invalid(_)));
    }

    #[test]
    fn find_signed_sync_recovers_original_record() {
        run_sync_test(true, || {
            let original = TestUser::find_sync(2).expect("seeded record should exist");
            let token = original.signed_id().expect("signed_id should succeed");
            let found =
                TestUser::find_signed_sync(&token).expect("find_signed_sync should succeed");
            assert_eq!(found, original);
        });
    }

    #[test]
    fn verify_signed_id_round_trips_boundary_ids() {
        for id in [0, 1, 42, -7, i64::MAX, i64::MIN] {
            let token = sign_id(id, TEST_SECRET).expect("sign_id should succeed");
            let recovered = verify_signed_id(&token, TEST_SECRET).expect("token should verify");
            assert_eq!(recovered, id);
        }
    }

    #[test]
    fn verify_signed_id_rejects_token_signed_with_another_secret() {
        let token = sign_id(42, TEST_SECRET).expect("sign_id should succeed");
        let error = verify_signed_id(&token, "another-secret")
            .expect_err("a token signed with a different secret should be rejected");
        assert!(matches!(error, RecordError::Invalid(_)));
    }

    #[test]
    fn verify_signed_id_rejects_id_swapped_under_a_valid_signature() {
        let token = sign_id(1, TEST_SECRET).expect("sign_id should succeed");
        let mut payload = URL_SAFE_NO_PAD
            .decode(&token)
            .expect("signed tokens should be valid base64");
        // The payload is `1.<mac>`: claim id 2 while keeping the MAC computed for id 1.
        payload[0] = b'2';
        let forged = URL_SAFE_NO_PAD.encode(&payload);
        let error = verify_signed_id(&forged, TEST_SECRET)
            .expect_err("a forged id should be rejected");
        assert!(matches!(error, RecordError::Invalid(_)));
    }

    #[test]
    fn verify_signed_id_rejects_malformed_tokens() {
        let missing_separator = URL_SAFE_NO_PAD.encode(b"42");
        let missing_signature = URL_SAFE_NO_PAD.encode(b"42.");
        let oversized = "A".repeat(4096);
        let malformed: [&str; 5] = [
            "",
            "not base64!",
            &missing_separator,
            &missing_signature,
            &oversized,
        ];
        for token in malformed {
            let error = verify_signed_id(token, TEST_SECRET)
                .expect_err("malformed tokens should be rejected");
            assert!(matches!(error, RecordError::Invalid(_)), "token {token:?}");
        }
    }
}