use crate::errors::TeeError;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// How (and whether) workloads use a trusted execution environment.
///
/// Serialized in snake_case (`"disabled"`, `"direct"`, `"remote"`, `"hybrid"`).
/// `Disabled` is the default; [`TeeConfig::is_enabled`] treats every other
/// variant as enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TeeMode {
    /// TEE execution is turned off.
    #[default]
    Disabled,
    /// An enabled mode. NOTE(review): presumably runs on a local TEE — confirm.
    Direct,
    /// An enabled mode. NOTE(review): presumably uses a remote TEE — confirm.
    Remote,
    /// An enabled mode. NOTE(review): presumably routes per `HybridRoutingSource` — confirm.
    Hybrid,
}
/// Whether TEE execution is mandatory or merely preferred.
///
/// Serialized as `"preferred"` / `"required"`. `Preferred` is the default.
/// [`TeeConfig::validate`] rejects `Required` combined with [`TeeMode::Disabled`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TeeRequirement {
    /// Use a TEE when configured, but do not insist on one.
    #[default]
    Preferred,
    /// A TEE must be used.
    Required,
}
/// A concrete TEE provider.
///
/// Serialized in snake_case (`"aws_nitro"`, `"azure_snp"`, ...), which is the
/// same spelling its `Display` impl produces. Variants compare in declaration
/// order (derived `Ord`), so reordering them changes sort order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TeeProvider {
    AwsNitro,
    AzureSnp,
    GcpConfidential,
    IntelTdx,
    AmdSevSnp,
}
impl core::fmt::Display for TeeProvider {
    /// Writes the provider's snake_case identifier (the same spelling serde
    /// uses for this enum).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let name = match self {
            Self::AwsNitro => "aws_nitro",
            Self::AzureSnp => "azure_snp",
            Self::GcpConfidential => "gcp_confidential",
            Self::IntelTdx => "intel_tdx",
            Self::AmdSevSnp => "amd_sev_snp",
        };
        // `pad` (unlike `write!`/`write_str`) honors width, fill, alignment and
        // precision flags, so `{:<12}` aligns provider names just like a `&str`.
        f.pad(name)
    }
}
/// Which TEE providers are acceptable.
///
/// With the enum's snake_case renaming this serializes as `"any"` or
/// `{"allow_list": [...]}`. `Any` is the default.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TeeProviderSelector {
    /// Every provider is acceptable.
    #[default]
    Any,
    /// Only the listed providers are acceptable; an empty list accepts none.
    AllowList(Vec<TeeProvider>),
}
impl TeeProviderSelector {
    /// Returns `true` if `provider` satisfies this selector.
    ///
    /// `Any` accepts every provider. An allow-list accepts only providers it
    /// contains, so an empty list rejects everything.
    pub fn accepts(&self, provider: TeeProvider) -> bool {
        let allowed = match self {
            Self::Any => return true,
            Self::AllowList(list) => list,
        };
        allowed.iter().any(|&candidate| candidate == provider)
    }
}
/// A workload's TEE requirements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TeeRequirements {
    /// Whether a TEE is mandatory or merely preferred (must be present when
    /// deserializing).
    pub requirement: TeeRequirement,
    /// Acceptable providers; falls back to `TeeProviderSelector::default()`
    /// when absent.
    #[serde(default)]
    pub providers: TeeProviderSelector,
    /// Optional attestation-age bound in seconds. It defaults to `None` and is
    /// omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub min_attestation_age_secs: Option<u64>,
}
impl Default for TeeRequirements {
    /// TEE preferred, any provider acceptable, and `min_attestation_age_secs`
    /// left unset.
    fn default() -> Self {
        let requirement = TeeRequirement::Preferred;
        let providers = TeeProviderSelector::Any;
        Self {
            requirement,
            providers,
            min_attestation_age_secs: None,
        }
    }
}
impl TeeRequirements {
    /// Requirements that demand a TEE. Every other field keeps its `Default`
    /// value.
    pub fn required() -> Self {
        let mut reqs = Self::default();
        reqs.requirement = TeeRequirement::Required;
        reqs
    }

    /// Requirements that merely prefer a TEE (the same as `Default`).
    pub fn preferred() -> Self {
        Default::default()
    }

