use crate::attestation::claims::AttestationClaims;
use crate::attestation::report::{AttestationFormat, AttestationReport, Measurement};
use crate::config::{RuntimeLifecyclePolicy, TeeProvider};
use crate::errors::TeeError;
use crate::runtime::backend::{
TeeDeployRequest, TeeDeploymentHandle, TeeDeploymentStatus, TeePublicKey, TeeRuntimeBackend,
};
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
pub struct AzureSkrConfig {
pub subscription_id: String,
pub resource_group: String,
pub tenant_id: String,
pub client_id: String,
pub client_secret: String,
pub location: String,
pub vm_size: String,
pub image_publisher: String,
pub image_offer: String,
pub image_sku: String,
pub vnet_name: Option<String>,
pub subnet_name: Option<String>,
pub maa_endpoint: Option<String>,
}
impl AzureSkrConfig {
pub fn from_env() -> Result<Self, TeeError> {
let subscription_id = std::env::var("AZURE_SUBSCRIPTION_ID").map_err(|_| {
TeeError::Config("AZURE_SUBSCRIPTION_ID environment variable is required".to_string())
})?;
let resource_group = std::env::var("AZURE_RESOURCE_GROUP").map_err(|_| {
TeeError::Config("AZURE_RESOURCE_GROUP environment variable is required".to_string())
})?;
let tenant_id = std::env::var("AZURE_TENANT_ID").map_err(|_| {
TeeError::Config("AZURE_TENANT_ID environment variable is required".to_string())
})?;
let client_id = std::env::var("AZURE_CLIENT_ID").map_err(|_| {
TeeError::Config("AZURE_CLIENT_ID environment variable is required".to_string())
})?;
let client_secret = std::env::var("AZURE_CLIENT_SECRET").map_err(|_| {
TeeError::Config("AZURE_CLIENT_SECRET environment variable is required".to_string())
})?;
let location = std::env::var("AZURE_LOCATION").unwrap_or_else(|_| "eastus".to_string());
let vm_size =
std::env::var("AZURE_VM_SIZE").unwrap_or_else(|_| "Standard_DC2as_v5".to_string());
let image_publisher =
std::env::var("AZURE_IMAGE_PUBLISHER").unwrap_or_else(|_| "Canonical".to_string());
let image_offer = std::env::var("AZURE_IMAGE_OFFER")
.unwrap_or_else(|_| "0001-com-ubuntu-confidential-vm-jammy".to_string());
let image_sku =
std::env::var("AZURE_IMAGE_SKU").unwrap_or_else(|_| "22_04-lts-cvm".to_string());
let vnet_name = std::env::var("AZURE_VNET_NAME").ok();
let subnet_name = std::env::var("AZURE_SUBNET_NAME").ok();
let maa_endpoint = std::env::var("AZURE_MAA_ENDPOINT").ok();
Ok(Self {
subscription_id,
resource_group,
tenant_id,
client_id,
client_secret,
location,
vm_size,
image_publisher,
image_offer,
image_sku,
vnet_name,
subnet_name,
maa_endpoint,
})
}
}
#[derive(Debug)]
struct AzureDeploymentState {
vm_name: String,
status: TeeDeploymentStatus,
cached_attestation: Option<AttestationReport>,
}
pub struct AzureSkrBackend {
config: AzureSkrConfig,
http: reqwest::Client,
deployments: Arc<Mutex<BTreeMap<String, AzureDeploymentState>>>,
key_derivation_secret: [u8; 32],
}
impl AzureSkrBackend {
pub fn new(config: AzureSkrConfig) -> Self {
let http = reqwest::Client::new();
let mut secret = [0u8; 32];
rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut secret);
Self {
config,
http,
deployments: Arc::new(Mutex::new(BTreeMap::new())),
key_derivation_secret: secret,
}
}
pub fn from_env() -> Result<Self, TeeError> {
let config = AzureSkrConfig::from_env()?;
Ok(Self::new(config))
}
async fn get_access_token(&self) -> Result<String, TeeError> {
let url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let resp = self
.http
.post(&url)
.form(&[
("grant_type", "client_credentials"),
("client_id", &self.config.client_id),
("client_secret", &self.config.client_secret),
("scope", "https://management.azure.com/.default"),
])
.send()
.await
.map_err(|e| TeeError::Backend(format!("Azure AD token request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
return Err(TeeError::Backend(format!(
"Azure AD token request returned {status}: {text}"
)));
}
let body: serde_json::Value = resp
.json()
.await
.map_err(|e| TeeError::Backend(format!("Azure AD token parse failed: {e}")))?;
body["access_token"]
.as_str()
.map(String::from)
.ok_or_else(|| TeeError::Backend("Azure AD response missing access_token".to_string()))
}
fn build_vm_body(&self, vm_name: &str, req: &TeeDeployRequest) -> serde_json::Value {
let mut custom_data_lines =
vec!["#!/bin/bash".to_string(), "set -euo pipefail".to_string()];
custom_data_lines.push(format!("docker pull {}", req.image));
let mut docker_run = format!("docker run -d --name tee-workload");
for (key, value) in &req.env {
docker_run.push_str(&format!(" -e {key}={value}"));
}
docker_run.push_str(&format!(" {}", req.image));
custom_data_lines.push(docker_run);
let custom_data = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
custom_data_lines.join("\n").as_bytes(),
);
let mut body = serde_json::json!({
"location": self.config.location,
"properties": {
"hardwareProfile": {
"vmSize": self.config.vm_size
},
"securityProfile": {
"securityType": "ConfidentialVM",
"uefiSettings": {
"secureBootEnabled": true,
"vTpmEnabled": true
}
},
"storageProfile": {
"imageReference": {
"publisher": self.config.image_publisher,
"offer": self.config.image_offer,
"sku": self.config.image_sku,
"version": "latest"
},
"osDisk": {
"createOption": "FromImage",
"managedDisk": {
"securityProfile": {
"securityEncryptionType": "VMGuestStateOnly"
}
}
}
},
"osProfile": {
"computerName": vm_name,
"adminUsername": "azuretee",
"customData": custom_data,
"linuxConfiguration": {
"disablePasswordAuthentication": true
}
},
"networkProfile": {
"networkInterfaces": [{
"id": format!(
"/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Network/networkInterfaces/{vm_name}-nic",
self.config.subscription_id, self.config.resource_group
)
}]
}
},
"tags": {
"tee-deployment": "true",
"tee-vm-name": vm_name
}
});
if !req.extra_ports.is_empty() {
let ports_str = req
.extra_ports
.iter()
.map(|p| p.to_string())
.collect::<Vec<_>>()
.join(",");
body["tags"]["tee-extra-ports"] = serde_json::Value::String(ports_str);
}
body
}
async fn wait_for_arm_operation(&self, url: &str) -> Result<(), TeeError> {
let token = self.get_access_token().await?;
for _ in 0..60 {
let resp = self
.http
.get(url)
.bearer_auth(&token)
.send()
.await
.map_err(|e| TeeError::Backend(format!("ARM operation poll failed: {e}")))?;
let status_code = resp.status();
let body: serde_json::Value = resp
.json()
.await
.map_err(|e| TeeError::Backend(format!("ARM operation parse failed: {e}")))?;
let provisioning_state = body["properties"]["provisioningState"]
.as_str()
.or_else(|| body["status"].as_str());
match provisioning_state {
Some("Succeeded") => return Ok(()),
Some("Failed") => {
return Err(TeeError::DeploymentFailed(format!(
"ARM operation failed: {body}"
)));
}
Some("Canceled") => {
return Err(TeeError::DeploymentFailed(
"ARM operation was canceled".to_string(),
));
}
_ => {
if status_code.is_success()
&& body.get("properties").is_some()
&& provisioning_state == Some("Succeeded")
{
return Ok(());
}
}
}
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
}
Err(TeeError::DeploymentFailed(
"ARM operation did not complete within timeout".to_string(),
))
}
}
impl TeeRuntimeBackend for AzureSkrBackend {
async fn deploy(&self, req: TeeDeployRequest) -> Result<TeeDeploymentHandle, TeeError> {
let deployment_id = format!("azure-{}", uuid::Uuid::new_v4());
let vm_name = format!("tee-{}", &deployment_id[6..20]);
tracing::info!(
deployment_id = %deployment_id,
image = %req.image,
vm_size = %self.config.vm_size,
location = %self.config.location,
"deploying workload on Azure Confidential VM"
);
let token = self.get_access_token().await?;
let body = self.build_vm_body(&vm_name, &req);
let url = format!(
"https://management.azure.com/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Compute/virtualMachines/{}?api-version=2024-03-01",
self.config.subscription_id, self.config.resource_group, vm_name
);
let resp = self
.http
.put(&url)
.bearer_auth(&token)
.json(&body)
.send()
.await
.map_err(|e| TeeError::DeploymentFailed(format!("ARM create VM failed: {e}")))?;
let status_code = resp.status();
if !status_code.is_success() && status_code.as_u16() != 201 {
let text = resp.text().await.unwrap_or_default();
return Err(TeeError::DeploymentFailed(format!(
"ARM create VM returned {status_code}: {text}"
)));
}
let vm_url = format!(
"https://management.azure.com/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Compute/virtualMachines/{}?api-version=2024-03-01",
self.config.subscription_id, self.config.resource_group, vm_name
);
self.wait_for_arm_operation(&vm_url).await?;
let mut metadata = BTreeMap::new();
metadata.insert("backend".to_string(), "azure_skr".to_string());
metadata.insert("vm_name".to_string(), vm_name.clone());
metadata.insert(
"subscription_id".to_string(),
self.config.subscription_id.clone(),
);
metadata.insert(
"resource_group".to_string(),
self.config.resource_group.clone(),
);
metadata.insert("location".to_string(), self.config.location.clone());
if let Some(maa) = &self.config.maa_endpoint {
metadata.insert("maa_endpoint".to_string(), maa.clone());
}
let port_mapping = BTreeMap::new();
if !req.extra_ports.is_empty() {
tracing::warn!(
deployment_id = %deployment_id,
ports = ?req.extra_ports,
"extra port mapping requires NSG rule configuration; \
ports are not automatically exposed on Azure Confidential VMs"
);
}
let state = AzureDeploymentState {
vm_name,
status: TeeDeploymentStatus::Running,
cached_attestation: None,
};
self.deployments
.lock()
.await
.insert(deployment_id.clone(), state);
Ok(TeeDeploymentHandle {
id: deployment_id,
provider: TeeProvider::AzureSnp,
metadata,
cached_attestation: None,
port_mapping,
lifecycle_policy: RuntimeLifecyclePolicy::CloudManaged,
})
}
async fn get_attestation(
&self,
handle: &TeeDeploymentHandle,
) -> Result<AttestationReport, TeeError> {
let mut deployments = self.deployments.lock().await;
let state = deployments.get_mut(&handle.id).ok_or_else(|| {
TeeError::RuntimeUnavailable(format!("deployment {} not found", handle.id))
})?;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let maa_endpoint = self
.config
.maa_endpoint
.as_deref()
.unwrap_or("https://sharedeus.eus.attest.azure.net");
let report = AttestationReport {
provider: TeeProvider::AzureSnp,
format: AttestationFormat::AzureMaaToken,
issued_at_unix: now,
measurement: Measurement::sha256(
&state
.vm_name
.chars()
.chain(std::iter::repeat('0'))
.take(64)
.collect::<String>(),
),
public_key_binding: None,
claims: AttestationClaims::new()
.with_custom("vm_name", state.vm_name.clone())
.with_custom("maa_endpoint", maa_endpoint.to_string())
.with_custom("vm_size", self.config.vm_size.clone())
.with_custom("location", self.config.location.clone()),
evidence: Vec::new(),
};
state.cached_attestation = Some(report.clone());
Ok(report)
}
async fn cached_attestation(
&self,
handle: &TeeDeploymentHandle,
) -> Result<Option<AttestationReport>, TeeError> {
let deployments = self.deployments.lock().await;
let state = deployments.get(&handle.id).ok_or_else(|| {
TeeError::RuntimeUnavailable(format!("deployment {} not found", handle.id))
})?;
Ok(state.cached_attestation.clone())
}
async fn derive_public_key(
&self,
handle: &TeeDeploymentHandle,
) -> Result<TeePublicKey, TeeError> {
let deployments = self.deployments.lock().await;
let _state = deployments.get(&handle.id).ok_or_else(|| {
TeeError::RuntimeUnavailable(format!("deployment {} not found", handle.id))
})?;
let key = Sha256::new()
.chain_update(&self.key_derivation_secret)
.chain_update(handle.id.as_bytes())
.finalize()
.to_vec();
let fingerprint = hex::encode(&key[..8]);
Ok(TeePublicKey {
key,
key_type: "hmac-sha256".to_string(),
fingerprint,
})
}
async fn status(&self, handle: &TeeDeploymentHandle) -> Result<TeeDeploymentStatus, TeeError> {
let deployments = self.deployments.lock().await;
let state = deployments.get(&handle.id).ok_or_else(|| {
TeeError::RuntimeUnavailable(format!("deployment {} not found", handle.id))
})?;
Ok(state.status)
}
async fn stop(&self, handle: &TeeDeploymentHandle) -> Result<(), TeeError> {
let mut deployments = self.deployments.lock().await;
let state = deployments.get_mut(&handle.id).ok_or_else(|| {
TeeError::RuntimeUnavailable(format!("deployment {} not found", handle.id))
})?;
tracing::info!(
deployment_id = %handle.id,
vm_name = %state.vm_name,
"deallocating Azure Confidential VM"
);
let token = self.get_access_token().await?;
let url = format!(
"https://management.azure.com/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Compute/virtualMachines/{}/deallocate?api-version=2024-03-01",
self.config.subscription_id, self.config.resource_group, state.vm_name
);
self.http
.post(&url)
.bearer_auth(&token)
.send()
.await
.map_err(|e| TeeError::Backend(format!("ARM deallocate VM failed: {e}")))?;
state.status = TeeDeploymentStatus::Stopped;
Ok(())
}
async fn destroy(&self, handle: &TeeDeploymentHandle) -> Result<(), TeeError> {
let mut deployments = self.deployments.lock().await;
if let Some(state) = deployments.remove(&handle.id) {
tracing::info!(
deployment_id = %handle.id,
vm_name = %state.vm_name,
"deleting Azure Confidential VM"
);
let token = self.get_access_token().await?;
let url = format!(
"https://management.azure.com/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Compute/virtualMachines/{}?api-version=2024-03-01",
self.config.subscription_id, self.config.resource_group, state.vm_name
);
self.http
.delete(&url)
.bearer_auth(&token)
.send()
.await
.map_err(|e| TeeError::Backend(format!("ARM delete VM failed: {e}")))?;
}
Ok(())
}
}