use anyhow::{Context, Result};
use async_trait::async_trait;
use cryptoki::context::{CInitializeArgs, Pkcs11};
use cryptoki::mechanism::Mechanism;
use cryptoki::object::{Attribute, AttributeType, ObjectHandle};
use cryptoki::session::{Session, UserType};
use cryptoki::slot::Slot;
use cryptoki::types::AuthPin;
use std::sync::Arc;
use zeroize::Zeroize;
use super::CryptoProvider;
/// `CryptoProvider` implementation backed by a key pair stored on a PKCS#11 token (HSM).
///
/// `new` logs in once to read the public key and derive the MAC base key; later signing
/// operations open a fresh authenticated session per call.
pub struct Pkcs11CryptoProvider {
/// Initialized PKCS#11 module handle used to open sessions.
pkcs11: Arc<Pkcs11>,
/// Slot on which sessions are opened.
slot: Slot,
/// User PIN, retained so that sessions opened after construction can log in.
pin: String,
/// Label used to look up the key objects on the token.
key_label: String,
/// Public key as a 33-byte compressed point (0x02/0x03 prefix followed by X).
public_key: Vec<u8>,
/// Key identifier returned by `key_id()`.
key_id: String,
/// Context bytes returned by `context()`.
context: Vec<u8>,
/// Base secret fed into `derive_mac_key`; computed once in `new` and wiped on drop.
mac_base_key: [u8; 32],
}
impl Pkcs11CryptoProvider {
/// Loads and initializes the PKCS#11 module, then performs one authenticated session to
/// collect everything the provider caches: the compressed public key and the MAC base key.
///
/// * `module_path` - path of the PKCS#11 shared library to load.
/// * `slot` - numeric slot identifier; converted to a `Slot`.
/// * `pin` - user PIN; used to log in now and stored for later sessions.
/// * `key_label` - label shared by the private and public key objects.
/// * `key_id` / `context` - stored unchanged and returned by the trait accessors.
///
/// # Errors
/// Fails when the module cannot be loaded or initialized, the slot is invalid, the session
/// cannot be opened or authenticated, either key object is missing, or the public key /
/// MAC base key cannot be obtained.
pub async fn new(
    module_path: &str,
    slot: u64,
    pin: &str,
    key_label: &str,
    key_id: String,
    context: Vec<u8>,
) -> Result<Self> {
    let module = Pkcs11::new(module_path).context("Failed to load PKCS#11 module")?;
    module
        .initialize(CInitializeArgs::OsThreads)
        .context("Failed to initialize PKCS#11")?;
    let pkcs11 = Arc::new(module);
    let slot = Slot::try_from(slot).context("Invalid slot number")?;

    // Everything that needs the token happens inside this scope; the session is closed
    // when the scope ends.
    let (public_key, mac_base_key) = {
        let session = pkcs11
            .open_rw_session(slot)
            .context("Failed to open HSM session")?;
        let credentials = AuthPin::new(pin.to_owned());
        session
            .login(UserType::User, Some(&credentials))
            .context("Failed to authenticate with HSM")?;

        let private_handle = Self::find_key_by_label(&session, key_label, true)
            .context("Failed to find private key in HSM")?;
        let public_handle = Self::find_key_by_label(&session, key_label, false)
            .context("Failed to find public key in HSM")?;

        let public_key = Self::extract_public_key(&session, public_handle)
            .context("Failed to extract public key from HSM")?;
        let mac_base_key = Self::derive_mac_base_key(&session, private_handle)
            .context("Failed to derive MAC base key from HSM")?;

        // Best effort: a failed logout is ignored.
        let _ = session.logout();
        (public_key, mac_base_key)
    };

    Ok(Self {
        pkcs11,
        slot,
        pin: pin.to_owned(),
        key_label: key_label.to_owned(),
        public_key,
        key_id,
        context,
        mac_base_key,
    })
}
fn open_authenticated_session(&self) -> Result<Session> {
let session = self
.pkcs11
.open_rw_session(self.slot)
.context("Failed to open HSM session")?;
let auth_pin = AuthPin::new(self.pin.clone());
session
.login(UserType::User, Some(&auth_pin))
.context("Failed to authenticate with HSM")?;
Ok(session)
}
/// Looks up a key object by label and returns the handle of the first match.
///
/// `is_private` selects the object class searched for: private key when `true`, public
/// key when `false`. If several objects share the label, only the first one reported by
/// the token is used.
///
/// # Errors
/// Fails when the search itself fails or when no object matches.
fn find_key_by_label(session: &Session, label: &str, is_private: bool) -> Result<ObjectHandle> {
    use cryptoki::object::ObjectClass;

    let wanted_class = if is_private {
        ObjectClass::PRIVATE_KEY
    } else {
        ObjectClass::PUBLIC_KEY
    };
    let search = [
        Attribute::Class(wanted_class),
        Attribute::Label(label.as_bytes().to_owned()),
    ];
    let matches = session
        .find_objects(&search)
        .context("HSM key search failed")?;
    match matches.into_iter().next() {
        Some(handle) => Ok(handle),
        None => Err(anyhow::anyhow!("Key '{}' not found in HSM", label)),
    }
}
fn extract_public_key(session: &Session, public_key_handle: ObjectHandle) -> Result<Vec<u8>> {
let attributes = session
.get_attributes(public_key_handle, &[AttributeType::EcPoint])
.context("Failed to read public key from HSM")?;
let ec_point = attributes
.first()
.and_then(|attr| {
if let Attribute::EcPoint(point) = attr {
Some(point.clone())
} else {
None
}
})
.ok_or_else(|| anyhow::anyhow!("EC_POINT attribute not found"))?;
if ec_point.len() < 65 {
anyhow::bail!("Invalid EC_POINT length: {}", ec_point.len());
}
let point_data = if ec_point[0] == 0x04 && ec_point[1] == 0x41 {
&ec_point[2..] } else {
&ec_point[..]
};
if point_data.len() != 65 || point_data[0] != 0x04 {
anyhow::bail!("Expected uncompressed P-256 point (65 bytes)");
}
let x = &point_data[1..33];
let y = &point_data[33..65];
let prefix = if y[31] & 1 == 0 { 0x02 } else { 0x03 };
let mut compressed = vec![prefix];
compressed.extend_from_slice(x);
Ok(compressed)
}
/// Derives the 32-byte MAC base key as `SHA-256(r || s)`, where `(r, s)` is the token's
/// ECDSA signature over the SHA-256 digest of a fixed derivation label.
///
/// The signature is the only input to the key, so both the raw and the normalized copy
/// are wiped as soon as they have been consumed.
///
/// NOTE(review): the result is only reproducible across provider constructions if the
/// token produces the same signature for the same digest (deterministic ECDSA). With a
/// randomized nonce every construction would yield a different base key — confirm the
/// behavior of the target HSM.
///
/// # Errors
/// Fails when the token refuses to sign or returns a signature in an unsupported format.
fn derive_mac_base_key(
    session: &Session,
    private_key_handle: ObjectHandle,
) -> Result<[u8; 32]> {
    use sha2::{Digest, Sha256};

    let mut hasher = Sha256::new();
    hasher.update(b"freebird-mac-base-key-derivation-v1");
    let digest = hasher.finalize();

    let mut raw_signature = session
        .sign(&Mechanism::Ecdsa, private_key_handle, &digest)
        .context("Failed to sign with HSM for MAC derivation")?;
    // Wipe the raw buffer before propagating a possible format error.
    let normalized = Self::normalize_ecdsa_signature(&raw_signature);
    raw_signature.zeroize();
    let mut signature = normalized?;

    let mut hasher = Sha256::new();
    hasher.update(&signature);
    let base_key = hasher.finalize();
    signature.zeroize();

    let mut result = [0u8; 32];
    result.copy_from_slice(&base_key);
    Ok(result)
}
/// Decodes the DER length field that starts at `data[idx]`.
///
/// Returns `(length, next)` where `next` is the index of the first byte after the length
/// field. Short form (one byte below 0x80) and long form with one to four length octets
/// are supported.
///
/// # Errors
/// Fails when `idx` is out of bounds, the long form announces zero or more than four
/// octets, or the announced octets run past the end of `data`.
fn parse_der_len(data: &[u8], idx: usize) -> Result<(usize, usize)> {
    let lead = match data.get(idx) {
        Some(&byte) => byte,
        None => anyhow::bail!("invalid DER length"),
    };
    // Short form: the byte itself is the length.
    if lead < 0x80 {
        return Ok((usize::from(lead), idx + 1));
    }
    // Long form: the low seven bits give the number of big-endian length octets.
    let octet_count = usize::from(lead & 0x7f);
    if !(1..=4).contains(&octet_count) {
        anyhow::bail!("unsupported DER length encoding");
    }
    let start = idx + 1;
    let end = start + octet_count;
    let octets = data
        .get(start..end)
        .ok_or_else(|| anyhow::anyhow!("truncated DER length"))?;
    let length = octets
        .iter()
        .fold(0usize, |acc, &octet| (acc << 8) | usize::from(octet));
    Ok((length, end))
}
/// Converts the content bytes of an ASN.1 INTEGER into a fixed 32-byte big-endian value.
///
/// Leading zero bytes are dropped (one byte is always kept for non-empty input) and the
/// remainder is right-aligned in a zero-filled array.
///
/// # Errors
/// Fails when more than 32 bytes remain after stripping.
fn asn1_int_to_32(bytes: &[u8]) -> Result<[u8; 32]> {
    let zero_run = bytes.iter().take_while(|&&byte| byte == 0).count();
    // Never strip the final byte, so an all-zero integer still leaves one byte.
    let skip = zero_run.min(bytes.len().saturating_sub(1));
    let magnitude = &bytes[skip..];
    if magnitude.len() > 32 {
        anyhow::bail!("ECDSA integer too large");
    }
    let mut out = [0u8; 32];
    let (_, tail) = out.split_at_mut(32 - magnitude.len());
    tail.copy_from_slice(magnitude);
    Ok(out)
}
/// Normalizes an ECDSA signature to the fixed 64-byte `r || s` form.
///
/// A 64-byte input is taken to be `r || s` already and returned unchanged. Any other
/// input must be a DER `SEQUENCE { INTEGER r, INTEGER s }` that spans the whole slice.
///
/// # Errors
/// Fails when the input is neither 64 bytes nor a well-formed DER signature, when an
/// integer does not fit in 32 bytes, or when bytes remain after `s`.
fn normalize_ecdsa_signature(sig: &[u8]) -> Result<[u8; 64]> {
    if sig.len() == 64 {
        let mut out = [0u8; 64];
        out.copy_from_slice(sig);
        return Ok(out);
    }
    if sig.first().copied() != Some(0x30) {
        anyhow::bail!("unsupported ECDSA signature format");
    }
    let (seq_len, body_start) = Self::parse_der_len(sig, 1)?;
    // Declared lengths can reach 2^32 - 1; checked arithmetic keeps the offset sums from
    // overflowing `usize` on 32-bit targets.
    if body_start.checked_add(seq_len) != Some(sig.len()) {
        anyhow::bail!("invalid DER signature length");
    }

    // Reads one INTEGER element starting at `start`; yields its 32-byte value and the
    // offset just past it. `name` ("r" or "s") only selects the error text.
    let read_integer = |start: usize, name: &str| -> Result<([u8; 32], usize)> {
        if sig.get(start).copied() != Some(0x02) {
            anyhow::bail!("missing DER INTEGER for {}", name);
        }
        let (len, value_start) = Self::parse_der_len(sig, start + 1)?;
        let value = value_start
            .checked_add(len)
            .and_then(|value_end| sig.get(value_start..value_end))
            .ok_or_else(|| anyhow::anyhow!("truncated DER {}", name))?;
        Ok((Self::asn1_int_to_32(value)?, value_start + value.len()))
    };

    let (r, after_r) = read_integer(body_start, "r")?;
    let (s, after_s) = read_integer(after_r, "s")?;
    if after_s != sig.len() {
        anyhow::bail!("trailing bytes in DER signature");
    }

    let mut out = [0u8; 64];
    out[..32].copy_from_slice(&r);
    out[32..].copy_from_slice(&s);
    Ok(out)
}
/// Placeholder for VOPRF evaluation on the token.
///
/// Always returns an error: the operation is not implemented for this provider and the
/// blinded input is ignored.
fn voprf_evaluate_internal(&self, _blinded: &[u8]) -> Result<Vec<u8>> {
    let not_implemented = anyhow::anyhow!(
        "PKCS#11 VOPRF evaluation not yet implemented. \
         HSM-native scalar multiplication requires vendor-specific extensions. \
         Consider using SoftwareCryptoProvider for now."
    );
    Err(not_implemented)
}
}
#[async_trait]
impl CryptoProvider for Pkcs11CryptoProvider {
    /// Delegates to `voprf_evaluate_internal`, which currently always returns an error.
    async fn voprf_evaluate(&self, blinded: &[u8]) -> Result<Vec<u8>> {
        self.voprf_evaluate_internal(blinded)
    }

