use crate::error::WSError;
use crate::platform::{KeyHandle, PublicKey, SecureKeyProvider, SecurityLevel};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tss_esapi::{
attributes::ObjectAttributesBuilder,
interface_types::{
algorithm::{HashingAlgorithm, PublicAlgorithm},
ecc::EccCurve,
resource_handles::Hierarchy,
},
structures::{
EccPoint, EccScheme, HashScheme, KeyDerivationFunctionScheme,
PublicBuilder, PublicEccParametersBuilder, SignatureScheme,
},
tcti_ldr::{DeviceConfig, NetworkTPMConfig, TctiNameConf},
utils::PublicKey as TssPublicKey,
Context,
};
/// Signature algorithms that a TPM-backed key may use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TpmAlgorithm {
    /// Ed25519. Listed for completeness; this provider currently never selects it.
    Ed25519,
    /// ECDSA over NIST P-256 with SHA-256; the algorithm this provider actually uses.
    EcdsaP256,
}
/// Book-keeping for one TPM-resident key created by this provider.
struct TpmKeyData {
    /// ESAPI handle of the loaded TPM object; used for signing and flushed on deletion.
    key_handle: tss_esapi::handles::KeyHandle,
    /// SEC1 uncompressed public key bytes (`0x04 || X || Y`).
    public_key: Vec<u8>,
    /// Algorithm of the key.
    // Written at creation but not read anywhere yet; presumably kept for when
    // persistent keys are supported. Silences the never-read warning until then.
    #[allow(dead_code)]
    algorithm: TpmAlgorithm,
    /// Identifier of a persistent key, if any; always `None` while `load_key`
    /// is unimplemented.
    // Same reason as `algorithm`: only ever written today.
    #[allow(dead_code)]
    key_id: Option<Vec<u8>>,
}
/// Key provider backed by a TPM 2.0 device or simulator.
///
/// Keys live inside the TPM; the provider maps its own monotonically
/// allocated `u64` handles to the underlying TPM object handles.
pub struct Tpm2Provider {
    /// Shared ESAPI context; every TPM command goes through this lock.
    context: Arc<Mutex<Context>>,
    /// Provider handle -> key metadata for keys created in this session.
    keys: Mutex<HashMap<u64, TpmKeyData>>,
    /// Next provider handle to hand out (starts at 1).
    next_handle: Mutex<u64>,
    /// Algorithm chosen at initialization.
    preferred_algorithm: TpmAlgorithm,
    /// Best-effort manufacturer description; `None` if the query failed.
    manufacturer: Option<String>,
}
impl Tpm2Provider {
pub fn new() -> Result<Self, WSError> {
let tcti = Self::detect_tcti()?;
Self::with_tcti(tcti)
}
pub fn with_tcti(tcti: TctiNameConf) -> Result<Self, WSError> {
let context = Context::new(tcti).map_err(|e| {
WSError::HardwareError(format!("Failed to connect to TPM: {}", e))
})?;
let context = Arc::new(Mutex::new(context));
let preferred_algorithm = Self::detect_preferred_algorithm(&context)?;
let manufacturer = Self::get_manufacturer_info(&context).ok();
log::info!(
"TPM2 provider initialized: algorithm={:?}, manufacturer={:?}",
preferred_algorithm,
manufacturer
);
Ok(Self {
context,
keys: Mutex::new(HashMap::new()),
next_handle: Mutex::new(1),
preferred_algorithm,
manufacturer,
})
}
pub fn with_simulator() -> Result<Self, WSError> {
let tcti = TctiNameConf::Swtpm(NetworkTPMConfig::default());
Self::with_tcti(tcti)
}
fn detect_tcti() -> Result<TctiNameConf, WSError> {
if std::env::var("TPM2_TCTI").is_ok() {
return TctiNameConf::from_environment_variable().map_err(|e| {
WSError::HardwareError(format!("Invalid TPM2_TCTI: {}", e))
});
}
#[cfg(target_os = "linux")]
{
use std::path::Path;
if Path::new("/dev/tpmrm0").exists() {
log::debug!("Using TPM resource manager at /dev/tpmrm0");
return Ok(TctiNameConf::Device(DeviceConfig::default()));
}
if Path::new("/dev/tpm0").exists() {
log::debug!("Using TPM device at /dev/tpm0");
return Ok(TctiNameConf::Device(DeviceConfig::default()));
}
}
#[cfg(target_os = "windows")]
{
log::debug!("Using Windows TBS");
return Ok(TctiNameConf::Tbs(Default::default()));
}
Err(WSError::HardwareError(
"No TPM2 device found. On Linux, ensure /dev/tpmrm0 exists and is accessible. \
Set TPM2_TCTI environment variable for custom configuration."
.to_string(),
))
}
fn detect_preferred_algorithm(context: &Arc<Mutex<Context>>) -> Result<TpmAlgorithm, WSError> {
let mut ctx = context
.lock()
.map_err(|_| WSError::InternalError("TPM context lock poisoned".to_string()))?;
let (caps, _more) = ctx
.get_capability(
tss_esapi::constants::CapabilityType::EccCurves,
0,
100,
)
.map_err(|e| {
WSError::HardwareError(format!("Failed to query TPM capabilities: {}", e))
})?;
log::debug!("TPM ECC capabilities: {:?}", caps);
log::info!("TPM2: Using ECDSA P-256 (universal TPM2 support)");
Ok(TpmAlgorithm::EcdsaP256)
}
fn get_manufacturer_info(context: &Arc<Mutex<Context>>) -> Result<String, WSError> {
let mut ctx = context
.lock()
.map_err(|_| WSError::InternalError("TPM context lock poisoned".to_string()))?;
let (caps, _) = ctx
.get_capability(
tss_esapi::constants::CapabilityType::TpmProperties,
tss_esapi::constants::tss::TPM2_PT_MANUFACTURER,
1,
)
.map_err(|e| WSError::InternalError(e.to_string()))?;
Ok(format!("{:?}", caps))
}
fn allocate_handle(&self) -> u64 {
let mut next = self.next_handle.lock().unwrap();
let handle = *next;
*next += 1;
handle
}
pub fn algorithm(&self) -> TpmAlgorithm {
self.preferred_algorithm
}
pub fn manufacturer(&self) -> Option<&str> {
self.manufacturer.as_deref()
}
fn generate_p256_key(
ctx: &mut Context,
) -> Result<(tss_esapi::handles::KeyHandle, Vec<u8>), WSError> {
let object_attributes = ObjectAttributesBuilder::new()
.with_fixed_tpm(true)
.with_fixed_parent(true)
.with_sensitive_data_origin(true)
.with_user_with_auth(true)
.with_sign_encrypt(true)
.build()
.map_err(|e| WSError::InternalError(format!("Failed to build attributes: {}", e)))?;
let ecc_params = PublicEccParametersBuilder::new()
.with_ecc_scheme(EccScheme::EcDsa(HashScheme::new(HashingAlgorithm::Sha256)))
.with_curve(EccCurve::NistP256)
.with_is_signing_key(true)
.with_is_decryption_key(false)
.with_restricted(false)
.with_key_derivation_function_scheme(KeyDerivationFunctionScheme::Null)
.build()
.map_err(|e| WSError::InternalError(format!("Failed to build ECC params: {}", e)))?;
let public = PublicBuilder::new()
.with_public_algorithm(PublicAlgorithm::Ecc)
.with_name_hashing_algorithm(HashingAlgorithm::Sha256)
.with_object_attributes(object_attributes)
.with_ecc_parameters(ecc_params)
.with_ecc_unique_identifier(EccPoint::default())
.build()
.map_err(|e| WSError::InternalError(format!("Failed to build public template: {}", e)))?;
let result = ctx
.execute_with_nullauth_session(|ctx| {
ctx.create_primary(Hierarchy::Owner, public.clone(), None, None, None, None)
})
.map_err(|e| WSError::InternalError(format!("TPM key creation failed: {}", e)))?;
let public_key = Self::extract_ecc_public_key(&result.out_public)?;
Ok((result.key_handle, public_key))
}
fn extract_ecc_public_key(
public: &tss_esapi::structures::Public,
) -> Result<Vec<u8>, WSError> {
let public_key = TssPublicKey::try_from(public.clone())
.map_err(|e| WSError::InternalError(format!("Failed to extract public key: {}", e)))?;
match public_key {
TssPublicKey::Ecc { x, y } => {
let mut pk = Vec::with_capacity(1 + x.len() + y.len());
pk.push(0x04); pk.extend_from_slice(&x);
pk.extend_from_slice(&y);
Ok(pk)
}
_ => Err(WSError::InternalError(
"Expected ECC public key from TPM".to_string(),
)),
}
}
fn convert_ecdsa_signature_to_der(
signature: &tss_esapi::structures::Signature,
) -> Result<Vec<u8>, WSError> {
match signature {
tss_esapi::structures::Signature::EcDsa(ecdsa_sig) => {
let r = ecdsa_sig.signature_r().value();
let s = ecdsa_sig.signature_s().value();
let mut der = Vec::new();
fn encode_integer(val: &[u8]) -> Vec<u8> {
let mut result = Vec::new();
let val = val.iter().skip_while(|&&b| b == 0).copied().collect::<Vec<_>>();
let val = if val.is_empty() { vec![0] } else { val };
let needs_padding = val[0] & 0x80 != 0;
let len = val.len() + if needs_padding { 1 } else { 0 };
result.push(0x02); result.push(len as u8);
if needs_padding {
result.push(0x00);
}
result.extend(val);
result
}
let r_der = encode_integer(r);
let s_der = encode_integer(s);
der.push(0x30); der.push((r_der.len() + s_der.len()) as u8);
der.extend(r_der);
der.extend(s_der);
Ok(der)
}
_ => Err(WSError::InternalError(
"Expected ECDSA signature from TPM".to_string(),
)),
}
}
}
impl SecureKeyProvider for Tpm2Provider {
fn name(&self) -> &str {
"TPM 2.0"
}
fn security_level(&self) -> SecurityLevel {
SecurityLevel::HardwareBacked
}
fn health_check(&self) -> Result<(), WSError> {
let mut ctx = self
.context
.lock()
.map_err(|_| WSError::InternalError("TPM context lock poisoned".to_string()))?;
ctx.get_capability(
tss_esapi::constants::CapabilityType::TpmProperties,
tss_esapi::constants::tss::TPM2_PT_FAMILY_INDICATOR,
1,
)
.map_err(|e| WSError::HardwareError(format!("TPM health check failed: {}", e)))?;
Ok(())
}
fn generate_key(&self) -> Result<KeyHandle, WSError> {
let mut ctx = self
.context
.lock()
.map_err(|_| WSError::InternalError("TPM context lock poisoned".to_string()))?;
let (tpm_handle, public_key) = Self::generate_p256_key(&mut ctx)?;
let handle_id = self.allocate_handle();
let key_data = TpmKeyData {
key_handle: tpm_handle,
public_key,
algorithm: TpmAlgorithm::EcdsaP256,
key_id: None,
};
drop(ctx); self.keys.lock().unwrap().insert(handle_id, key_data);
log::debug!("Generated TPM key with handle {}", handle_id);
Ok(KeyHandle::from_raw(handle_id))
}
fn load_key(&self, key_id: &str) -> Result<KeyHandle, WSError> {
Err(WSError::KeyNotFound(format!(
"Loading persistent TPM keys not yet implemented. Key ID: {}",
key_id
)))
}
fn sign(&self, handle: KeyHandle, data: &[u8]) -> Result<Vec<u8>, WSError> {
let tpm_handle = {
let keys = self.keys.lock().unwrap();
let key_data = keys.get(&handle.as_raw()).ok_or(WSError::InvalidKeyHandle)?;
key_data.key_handle
};
let mut ctx = self
.context
.lock()
.map_err(|_| WSError::InternalError("TPM context lock poisoned".to_string()))?;
let data_buffer = tss_esapi::structures::MaxBuffer::try_from(data.to_vec())
.map_err(|e| WSError::InternalError(format!("Data too large for TPM: {}", e)))?;
let signature = ctx
.execute_with_nullauth_session(|ctx| {
let (digest, ticket) = ctx.hash(
data_buffer.clone(),
HashingAlgorithm::Sha256,
Hierarchy::Owner,
)?;
ctx.sign(tpm_handle, digest, SignatureScheme::Null, ticket)
})
.map_err(|e| WSError::InternalError(format!("TPM signing failed: {}", e)))?;
Self::convert_ecdsa_signature_to_der(&signature)
}
fn get_public_key(&self, handle: KeyHandle) -> Result<PublicKey, WSError> {
let keys = self.keys.lock().unwrap();
let _key_data = keys.get(&handle.as_raw()).ok_or(WSError::InvalidKeyHandle)?;
Err(WSError::InternalError(
"TPM2 uses P-256, not Ed25519. Use get_public_key_bytes() instead.".to_string(),
))
}
fn delete_key(&self, handle: KeyHandle) -> Result<(), WSError> {
let key_data = {
let mut keys = self.keys.lock().unwrap();
keys.remove(&handle.as_raw()).ok_or(WSError::InvalidKeyHandle)?
};
let mut ctx = self
.context
.lock()
.map_err(|_| WSError::InternalError("TPM context lock poisoned".to_string()))?;
ctx.flush_context(key_data.key_handle.into())
.map_err(|e| WSError::InternalError(format!("Failed to flush TPM key: {}", e)))?;
log::debug!("Deleted TPM key with handle {}", handle.as_raw());
Ok(())
}
fn list_keys(&self) -> Result<Vec<KeyHandle>, WSError> {
let keys = self.keys.lock().unwrap();
Ok(keys.keys().map(|&id| KeyHandle::from_raw(id)).collect())
}
}
impl Tpm2Provider {
    /// Returns the SEC1 uncompressed public key bytes stored for `handle`.
    ///
    /// # Errors
    /// - `WSError::InvalidKeyHandle` for unknown handles.
    /// - `WSError::InternalError` if the key table lock is poisoned.
    pub fn get_public_key_bytes(&self, handle: KeyHandle) -> Result<Vec<u8>, WSError> {
        let keys = self
            .keys
            .lock()
            .map_err(|_| WSError::InternalError("TPM key table lock poisoned".to_string()))?;
        keys.get(&handle.as_raw())
            .map(|key_data| key_data.public_key.clone())
            .ok_or(WSError::InvalidKeyHandle)
    }

