dstack_sdk/
dstack_client.rs

1// SPDX-FileCopyrightText: © 2025 Created-for-a-purpose <rachitchahar@gmail.com>
2// SPDX-FileCopyrightText: © 2025 Daniel Sharifi <daniel.sharifi@nearone.org>
3// SPDX-FileCopyrightText: © 2025 tuddman <tuddman@users.noreply.github.com>
4//
5// SPDX-License-Identifier: Apache-2.0
6
7use anyhow::Result;
8use hex::encode as hex_encode;
9use http_client_unix_domain_socket::{ClientUnix, Method};
10use reqwest::Client;
11use serde::{de::DeserializeOwned, Serialize};
12use serde_json::{json, Value};
13use std::env;
14
15pub use dstack_sdk_types::dstack::*;
16
/// Resolves the endpoint the client should talk to.
///
/// Precedence, highest first:
/// 1. the explicitly supplied `endpoint`;
/// 2. the `DSTACK_SIMULATOR_ENDPOINT` environment variable (when it is set and
///    readable as a string);
/// 3. the default socket path `/var/run/dstack.sock`.
fn get_endpoint(endpoint: Option<&str>) -> String {
    match endpoint {
        Some(explicit) => explicit.to_owned(),
        None => env::var("DSTACK_SIMULATOR_ENDPOINT")
            .unwrap_or_else(|_| String::from("/var/run/dstack.sock")),
    }
}
26
/// Transport used to reach the dstack service.
///
/// This is a plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientKind {
    /// JSON requests sent over HTTP(S) to a URL endpoint.
    Http,
    /// JSON requests sent through a Unix domain socket.
    Unix,
}
32
/// Marker trait for dstack clients.
///
/// It declares no methods; implementing it only tags a type as a client.
pub trait BaseClient {}
34
35/// The main client for interacting with the dstack service
36pub struct DstackClient {
37    /// The base URL for HTTP requests
38    base_url: String,
39    /// The endpoint for Unix domain socket communication
40    endpoint: String,
41    /// The type of client (HTTP or Unix domain socket)
42    client: ClientKind,
43}
44
45impl BaseClient for DstackClient {}
46
47impl DstackClient {
48    pub fn new(endpoint: Option<&str>) -> Self {
49        let endpoint = get_endpoint(endpoint);
50        let (base_url, client) = match endpoint {
51            ref e if e.starts_with("http://") || e.starts_with("https://") => {
52                (e.to_string(), ClientKind::Http)
53            }
54            _ => ("http://localhost".to_string(), ClientKind::Unix),
55        };
56
57        DstackClient {
58            base_url,
59            endpoint,
60            client,
61        }
62    }
63
64    async fn send_rpc_request<S: Serialize, D: DeserializeOwned>(
65        &self,
66        path: &str,
67        payload: &S,
68    ) -> anyhow::Result<D> {
69        match &self.client {
70            ClientKind::Http => {
71                let client = Client::new();
72                let url = format!(
73                    "{}/{}",
74                    self.base_url.trim_end_matches('/'),
75                    path.trim_start_matches('/')
76                );
77                let res = client
78                    .post(&url)
79                    .json(payload)
80                    .header("Content-Type", "application/json")
81                    .send()
82                    .await?
83                    .error_for_status()?;
84                Ok(res.json().await?)
85            }
86            ClientKind::Unix => {
87                let mut unix_client = ClientUnix::try_new(&self.endpoint).await?;
88                let res = unix_client
89                    .send_request_json::<_, _, Value>(
90                        path,
91                        Method::POST,
92                        &[("Content-Type", "application/json")],
93                        Some(&payload),
94                    )
95                    .await?;
96                Ok(res.1)
97            }
98        }
99    }
100
101    pub async fn get_key(
102        &self,
103        path: Option<String>,
104        purpose: Option<String>,
105    ) -> Result<GetKeyResponse> {
106        let data = json!({
107            "path": path.unwrap_or_default(),
108            "purpose": purpose.unwrap_or_default(),
109        });
110        let response = self.send_rpc_request("/GetKey", &data).await?;
111        let response = serde_json::from_value::<GetKeyResponse>(response)?;
112
113        Ok(response)
114    }
115
116    pub async fn get_quote(&self, report_data: Vec<u8>) -> Result<GetQuoteResponse> {
117        if report_data.is_empty() || report_data.len() > 64 {
118            anyhow::bail!("Invalid report data length")
119        }
120        let hex_data = hex_encode(report_data);
121        let data = json!({ "report_data": hex_data });
122        let response = self.send_rpc_request("/GetQuote", &data).await?;
123        let response = serde_json::from_value::<GetQuoteResponse>(response)?;
124
125        Ok(response)
126    }
127
128    pub async fn info(&self) -> Result<InfoResponse> {
129        let response = self.send_rpc_request("/Info", &json!({})).await?;
130        Ok(InfoResponse::validated_from_value(response)?)
131    }
132
133    pub async fn emit_event(&self, event: String, payload: Vec<u8>) -> Result<()> {
134        if event.is_empty() {
135            anyhow::bail!("Event name cannot be empty")
136        }
137        let hex_payload = hex_encode(payload);
138        let data = json!({ "event": event, "payload": hex_payload });
139        self.send_rpc_request::<_, ()>("/EmitEvent", &data).await?;
140        Ok(())
141    }
142
143    pub async fn get_tls_key(&self, tls_key_config: TlsKeyConfig) -> Result<GetTlsKeyResponse> {
144        let response = self.send_rpc_request("/GetTlsKey", &tls_key_config).await?;
145        let response = serde_json::from_value::<GetTlsKeyResponse>(response)?;
146
147        Ok(response)
148    }
149}