dstack_sdk_types/
dstack.rs

1// SPDX-FileCopyrightText: © 2025 Daniel Sharifi <daniel.sharifi@nearone.org>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5use alloc::{
6    collections::BTreeMap,
7    string::{String, ToString},
8    vec::Vec,
9};
10use anyhow::{Context as _, Result};
11use hex::{encode as hex_encode, FromHexError};
12use serde::{Deserialize, Serialize};
13use serde_json::{from_str, Value};
14use sha2::Digest;
15
16#[cfg(feature = "borsh_schema")]
17use borsh::BorshSchema;
18#[cfg(feature = "borsh")]
19use borsh::{BorshDeserialize, BorshSerialize};
20
/// All-zero hexadecimal string used as the starting register value when
/// replaying an RTMR: `replay_rtmr` returns it unchanged for an empty history
/// and hex-decodes it as the initial accumulator otherwise.
const INIT_MR: &str = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000";
22
23fn replay_rtmr(history: Vec<String>) -> Result<String, FromHexError> {
24    if history.is_empty() {
25        return Ok(INIT_MR.to_string());
26    }
27    let mut mr = hex::decode(INIT_MR)?;
28    for content in history {
29        let mut content_bytes = hex::decode(content)?;
30        if content_bytes.len() < 48 {
31            content_bytes.resize(48, 0);
32        }
33        mr.extend_from_slice(&content_bytes);
34        mr = sha2::Sha384::digest(&mr).to_vec();
35    }
36    Ok(hex_encode(mr))
37}
38
/// Represents an event log entry in the system.
///
/// Entries are obtained by JSON-decoding `GetQuoteResponse::event_log`
/// (see `GetQuoteResponse::decode_event_log`) or as part of
/// `TcbInfo::event_log`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct EventLog {
    /// The index of the IMR (Integrity Measurement Register).
    /// `GetQuoteResponse::replay_rtmrs` only replays entries whose index is
    /// 0, 1, 2 or 3; entries with other indices are ignored there.
    pub imr: u32,
    /// The type of event being logged
    pub event_type: u32,
    /// The cryptographic digest of the event, as a hexadecimal string
    /// (it is hex-decoded when register values are replayed).
    pub digest: String,
    /// The type of event as a string
    pub event: String,
    /// The payload data associated with the event
    pub event_payload: String,
}
55
/// Configuration for TLS key generation.
///
/// Construct it through the generated builder (`TlsKeyConfig::builder()`);
/// every field has a builder default, listed on the field below. These
/// defaults apply to the builder only: no `#[serde(default)]` is declared, so
/// all fields must be present when deserializing.
#[derive(Debug, bon::Builder, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct TlsKeyConfig {
    /// The subject name for the certificate.
    /// Builder: accepts anything convertible into `String`; defaults to the empty string.
    #[builder(into, default = String::new())]
    pub subject: String,
    /// Alternative names for the certificate.
    /// Builder default: empty list.
    #[builder(default = Vec::new())]
    pub alt_names: Vec<String>,
    /// Whether the key should be used for remote attestation TLS.
    /// Builder default: `false`.
    #[builder(default = false)]
    pub usage_ra_tls: bool,
    /// Whether the key should be used for server authentication.
    /// Builder default: `true`.
    #[builder(default = true)]
    pub usage_server_auth: bool,
    /// Whether the key should be used for client authentication.
    /// Builder default: `false`.
    #[builder(default = false)]
    pub usage_client_auth: bool,
}
77
/// Response containing a key and its signature chain.
///
/// Both fields carry hexadecimal text; use `decode_key` and
/// `decode_signature_chain` to obtain the raw bytes.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct GetKeyResponse {
    /// The key in hexadecimal format
    pub key: String,
    /// The chain of signatures verifying the key; each entry is a
    /// hexadecimal string (decoded entry-by-entry by `decode_signature_chain`).
    pub signature_chain: Vec<String>,
}
88
89impl GetKeyResponse {
90    pub fn decode_key(&self) -> Result<Vec<u8>, FromHexError> {
91        hex::decode(&self.key)
92    }
93
94    pub fn decode_signature_chain(&self) -> Result<Vec<Vec<u8>>, FromHexError> {
95        self.signature_chain.iter().map(hex::decode).collect()
96    }
97}
98
/// Response containing a quote and associated event log.
///
/// `report_data` and `vm_config` are optional on the wire: when missing from
/// the serialized input they deserialize to the empty string.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct GetQuoteResponse {
    /// The attestation quote in hexadecimal format
    pub quote: String,
    /// The event log associated with the quote, as a JSON-encoded string
    /// (parsed into `Vec<EventLog>` by `decode_event_log`).
    pub event_log: String,
    /// The report data (empty string if absent when deserializing)
    #[serde(default)]
    pub report_data: String,
    /// VM configuration (empty string if absent when deserializing)
    #[serde(default)]
    pub vm_config: String,
}
115
116impl GetQuoteResponse {
117    pub fn decode_quote(&self) -> Result<Vec<u8>, FromHexError> {
118        hex::decode(&self.quote)
119    }
120
121    pub fn decode_event_log(&self) -> Result<Vec<EventLog>, serde_json::Error> {
122        serde_json::from_str(&self.event_log)
123    }
124
125    pub fn replay_rtmrs(&self) -> Result<BTreeMap<u8, String>> {
126        let parsed_event_log: Vec<EventLog> = self.decode_event_log()?;
127        let mut rtmrs = BTreeMap::new();
128        for idx in 0..4 {
129            let mut history = Vec::new();
130            for event in &parsed_event_log {
131                if event.imr == idx {
132                    history.push(event.digest.clone());
133                }
134            }
135            rtmrs.insert(
136                idx as u8,
137                replay_rtmr(history)
138                    .ok()
139                    .context("Invalid digest in event log")?,
140            );
141        }
142        Ok(rtmrs)
143    }
144}
145
/// Response containing instance information and attestation data.
///
/// Fields marked `#[serde(default)]` (`mr_aggregated`, `os_image_hash`,
/// `vm_config`) deserialize to the empty string when missing from the input.
/// When `tcb_info` may arrive as a JSON-encoded string instead of a nested
/// object, deserialize through `InfoResponse::validated_from_value`.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct InfoResponse {
    /// The application identifier
    pub app_id: String,
    /// The instance identifier
    pub instance_id: String,
    /// The application certificate
    pub app_cert: String,
    /// Trusted Computing Base information
    pub tcb_info: TcbInfo,
    /// The name of the application
    pub app_name: String,
    /// The device identifier
    pub device_id: String,
    /// The aggregated measurement register (empty string if absent when deserializing)
    #[serde(default)]
    pub mr_aggregated: String,
    /// The hash of the OS image
    /// Optional: empty if OS image is not measured by KMS
    #[serde(default)]
    pub os_image_hash: String,
    /// Information about the key provider
    pub key_provider_info: String,
    /// The hash of the compose configuration
    pub compose_hash: String,
    /// VM configuration (empty string if absent when deserializing)
    #[serde(default)]
    pub vm_config: String,
}
178
179impl InfoResponse {
180    pub fn validated_from_value(mut obj: Value) -> Result<Self, serde_json::Error> {
181        if let Some(tcb_info_str) = obj.get("tcb_info").and_then(Value::as_str) {
182            let parsed_tcb_info: TcbInfo = from_str(tcb_info_str)?;
183            obj["tcb_info"] = serde_json::to_value(parsed_tcb_info)?;
184        }
185        serde_json::from_value(obj)
186    }
187}
188
/// Trusted Computing Base information structure.
///
/// Appears as the `tcb_info` field of `InfoResponse`. All fields except
/// `os_image_hash` are required when deserializing.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct TcbInfo {
    /// The measurement root of trust
    pub mrtd: String,
    /// The value of RTMR0 (Runtime Measurement Register 0)
    pub rtmr0: String,
    /// The value of RTMR1 (Runtime Measurement Register 1)
    pub rtmr1: String,
    /// The value of RTMR2 (Runtime Measurement Register 2)
    pub rtmr2: String,
    /// The value of RTMR3 (Runtime Measurement Register 3)
    pub rtmr3: String,
    /// The hash of the OS image. This is empty if the OS image is not measured by KMS.
    /// Also deserializes to the empty string when the field is missing.
    #[serde(default)]
    pub os_image_hash: String,
    /// The hash of the compose configuration
    pub compose_hash: String,
    /// The device identifier
    pub device_id: String,
    /// The app compose
    pub app_compose: String,
    /// The event log entries, already parsed (unlike
    /// `GetQuoteResponse::event_log`, which is a JSON-encoded string).
    pub event_log: Vec<EventLog>,
}
216
/// Response containing TLS key and certificate chain.
///
/// No decoding helpers are provided for this type in this file.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct GetTlsKeyResponse {
    /// The TLS key in hexadecimal format
    // NOTE(review): nothing in this file decodes this field, so the stated
    // encoding cannot be confirmed here — verify against the service (hex vs PEM).
    pub key: String,
    /// The chain of certificates
    pub certificate_chain: Vec<String>,
}