use std::{
future,
pin::Pin,
sync::{Arc, Mutex},
};
use futures::{Stream, StreamExt};
use p256::{
ecdsa::{Signature, SigningKey, signature::hazmat::PrehashSigner},
elliptic_curve::SecretKey,
};
use crate::{KeyError, SigningError};
/// Default number of signers that [`AuthorizationContext::sign`] resolves
/// concurrently.
const SIGNATURE_RESOLUTION_CONCURRENCY: usize = 10;
/// A set of signers that together authorize a request.
///
/// Each signer is anything implementing [`IntoSignature`]: a key source
/// implementing [`IntoKey`], a raw P-256 secret key, or a pre-computed
/// [`Signature`].
///
/// The signer list is stored behind `Arc<Mutex<..>>`, so the derived `Clone`
/// is shallow. All clones share one list, and a signer pushed through any
/// clone is visible to every clone.
#[derive(Clone)]
pub struct AuthorizationContext {
    // Registered signers, type-erased via `IntoSignatureBoxed` so heterogeneous
    // sources can live in one list.
    signers: Arc<Mutex<Vec<Arc<dyn IntoSignatureBoxed + Send + Sync>>>>,
    // Maximum number of signers resolved at once by `sign`.
    resolution_concurrency: usize,
}
impl std::fmt::Debug for AuthorizationContext {
    /// Formats the context without exposing signer internals.
    ///
    /// Signers are trait objects with no `Debug` bound, and they may wrap key
    /// material. Only the number of registered signers and the concurrency
    /// limit are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The guarded value is a plain `Vec`, which stays structurally valid
        // even if a previous holder panicked. Recover the guard rather than
        // panicking inside a `Debug` impl.
        let signer_count = self
            .signers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len();
        f.debug_struct("AuthorizationContext")
            .field("signers", &signer_count)
            .field("resolution_concurrency", &self.resolution_concurrency)
            .finish()
    }
}
impl Default for AuthorizationContext {
fn default() -> Self {
Self::new()
}
}
impl AuthorizationContext {
    /// Creates an empty context that resolves up to
    /// `SIGNATURE_RESOLUTION_CONCURRENCY` signers at a time.
    #[must_use]
    pub fn new() -> Self {
        Self {
            signers: Default::default(),
            resolution_concurrency: SIGNATURE_RESOLUTION_CONCURRENCY,
        }
    }