    /// Returns `true` when a TEE is mandatory.
    pub fn is_required(&self) -> bool {
        matches!(self.requirement, TeeRequirement::Required)
    }
}
/// Runtime lifecycle policy derived from the TEE mode.
///
/// [`TeeConfig::lifecycle_policy`] returns `CloudManaged` when TEE is enabled
/// and `Container` otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeLifecyclePolicy {
    /// Chosen when TEE is disabled.
    Container,
    /// Chosen when any TEE mode is enabled.
    CloudManaged,
}
/// How secrets may be delivered to a workload.
///
/// Defaults to `EnvOrSealed` (see its `Default` impl). [`TeeConfig::validate`]
/// requires `SealedOnly` whenever TEE is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SecretInjectionPolicy {
    /// NOTE(review): presumably environment variables or sealed secrets — confirm.
    EnvOrSealed,
    /// NOTE(review): presumably sealed secrets only — confirm.
    SealedOnly,
}
/// Whether a TEE public key must be present. `Required` is the default.
///
/// NOTE(review): how this policy is enforced is not visible in this file — confirm
/// with the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TeePublicKeyPolicy {
    #[default]
    Required,
    Optional,
}
/// When attestation freshness is evaluated. Currently a single option, which
/// is also the default.
///
/// NOTE(review): presumably attestation is checked only at provision time —
/// confirm with the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum AttestationFreshnessPolicy {
    #[default]
    ProvisionTimeOnly,
}
/// Where hybrid-mode routing rules come from.
///
/// Serialized as `{"policy_file": "<path>"}`. The default path is
/// `/etc/tee/routing.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HybridRoutingSource {
    /// Rules are stored in the file at this path.
    PolicyFile(PathBuf),
}
impl Default for HybridRoutingSource {
    /// Points at the policy file `/etc/tee/routing.json`.
    fn default() -> Self {
        let path: PathBuf = "/etc/tee/routing.json".into();
        Self::PolicyFile(path)
    }
}
/// Settings for TEE key-exchange sessions.
///
/// Every field may be omitted when deserializing. Missing fields take the same
/// values the `Default` impl uses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TeeKeyExchangeConfig {
    /// Session time-to-live in seconds (default 300).
    #[serde(default = "default_session_ttl_secs")]
    pub session_ttl_secs: u64,
    /// Maximum number of sessions (default 64).
    #[serde(default = "default_max_sessions")]
    pub max_sessions: usize,
    /// Whether on-chain verification is enabled (default `false`).
    #[serde(default)]
    pub on_chain_verification: bool,
}
/// Default `TeeKeyExchangeConfig::session_ttl_secs`: five minutes.
fn default_session_ttl_secs() -> u64 {
    5 * 60
}
/// Default `TeeKeyExchangeConfig::max_sessions`.
fn default_max_sessions() -> usize {
    const DEFAULT_MAX_SESSIONS: usize = 64;
    DEFAULT_MAX_SESSIONS
}
impl Default for TeeKeyExchangeConfig {
    /// Uses the same values serde applies to missing fields, with on-chain
    /// verification off.
    fn default() -> Self {
        let session_ttl_secs = default_session_ttl_secs();
        let max_sessions = default_max_sessions();
        Self {
            session_ttl_secs,
            max_sessions,
            on_chain_verification: false,
        }
    }
}
/// Top-level TEE configuration.
///
/// Deserialization goes through the private `TeeConfigRaw` mirror and then
/// [`TeeConfig::validate`], so invalid combinations are rejected at parse time.
/// Because of the container-level `try_from`, serde never reads field-level
/// deserialize attributes on this struct. Per-field defaults are therefore
/// declared on `TeeConfigRaw`, not here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(try_from = "TeeConfigRaw")]
pub struct TeeConfig {
    /// Whether a TEE is mandatory or merely preferred.
    pub requirement: TeeRequirement,
    /// TEE mode; anything other than `Disabled` counts as enabled.
    pub mode: TeeMode,
    /// Acceptable providers.
    pub provider_selector: TeeProviderSelector,
    /// Key-exchange session settings.
    pub key_exchange: TeeKeyExchangeConfig,
    /// Maximum attestation age in seconds.
    pub max_attestation_age_secs: u64,
    /// Secret delivery policy; `validate` requires `SealedOnly` when enabled.
    pub secret_injection: SecretInjectionPolicy,
    /// When attestation freshness is evaluated.
    pub attestation_freshness: AttestationFreshnessPolicy,
    /// Whether a TEE public key is required.
    pub public_key_policy: TeePublicKeyPolicy,
    /// Source of hybrid-mode routing rules.
    pub hybrid_routing_source: HybridRoutingSource,
}
/// Default maximum attestation age (one hour). Used by the serde defaults, by
/// `TeeConfig::default` and by `TeeConfigBuilder::build`.
fn default_max_attestation_age_secs() -> u64 {
    60 * 60
}
impl Default for SecretInjectionPolicy {
fn default() -> Self {
Self::EnvOrSealed
}
}
impl Default for TeeConfig {
    /// A disabled, non-mandatory configuration. Every field takes its own
    /// `Default`, except the attestation age, which comes from
    /// `default_max_attestation_age_secs`. This combination passes
    /// `TeeConfig::validate`.
    fn default() -> Self {
        Self {
            requirement: Default::default(),
            mode: Default::default(),
            provider_selector: Default::default(),
            key_exchange: Default::default(),
            max_attestation_age_secs: default_max_attestation_age_secs(),
            secret_injection: Default::default(),
            attestation_freshness: Default::default(),
            public_key_policy: Default::default(),
            hybrid_routing_source: Default::default(),
        }
    }
}
impl TeeConfig {
    /// Starts a [`TeeConfigBuilder`] with every field unset.
    pub fn builder() -> TeeConfigBuilder {
        TeeConfigBuilder::default()
    }

