#[cfg(feature = "alloc")]
use alloc::{collections::BTreeMap, vec::Vec};
use noxtls_core::{Error, Result};
use noxtls_crypto::{
aes_gcm_encrypt, p256_ecdh_shared_secret, p256_ecdsa_sign_sha256, rsaes_oaep_sha256_decrypt,
rsaes_pkcs1_v15_decrypt, rsassa_pss_sha256_sign, rsassa_sha256_sign, sha256, x25519, AesCipher,
P256PrivateKey, P256PublicKey, RsaPrivateKey,
};
/// Opaque, byte-string identifier naming a key held by a [`PsaCryptoBackend`].
///
/// Handles are compared, ordered and hashed by their raw identifier bytes.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PsaExternalKeyHandle {
    id: Vec<u8>,
}

impl PsaExternalKeyHandle {
    /// Wraps `id` as a handle; no validation is performed on the bytes.
    pub fn new(id: Vec<u8>) -> Self {
        PsaExternalKeyHandle { id }
    }

    /// Returns the raw identifier bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.id.as_slice()
    }
}
/// Signature schemes accepted by [`PsaCryptoBackend::sign`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PsaSignAlgorithm {
    /// RSA PKCS#1 v1.5 signature with SHA-256.
    RsaPkcs1Sha256,
    /// RSA-PSS signature with SHA-256; the software backend requires a caller-supplied salt.
    RsaPssSha256,
    /// ECDSA over NIST P-256 with SHA-256.
    EcdsaP256Sha256,
}
/// Asymmetric decryption schemes accepted by [`PsaCryptoBackend::decrypt`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PsaDecryptAlgorithm {
    /// RSAES-PKCS1-v1_5 decryption.
    RsaPkcs1v15,
    /// RSAES-OAEP with SHA-256; the software backend uses an empty label when none is given.
    RsaOaepSha256,
}
/// Key-agreement schemes accepted by [`PsaCryptoBackend::derive`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PsaDeriveAlgorithm {
    /// X25519 Diffie-Hellman; the software backend expects a 32-byte peer public key.
    X25519,
    /// ECDH over NIST P-256; the software backend expects an uncompressed peer point.
    EcdhP256,
}
/// Parameters for [`PsaCryptoBackend::sign`].
#[derive(Debug, Clone)]
pub struct KeySignRequest<'a> {
    /// Handle of the registered private key to sign with.
    pub handle: &'a PsaExternalKeyHandle,
    /// Signature scheme to apply; must match the key type behind `handle`.
    pub algorithm: PsaSignAlgorithm,
    /// Data to sign.
    /// NOTE(review): presumably the unhashed message, with SHA-256 applied inside the
    /// signing primitive — confirm against the primitives' contracts.
    pub message: &'a [u8],
    /// Optional salt; the software backend requires it for `RsaPssSha256` and ignores it
    /// for the other schemes.
    pub salt: Option<&'a [u8]>,
}
/// Parameters for [`PsaCryptoBackend::decrypt`].
#[derive(Debug, Clone)]
pub struct KeyDecryptRequest<'a> {
    /// Handle of the registered private key to decrypt with.
    pub handle: &'a PsaExternalKeyHandle,
    /// Decryption scheme to apply; must match the key type behind `handle`.
    pub algorithm: PsaDecryptAlgorithm,
    /// Ciphertext to decrypt.
    pub ciphertext: &'a [u8],
    /// Optional OAEP label; the software backend substitutes an empty label when `None`.
    pub label: Option<&'a [u8]>,
}
/// Parameters for [`PsaCryptoBackend::derive`].
#[derive(Debug, Clone)]
pub struct KeyDeriveRequest<'a> {
    /// Handle of the registered private key used for the key agreement.
    pub handle: &'a PsaExternalKeyHandle,
    /// Key-agreement scheme to apply; must match the key type behind `handle`.
    pub algorithm: PsaDeriveAlgorithm,
    /// Encoded public key of the peer.
    pub peer_public_key: &'a [u8],
}
/// Input to an AES-GCM encryption with a caller-supplied key.
///
/// `Debug` is implemented by hand so that formatting a request (for example in a log line)
/// never prints the secret `key` or the `plaintext`; only their lengths are shown.
#[derive(Clone)]
pub struct AeadEncryptRequest<'a> {
    /// Raw AES key bytes (secret).
    pub key: &'a [u8],
    /// GCM nonce (IV).
    pub nonce: &'a [u8],
    /// Additional authenticated data: authenticated but not encrypted.
    pub aad: &'a [u8],
    /// Data to encrypt (secret).
    pub plaintext: &'a [u8],
}