    /// Registers a signer and returns the context for chaining.
    ///
    /// The signer can be a key source, a key, or a pre-computed signature.
    /// The signer list is shared through an `Arc`, so pushing onto one clone
    /// of a context makes the signer visible to every clone.
    ///
    /// # Panics
    ///
    /// Panics if the signer-list mutex is poisoned.
    pub fn push<T: IntoSignature + 'static + Send + Sync>(self, key: T) -> Self {
        self.signers
            .lock()
            .expect("lock poisoned")
            .push(Arc::new(key));
        self
    }

    /// Produces one signature per registered signer over `message`.
    ///
    /// Up to `resolution_concurrency` signers are resolved concurrently.
    /// Items are yielded in completion order, not registration order. A
    /// failing signer yields an `Err` item, and the stream keeps going.
    /// Consumers such as `try_collect` stop at the first error.
    ///
    /// # Panics
    ///
    /// Panics if the signer-list mutex is poisoned.
    pub fn sign<'a>(
        &'a self,
        message: &'a [u8],
    ) -> impl Stream<Item = Result<Signature, SigningError>> + 'a {
        // Snapshot the signer list (cheap `Arc` clones) so the lock is released
        // before any signing work starts. Signers pushed later are not seen
        // by this stream.
        let keys = self.signers.lock().expect("lock poisoned").clone();
        futures::stream::iter(keys)
            // `key` is already an owned `Arc`, so it moves straight into the
            // future that borrows it while signing.
            .map(move |key| async move { key.sign_boxed(message).await })
            .buffer_unordered(self.resolution_concurrency)
    }

    /// Checks every signer by signing an empty message and returns the errors.
    ///
    /// An empty vector means every signer produced a signature. This performs
    /// real signing, so it incurs whatever work each signer does. For example,
    /// `JwtUser` exchanges its JWT for a key.
    pub async fn validate(&self) -> Vec<SigningError> {
        self.sign(&[])
            .filter_map(|result| future::ready(result.err()))
            .collect::<Vec<_>>()
            .await
    }
}
/// The secret key type used for signing: a P-256 (`NistP256`) secret key.
type Key = SecretKey<p256::NistP256>;
/// A source that can asynchronously produce a P-256 signing key.
///
/// Every type implementing `IntoKey + Sync` automatically implements
/// [`IntoSignature`], which fetches the key and signs with it.
pub trait IntoKey {
    /// Resolves the signing key.
    ///
    /// # Errors
    ///
    /// Returns a [`KeyError`] if the key cannot be obtained, for example when
    /// the key material fails to parse.
    fn get_key(&self) -> impl Future<Output = Result<Key, KeyError>> + Send;
}
/// Something that can produce an ECDSA P-256 [`Signature`] over a message.
pub trait IntoSignature {
    /// Signs `message` and returns the signature.
    ///
    /// Implementations are not required to use `message`. A pre-computed
    /// [`Signature`] returns itself unchanged.
    ///
    /// # Errors
    ///
    /// Returns a [`SigningError`], for example when a key cannot be resolved
    /// or the signing operation fails.
    fn sign(&self, message: &[u8]) -> impl Future<Output = Result<Signature, SigningError>> + Send;
}
impl<T> IntoSignature for T
where
    T: IntoKey + Sync,
{
    /// Resolves the key from the source, then signs `message` with it.
    ///
    /// A key-resolution failure is converted into a `SigningError`.
    async fn sign(&self, message: &[u8]) -> Result<Signature, SigningError> {
        match self.get_key().await {
            Ok(resolved) => resolved.sign(message).await,
            Err(key_error) => Err(key_error.into()),
        }
    }
}
/// Object-safe counterpart of [`IntoSignature`].
///
/// `IntoSignature::sign` returns `impl Future`, so `IntoSignature` cannot be
/// used as a trait object. This trait boxes the future instead, which lets
/// heterogeneous signers be stored together as `dyn IntoSignatureBoxed`.
trait IntoSignatureBoxed {
    /// Signs `message` and returns a boxed `Send` future that borrows both
    /// `self` and `message` for `'a`.
    fn sign_boxed<'a>(
        &'a self,
        message: &'a [u8],
    ) -> Pin<Box<dyn Future<Output = Result<Signature, SigningError>> + Send + 'a>>;
}
impl<T: IntoSignature + 'static> IntoSignatureBoxed for T {
fn sign_boxed<'a>(
&'a self,
message: &'a [u8],
) -> Pin<Box<dyn Future<Output = Result<Signature, SigningError>> + Send + 'a>> {
Box::pin(self.sign(message))
}
}
/// Adapts a closure `Fn(&[u8]) -> Fut` into an [`IntoSignature`] signer.
///
/// `Fut` must resolve to `Result<Signature, SigningError>`. `Fut` is a single
/// type for every call, so it cannot borrow the message slice. The closure
/// must copy any message bytes it needs before returning its future.
pub struct FnSigner<F>(pub F);
/// Adapts a closure `Fn() -> Fut` into an [`IntoKey`] key source.
///
/// `Fut` must resolve to `Result<Key, KeyError>`. The closure is invoked on
/// every key request; nothing is cached.
pub struct FnKey<F>(pub F);
impl<F, Fut> IntoSignature for FnSigner<F>
where
    F: Fn(&[u8]) -> Fut,
    Fut: Future<Output = Result<Signature, SigningError>> + Send,
{
    /// Calls the wrapped closure with `message` and returns its future as-is.
    fn sign(&self, message: &[u8]) -> impl Future<Output = Result<Signature, SigningError>> + Send {
        let Self(callback) = self;
        callback(message)
    }
}
impl<F, Fut> IntoKey for FnKey<F>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<Key, KeyError>> + Send,
{
    /// Calls the wrapped closure and returns its future as-is.
    fn get_key(&self) -> impl Future<Output = Result<Key, KeyError>> + Send {
        let Self(factory) = self;
        factory()
    }
}
/// A key source that obtains its signing key through the client's
/// `jwt_exchange` component.
///
/// Field `0` is the `PrivyClient` that performs the exchange. Field `1` is a
/// string that is presumably the user's JWT (confirm against
/// `exchange_jwt_for_authorization_key`).
pub struct JwtUser(pub crate::PrivyClient, pub String);
impl IntoKey for JwtUser {
    /// Delegates to `exchange_jwt_for_authorization_key`, passing this whole
    /// `JwtUser` to the exchanger.
    async fn get_key(&self) -> Result<Key, KeyError> {
        let Self(client, _) = self;
        client
            .jwt_exchange
            .exchange_jwt_for_authorization_key(self)
            .await
    }
}
impl IntoSignature for Key {
async fn sign(&self, message: &[u8]) -> Result<Signature, SigningError> {
use sha2::{Digest, Sha256};
tracing::debug!(
"Starting ECDSA signing process for {} byte message",
message.len()
);
let hashed = {
let mut sha256 = Sha256::new();
sha256.update(message);
sha256.finalize()
};
tracing::debug!("SHA256 hash computed: {}", hex::encode(hashed));
let signing_key = SigningKey::from(self.clone());
let signature: Signature = signing_key.sign_prehash(&hashed)?;
tracing::debug!("ECDSA signature generated using deterministic RFC 6979");
Ok(signature)
}
}
impl IntoSignature for Signature {
    /// Returns this pre-computed signature unchanged.
    ///
    /// `_message` is ignored and the signature is not checked against it. The
    /// caller is responsible for supplying a signature over the right payload.
    async fn sign(&self, _message: &[u8]) -> Result<Signature, SigningError> {
        let precomputed: Signature = *self;
        Ok(precomputed)
    }
}
/// A SEC1 PEM-encoded P-256 private key held in memory.
///
/// The PEM text is wrapped in `zeroize::Zeroizing`, so it is wiped when the
/// value is dropped. No `Debug` is derived here, which keeps the PEM out of
/// formatted output.
pub struct PrivateKey(zeroize::Zeroizing<String>);
impl PrivateKey {
    /// Wraps `key`, a SEC1 PEM string (`-----BEGIN EC PRIVATE KEY-----`).
    ///
    /// The string is not parsed here. Malformed input is reported later by
    /// `get_key` as `KeyError::InvalidFormat`.
    pub fn new(key: String) -> Self {
        let wiped_on_drop = zeroize::Zeroizing::new(key);
        Self(wiped_on_drop)
    }
}
impl IntoKey for PrivateKey {
async fn get_key(&self) -> Result<Key, KeyError> {
SecretKey::<p256::NistP256>::from_sec1_pem(self.0.as_str()).map_err(|e| {
tracing::error!("Failed to parse SEC1 PEM: {:?}", e);
KeyError::InvalidFormat("provided PEM string is malformed".to_string())
})
}
}
#[cfg(test)]
mod tests {
    use base64::{Engine, engine::general_purpose::STANDARD};
    use futures::TryStreamExt;
    use p256::{
        ecdsa::Signature,
        elliptic_curve::{SecretKey, generic_array::GenericArray},
    };
    use test_case::test_case;
    use tracing_test::traced_test;