    /// Verifies a DER-encoded ECDSA P-256 signature over `data` in software.
    ///
    /// Uses the `p256` crate with the public key stored for `handle`.
    ///
    /// Returns `Ok(false)` when the signature does not match. Returns
    /// `WSError::VerificationError` when the stored key or the DER signature
    /// cannot be parsed.
    pub fn verify_signature(
        &self,
        handle: KeyHandle,
        data: &[u8],
        signature_der: &[u8],
    ) -> Result<bool, WSError> {
        use p256::ecdsa::{signature::Verifier, Signature, VerifyingKey};
        let public_key_bytes = self.get_public_key_bytes(handle)?;
        let verifying_key = VerifyingKey::from_sec1_bytes(&public_key_bytes)
            .map_err(|e| WSError::VerificationError(format!("Invalid public key: {}", e)))?;
        let signature = Signature::from_der(signature_der)
            .map_err(|e| WSError::VerificationError(format!("Invalid signature: {}", e)))?;
        Ok(verifying_key.verify(data, &signature).is_ok())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connects to a running swtpm. Returns `None` when none is reachable, so
    /// callers can skip without failing.
    fn simulator_or_skip() -> Option<Tpm2Provider> {
        Tpm2Provider::with_simulator().ok()
    }

    /// Smoke test against whatever TPM auto-detection finds. A missing TPM is
    /// reported, not treated as a failure.
    #[test]
    #[ignore = "Requires TPM hardware or swtpm simulator"]
    fn test_tpm2_provider_creation() {
        match Tpm2Provider::new() {
            Err(e) => {
                println!("TPM2 not available (expected on most dev machines): {}", e);
            }
            Ok(p) => {
                println!("TPM2 provider created successfully");
                println!("Algorithm: {:?}", p.algorithm());
                println!("Manufacturer: {:?}", p.manufacturer());
            }
        }
    }

    /// The simulator-backed provider must pass its own health check.
    #[test]
    #[ignore = "Requires TPM hardware or swtpm simulator"]
    fn test_tpm2_with_simulator() {
        let provider = match Tpm2Provider::with_simulator() {
            Err(e) => {
                println!("swtpm not running: {}", e);
                return;
            }
            Ok(p) => p,
        };
        provider.health_check().expect("Health check failed");
        println!("TPM health check passed");
    }

    /// A generated key exposes a 65-byte uncompressed P-256 public key and can be deleted.
    #[test]
    #[ignore = "Requires TPM hardware or swtpm simulator"]
    fn test_tpm2_key_generation() {
        let provider = match simulator_or_skip() {
            Some(p) => p,
            None => return,
        };
        let handle = provider.generate_key().expect("Key generation failed");
        println!("Generated key with handle: {}", handle.as_raw());
        let pk_bytes = provider
            .get_public_key_bytes(handle)
            .expect("Get public key failed");
        println!("Public key length: {} bytes", pk_bytes.len());
        assert_eq!(pk_bytes.len(), 65);
        provider.delete_key(handle).expect("Key deletion failed");
    }

    /// Round-trip: a TPM signature verifies over the signed data and not over other data.
    #[test]
    #[ignore = "Requires TPM hardware or swtpm simulator"]
    fn test_tpm2_sign_verify() {
        let provider = match simulator_or_skip() {
            Some(p) => p,
            None => return,
        };
        let handle = provider.generate_key().expect("Key generation failed");
        let message = b"test data for TPM signing";
        let sig = provider.sign(handle, message).expect("Signing failed");
        println!("Signature length: {} bytes", sig.len());

        let ok = provider
            .verify_signature(handle, message, &sig)
            .expect("Verification failed");
        assert!(ok, "Signature verification should succeed");

        let tampered_ok = provider
            .verify_signature(handle, b"wrong data", &sig)
            .expect("Verification call failed");
        assert!(!tampered_ok, "Wrong data should not verify");

        provider.delete_key(handle).unwrap();
    }

    /// Generating and deleting keys is reflected in `list_keys`.
    #[test]
    #[ignore = "Requires TPM hardware or swtpm simulator"]
    fn test_tpm2_list_keys() {
        let provider = match simulator_or_skip() {
            Some(p) => p,
            None => return,
        };
        let baseline = provider.list_keys().expect("List keys failed").len();

        let first = provider.generate_key().expect("Key 1 generation failed");
        let second = provider.generate_key().expect("Key 2 generation failed");
        let after_create = provider.list_keys().expect("List keys failed");
        assert_eq!(after_create.len(), baseline + 2);

        provider.delete_key(first).unwrap();
        provider.delete_key(second).unwrap();
        let after_delete = provider.list_keys().expect("List keys failed");
        assert_eq!(after_delete.len(), baseline);
    }
}