1use std::time::Duration;
2
3use base64::{engine::general_purpose, Engine as _};
4use chrono::{DateTime, Utc};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tracing::debug;
7
8use crate::error::CommandError;
9
/// Tuning knobs for [`CommandsClient`] polling and response handling.
#[derive(Debug, Clone)]
pub struct CommandsClientConfig {
    /// Default time budget for waiting on a command once it has been created.
    /// A per-call `InvokeOptions::timeout` takes precedence.
    pub timeout: Duration,
    /// Delay before the first status poll.
    pub poll_interval: Duration,
    /// Upper bound applied when the delay between polls is grown.
    pub max_poll_interval: Duration,
    /// Factor by which the poll delay is multiplied after each non-terminal poll.
    pub poll_backoff: f64,
    /// Whether responses stored on the `local` storage backend may be read
    /// from the filesystem. When `false`, such responses are rejected.
    pub allow_local_storage: bool,
}
23
24impl Default for CommandsClientConfig {
25 fn default() -> Self {
26 Self {
27 timeout: Duration::from_secs(60),
28 poll_interval: Duration::from_millis(500),
29 max_poll_interval: Duration::from_secs(5),
30 poll_backoff: 1.5,
31 allow_local_storage: false,
32 }
33 }
34}
35
36pub struct InvokeOptions {
38 pub timeout: Option<Duration>,
40 pub deadline: Option<DateTime<Utc>>,
42 pub idempotency_key: Option<String>,
44}
45
46pub struct CommandsClient {
48 manager_url: String,
49 deployment_id: String,
50 token: String,
51 http_client: reqwest::Client,
52 config: CommandsClientConfig,
53}
54
55#[derive(Deserialize)]
58#[serde(rename_all = "camelCase")]
59struct CreateCommandResponse {
60 command_id: String,
61 state: String,
62 next: String,
63}
64
65#[derive(Deserialize)]
66#[serde(rename_all = "camelCase")]
67struct CommandStatusResponse {
68 command_id: String,
69 state: String,
70 #[serde(default)]
71 response: Option<CommandResponseBody>,
72}
73
74#[derive(Deserialize)]
75#[serde(rename_all = "camelCase")]
76struct CommandResponseBody {
77 status: String,
78 #[serde(default)]
79 response: Option<BodySpecResponse>,
80 #[serde(default)]
81 code: Option<String>,
82 #[serde(default)]
83 message: Option<String>,
84}
85
86#[derive(Deserialize)]
87#[serde(rename_all = "camelCase")]
88struct BodySpecResponse {
89 mode: String,
90 #[serde(default)]
91 inline_base64: Option<String>,
92 #[serde(default)]
93 storage_get_request: Option<StorageGetRequest>,
94}
95
96#[derive(Deserialize)]
97#[serde(rename_all = "camelCase")]
98struct StorageGetRequest {
99 backend: StorageBackend,
100}
101
102#[derive(Deserialize)]
103#[serde(rename_all = "camelCase")]
104struct StorageBackend {
105 #[serde(rename = "type")]
106 backend_type: String,
107 #[serde(default)]
108 url: Option<String>,
109 #[serde(default)]
110 method: Option<String>,
111 #[serde(default)]
112 headers: Option<std::collections::HashMap<String, String>>,
113 #[serde(default, rename = "filePath")]
114 file_path: Option<String>,
115}
116
117impl CommandsClient {
118 pub fn new(manager_url: &str, deployment_id: &str, token: &str) -> Self {
120 Self::with_config(
121 manager_url,
122 deployment_id,
123 token,
124 CommandsClientConfig::default(),
125 )
126 }
127
128 pub fn with_config(
130 manager_url: &str,
131 deployment_id: &str,
132 token: &str,
133 config: CommandsClientConfig,
134 ) -> Self {
135 let mut headers = reqwest::header::HeaderMap::new();
136 headers.insert(
137 reqwest::header::AUTHORIZATION,
138 reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
139 .expect("invalid token"),
140 );
141
142 let http_client = reqwest::Client::builder()
143 .default_headers(headers)
144 .build()
145 .expect("failed to build HTTP client");
146
147 Self {
148 manager_url: manager_url.trim_end_matches('/').to_string(),
149 deployment_id: deployment_id.to_string(),
150 token: token.to_string(),
151 http_client,
152 config,
153 }
154 }
155
156 pub async fn invoke<P: Serialize, R: DeserializeOwned>(
160 &self,
161 command: &str,
162 params: P,
163 ) -> Result<R, CommandError> {
164 self.invoke_with_options(command, params, None).await
165 }
166
167 pub async fn invoke_with_options<P: Serialize, R: DeserializeOwned>(
169 &self,
170 command: &str,
171 params: P,
172 options: Option<InvokeOptions>,
173 ) -> Result<R, CommandError> {
174 let timeout = options
175 .as_ref()
176 .and_then(|o| o.timeout)
177 .unwrap_or(self.config.timeout);
178
179 let command_id = self.create(command, params, options.as_ref()).await?;
181
182 debug!(command_id = %command_id, command = %command, "Command created, polling for result");
183
184 let start = tokio::time::Instant::now();
186 let mut interval = self.config.poll_interval;
187
188 loop {
189 if start.elapsed() > timeout {
190 return Err(CommandError::Timeout {
191 command_id,
192 last_state: "polling".to_string(),
193 });
194 }
195
196 tokio::time::sleep(interval).await;
197
198 let status = self.get_status(&command_id).await?;
199
200 match status.state.as_str() {
201 "SUCCEEDED" => {
202 return self.decode_response(&command_id, status.response).await;
203 }
204 "FAILED" => {
205 let (code, message) = status
206 .response
207 .as_ref()
208 .map(|r| {
209 (
210 r.code.clone().unwrap_or_default(),
211 r.message.clone().unwrap_or_default(),
212 )
213 })
214 .unwrap_or_default();
215 return Err(CommandError::DeploymentError {
216 command_id,
217 code,
218 message,
219 });
220 }
221 "EXPIRED" => {
222 return Err(CommandError::Expired { command_id });
223 }
224 _ => {
225 interval = Duration::from_secs_f64(
227 (interval.as_secs_f64() * self.config.poll_backoff)
228 .min(self.config.max_poll_interval.as_secs_f64()),
229 );
230 }
231 }
232 }
233 }
234
235 pub async fn create<P: Serialize>(
237 &self,
238 command: &str,
239 params: P,
240 options: Option<&InvokeOptions>,
241 ) -> Result<String, CommandError> {
242 let params_json = serde_json::to_vec(¶ms)?;
243 let params_base64 = general_purpose::STANDARD.encode(¶ms_json);
244
245 let mut body = serde_json::json!({
246 "deploymentId": self.deployment_id,
247 "command": command,
248 "params": {
249 "mode": "inline",
250 "inlineBase64": params_base64,
251 },
252 });
253
254 if let Some(opts) = options {
255 if let Some(deadline) = opts.deadline {
256 body["deadline"] = serde_json::Value::String(deadline.to_rfc3339());
257 }
258 if let Some(ref key) = opts.idempotency_key {
259 body["idempotencyKey"] = serde_json::Value::String(key.clone());
260 }
261 }
262
263 let url = format!("{}/commands", self.manager_url);
264 let resp = self.http_client.post(&url).json(&body).send().await?;
265
266 if !resp.status().is_success() {
267 let status = resp.status().as_u16();
268 let body = resp.text().await.unwrap_or_default();
269 return Err(CommandError::CreationFailed { status, body });
270 }
271
272 let result: CreateCommandResponse = resp.json().await?;
273 Ok(result.command_id)
274 }
275
276 pub async fn get_status(
278 &self,
279 command_id: &str,
280 ) -> Result<CommandStatusResponse, CommandError> {
281 let url = format!("{}/commands/{}", self.manager_url, command_id);
282 let resp = self.http_client.get(&url).send().await?;
283
284 if !resp.status().is_success() {
285 let status = resp.status().as_u16();
286 let body = resp.text().await.unwrap_or_default();
287 return Err(CommandError::CreationFailed { status, body });
288 }
289
290 Ok(resp.json().await?)
291 }
292
293 async fn decode_response<R: DeserializeOwned>(
296 &self,
297 command_id: &str,
298 response: Option<CommandResponseBody>,
299 ) -> Result<R, CommandError> {
300 let resp = response.ok_or_else(|| CommandError::ResponseDecodingFailed {
301 command_id: command_id.to_string(),
302 reason: "No response body in SUCCEEDED status".to_string(),
303 })?;
304
305 let body = resp
306 .response
307 .ok_or_else(|| CommandError::ResponseDecodingFailed {
308 command_id: command_id.to_string(),
309 reason: "No response field in success response".to_string(),
310 })?;
311
312 let bytes = match body.mode.as_str() {
313 "inline" => {
314 let base64_data =
315 body.inline_base64
316 .ok_or_else(|| CommandError::ResponseDecodingFailed {
317 command_id: command_id.to_string(),
318 reason: "Inline response missing inlineBase64 field".to_string(),
319 })?;
320
321 general_purpose::STANDARD
322 .decode(&base64_data)
323 .map_err(|e| CommandError::ResponseDecodingFailed {
324 command_id: command_id.to_string(),
325 reason: format!("Base64 decode failed: {}", e),
326 })?
327 }
328 "storage" => {
329 let get_request = body.storage_get_request.ok_or_else(|| {
330 CommandError::ResponseDecodingFailed {
331 command_id: command_id.to_string(),
332 reason: "Storage response missing storageGetRequest".to_string(),
333 }
334 })?;
335
336 self.download_from_storage(&get_request).await?
337 }
338 other => {
339 return Err(CommandError::ResponseDecodingFailed {
340 command_id: command_id.to_string(),
341 reason: format!("Unknown response mode: {}", other),
342 })
343 }
344 };
345
346 serde_json::from_slice(&bytes).map_err(|e| CommandError::ResponseDecodingFailed {
347 command_id: command_id.to_string(),
348 reason: format!("JSON decode failed: {}", e),
349 })
350 }
351
352 async fn download_from_storage(
353 &self,
354 get_request: &StorageGetRequest,
355 ) -> Result<Vec<u8>, CommandError> {
356 match get_request.backend.backend_type.as_str() {
357 "http" => {
358 let url = get_request.backend.url.as_deref().ok_or_else(|| {
359 CommandError::StorageOperationFailed {
360 reason: "HTTP storage backend missing url".to_string(),
361 }
362 })?;
363
364 let method = get_request.backend.method.as_deref().unwrap_or("GET");
365
366 let plain_http = reqwest::Client::new();
368 let mut req = match method {
369 "PUT" => plain_http.put(url),
370 "POST" => plain_http.post(url),
371 _ => plain_http.get(url),
372 };
373
374 if let Some(headers) = &get_request.backend.headers {
375 for (k, v) in headers {
376 req = req.header(k.as_str(), v.as_str());
377 }
378 }
379
380 let resp = req
381 .send()
382 .await
383 .map_err(|e| CommandError::StorageOperationFailed {
384 reason: format!("Storage download failed: {}", e),
385 })?;
386
387 if !resp.status().is_success() {
388 return Err(CommandError::StorageOperationFailed {
389 reason: format!("Storage download returned HTTP {}", resp.status()),
390 });
391 }
392
393 resp.bytes().await.map(|b| b.to_vec()).map_err(|e| {
394 CommandError::StorageOperationFailed {
395 reason: format!("Failed to read storage response bytes: {}", e),
396 }
397 })
398 }
399 "local" if self.config.allow_local_storage => {
400 let file_path = get_request.backend.file_path.as_deref().ok_or_else(|| {
401 CommandError::StorageOperationFailed {
402 reason: "Local storage backend missing filePath".to_string(),
403 }
404 })?;
405
406 let path = std::path::Path::new(file_path);
407 if path.is_absolute() || file_path.contains("..") {
408 return Err(CommandError::StorageOperationFailed {
409 reason: "Local storage path traversal detected".to_string(),
410 });
411 }
412
413 tokio::fs::read(file_path)
414 .await
415 .map_err(|e| CommandError::StorageOperationFailed {
416 reason: format!("Failed to read local file {}: {}", file_path, e),
417 })
418 }
419 "local" => Err(CommandError::StorageOperationFailed {
420 reason: "Local storage backend not allowed (set allow_local_storage: true)"
421 .to_string(),
422 }),
423 other => Err(CommandError::StorageOperationFailed {
424 reason: format!("Unknown storage backend type: {}", other),
425 }),
426 }
427 }
428}