    use super::*;
    use crate::{AuthorizationContext, FnKey, FnSigner, KeyError};

    /// SEC1 PEM fixture shared by every test that needs a real key.
    const TEST_PRIVATE_KEY_PEM: &str = include_str!("../tests/test_private_key.pem");

    /// Placeholder error used to build `Other` variants in failure-path tests.
    #[derive(Debug)]
    struct DummyError;

    impl std::error::Error for DummyError {}

    impl std::fmt::Display for DummyError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "dummy error")
        }
    }

    #[tokio::test]
    async fn test_private_key_creation() {
        let key = PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string());
        let result = key.get_key().await;
        assert!(result.is_ok(), "Should successfully parse valid PEM key");
    }

    #[tokio::test]
    async fn test_private_key_invalid_format() {
        let key = PrivateKey::new("invalid_pem_data".to_string());
        let result = key.get_key().await;
        assert!(
            matches!(result, Err(KeyError::InvalidFormat(_))),
            "Expected InvalidFormat error for invalid PEM data"
        );
    }

    #[tokio::test]
    async fn test_private_key_signing() {
        let key = PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string());
        let test_key = key.get_key().await.unwrap();
        let message1 = b"test message for signing";
        let message2 = b"different message";
        let signature1a = test_key.sign(message1).await.unwrap();
        let signature1b = test_key.sign(message1).await.unwrap();
        assert_eq!(
            signature1a, signature1b,
            "Deterministic signing should produce identical signatures"
        );
        let signature2 = test_key.sign(message2).await.unwrap();
        assert_ne!(
            signature1a, signature2,
            "Different messages should produce different signatures"
        );
    }

    #[test_case(b"" ; "empty message")]
    #[test_case(b"short" ; "short message")]
    #[test_case(&[0u8; 1000] ; "long message")]
    #[test_case(b"special chars: \x00\xff\n\r\t" ; "special characters")]
    #[tokio::test]
    async fn test_signing_various_messages(message: &[u8]) {
        let key = PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string());
        let test_key = key.get_key().await.unwrap();
        let signature = test_key.sign(message).await;
        assert!(
            signature.is_ok(),
            "Should successfully sign message of length {}",
            message.len()
        );
    }

    #[tokio::test]
    async fn test_signature_into_signature() {
        let key = PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string());
        let test_key = key.get_key().await.unwrap();
        let original_signature = test_key.sign(b"test").await.unwrap();
        let result = original_signature.sign(b"ignored_message").await.unwrap();
        assert_eq!(
            result, original_signature,
            "Signature should return itself regardless of message"
        );
    }

    #[tokio::test]
    #[traced_test]
    async fn test_authorization_context_empty() {
        let ctx = AuthorizationContext::new();
        let signatures: Vec<_> = ctx.sign(b"test").try_collect().await.unwrap();
        assert!(
            signatures.is_empty(),
            "Empty context should produce no signatures"
        );
    }

    #[tokio::test]
    #[traced_test]
    async fn test_authorization_context_single_key() {
        let key = PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string());
        let ctx = AuthorizationContext::new().push(key);
        let signatures: Vec<_> = ctx.sign(b"test").try_collect().await.unwrap();
        assert_eq!(
            signatures.len(),
            1,
            "Context with one key should produce one signature"
        );
    }

    #[tokio::test]
    #[traced_test]
    async fn test_authorization_context_multiple_keys() {
        let key_bytes = [2u8; 32];
        let second_key = SecretKey::<p256::NistP256>::from_bytes(&key_bytes.into()).unwrap();
        let ctx = AuthorizationContext::new()
            .push(PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string()))
            .push(second_key);
        let signatures: Vec<_> = ctx.sign(b"test").try_collect().await.unwrap();
        assert_eq!(
            signatures.len(),
            2,
            "Context with two keys should produce two signatures"
        );
        assert_ne!(
            signatures[0], signatures[1],
            "Different keys should produce different signatures"
        );
    }

    #[tokio::test]
    #[traced_test]
    async fn test_authorization_context_validation() {
        let ctx =
            AuthorizationContext::new().push(PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string()));
        let errors = ctx.validate().await;
        assert!(
            errors.is_empty(),
            "Valid context should have no validation errors"
        );
        let ctx2 =
            AuthorizationContext::new().push(PrivateKey::new("invalid_key_data".to_string()));
        let errors2 = ctx2.validate().await;
        assert!(
            !errors2.is_empty(),
            "Invalid key should produce validation errors"
        );
    }

    #[tokio::test]
    async fn test_fn_signer_wrapper() {
        let signer = FnSigner(|_message: &[u8]| async move {
            let result: Result<Signature, SigningError> =
                Err(SigningError::Other(Box::new(DummyError)));
            result
        });
        assert!(matches!(
            signer.sign(&[0]).await,
            Err(SigningError::Other(_))
        ));
    }

    #[tokio::test]
    async fn test_fn_key_wrapper() {
        let key_fn = FnKey(|| async {
            PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string())
                .get_key()
                .await
        });
        let key1 = key_fn.get_key().await.unwrap();
        let key2 = key_fn.get_key().await.unwrap();
        assert_eq!(key1.to_bytes(), key2.to_bytes());
    }

    #[tokio::test]
    #[traced_test]
    async fn test_authorization_context_concurrent_signing() {
        let mut ctx = AuthorizationContext::new();
        for i in 0..5u8 {
            // Vary the first byte so each key is a distinct (valid) scalar.
            let mut key_bytes = [1u8; 32];
            key_bytes[0] = i + 1;
            let key = SecretKey::<p256::NistP256>::from_bytes(&key_bytes.into()).unwrap();
            ctx = ctx.push(key);
        }
        let message = b"concurrent test message";
        let signatures: Vec<_> = ctx.sign(message).try_collect().await.unwrap();
        assert_eq!(
            signatures.len(),
            5,
            "Should produce 5 signatures concurrently"
        );
        for (i, first) in signatures.iter().enumerate() {
            for second in &signatures[i + 1..] {
                assert_ne!(
                    first, second,
                    "Signatures from different keys should be different"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_key_public_key_derivation() {
        let private_key = PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string());
        let key = private_key.get_key().await.unwrap();
        let public_key = key.public_key();
        assert!(
            !public_key.to_string().is_empty(),
            "Public key string should not be empty"
        );
        let public_key2 = key.public_key();
        assert_eq!(
            public_key.to_string(),
            public_key2.to_string(),
            "Public key derivation should be consistent"
        );
    }

    #[tokio::test]
    #[traced_test]
    async fn test_authorization_context_mixed_sources() {
        let signature_bytes = STANDARD
            .decode("J7GLk/CIqvCNCOSJ8sUZb0rCsqWF9l1H1VgYfsAd1ew2uBJHE5hoY+kV7CSzdKkgOhtdvzj22gXA7gcn5gSqvQ==")
            .unwrap();
        let precomputed =
            Signature::from_bytes(GenericArray::from_slice(&signature_bytes)).expect("right size");
        let ctx = AuthorizationContext::new()
            .push(PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string()))
            .push(precomputed);
        let sigs = ctx
            .sign(&[0, 1, 2, 3])
            .try_collect::<Vec<_>>()
            .await
            .expect("passes");
        assert_eq!(
            sigs.len(),
            2,
            "Context with mixed sources should produce one signature per source"
        );
        assert!(
            sigs.contains(&precomputed),
            "Pre-computed signature should be passed through unchanged"
        );
    }

    #[tokio::test]
    async fn test_signing_error_propagation() {
        struct FailingKey;

        impl IntoKey for FailingKey {
            async fn get_key(&self) -> Result<SecretKey<p256::NistP256>, KeyError> {
                Err(KeyError::Other(Box::new(DummyError)))
            }
        }

        let failing_key = FailingKey;
        let result = failing_key.sign(b"test").await;
        assert!(matches!(result, Err(SigningError::Key(KeyError::Other(_)))));
    }

    #[tokio::test]
    async fn test_authorization_context_clone_and_debug() {
        let ctx1 =
            AuthorizationContext::new().push(PrivateKey::new(TEST_PRIVATE_KEY_PEM.to_string()));
        let ctx2 = ctx1.clone();
        let sigs1: Vec<_> = ctx1.sign(b"test").try_collect().await.unwrap();
        let sigs2: Vec<_> = ctx2.sign(b"test").try_collect().await.unwrap();
        assert_eq!(sigs1.len(), 1);
        assert_eq!(sigs2.len(), 1);
        assert_eq!(
            sigs1[0], sigs2[0],
            "Cloned context should produce same signatures"
        );
        let debug_str = format!("{ctx1:?}");
        assert!(
            debug_str.contains("AuthorizationContext"),
            "Debug output should contain struct name"
        );
    }
}