impl core::fmt::Debug for AeadEncryptRequest<'_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Secret fields are replaced by their length so `{:?}` cannot leak them.
        f.debug_struct("AeadEncryptRequest")
            .field("key", &format_args!("<{} bytes redacted>", self.key.len()))
            .field("nonce", &self.nonce)
            .field("aad", &self.aad)
            .field(
                "plaintext",
                &format_args!("<{} bytes redacted>", self.plaintext.len()),
            )
            .finish()
    }
}
/// Output of an AES-GCM encryption: the ciphertext and its detached 16-byte tag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AeadEncryptResponse {
    /// Encrypted bytes, without the authentication tag appended.
    pub ciphertext: Vec<u8>,
    /// GCM authentication tag.
    pub tag: [u8; 16],
}
/// Operations a PSA-style crypto backend supplies to [`PsaProvider`].
///
/// Private-key operations address keys through a [`PsaExternalKeyHandle`] rather than raw
/// key bytes.
pub trait PsaCryptoBackend {
    /// Produces a signature over `request.message` with the key behind `request.handle`.
    fn sign(&self, request: &KeySignRequest<'_>) -> Result<Vec<u8>>;

    /// Decrypts `request.ciphertext` with the key behind `request.handle`.
    fn decrypt(&self, request: &KeyDecryptRequest<'_>) -> Result<Vec<u8>>;

    /// Computes a shared secret between the key behind `request.handle` and the peer key.
    fn derive(&self, request: &KeyDeriveRequest<'_>) -> Result<Vec<u8>>;

    /// Fills every byte of `out`; the quality of the bytes is backend-defined.
    fn random(&self, out: &mut [u8]) -> Result<()>;

    /// Returns the SHA-256 digest of `input`.
    fn sha256(&self, input: &[u8]) -> Result<[u8; 32]>;

    /// Encrypts `request.plaintext` with AES-GCM using the key carried in the request.
    fn aes_gcm_encrypt(&self, request: &AeadEncryptRequest<'_>) -> Result<AeadEncryptResponse>;
}
/// Front-end that forwards PSA-style operations to a pluggable backend `B`.
#[derive(Debug, Clone)]
pub struct PsaProvider<B> {
    backend: B,
}

impl<B> PsaProvider<B> {
    /// Wraps `backend` in a provider; no other state is created.
    pub fn new(backend: B) -> Self {
        PsaProvider { backend }
    }
}
impl<B: PsaCryptoBackend> PsaProvider<B> {
pub fn sign(&self, request: &KeySignRequest<'_>) -> Result<Vec<u8>> {
self.backend.sign(request)
}
pub fn decrypt(&self, request: &KeyDecryptRequest<'_>) -> Result<Vec<u8>> {
self.backend
.decrypt(request)
.map_err(|_| Error::CryptoFailure("psa cryptographic operation failed"))
}
pub fn derive(&self, request: &KeyDeriveRequest<'_>) -> Result<Vec<u8>> {
self.backend.derive(request)
}
pub fn random(&self, out: &mut [u8]) -> Result<()> {
self.backend.random(out)
}
pub fn sha256(&self, input: &[u8]) -> Result<[u8; 32]> {
self.backend.sha256(input)
}
pub fn aes_gcm_encrypt(&self, request: &AeadEncryptRequest<'_>) -> Result<AeadEncryptResponse> {
self.backend.aes_gcm_encrypt(request)
}
}
/// Per-handle permissions checked by the software backend before each private-key operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct HandlePolicy {
    /// Whether `sign` may use this key.
    allow_sign: bool,
    /// Whether `decrypt` may use this key.
    allow_decrypt: bool,
    /// Whether `derive` may use this key.
    allow_derive: bool,
}
/// Private key material held by the software backend.
///
/// `Debug` is implemented by hand so that formatting the backend (which derives `Debug`)
/// reports only the key type and never the secret bytes.
#[derive(Clone)]
enum SoftwarePrivateMaterial {
    /// RSA private key (registered for signing and/or decryption).
    Rsa(RsaPrivateKey),
    /// Raw 32-byte X25519 private scalar.
    X25519([u8; 32]),
    /// P-256 private key (registered for ECDSA signing and/or ECDH).
    P256(P256PrivateKey),
}

impl core::fmt::Debug for SoftwarePrivateMaterial {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let variant = match self {
            Self::Rsa(_) => "Rsa",
            Self::X25519(_) => "X25519",
            Self::P256(_) => "P256",
        };
        // Never forward to the inner value's Debug: the X25519 array would print raw key bytes.
        f.debug_tuple(variant)
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
/// In-process backend that keeps private keys in memory, keyed by the handle's bytes.
///
/// Each entry pairs the key material with the operations its handle may perform.
#[derive(Default, Debug, Clone)]
pub struct PsaSoftwareBackend {
    keys: BTreeMap<Vec<u8>, (SoftwarePrivateMaterial, HandlePolicy)>,
}
impl PsaSoftwareBackend {
pub fn new() -> Self {
Self {
keys: BTreeMap::new(),
}
}
pub fn register_rsa_key(
&mut self,
handle: PsaExternalKeyHandle,
key: RsaPrivateKey,
allow_sign: bool,
allow_decrypt: bool,
) -> Result<()> {
self.insert_key(
handle,
SoftwarePrivateMaterial::Rsa(key),
HandlePolicy {
allow_sign,
allow_decrypt,
allow_derive: false,
},
)
}
pub fn register_x25519_key(
&mut self,
handle: PsaExternalKeyHandle,
key: [u8; 32],
allow_derive: bool,
) -> Result<()> {
self.insert_key(
handle,
SoftwarePrivateMaterial::X25519(key),
HandlePolicy {
allow_sign: false,
allow_decrypt: false,
allow_derive,
},
)
}
pub fn register_p256_key(
&mut self,
handle: PsaExternalKeyHandle,
key: P256PrivateKey,
allow_sign: bool,
allow_derive: bool,
) -> Result<()> {
self.insert_key(
handle,
SoftwarePrivateMaterial::P256(key),
HandlePolicy {
allow_sign,
allow_decrypt: false,
allow_derive,
},
)
}
fn insert_key(
&mut self,
handle: PsaExternalKeyHandle,
material: SoftwarePrivateMaterial,
policy: HandlePolicy,
) -> Result<()> {
if self.keys.contains_key(handle.as_bytes()) {
return Err(Error::StateError("psa key handle already registered"));
}
self.keys.insert(handle.id, (material, policy));
Ok(())
}
fn resolve_key(
&self,
handle: &PsaExternalKeyHandle,
) -> Result<&(SoftwarePrivateMaterial, HandlePolicy)> {
self.keys
.get(handle.as_bytes())
.ok_or(Error::StateError("psa key handle invalid"))
}
}
impl PsaCryptoBackend for PsaSoftwareBackend {
fn sign(&self, request: &KeySignRequest<'_>) -> Result<Vec<u8>> {
let (material, policy) = self.resolve_key(request.handle)?;
if !policy.allow_sign {
return Err(Error::StateError("psa sign not permitted by key policy"));
}
match (request.algorithm, material) {
(PsaSignAlgorithm::RsaPkcs1Sha256, SoftwarePrivateMaterial::Rsa(key)) => {
rsassa_sha256_sign(key, request.message)
}
(PsaSignAlgorithm::RsaPssSha256, SoftwarePrivateMaterial::Rsa(key)) => {
let salt = request.salt.ok_or(Error::InvalidLength(
"rsa-pss-sha256 signing requires a salt",
))?;
rsassa_pss_sha256_sign(key, request.message, salt)
}
(PsaSignAlgorithm::EcdsaP256Sha256, SoftwarePrivateMaterial::P256(key)) => {
let (r, s) = p256_ecdsa_sign_sha256(key, request.message)?;
let mut signature = Vec::with_capacity(64);
signature.extend_from_slice(&r);
signature.extend_from_slice(&s);
Ok(signature)
}
_ => Err(Error::UnsupportedFeature("psa sign algorithm/key mismatch")),
}
}
fn decrypt(&self, request: &KeyDecryptRequest<'_>) -> Result<Vec<u8>> {
let (material, policy) = self.resolve_key(request.handle)?;
if !policy.allow_decrypt {
return Err(Error::StateError("psa decrypt not permitted by key policy"));
}
match (request.algorithm, material) {
(PsaDecryptAlgorithm::RsaPkcs1v15, SoftwarePrivateMaterial::Rsa(key)) => {
rsaes_pkcs1_v15_decrypt(key, request.ciphertext)
}
(PsaDecryptAlgorithm::RsaOaepSha256, SoftwarePrivateMaterial::Rsa(key)) => {
rsaes_oaep_sha256_decrypt(key, request.ciphertext, request.label.unwrap_or(&[]))
}
_ => Err(Error::UnsupportedFeature(
"psa decrypt algorithm/key mismatch",
)),
}
}
fn derive(&self, request: &KeyDeriveRequest<'_>) -> Result<Vec<u8>> {
let (material, policy) = self.resolve_key(request.handle)?;
if !policy.allow_derive {
return Err(Error::StateError("psa derive not permitted by key policy"));
}
match (request.algorithm, material) {
(PsaDeriveAlgorithm::X25519, SoftwarePrivateMaterial::X25519(private)) => {
if request.peer_public_key.len() != 32 {
return Err(Error::ParseFailure("x25519 peer public key length invalid"));
}
let mut peer = [0u8; 32];
peer.copy_from_slice(request.peer_public_key);
Ok(x25519(private, &peer).to_vec())
}
(PsaDeriveAlgorithm::EcdhP256, SoftwarePrivateMaterial::P256(private)) => {
let peer = P256PublicKey::from_uncompressed(request.peer_public_key)?;
Ok(p256_ecdh_shared_secret(private, &peer)?.to_vec())
}
_ => Err(Error::UnsupportedFeature(
"psa derive algorithm/key mismatch",
)),
}
}
fn random(&self, out: &mut [u8]) -> Result<()> {
for (idx, byte) in out.iter_mut().enumerate() {
*byte = (idx as u8).wrapping_mul(17).wrapping_add(0x5A);
}
Ok(())
}
fn sha256(&self, input: &[u8]) -> Result<[u8; 32]> {
Ok(sha256(input))
}
fn aes_gcm_encrypt(&self, request: &AeadEncryptRequest<'_>) -> Result<AeadEncryptResponse> {
let cipher = AesCipher::new(request.key)?;
let (ciphertext, tag) =
aes_gcm_encrypt(&cipher, request.nonce, request.aad, request.plaintext)?;
Ok(AeadEncryptResponse { ciphertext, tag })
}
}
/// [`PsaProvider`] backed by the in-process [`PsaSoftwareBackend`].
///
/// NOTE(review): that backend's `random` fills buffers with a deterministic pattern. Do not
/// rely on this alias for randomness in production without replacing the backend.
pub type PsaSoftwareProvider = PsaProvider<PsaSoftwareBackend>;