use alloc::{
collections::BTreeMap,
string::{String, ToString},
vec::Vec,
};
use anyhow::{bail, Context as _, Result};
use hex::{encode as hex_encode, FromHexError};
use serde::{Deserialize, Serialize};
use sha2::Digest;
#[cfg(feature = "borsh_schema")]
use borsh::BorshSchema;
#[cfg(feature = "borsh")]
use borsh::{BorshDeserialize, BorshSerialize};
use crate::dstack::EventLog;
/// Hex-encoded initial (reset) value of an RTMR: an all-zero register.
///
/// `replay_rtmr` returns this unchanged for an empty history and otherwise
/// hex-decodes it as the starting value before the first SHA-384 extension.
/// The length (presumably 48 bytes, the SHA-384 digest width) must not be altered,
/// or every replayed RTMR changes.
const INIT_MR: &str = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000";
/// Hash algorithm identifiers that can be associated with a quote request.
///
/// NOTE(review): how the service applies the selected algorithm is not visible
/// here; confirm against the service API.
///
/// String forms: [`QuoteHashAlgorithm::as_str`] yields lowercase names such as
/// `"sha3-256"`, whereas the derived serde impls use the Rust variant names
/// (for example `"Sha3_256"`), because no `#[serde(rename)]` is applied.
///
/// Variant order is the Borsh discriminant when the `borsh` feature is
/// enabled, so new variants must only be appended at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub enum QuoteHashAlgorithm {
    /// SHA-256 (`"sha256"`).
    Sha256,
    /// SHA-384 (`"sha384"`).
    Sha384,
    /// SHA-512 (`"sha512"`).
    Sha512,
    /// SHA3-256 (`"sha3-256"`).
    Sha3_256,
    /// SHA3-384 (`"sha3-384"`).
    Sha3_384,
    /// SHA3-512 (`"sha3-512"`).
    Sha3_512,
    /// Keccak-256 (`"keccak256"`).
    Keccak256,
    /// Keccak-384 (`"keccak384"`).
    Keccak384,
    /// Keccak-512 (`"keccak512"`).
    Keccak512,
    /// `"raw"`. Presumably means the data is used without hashing; confirm with the service.
    Raw,
}
impl QuoteHashAlgorithm {
pub fn as_str(&self) -> &'static str {
match self {
Self::Sha256 => "sha256",
Self::Sha384 => "sha384",
Self::Sha512 => "sha512",
Self::Sha3_256 => "sha3-256",
Self::Sha3_384 => "sha3-384",
Self::Sha3_512 => "sha3-512",
Self::Keccak256 => "keccak256",
Self::Keccak384 => "keccak384",
Self::Keccak512 => "keccak512",
Self::Raw => "raw",
}
}
}
/// Replays a single RTMR from its ordered list of hex-encoded event digests.
///
/// The register starts at the all-zero `INIT_MR`. For each digest, in order:
/// 1. Hex-decode it and right-pad it with zero bytes to at least 48 bytes
///    (longer digests are used as-is).
/// 2. Append it to the current register value.
/// 3. Hash the concatenation with SHA-384 to produce the next register value.
///
/// Returns the final register value, hex-encoded. An empty history returns
/// `INIT_MR` unchanged.
///
/// # Errors
/// Returns the [`FromHexError`] of the first value (`INIT_MR`, then each digest
/// in order) that is not valid hex.
fn replay_rtmr(history: Vec<String>) -> Result<String, FromHexError> {
    if history.is_empty() {
        // Nothing was extended into this register, so it keeps its reset value.
        return Ok(INIT_MR.to_string());
    }
    let start = hex::decode(INIT_MR)?;
    let register = history.into_iter().try_fold(
        start,
        |current, digest_hex| -> Result<Vec<u8>, FromHexError> {
            let mut event_digest = hex::decode(digest_hex)?;
            // Zero-pad short digests to 48 bytes; resizing to the current length is a no-op.
            let padded_len = event_digest.len().max(48);
            event_digest.resize(padded_len, 0);
            let mut extend_input = current;
            extend_input.extend_from_slice(&event_digest);
            Ok(sha2::Sha384::digest(&extend_input).to_vec())
        },
    )?;
    Ok(hex_encode(register))
}
/// Response carrying a derived key and its certificate chain.
///
/// `Debug` is implemented by hand so that the private-key PEM in `key` is never
/// written to logs or panic messages.
#[derive(Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct DeriveKeyResponse {
    /// PEM-encoded key. `decode_key` expects the `PRIVATE KEY` (PKCS#8) label.
    pub key: String,
    /// Certificate chain for the key. Presumably PEM-encoded certificates; confirm with the service.
    pub certificate_chain: Vec<String>,
}