    /// Derives the MAC key for `(issuer_id, kid, epoch)` from the cached base key via
    /// `crate::derive_mac_key_v2`. Does not contact the token.
    async fn derive_mac_key(&self, issuer_id: &str, kid: &str, epoch: u32) -> Result<[u8; 32]> {
        Ok(crate::derive_mac_key_v2(
            &self.mac_base_key,
            issuer_id,
            kid,
            epoch,
        ))
    }

    /// Signs `SHA-256(token_bytes || kid || exp (big-endian i64) || issuer_id)` with the
    /// token's private key and returns the signature as 64-byte `r || s`.
    ///
    /// The PKCS#11 calls are synchronous; nothing in this method awaits.
    ///
    /// NOTE(review): the fields are concatenated without length prefixes, so different
    /// splits between `token_bytes` and `kid` hash identically. Changing the framing would
    /// alter the signed format and needs coordination with verifiers.
    ///
    /// # Errors
    /// Fails when no authenticated session can be opened, the private key is not found,
    /// the token refuses to sign, or the signature format is unsupported.
    async fn sign_token_metadata(
        &self,
        token_bytes: &[u8],
        kid: &str,
        exp: i64,
        issuer_id: &str,
    ) -> Result<[u8; 64]> {
        use sha2::{Digest, Sha256};

        // Feeding the hasher piecewise yields the same digest as hashing the
        // concatenation, without building a temporary buffer.
        let mut hasher = Sha256::new();
        hasher.update(token_bytes);
        hasher.update(kid.as_bytes());
        hasher.update(&exp.to_be_bytes());
        hasher.update(issuer_id.as_bytes());
        let msg_hash = hasher.finalize();

        let sig = {
            let session = self.open_authenticated_session()?;
            let private_key_handle = Self::find_key_by_label(&session, &self.key_label, true)
                .context("Failed to find private key in HSM")?;
            session
                .sign(&Mechanism::Ecdsa, private_key_handle, &msg_hash)
                .context("Failed to sign token metadata with HSM")?
            // No explicit logout: PKCS#11 login state is shared by every session of the
            // process, so logging out here could invalidate a concurrent signing call.
            // Dropping the session closes it on success and error paths alike.
        };
        Self::normalize_ecdsa_signature(&sig)
    }