    /// Returns `true` for every mode except [`TeeMode::Disabled`].
    pub fn is_enabled(&self) -> bool {
        !matches!(self.mode, TeeMode::Disabled)
    }

    /// Returns `Container` when TEE is disabled and `CloudManaged` for any
    /// enabled mode.
    pub fn lifecycle_policy(&self) -> RuntimeLifecyclePolicy {
        match self.mode {
            TeeMode::Disabled => RuntimeLifecyclePolicy::Container,
            TeeMode::Direct | TeeMode::Remote | TeeMode::Hybrid => {
                RuntimeLifecyclePolicy::CloudManaged
            }
        }
    }

    /// Checks cross-field invariants.
    ///
    /// # Errors
    ///
    /// Returns `TeeError::Config` when:
    /// - `requirement` is `Required` but `mode` is `Disabled` (checked first), or
    /// - TEE is enabled but `secret_injection` is not `SealedOnly`.
    pub fn validate(&self) -> Result<(), TeeError> {
        let enabled = self.is_enabled();
        if !enabled && self.requirement == TeeRequirement::Required {
            return Err(TeeError::Config(String::from(
                "TEE requirement is Required but mode is Disabled",
            )));
        }
        if enabled && self.secret_injection != SecretInjectionPolicy::SealedOnly {
            return Err(TeeError::Config(String::from(
                "TEE-enabled configs must use SealedOnly secret injection",
            )));
        }
        Ok(())
    }
}
/// Wire-format mirror of [`TeeConfig`], used only for deserialization.
///
/// `TeeConfig` deserializes via `#[serde(try_from = "TeeConfigRaw")]`, so the
/// serde defaults declared here are the ones that take effect. The `TryFrom`
/// impl then runs `TeeConfig::validate`. `requirement`, `mode`,
/// `provider_selector` and `key_exchange` have no serde default and must be
/// present in the input.
#[derive(Deserialize)]
struct TeeConfigRaw {
    requirement: TeeRequirement,
    mode: TeeMode,
    provider_selector: TeeProviderSelector,
    key_exchange: TeeKeyExchangeConfig,
    #[serde(default = "default_max_attestation_age_secs")]
    max_attestation_age_secs: u64,
    #[serde(default)]
    secret_injection: SecretInjectionPolicy,
    #[serde(default)]
    attestation_freshness: AttestationFreshnessPolicy,
    #[serde(default)]
    public_key_policy: TeePublicKeyPolicy,
    #[serde(default)]
    hybrid_routing_source: HybridRoutingSource,
}
impl TryFrom<TeeConfigRaw> for TeeConfig {
type Error = TeeError;
fn try_from(raw: TeeConfigRaw) -> Result<Self, Self::Error> {
let config = TeeConfig {
requirement: raw.requirement,
mode: raw.mode,
provider_selector: raw.provider_selector,
key_exchange: raw.key_exchange,
max_attestation_age_secs: raw.max_attestation_age_secs,
secret_injection: raw.secret_injection,
attestation_freshness: raw.attestation_freshness,
public_key_policy: raw.public_key_policy,
hybrid_routing_source: raw.hybrid_routing_source,
};
config.validate()?;
Ok(config)
}
}
/// Builder for [`TeeConfig`], obtained via [`TeeConfig::builder`].
///
/// Every field starts as `None` and falls back to a default in `build`. There
/// is no secret-injection field: `build` derives that policy from the mode.
#[derive(Debug, Default)]
pub struct TeeConfigBuilder {
    requirement: Option<TeeRequirement>,
    mode: Option<TeeMode>,
    provider_selector: Option<TeeProviderSelector>,
    key_exchange: Option<TeeKeyExchangeConfig>,
    max_attestation_age_secs: Option<u64>,
    attestation_freshness: Option<AttestationFreshnessPolicy>,
    public_key_policy: Option<TeePublicKeyPolicy>,
    hybrid_routing_source: Option<HybridRoutingSource>,
}
impl TeeConfigBuilder {
    /// Sets whether a TEE is mandatory or merely preferred.
    pub fn requirement(mut self, requirement: TeeRequirement) -> Self {
        self.requirement = Some(requirement);
        self
    }