impl core::fmt::Debug for DeriveKeyResponse {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The key is secret material, so only a placeholder is printed.
        f.debug_struct("DeriveKeyResponse")
            .field("key", &"<redacted>")
            .field("certificate_chain", &self.certificate_chain)
            .finish()
    }
}
impl DeriveKeyResponse {
    /// Extracts the raw 32-byte private key from the PEM-encoded `key`.
    ///
    /// Steps:
    /// 1. Parse `key` (trimmed) as PEM and require the `PRIVATE KEY` label.
    /// 2. Parse the DER body as PKCS#8 `PrivateKeyInfo`.
    /// 3. Read the inner private-key bytes as an `ECPrivateKey` SEQUENCE:
    ///    `version`, then the `privateKey` OCTET STRING. Any remaining elements
    ///    are skipped; presumably the optional RFC 5915 parameters/publicKey fields.
    ///
    /// # Errors
    /// Fails if any of the following is true:
    /// - the PEM cannot be parsed;
    /// - the label is not `PRIVATE KEY`;
    /// - the PKCS#8 or ECPrivateKey DER is malformed;
    /// - the extracted key is not exactly 32 bytes.
    pub fn decode_key(&self) -> Result<Vec<u8>, anyhow::Error> {
        use pkcs8::der::asn1::{Int, OctetString};
        use pkcs8::der::{Decode, Document, Reader, SliceReader};
        use pkcs8::PrivateKeyInfo;

        let (pem_label, der_doc) = Document::from_pem(self.key.trim())
            .map_err(|err| anyhow::anyhow!("Failed to parse PEM: {:?}", err))?;
        if pem_label != "PRIVATE KEY" {
            bail!("Expected PRIVATE KEY PEM label, got: {}", pem_label);
        }

        let pkcs8_info = PrivateKeyInfo::from_der(der_doc.as_bytes())
            .map_err(|err| anyhow::anyhow!("Failed to parse PKCS#8 private key: {:?}", err))?;

        let mut ec_reader = SliceReader::new(pkcs8_info.private_key)
            .map_err(|err| anyhow::anyhow!("Failed to create reader: {:?}", err))?;
        let scalar = ec_reader
            .sequence(|seq| {
                // DER fields must be read in declaration order: version first, then the key.
                let _version: Int = seq.decode()?;
                let secret: OctetString = seq.decode()?;
                // Consume and discard any trailing elements of the SEQUENCE.
                while !seq.is_finished() {
                    let _skipped: pkcs8::der::Any = seq.decode()?;
                }
                Ok(secret.as_bytes().to_vec())
            })
            .map_err(|err| anyhow::anyhow!("Failed to parse ECPrivateKey structure: {:?}", err))?;

        match scalar.len() {
            32 => Ok(scalar),
            other => bail!("Expected 32-byte ECDSA P-256 private key, got {} bytes", other),
        }
    }
}
/// Response carrying a TDX quote together with the event log used to replay RTMRs.
///
/// Field order is the serialization layout when the `borsh` feature is enabled,
/// so fields must not be reordered.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct TdxQuoteResponse {
    /// Hex-encoded quote bytes (see `decode_quote`).
    pub quote: String,
    /// JSON-encoded array of `EventLog` entries (see `decode_event_log`).
    pub event_log: String,
    /// Optional hash algorithm name; `None` when the field is absent from the input.
    /// Presumably one of the `QuoteHashAlgorithm::as_str` names; confirm with the service.
    #[serde(default)]
    pub hash_algorithm: Option<String>,
    /// Optional prefix string; `None` when the field is absent from the input.
    /// NOTE(review): its meaning is not visible here; confirm with the service.
    #[serde(default)]
    pub prefix: Option<String>,
}
impl TdxQuoteResponse {
pub fn decode_quote(&self) -> Result<Vec<u8>, FromHexError> {
hex::decode(&self.quote)
}
pub fn decode_event_log(&self) -> Result<Vec<EventLog>, serde_json::Error> {
serde_json::from_str(&self.event_log)
}
pub fn replay_rtmrs(&self) -> Result<BTreeMap<u8, String>> {
let parsed_event_log: Vec<EventLog> = self.decode_event_log()?;
let mut rtmrs = BTreeMap::new();
for idx in 0..4 {
let mut history = Vec::new();
for event in &parsed_event_log {
if event.imr == idx {
history.push(event.digest.clone());
}
}
rtmrs.insert(
idx as u8,
replay_rtmr(history)
.ok()
.context("Invalid digest in event log")?,
);
}
Ok(rtmrs)
}
}
/// TCB (measurement) information returned as part of [`TappdInfoResponse`].
///
/// Field order is the serialization layout when the `borsh` feature is enabled,
/// so fields must not be reordered.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct TappdTcbInfo {
    /// MRTD measurement. Presumably hex-encoded; confirm with the service.
    pub mrtd: String,
    /// RTMR0 value. Presumably hex-encoded and comparable with `TdxQuoteResponse::replay_rtmrs`; confirm.
    pub rtmr0: String,
    /// RTMR1 value. Presumably hex-encoded; confirm.
    pub rtmr1: String,
    /// RTMR2 value. Presumably hex-encoded; confirm.
    pub rtmr2: String,
    /// RTMR3 value. Presumably hex-encoded; confirm.
    pub rtmr3: String,
    /// Event log entries, already structured. This differs from
    /// `TdxQuoteResponse::event_log`, which is a JSON string.
    pub event_log: Vec<EventLog>,
    /// Application compose description.
    /// NOTE(review): its format is not visible here; confirm with the service.
    pub app_compose: String,
}
/// Application and instance information, including its [`TappdTcbInfo`].
///
/// Field order is the serialization layout when the `borsh` feature is enabled,
/// so fields must not be reordered.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct TappdInfoResponse {
    /// Application identifier.
    pub app_id: String,
    /// Identifier of this running instance.
    pub instance_id: String,
    /// Application certificate. Presumably PEM-encoded; confirm with the service.
    pub app_cert: String,
    /// Measurement information for this instance.
    pub tcb_info: TappdTcbInfo,
    /// Human-readable application name.
    pub app_name: String,
}