dstack_sdk/
dstack_client.rs1use anyhow::Result;
8use hex::encode as hex_encode;
9use http_client_unix_domain_socket::{ClientUnix, Method};
10use reqwest::Client;
11use serde::{de::DeserializeOwned, Serialize};
12use serde_json::{json, Value};
13use std::env;
14
15pub use dstack_sdk_types::dstack::*;
16
/// Resolves which dstack endpoint the client should talk to.
///
/// Precedence, highest first:
/// 1. the explicit `endpoint` argument,
/// 2. the `DSTACK_SIMULATOR_ENDPOINT` environment variable (if set and valid Unicode),
/// 3. the default Unix socket path `/var/run/dstack.sock`.
fn get_endpoint(endpoint: Option<&str>) -> String {
    match endpoint {
        Some(explicit) => explicit.to_string(),
        None => env::var("DSTACK_SIMULATOR_ENDPOINT")
            .unwrap_or_else(|_| "/var/run/dstack.sock".to_string()),
    }
}
26
/// Transport used by a [`DstackClient`] to reach its endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientKind {
    /// HTTP(S) requests sent to the client's `http://` / `https://` endpoint URL.
    Http,
    /// HTTP over the Unix domain socket whose path is the client's endpoint.
    Unix,
}

/// Marker trait with no methods; implemented by [`DstackClient`].
pub trait BaseClient {}

/// Client for the dstack RPC API (`/GetKey`, `/GetQuote`, `/Info`, `/EmitEvent`, `/GetTlsKey`).
#[derive(Debug)]
pub struct DstackClient {
    /// URL prefix for HTTP requests. For the Unix transport this is only the
    /// placeholder `http://localhost` and is not read when sending requests.
    base_url: String,
    /// The resolved endpoint: either an `http(s)://` URL or a Unix socket path.
    endpoint: String,
    /// Transport selected from the shape of `endpoint`.
    client: ClientKind,
}

impl BaseClient for DstackClient {}
46
47impl DstackClient {
48 pub fn new(endpoint: Option<&str>) -> Self {
49 let endpoint = get_endpoint(endpoint);
50 let (base_url, client) = match endpoint {
51 ref e if e.starts_with("http://") || e.starts_with("https://") => {
52 (e.to_string(), ClientKind::Http)
53 }
54 _ => ("http://localhost".to_string(), ClientKind::Unix),
55 };
56
57 DstackClient {
58 base_url,
59 endpoint,
60 client,
61 }
62 }
63
64 async fn send_rpc_request<S: Serialize, D: DeserializeOwned>(
65 &self,
66 path: &str,
67 payload: &S,
68 ) -> anyhow::Result<D> {
69 match &self.client {
70 ClientKind::Http => {
71 let client = Client::new();
72 let url = format!(
73 "{}/{}",
74 self.base_url.trim_end_matches('/'),
75 path.trim_start_matches('/')
76 );
77 let res = client
78 .post(&url)
79 .json(payload)
80 .header("Content-Type", "application/json")
81 .send()
82 .await?
83 .error_for_status()?;
84 Ok(res.json().await?)
85 }
86 ClientKind::Unix => {
87 let mut unix_client = ClientUnix::try_new(&self.endpoint).await?;
88 let res = unix_client
89 .send_request_json::<_, _, Value>(
90 path,
91 Method::POST,
92 &[("Content-Type", "application/json")],
93 Some(&payload),
94 )
95 .await?;
96 Ok(res.1)
97 }
98 }
99 }
100
101 pub async fn get_key(
102 &self,
103 path: Option<String>,
104 purpose: Option<String>,
105 ) -> Result<GetKeyResponse> {
106 let data = json!({
107 "path": path.unwrap_or_default(),
108 "purpose": purpose.unwrap_or_default(),
109 });
110 let response = self.send_rpc_request("/GetKey", &data).await?;
111 let response = serde_json::from_value::<GetKeyResponse>(response)?;
112
113 Ok(response)
114 }
115
116 pub async fn get_quote(&self, report_data: Vec<u8>) -> Result<GetQuoteResponse> {
117 if report_data.is_empty() || report_data.len() > 64 {
118 anyhow::bail!("Invalid report data length")
119 }
120 let hex_data = hex_encode(report_data);
121 let data = json!({ "report_data": hex_data });
122 let response = self.send_rpc_request("/GetQuote", &data).await?;
123 let response = serde_json::from_value::<GetQuoteResponse>(response)?;
124
125 Ok(response)
126 }
127
128 pub async fn info(&self) -> Result<InfoResponse> {
129 let response = self.send_rpc_request("/Info", &json!({})).await?;
130 Ok(InfoResponse::validated_from_value(response)?)
131 }
132
133 pub async fn emit_event(&self, event: String, payload: Vec<u8>) -> Result<()> {
134 if event.is_empty() {
135 anyhow::bail!("Event name cannot be empty")
136 }
137 let hex_payload = hex_encode(payload);
138 let data = json!({ "event": event, "payload": hex_payload });
139 self.send_rpc_request::<_, ()>("/EmitEvent", &data).await?;
140 Ok(())
141 }
142
143 pub async fn get_tls_key(&self, tls_key_config: TlsKeyConfig) -> Result<GetTlsKeyResponse> {
144 let response = self.send_rpc_request("/GetTlsKey", &tls_key_config).await?;
145 let response = serde_json::from_value::<GetTlsKeyResponse>(response)?;
146
147 Ok(response)
148 }
149}