Skip to main content

alien_commands_client/
client.rs

1use std::time::Duration;
2
3use base64::{engine::general_purpose, Engine as _};
4use chrono::{DateTime, Utc};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tracing::debug;
7
8use crate::error::CommandError;
9
/// Configuration for the commands client.
///
/// Start from `CommandsClientConfig::default()` and override individual fields
/// with struct-update syntax, e.g.
/// `CommandsClientConfig { timeout: Duration::from_secs(5), ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct CommandsClientConfig {
    /// Time allowed for a command to reach a terminal state once it has been
    /// created (default: 60s).
    pub timeout: Duration,
    /// Initial delay between status polls (default: 500ms).
    pub poll_interval: Duration,
    /// Upper bound on the delay between status polls (default: 5s).
    pub max_poll_interval: Duration,
    /// Factor applied to the poll delay after each non-terminal poll (default: 1.5).
    pub poll_backoff: f64,
    /// Allow the `local` storage backend, which reads response payloads from a
    /// relative file path instead of over HTTP (dev only; default: false).
    pub allow_local_storage: bool,
}
23
24impl Default for CommandsClientConfig {
25    fn default() -> Self {
26        Self {
27            timeout: Duration::from_secs(60),
28            poll_interval: Duration::from_millis(500),
29            max_poll_interval: Duration::from_secs(5),
30            poll_backoff: 1.5,
31            allow_local_storage: false,
32        }
33    }
34}
35
36/// Options for a single invoke call.
37pub struct InvokeOptions {
38    /// Override the default timeout for this invocation.
39    pub timeout: Option<Duration>,
40    /// Set a deadline for the command (server-side expiry).
41    pub deadline: Option<DateTime<Utc>>,
42    /// Idempotency key to prevent duplicate commands.
43    pub idempotency_key: Option<String>,
44}
45
46/// High-level client for invoking commands on Alien deployments.
47pub struct CommandsClient {
48    manager_url: String,
49    deployment_id: String,
50    token: String,
51    http_client: reqwest::Client,
52    config: CommandsClientConfig,
53}
54
// -- API response types (internal) --
56
57#[derive(Deserialize)]
58#[serde(rename_all = "camelCase")]
59struct CreateCommandResponse {
60    command_id: String,
61    state: String,
62    next: String,
63}
64
65#[derive(Deserialize)]
66#[serde(rename_all = "camelCase")]
67struct CommandStatusResponse {
68    command_id: String,
69    state: String,
70    #[serde(default)]
71    response: Option<CommandResponseBody>,
72}
73
74#[derive(Deserialize)]
75#[serde(rename_all = "camelCase")]
76struct CommandResponseBody {
77    status: String,
78    #[serde(default)]
79    response: Option<BodySpecResponse>,
80    #[serde(default)]
81    code: Option<String>,
82    #[serde(default)]
83    message: Option<String>,
84}
85
86#[derive(Deserialize)]
87#[serde(rename_all = "camelCase")]
88struct BodySpecResponse {
89    mode: String,
90    #[serde(default)]
91    inline_base64: Option<String>,
92    #[serde(default)]
93    storage_get_request: Option<StorageGetRequest>,
94}
95
96#[derive(Deserialize)]
97#[serde(rename_all = "camelCase")]
98struct StorageGetRequest {
99    backend: StorageBackend,
100}
101
102#[derive(Deserialize)]
103#[serde(rename_all = "camelCase")]
104struct StorageBackend {
105    #[serde(rename = "type")]
106    backend_type: String,
107    #[serde(default)]
108    url: Option<String>,
109    #[serde(default)]
110    method: Option<String>,
111    #[serde(default)]
112    headers: Option<std::collections::HashMap<String, String>>,
113    #[serde(default, rename = "filePath")]
114    file_path: Option<String>,
115}
116
117impl CommandsClient {
118    /// Create a new commands client with default config.
119    pub fn new(manager_url: &str, deployment_id: &str, token: &str) -> Self {
120        Self::with_config(
121            manager_url,
122            deployment_id,
123            token,
124            CommandsClientConfig::default(),
125        )
126    }
127
128    /// Create a new commands client with custom config.
129    pub fn with_config(
130        manager_url: &str,
131        deployment_id: &str,
132        token: &str,
133        config: CommandsClientConfig,
134    ) -> Self {
135        let mut headers = reqwest::header::HeaderMap::new();
136        headers.insert(
137            reqwest::header::AUTHORIZATION,
138            reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
139                .expect("invalid token"),
140        );
141
142        let http_client = reqwest::Client::builder()
143            .default_headers(headers)
144            .build()
145            .expect("failed to build HTTP client");
146
147        Self {
148            manager_url: manager_url.trim_end_matches('/').to_string(),
149            deployment_id: deployment_id.to_string(),
150            token: token.to_string(),
151            http_client,
152            config,
153        }
154    }
155
156    /// Invoke a command and wait for the result.
157    ///
158    /// Sends params inline, polls for completion, and decodes the response.
159    pub async fn invoke<P: Serialize, R: DeserializeOwned>(
160        &self,
161        command: &str,
162        params: P,
163    ) -> Result<R, CommandError> {
164        self.invoke_with_options(command, params, None).await
165    }
166
167    /// Invoke a command with options and wait for the result.
168    pub async fn invoke_with_options<P: Serialize, R: DeserializeOwned>(
169        &self,
170        command: &str,
171        params: P,
172        options: Option<InvokeOptions>,
173    ) -> Result<R, CommandError> {
174        let timeout = options
175            .as_ref()
176            .and_then(|o| o.timeout)
177            .unwrap_or(self.config.timeout);
178
179        // Step 1: Create the command (always inline — server handles storage)
180        let command_id = self.create(command, params, options.as_ref()).await?;
181
182        debug!(command_id = %command_id, command = %command, "Command created, polling for result");
183
184        // Step 2: Poll for completion with exponential backoff
185        let start = tokio::time::Instant::now();
186        let mut interval = self.config.poll_interval;
187
188        loop {
189            if start.elapsed() > timeout {
190                return Err(CommandError::Timeout {
191                    command_id,
192                    last_state: "polling".to_string(),
193                });
194            }
195
196            tokio::time::sleep(interval).await;
197
198            let status = self.get_status(&command_id).await?;
199
200            match status.state.as_str() {
201                "SUCCEEDED" => {
202                    return self.decode_response(&command_id, status.response).await;
203                }
204                "FAILED" => {
205                    let (code, message) = status
206                        .response
207                        .as_ref()
208                        .map(|r| {
209                            (
210                                r.code.clone().unwrap_or_default(),
211                                r.message.clone().unwrap_or_default(),
212                            )
213                        })
214                        .unwrap_or_default();
215                    return Err(CommandError::DeploymentError {
216                        command_id,
217                        code,
218                        message,
219                    });
220                }
221                "EXPIRED" => {
222                    return Err(CommandError::Expired { command_id });
223                }
224                _ => {
225                    // Still in progress — backoff
226                    interval = Duration::from_secs_f64(
227                        (interval.as_secs_f64() * self.config.poll_backoff)
228                            .min(self.config.max_poll_interval.as_secs_f64()),
229                    );
230                }
231            }
232        }
233    }
234
235    /// Create a command without waiting for the result. Returns the command ID.
236    pub async fn create<P: Serialize>(
237        &self,
238        command: &str,
239        params: P,
240        options: Option<&InvokeOptions>,
241    ) -> Result<String, CommandError> {
242        let params_json = serde_json::to_vec(&params)?;
243        let params_base64 = general_purpose::STANDARD.encode(&params_json);
244
245        let mut body = serde_json::json!({
246            "deploymentId": self.deployment_id,
247            "command": command,
248            "params": {
249                "mode": "inline",
250                "inlineBase64": params_base64,
251            },
252        });
253
254        if let Some(opts) = options {
255            if let Some(deadline) = opts.deadline {
256                body["deadline"] = serde_json::Value::String(deadline.to_rfc3339());
257            }
258            if let Some(ref key) = opts.idempotency_key {
259                body["idempotencyKey"] = serde_json::Value::String(key.clone());
260            }
261        }
262
263        let url = format!("{}/commands", self.manager_url);
264        let resp = self.http_client.post(&url).json(&body).send().await?;
265
266        if !resp.status().is_success() {
267            let status = resp.status().as_u16();
268            let body = resp.text().await.unwrap_or_default();
269            return Err(CommandError::CreationFailed { status, body });
270        }
271
272        let result: CreateCommandResponse = resp.json().await?;
273        Ok(result.command_id)
274    }
275
276    /// Poll for a command's status.
277    pub async fn get_status(
278        &self,
279        command_id: &str,
280    ) -> Result<CommandStatusResponse, CommandError> {
281        let url = format!("{}/commands/{}", self.manager_url, command_id);
282        let resp = self.http_client.get(&url).send().await?;
283
284        if !resp.status().is_success() {
285            let status = resp.status().as_u16();
286            let body = resp.text().await.unwrap_or_default();
287            return Err(CommandError::CreationFailed { status, body });
288        }
289
290        Ok(resp.json().await?)
291    }
292
293    // -- Internal helpers --
294
295    async fn decode_response<R: DeserializeOwned>(
296        &self,
297        command_id: &str,
298        response: Option<CommandResponseBody>,
299    ) -> Result<R, CommandError> {
300        let resp = response.ok_or_else(|| CommandError::ResponseDecodingFailed {
301            command_id: command_id.to_string(),
302            reason: "No response body in SUCCEEDED status".to_string(),
303        })?;
304
305        let body = resp
306            .response
307            .ok_or_else(|| CommandError::ResponseDecodingFailed {
308                command_id: command_id.to_string(),
309                reason: "No response field in success response".to_string(),
310            })?;
311
312        let bytes = match body.mode.as_str() {
313            "inline" => {
314                let base64_data =
315                    body.inline_base64
316                        .ok_or_else(|| CommandError::ResponseDecodingFailed {
317                            command_id: command_id.to_string(),
318                            reason: "Inline response missing inlineBase64 field".to_string(),
319                        })?;
320
321                general_purpose::STANDARD
322                    .decode(&base64_data)
323                    .map_err(|e| CommandError::ResponseDecodingFailed {
324                        command_id: command_id.to_string(),
325                        reason: format!("Base64 decode failed: {}", e),
326                    })?
327            }
328            "storage" => {
329                let get_request = body.storage_get_request.ok_or_else(|| {
330                    CommandError::ResponseDecodingFailed {
331                        command_id: command_id.to_string(),
332                        reason: "Storage response missing storageGetRequest".to_string(),
333                    }
334                })?;
335
336                self.download_from_storage(&get_request).await?
337            }
338            other => {
339                return Err(CommandError::ResponseDecodingFailed {
340                    command_id: command_id.to_string(),
341                    reason: format!("Unknown response mode: {}", other),
342                })
343            }
344        };
345
346        serde_json::from_slice(&bytes).map_err(|e| CommandError::ResponseDecodingFailed {
347            command_id: command_id.to_string(),
348            reason: format!("JSON decode failed: {}", e),
349        })
350    }
351
352    async fn download_from_storage(
353        &self,
354        get_request: &StorageGetRequest,
355    ) -> Result<Vec<u8>, CommandError> {
356        match get_request.backend.backend_type.as_str() {
357            "http" => {
358                let url = get_request.backend.url.as_deref().ok_or_else(|| {
359                    CommandError::StorageOperationFailed {
360                        reason: "HTTP storage backend missing url".to_string(),
361                    }
362                })?;
363
364                let method = get_request.backend.method.as_deref().unwrap_or("GET");
365
366                // Use a plain client (no auth headers — presigned URL carries auth)
367                let plain_http = reqwest::Client::new();
368                let mut req = match method {
369                    "PUT" => plain_http.put(url),
370                    "POST" => plain_http.post(url),
371                    _ => plain_http.get(url),
372                };
373
374                if let Some(headers) = &get_request.backend.headers {
375                    for (k, v) in headers {
376                        req = req.header(k.as_str(), v.as_str());
377                    }
378                }
379
380                let resp = req
381                    .send()
382                    .await
383                    .map_err(|e| CommandError::StorageOperationFailed {
384                        reason: format!("Storage download failed: {}", e),
385                    })?;
386
387                if !resp.status().is_success() {
388                    return Err(CommandError::StorageOperationFailed {
389                        reason: format!("Storage download returned HTTP {}", resp.status()),
390                    });
391                }
392
393                resp.bytes().await.map(|b| b.to_vec()).map_err(|e| {
394                    CommandError::StorageOperationFailed {
395                        reason: format!("Failed to read storage response bytes: {}", e),
396                    }
397                })
398            }
399            "local" if self.config.allow_local_storage => {
400                let file_path = get_request.backend.file_path.as_deref().ok_or_else(|| {
401                    CommandError::StorageOperationFailed {
402                        reason: "Local storage backend missing filePath".to_string(),
403                    }
404                })?;
405
406                let path = std::path::Path::new(file_path);
407                if path.is_absolute() || file_path.contains("..") {
408                    return Err(CommandError::StorageOperationFailed {
409                        reason: "Local storage path traversal detected".to_string(),
410                    });
411                }
412
413                tokio::fs::read(file_path)
414                    .await
415                    .map_err(|e| CommandError::StorageOperationFailed {
416                        reason: format!("Failed to read local file {}: {}", file_path, e),
417                    })
418            }
419            "local" => Err(CommandError::StorageOperationFailed {
420                reason: "Local storage backend not allowed (set allow_local_storage: true)"
421                    .to_string(),
422            }),
423            other => Err(CommandError::StorageOperationFailed {
424                reason: format!("Unknown storage backend type: {}", other),
425            }),
426        }
427    }
428}