1use std::time::Duration;
2
3use base64::{engine::general_purpose, Engine as _};
4use chrono::{DateTime, Utc};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tracing::debug;
7
8use crate::error::CommandError;
9
/// Tuning knobs for [`CommandsClient`].
///
/// Start from `CommandsClientConfig::default()` and override individual fields
/// with struct-update syntax, e.g.
/// `CommandsClientConfig { timeout: Duration::from_secs(10), ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct CommandsClientConfig {
    /// Overall polling budget for an invocation, measured from the moment the
    /// command has been created. Can be overridden per call through
    /// `InvokeOptions::timeout`.
    pub timeout: Duration,
    /// Delay before the first status poll; also the starting point for backoff.
    pub poll_interval: Duration,
    /// Upper bound on the poll delay as it grows by `poll_backoff`.
    pub max_poll_interval: Duration,
    /// Factor the poll delay is multiplied by after each non-terminal status.
    pub poll_backoff: f64,
    /// Whether results stored on the `local` storage backend may be read from
    /// the local filesystem.
    pub allow_local_storage: bool,
}
23
24impl Default for CommandsClientConfig {
25 fn default() -> Self {
26 Self {
27 timeout: Duration::from_secs(60),
28 poll_interval: Duration::from_millis(500),
29 max_poll_interval: Duration::from_secs(5),
30 poll_backoff: 1.5,
31 allow_local_storage: false,
32 }
33 }
34}
35
36pub struct InvokeOptions {
38 pub timeout: Option<Duration>,
40 pub deadline: Option<DateTime<Utc>>,
42 pub idempotency_key: Option<String>,
44}
45
46pub struct CommandsClient {
48 manager_url: String,
49 deployment_id: String,
50 http_client: reqwest::Client,
51 config: CommandsClientConfig,
52}
53
54#[derive(Deserialize)]
57#[serde(rename_all = "camelCase")]
58struct CreateCommandResponse {
59 command_id: String,
60}
61
62#[derive(Deserialize)]
63#[serde(rename_all = "camelCase")]
64struct CommandStatusResponse {
65 state: String,
66 #[serde(default)]
67 response: Option<CommandResponseBody>,
68}
69
70#[derive(Deserialize)]
71#[serde(rename_all = "camelCase")]
72struct CommandResponseBody {
73 #[serde(default)]
74 response: Option<BodySpecResponse>,
75 #[serde(default)]
76 code: Option<String>,
77 #[serde(default)]
78 message: Option<String>,
79}
80
81#[derive(Deserialize)]
82#[serde(rename_all = "camelCase")]
83struct BodySpecResponse {
84 mode: String,
85 #[serde(default)]
86 inline_base64: Option<String>,
87 #[serde(default)]
88 storage_get_request: Option<StorageGetRequest>,
89}
90
91#[derive(Deserialize)]
92#[serde(rename_all = "camelCase")]
93struct StorageGetRequest {
94 backend: StorageBackend,
95}
96
97#[derive(Deserialize)]
98#[serde(rename_all = "camelCase")]
99struct StorageBackend {
100 #[serde(rename = "type")]
101 backend_type: String,
102 #[serde(default)]
103 url: Option<String>,
104 #[serde(default)]
105 method: Option<String>,
106 #[serde(default)]
107 headers: Option<std::collections::HashMap<String, String>>,
108 #[serde(default, rename = "filePath")]
109 file_path: Option<String>,
110}
111
112impl CommandsClient {
113 pub fn new(manager_url: &str, deployment_id: &str, token: &str) -> Self {
115 Self::with_config(
116 manager_url,
117 deployment_id,
118 token,
119 CommandsClientConfig::default(),
120 )
121 }
122
123 pub fn with_config(
125 manager_url: &str,
126 deployment_id: &str,
127 token: &str,
128 config: CommandsClientConfig,
129 ) -> Self {
130 let mut headers = reqwest::header::HeaderMap::new();
131 headers.insert(
132 reqwest::header::AUTHORIZATION,
133 reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
134 .expect("invalid token"),
135 );
136
137 let http_client = reqwest::Client::builder()
138 .default_headers(headers)
139 .build()
140 .expect("failed to build HTTP client");
141
142 Self {
143 manager_url: manager_url.trim_end_matches('/').to_string(),
144 deployment_id: deployment_id.to_string(),
145 http_client,
146 config,
147 }
148 }
149
150 pub fn with_http_client(
155 manager_url: &str,
156 deployment_id: &str,
157 http_client: reqwest::Client,
158 config: CommandsClientConfig,
159 ) -> Self {
160 Self {
161 manager_url: manager_url.trim_end_matches('/').to_string(),
162 deployment_id: deployment_id.to_string(),
163 http_client,
164 config,
165 }
166 }
167
168 pub async fn invoke<P: Serialize, R: DeserializeOwned>(
172 &self,
173 command: &str,
174 params: P,
175 ) -> Result<R, CommandError> {
176 self.invoke_with_options(command, params, None).await
177 }
178
179 pub async fn invoke_with_options<P: Serialize, R: DeserializeOwned>(
181 &self,
182 command: &str,
183 params: P,
184 options: Option<InvokeOptions>,
185 ) -> Result<R, CommandError> {
186 let timeout = options
187 .as_ref()
188 .and_then(|o| o.timeout)
189 .unwrap_or(self.config.timeout);
190
191 let command_id = self.create(command, params, options.as_ref()).await?;
193
194 debug!(command_id = %command_id, command = %command, "Command created, polling for result");
195
196 let start = tokio::time::Instant::now();
198 let mut interval = self.config.poll_interval;
199
200 loop {
201 if start.elapsed() > timeout {
202 return Err(CommandError::Timeout {
203 command_id,
204 last_state: "polling".to_string(),
205 });
206 }
207
208 tokio::time::sleep(interval).await;
209
210 let status = self.get_status(&command_id).await?;
211
212 match status.state.as_str() {
213 "SUCCEEDED" => {
214 return self.decode_response(&command_id, status.response).await;
215 }
216 "FAILED" => {
217 let (code, message) = status
218 .response
219 .as_ref()
220 .map(|r| {
221 (
222 r.code.clone().unwrap_or_default(),
223 r.message.clone().unwrap_or_default(),
224 )
225 })
226 .unwrap_or_default();
227 return Err(CommandError::DeploymentError {
228 command_id,
229 code,
230 message,
231 });
232 }
233 "EXPIRED" => {
234 return Err(CommandError::Expired { command_id });
235 }
236 _ => {
237 interval = Duration::from_secs_f64(
239 (interval.as_secs_f64() * self.config.poll_backoff)
240 .min(self.config.max_poll_interval.as_secs_f64()),
241 );
242 }
243 }
244 }
245 }
246
247 pub async fn create<P: Serialize>(
249 &self,
250 command: &str,
251 params: P,
252 options: Option<&InvokeOptions>,
253 ) -> Result<String, CommandError> {
254 let params_json = serde_json::to_vec(¶ms)?;
255 let params_base64 = general_purpose::STANDARD.encode(¶ms_json);
256
257 let mut body = serde_json::json!({
258 "deploymentId": self.deployment_id,
259 "command": command,
260 "params": {
261 "mode": "inline",
262 "inlineBase64": params_base64,
263 },
264 });
265
266 if let Some(opts) = options {
267 if let Some(deadline) = opts.deadline {
268 body["deadline"] = serde_json::Value::String(deadline.to_rfc3339());
269 }
270 if let Some(ref key) = opts.idempotency_key {
271 body["idempotencyKey"] = serde_json::Value::String(key.clone());
272 }
273 }
274
275 let url = format!("{}/commands", self.manager_url);
276 let resp = self.http_client.post(&url).json(&body).send().await?;
277
278 if !resp.status().is_success() {
279 let status = resp.status().as_u16();
280 let body = resp.text().await.unwrap_or_default();
281 return Err(CommandError::CreationFailed { status, body });
282 }
283
284 let result: CreateCommandResponse = resp.json().await?;
285 Ok(result.command_id)
286 }
287
288 async fn get_status(&self, command_id: &str) -> Result<CommandStatusResponse, CommandError> {
290 let url = format!("{}/commands/{}", self.manager_url, command_id);
291 let resp = self.http_client.get(&url).send().await?;
292
293 if !resp.status().is_success() {
294 let status = resp.status().as_u16();
295 let body = resp.text().await.unwrap_or_default();
296 return Err(CommandError::CreationFailed { status, body });
297 }
298
299 Ok(resp.json().await?)
300 }
301
302 async fn decode_response<R: DeserializeOwned>(
305 &self,
306 command_id: &str,
307 response: Option<CommandResponseBody>,
308 ) -> Result<R, CommandError> {
309 let resp = response.ok_or_else(|| CommandError::ResponseDecodingFailed {
310 command_id: command_id.to_string(),
311 reason: "No response body in SUCCEEDED status".to_string(),
312 })?;
313
314 let body = resp
315 .response
316 .ok_or_else(|| CommandError::ResponseDecodingFailed {
317 command_id: command_id.to_string(),
318 reason: "No response field in success response".to_string(),
319 })?;
320
321 let bytes = match body.mode.as_str() {
322 "inline" => {
323 let base64_data =
324 body.inline_base64
325 .ok_or_else(|| CommandError::ResponseDecodingFailed {
326 command_id: command_id.to_string(),
327 reason: "Inline response missing inlineBase64 field".to_string(),
328 })?;
329
330 general_purpose::STANDARD
331 .decode(&base64_data)
332 .map_err(|e| CommandError::ResponseDecodingFailed {
333 command_id: command_id.to_string(),
334 reason: format!("Base64 decode failed: {}", e),
335 })?
336 }
337 "storage" => {
338 let get_request = body.storage_get_request.ok_or_else(|| {
339 CommandError::ResponseDecodingFailed {
340 command_id: command_id.to_string(),
341 reason: "Storage response missing storageGetRequest".to_string(),
342 }
343 })?;
344
345 self.download_from_storage(&get_request).await?
346 }
347 other => {
348 return Err(CommandError::ResponseDecodingFailed {
349 command_id: command_id.to_string(),
350 reason: format!("Unknown response mode: {}", other),
351 })
352 }
353 };
354
355 serde_json::from_slice(&bytes).map_err(|e| CommandError::ResponseDecodingFailed {
356 command_id: command_id.to_string(),
357 reason: format!("JSON decode failed: {}", e),
358 })
359 }
360
361 async fn download_from_storage(
362 &self,
363 get_request: &StorageGetRequest,
364 ) -> Result<Vec<u8>, CommandError> {
365 match get_request.backend.backend_type.as_str() {
366 "http" => {
367 let url = get_request.backend.url.as_deref().ok_or_else(|| {
368 CommandError::StorageOperationFailed {
369 reason: "HTTP storage backend missing url".to_string(),
370 }
371 })?;
372
373 let method = get_request.backend.method.as_deref().unwrap_or("GET");
374
375 let plain_http = reqwest::Client::new();
377 let mut req = match method {
378 "PUT" => plain_http.put(url),
379 "POST" => plain_http.post(url),
380 _ => plain_http.get(url),
381 };
382
383 if let Some(headers) = &get_request.backend.headers {
384 for (k, v) in headers {
385 req = req.header(k.as_str(), v.as_str());
386 }
387 }
388
389 let resp = req
390 .send()
391 .await
392 .map_err(|e| CommandError::StorageOperationFailed {
393 reason: format!("Storage download failed: {}", e),
394 })?;
395
396 if !resp.status().is_success() {
397 return Err(CommandError::StorageOperationFailed {
398 reason: format!("Storage download returned HTTP {}", resp.status()),
399 });
400 }
401
402 resp.bytes().await.map(|b| b.to_vec()).map_err(|e| {
403 CommandError::StorageOperationFailed {
404 reason: format!("Failed to read storage response bytes: {}", e),
405 }
406 })
407 }
408 "local" if self.config.allow_local_storage => {
409 let file_path = get_request.backend.file_path.as_deref().ok_or_else(|| {
410 CommandError::StorageOperationFailed {
411 reason: "Local storage backend missing filePath".to_string(),
412 }
413 })?;
414
415 let path = std::path::Path::new(file_path);
416 if path.is_absolute() || file_path.contains("..") {
417 return Err(CommandError::StorageOperationFailed {
418 reason: "Local storage path traversal detected".to_string(),
419 });
420 }
421
422 tokio::fs::read(file_path)
423 .await
424 .map_err(|e| CommandError::StorageOperationFailed {
425 reason: format!("Failed to read local file {}: {}", file_path, e),
426 })
427 }
428 "local" => Err(CommandError::StorageOperationFailed {
429 reason: "Local storage backend not allowed (set allow_local_storage: true)"
430 .to_string(),
431 }),
432 other => Err(CommandError::StorageOperationFailed {
433 reason: format!("Unknown storage backend type: {}", other),
434 }),
435 }
436 }
437}