Skip to main content

dstack_sdk_types/
dstack.rs

1// SPDX-FileCopyrightText: © 2025 Daniel Sharifi <daniel.sharifi@nearone.org>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5use alloc::{
6    collections::BTreeMap,
7    string::{String, ToString},
8    vec::Vec,
9};
10use anyhow::{Context as _, Result};
11use hex::{encode as hex_encode, FromHexError};
12use serde::{Deserialize, Serialize};
13use serde_json::{from_str, Value};
14use sha2::Digest;
15
16#[cfg(feature = "borsh_schema")]
17use borsh::BorshSchema;
18#[cfg(feature = "borsh")]
19use borsh::{BorshDeserialize, BorshSerialize};
20
/// Hex-encoded all-zero value that a register holds before any event is replayed.
///
/// [`replay_rtmr`] returns this string unchanged for an empty history and otherwise
/// hex-decodes it to obtain the starting register contents.
const INIT_MR: &str = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000";
22
23fn replay_rtmr(history: Vec<String>) -> Result<String, FromHexError> {
24    if history.is_empty() {
25        return Ok(INIT_MR.to_string());
26    }
27    let mut mr = hex::decode(INIT_MR)?;
28    for content in history {
29        let mut content_bytes = hex::decode(content)?;
30        if content_bytes.len() < 48 {
31            content_bytes.resize(48, 0);
32        }
33        mr.extend_from_slice(&content_bytes);
34        mr = sha2::Sha384::digest(&mr).to_vec();
35    }
36    Ok(hex_encode(mr))
37}
38
/// Represents a single entry of a measurement event log.
///
/// Entries appear as a JSON-encoded list in [`GetQuoteResponse::event_log`] and as a
/// parsed list in [`TcbInfo::event_log`]. [`GetQuoteResponse::replay_rtmrs`] groups
/// entries by [`EventLog::imr`] and replays their [`EventLog::digest`] values.
///
/// Field order matters: it determines the derived `Ord` and, with the `borsh` feature,
/// the Borsh wire encoding.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct EventLog {
    /// The index of the IMR (Integrity Measurement Register) this event was recorded into.
    /// `GetQuoteResponse::replay_rtmrs` only considers indices 0 through 3.
    pub imr: u32,
    /// The numeric type of the event being logged
    pub event_type: u32,
    /// The hex-encoded cryptographic digest of the event; this is the value that is
    /// replayed into the register
    pub digest: String,
    /// The type of event as a string
    pub event: String,
    /// The payload data associated with the event
    pub event_payload: String,
}
55
/// Configuration for TLS key generation.
///
/// The `bon::Builder` derive generates a builder in which every field may be omitted and
/// falls back to the default given in its `#[builder(...)]` attribute. No
/// `#[serde(default)]` is set, so deserializing this type still requires every field.
#[derive(Debug, bon::Builder, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct TlsKeyConfig {
    /// The subject name for the certificate.
    /// Builder default: empty string; the setter accepts anything `Into<String>`.
    #[builder(into, default = String::new())]
    pub subject: String,
    /// Alternative names for the certificate. Builder default: empty list.
    #[builder(default = Vec::new())]
    pub alt_names: Vec<String>,
    /// Whether the key should be used for remote attestation TLS. Builder default: `false`.
    #[builder(default = false)]
    pub usage_ra_tls: bool,
    /// Whether the key should be used for server authentication. Builder default: `true`.
    #[builder(default = true)]
    pub usage_server_auth: bool,
    /// Whether the key should be used for client authentication. Builder default: `false`.
    #[builder(default = false)]
    pub usage_client_auth: bool,
}
77
/// Response containing a key and its signature chain.
///
/// Both fields are hex-encoded; use [`GetKeyResponse::decode_key`] and
/// [`GetKeyResponse::decode_signature_chain`] to obtain raw bytes.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct GetKeyResponse {
    /// The key in hexadecimal format
    pub key: String,
    /// The chain of signatures verifying the key, each entry hex-encoded
    pub signature_chain: Vec<String>,
}
88
89impl GetKeyResponse {
90    pub fn decode_key(&self) -> Result<Vec<u8>, FromHexError> {
91        hex::decode(&self.key)
92    }
93
94    pub fn decode_signature_chain(&self) -> Result<Vec<Vec<u8>>, FromHexError> {
95        self.signature_chain.iter().map(hex::decode).collect()
96    }
97}
98
/// Response containing a quote and associated event log.
///
/// `event_log` is kept as a JSON string; parse it with
/// [`GetQuoteResponse::decode_event_log`] or replay it with
/// [`GetQuoteResponse::replay_rtmrs`].
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct GetQuoteResponse {
    /// The attestation quote in hexadecimal format
    pub quote: String,
    /// The event log associated with the quote, as a JSON-encoded array of [`EventLog`]
    pub event_log: String,
    /// The report data (empty string when absent from the serialized input)
    #[serde(default)]
    pub report_data: String,
    /// VM configuration (empty string when absent from the serialized input)
    #[serde(default)]
    pub vm_config: String,
}
115
/// Response containing a versioned attestation.
///
/// Use [`AttestResponse::decode_attestation`] to obtain the raw bytes.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct AttestResponse {
    /// The attestation in hexadecimal format
    pub attestation: String,
}
124
125impl AttestResponse {
126    pub fn decode_attestation(&self) -> Result<Vec<u8>, FromHexError> {
127        hex::decode(&self.attestation)
128    }
129}
130
131impl GetQuoteResponse {
132    pub fn decode_quote(&self) -> Result<Vec<u8>, FromHexError> {
133        hex::decode(&self.quote)
134    }
135
136    pub fn decode_event_log(&self) -> Result<Vec<EventLog>, serde_json::Error> {
137        serde_json::from_str(&self.event_log)
138    }
139
140    pub fn replay_rtmrs(&self) -> Result<BTreeMap<u8, String>> {
141        let parsed_event_log: Vec<EventLog> = self.decode_event_log()?;
142        let mut rtmrs = BTreeMap::new();
143        for idx in 0..4 {
144            let mut history = Vec::new();
145            for event in &parsed_event_log {
146                if event.imr == idx {
147                    history.push(event.digest.clone());
148                }
149            }
150            rtmrs.insert(
151                idx as u8,
152                replay_rtmr(history)
153                    .ok()
154                    .context("Invalid digest in event log")?,
155            );
156        }
157        Ok(rtmrs)
158    }
159}
160
/// Response containing instance information and attestation data.
///
/// `tcb_info` may be delivered as a JSON-encoded string instead of a nested object. Use
/// [`InfoResponse::validated_from_value`] to deserialize from a [`serde_json::Value`] in
/// either form. Fields marked `#[serde(default)]` become an empty string when absent; all
/// other fields are required.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct InfoResponse {
    /// The application identifier
    pub app_id: String,
    /// The instance identifier
    pub instance_id: String,
    /// The application certificate
    pub app_cert: String,
    /// Trusted Computing Base information
    pub tcb_info: TcbInfo,
    /// The name of the application
    pub app_name: String,
    /// The device identifier
    pub device_id: String,
    /// The aggregated measurement register (empty string when absent)
    #[serde(default)]
    pub mr_aggregated: String,
    /// The hash of the OS image
    /// Optional: empty if OS image is not measured by KMS
    #[serde(default)]
    pub os_image_hash: String,
    /// Information about the key provider
    pub key_provider_info: String,
    /// The hash of the compose configuration
    pub compose_hash: String,
    /// VM configuration (empty string when absent)
    #[serde(default)]
    pub vm_config: String,
}
193
194impl InfoResponse {
195    pub fn validated_from_value(mut obj: Value) -> Result<Self, serde_json::Error> {
196        if let Some(tcb_info_str) = obj.get("tcb_info").and_then(Value::as_str) {
197            let parsed_tcb_info: TcbInfo = from_str(tcb_info_str)?;
198            obj["tcb_info"] = serde_json::to_value(parsed_tcb_info)?;
199        }
200        serde_json::from_value(obj)
201    }
202}
203
/// Trusted Computing Base information structure.
///
/// Unlike [`GetQuoteResponse::event_log`], the event log here is already parsed into
/// [`EventLog`] entries. Field order matters: it determines the derived `Ord` and, with the
/// `borsh` feature, the Borsh wire encoding.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct TcbInfo {
    /// The MRTD measurement value.
    /// NOTE(review): earlier wording called this the "measurement root of trust"; TDX
    /// defines MRTD as the measurement of the initial TD contents — confirm.
    pub mrtd: String,
    /// The value of RTMR0 (Runtime Measurement Register 0)
    pub rtmr0: String,
    /// The value of RTMR1 (Runtime Measurement Register 1)
    pub rtmr1: String,
    /// The value of RTMR2 (Runtime Measurement Register 2)
    pub rtmr2: String,
    /// The value of RTMR3 (Runtime Measurement Register 3)
    pub rtmr3: String,
    /// The hash of the OS image. This is empty if the OS image is not measured by KMS
    /// (and also when the field is absent from the serialized input).
    #[serde(default)]
    pub os_image_hash: String,
    /// The hash of the compose configuration
    pub compose_hash: String,
    /// The device identifier
    pub device_id: String,
    /// The app compose
    pub app_compose: String,
    /// The event log entries
    pub event_log: Vec<EventLog>,
}
231
/// Response containing a TLS key and its certificate chain.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct GetTlsKeyResponse {
    /// The TLS key in hexadecimal format
    pub key: String,
    /// The chain of certificates, in the order returned by the service
    pub certificate_chain: Vec<String>,
}
242
/// Response from a Sign request.
///
/// All fields are hex-encoded; the `decode_*` methods on this type return raw bytes.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct SignResponse {
    /// The signature in hexadecimal format
    pub signature: String,
    /// The chain of signatures in hexadecimal format
    pub signature_chain: Vec<String>,
    /// The public key in hexadecimal format
    pub public_key: String,
}
255
256impl SignResponse {
257    /// Decodes the signature from hex to bytes
258    pub fn decode_signature(&self) -> Result<Vec<u8>, FromHexError> {
259        hex::decode(&self.signature)
260    }
261
262    /// Decodes the public key from hex to bytes
263    pub fn decode_public_key(&self) -> Result<Vec<u8>, FromHexError> {
264        hex::decode(&self.public_key)
265    }
266
267    /// Decodes the signature chain from hex to bytes
268    pub fn decode_signature_chain(&self) -> Result<Vec<Vec<u8>>, FromHexError> {
269        self.signature_chain.iter().map(hex::decode).collect()
270    }
271}
272
/// Response from a Verify request.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct VerifyResponse {
    /// Whether the signature is valid
    pub valid: bool,
}
281
/// Response from a Version request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))]
pub struct VersionResponse {
    /// The dstack version (e.g. "0.5.7")
    pub version: String,
    /// Git revision of the dstack build
    pub rev: String,
}