    /// Sets the TEE mode.
    pub fn mode(mut self, mode: TeeMode) -> Self {
        self.mode = Some(mode);
        self
    }

    /// Sets the provider selector, replacing any earlier `provider_selector`
    /// or `allow_providers` call.
    pub fn provider_selector(mut self, selector: TeeProviderSelector) -> Self {
        self.provider_selector = Some(selector);
        self
    }

    /// Restricts providers to `providers`. This is shorthand for
    /// `provider_selector(TeeProviderSelector::AllowList(..))`.
    pub fn allow_providers(mut self, providers: impl IntoIterator<Item = TeeProvider>) -> Self {
        self.provider_selector = Some(TeeProviderSelector::AllowList(
            providers.into_iter().collect(),
        ));
        self
    }

    /// Sets the key-exchange settings.
    pub fn key_exchange(mut self, config: TeeKeyExchangeConfig) -> Self {
        self.key_exchange = Some(config);
        self
    }

    /// Sets the maximum attestation age in seconds.
    pub fn max_attestation_age_secs(mut self, secs: u64) -> Self {
        self.max_attestation_age_secs = Some(secs);
        self
    }

    /// Sets the attestation freshness policy.
    pub fn attestation_freshness(mut self, policy: AttestationFreshnessPolicy) -> Self {
        self.attestation_freshness = Some(policy);
        self
    }

    /// Sets the public-key policy.
    pub fn public_key_policy(mut self, policy: TeePublicKeyPolicy) -> Self {
        self.public_key_policy = Some(policy);
        self
    }

    /// Sets the hybrid routing source.
    pub fn hybrid_routing_source(mut self, source: HybridRoutingSource) -> Self {
        self.hybrid_routing_source = Some(source);
        self
    }

    /// Assembles and validates the configuration.
    ///
    /// Unset fields use their types' `Default`, except the attestation age,
    /// which uses `default_max_attestation_age_secs`. The secret-injection
    /// policy is not configurable: it is `SealedOnly` for any enabled mode and
    /// `EnvOrSealed` when disabled.
    ///
    /// # Errors
    ///
    /// Returns `TeeError::Config` when the assembled config fails
    /// [`TeeConfig::validate`], e.g. a `Required` requirement with
    /// `TeeMode::Disabled`.
    pub fn build(self) -> Result<TeeConfig, TeeError> {
        let mode = self.mode.unwrap_or_default();
        // Enabled modes are forced to SealedOnly, so the builder always
        // satisfies validate's secret-injection rule.
        let secret_injection = if mode == TeeMode::Disabled {
            SecretInjectionPolicy::EnvOrSealed
        } else {
            SecretInjectionPolicy::SealedOnly
        };
        let config = TeeConfig {
            requirement: self.requirement.unwrap_or_default(),
            mode,
            provider_selector: self.provider_selector.unwrap_or_default(),
            key_exchange: self.key_exchange.unwrap_or_default(),
            max_attestation_age_secs: self
                .max_attestation_age_secs
                .unwrap_or_else(default_max_attestation_age_secs),
            secret_injection,
            attestation_freshness: self.attestation_freshness.unwrap_or_default(),
            public_key_policy: self.public_key_policy.unwrap_or_default(),
            hybrid_routing_source: self.hybrid_routing_source.unwrap_or_default(),
        };
        // Delegate to the same rules applied to deserialized configs so the
        // builder and serde paths cannot drift apart.
        config.validate()?;
        Ok(config)
    }
}