    /// Compressed public key extracted from the token at construction.
    fn public_key(&self) -> &[u8] {
        &self.public_key
    }

    /// Key identifier supplied at construction.
    fn key_id(&self) -> &str {
        &self.key_id
    }

    /// Context bytes supplied at construction.
    fn context(&self) -> &[u8] {
        &self.context
    }
}
/// Wipes the secrets held by the provider when it is dropped.
impl Drop for Pkcs11CryptoProvider {
    fn drop(&mut self) {
        self.mac_base_key.zeroize();
        // The PIN is a credential as well; clear its buffer instead of leaving it in
        // freed heap memory.
        self.pin.zeroize();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test against a local SoftHSM installation; ignored by default because it
    /// needs an initialized token containing a key labelled `test-key`.
    #[tokio::test]
    #[ignore]
    async fn test_pkcs11_provider_creation() {
        let result = Pkcs11CryptoProvider::new(
            "/usr/lib/softhsm/libsofthsm2.so",
            0,
            "1234",
            "test-key",
            "key-001".to_string(),
            b"test-context".to_vec(),
        )
        .await;
        // Stay lenient when no token is available, but check something real when
        // construction succeeds instead of asserting a tautology.
        match result {
            Ok(provider) => assert_eq!(provider.public_key.len(), 33),
            Err(err) => eprintln!("PKCS#11 provider unavailable: {:#}", err),
        }
    }

    #[test]
    fn test_normalize_ecdsa_signature_raw_passthrough() {
        let mut raw = [0u8; 64];
        for (i, b) in raw.iter_mut().enumerate() {
            *b = i as u8;
        }
        let normalized = Pkcs11CryptoProvider::normalize_ecdsa_signature(&raw).unwrap();
        assert_eq!(normalized, raw);
    }

    #[test]
    fn test_normalize_ecdsa_signature_der_standard() {
        let r: [u8; 32] = [
            0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee,
            0xff, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0,
            0xe0, 0xf0, 0x01, 0x02,
        ];
        let s: [u8; 32] = [
            0xfe, 0xed, 0xdc, 0xcb, 0xba, 0xa9, 0x98, 0x87, 0x76, 0x65, 0x54, 0x43, 0x32, 0x21,
            0x10, 0x0f, 0x1e, 0x2d, 0x3c, 0x4b, 0x5a, 0x69, 0x78, 0x87, 0x96, 0xa5, 0xb4, 0xc3,
            0xd2, 0xe1, 0xf0, 0x00,
        ];
        // SEQUENCE (0x44 bytes) { INTEGER (0x20) r, INTEGER (0x20) s }
        let mut der = vec![0x30, 0x44, 0x02, 0x20];
        der.extend_from_slice(&r);
        der.extend_from_slice(&[0x02, 0x20]);
        der.extend_from_slice(&s);
        let normalized = Pkcs11CryptoProvider::normalize_ecdsa_signature(&der).unwrap();
        assert_eq!(&normalized[..32], &r);
        assert_eq!(&normalized[32..], &s);
    }

    #[test]
    fn test_normalize_ecdsa_signature_der_with_leading_zeros() {
        let mut r = [0u8; 32];
        let mut s = [0u8; 32];
        r[0] = 0x80;
        r[31] = 0x7f;
        s[0] = 0x90;
        s[31] = 0x01;
        // High bit set in both values, so DER prepends a 0x00 sign byte to each.
        let mut der = vec![0x30, 0x46, 0x02, 0x21, 0x00];
        der.extend_from_slice(&r);
        der.extend_from_slice(&[0x02, 0x21, 0x00]);
        der.extend_from_slice(&s);
        let normalized = Pkcs11CryptoProvider::normalize_ecdsa_signature(&der).unwrap();
        assert_eq!(&normalized[..32], &r);
        assert_eq!(&normalized[32..], &s);
    }

    #[test]
    fn test_normalize_ecdsa_signature_pads_short_integers() {
        // SEQUENCE { INTEGER 1, INTEGER 2 }
        let der = [0x30, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01, 0x02];
        let normalized = Pkcs11CryptoProvider::normalize_ecdsa_signature(&der).unwrap();
        let mut expected = [0u8; 64];
        expected[31] = 0x01;
        expected[63] = 0x02;
        assert_eq!(normalized, expected);
    }

    #[test]
    fn test_normalize_ecdsa_signature_rejects_invalid_format() {
        let bad = [0x01, 0x02, 0x03];
        let err = Pkcs11CryptoProvider::normalize_ecdsa_signature(&bad);
        assert!(err.is_err());
    }

    #[test]
    fn test_normalize_ecdsa_signature_rejects_truncated_der() {
        let der = [0x30, 0x44, 0x02, 0x20, 0xaa];
        let err = Pkcs11CryptoProvider::normalize_ecdsa_signature(&der);
        assert!(err.is_err());
    }

    #[test]
    fn test_normalize_ecdsa_signature_rejects_trailing_bytes() {
        // The sequence length covers one extra byte after `s`.
        let der = [0x30, 0x07, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0xff];
        let err = Pkcs11CryptoProvider::normalize_ecdsa_signature(&der);
        assert!(err.is_err());
    }

    #[test]
    fn test_asn1_int_to_32_rejects_oversized_integer() {
        let err = Pkcs11CryptoProvider::asn1_int_to_32(&[0x01; 33]);
        assert!(err.is_err());
    }
}