Skip to main content

alien_commands_client/
client.rs

1use std::time::Duration;
2
3use base64::{engine::general_purpose, Engine as _};
4use chrono::{DateTime, Utc};
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use tracing::debug;
7
8use crate::error::CommandError;
9
/// Configuration for the commands client.
///
/// Every field has a default (see the `Default` impl below), so individual
/// settings can be overridden with struct-update syntax, e.g.
/// `CommandsClientConfig { timeout: Duration::from_secs(120), ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct CommandsClientConfig {
    /// Time budget for waiting on a command's result, counted from the moment
    /// the command has been created (default: 60s).
    pub timeout: Duration,
    /// Delay before the first status poll (default: 500ms).
    pub poll_interval: Duration,
    /// Upper bound on the delay between two status polls (default: 5s).
    pub max_poll_interval: Duration,
    /// Factor applied to the polling delay after each in-progress poll
    /// (default: 1.5).
    pub poll_backoff: f64,
    /// Allow the `local` storage backend, which reads result payloads from a
    /// relative `filePath` on this machine (dev only; default: false).
    pub allow_local_storage: bool,
}

impl Default for CommandsClientConfig {
    /// 60s timeout; polls start at 500ms and grow by 1.5x up to 5s; the
    /// `local` storage backend is disabled.
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(60),
            poll_interval: Duration::from_millis(500),
            max_poll_interval: Duration::from_secs(5),
            poll_backoff: 1.5,
            allow_local_storage: false,
        }
    }
}
35
36/// Options for a single invoke call.
37pub struct InvokeOptions {
38    /// Override the default timeout for this invocation.
39    pub timeout: Option<Duration>,
40    /// Set a deadline for the command (server-side expiry).
41    pub deadline: Option<DateTime<Utc>>,
42    /// Idempotency key to prevent duplicate commands.
43    pub idempotency_key: Option<String>,
44}
45
46/// High-level client for invoking commands on Alien deployments.
47pub struct CommandsClient {
48    manager_url: String,
49    deployment_id: String,
50    http_client: reqwest::Client,
51    config: CommandsClientConfig,
52}
53
// -- API response types (internal) --
//
// These mirror the manager's JSON bodies; `rename_all = "camelCase"` maps the
// snake_case Rust fields to camelCase keys on the wire.

/// Body of a successful `POST {manager_url}/commands`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct CreateCommandResponse {
    /// Server-assigned command ID, later used to poll
    /// `GET {manager_url}/commands/{id}`.
    command_id: String,
}
61
/// Body of `GET {manager_url}/commands/{id}`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct CommandStatusResponse {
    /// Lifecycle state. The polling loop treats `SUCCEEDED`, `FAILED` and
    /// `EXPIRED` as terminal and any other value as still in progress.
    state: String,
    /// Outcome details; `None` when the key is absent.
    #[serde(default)]
    response: Option<CommandResponseBody>,
}
69
/// Outcome payload attached to a command status.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct CommandResponseBody {
    /// Descriptor of the result bytes; read when the state is `SUCCEEDED`.
    #[serde(default)]
    response: Option<BodySpecResponse>,
    /// Error code; read when the state is `FAILED` (empty string if absent).
    #[serde(default)]
    code: Option<String>,
    /// Error message; read when the state is `FAILED` (empty string if absent).
    #[serde(default)]
    message: Option<String>,
}
80
/// Describes where a command's result bytes live.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct BodySpecResponse {
    /// `"inline"` or `"storage"`; any other value is rejected when decoding.
    mode: String,
    /// Standard-alphabet base64 of the JSON result, used when `mode` is `"inline"`.
    #[serde(default)]
    inline_base64: Option<String>,
    /// Download instructions, used when `mode` is `"storage"`.
    #[serde(default)]
    storage_get_request: Option<StorageGetRequest>,
}
90
/// Instructions for fetching a result stored out of line.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct StorageGetRequest {
    /// Where (and how) to fetch the result bytes.
    backend: StorageBackend,
}
96
/// Storage location descriptor. Two `type` values are handled when
/// downloading: `"http"` (uses `url`, `method`, `headers`) and `"local"` (uses
/// `filePath`, only when `allow_local_storage` is enabled).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct StorageBackend {
    /// JSON key `type`; renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    backend_type: String,
    /// Request URL for the `http` backend. Presumably presigned: the download
    /// is sent without this client's auth headers.
    #[serde(default)]
    url: Option<String>,
    /// HTTP method for the `http` backend. `PUT` and `POST` are honoured and
    /// anything else (including absent) results in a GET.
    #[serde(default)]
    method: Option<String>,
    /// Extra headers added to the `http` download request.
    #[serde(default)]
    headers: Option<std::collections::HashMap<String, String>>,
    /// Relative file path for the `local` backend.
    #[serde(default, rename = "filePath")]
    file_path: Option<String>,
}
111
112impl CommandsClient {
113    /// Create a new commands client with default config.
114    pub fn new(manager_url: &str, deployment_id: &str, token: &str) -> Self {
115        Self::with_config(
116            manager_url,
117            deployment_id,
118            token,
119            CommandsClientConfig::default(),
120        )
121    }
122
123    /// Create a new commands client with custom config.
124    pub fn with_config(
125        manager_url: &str,
126        deployment_id: &str,
127        token: &str,
128        config: CommandsClientConfig,
129    ) -> Self {
130        let mut headers = reqwest::header::HeaderMap::new();
131        headers.insert(
132            reqwest::header::AUTHORIZATION,
133            reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
134                .expect("invalid token"),
135        );
136
137        let http_client = reqwest::Client::builder()
138            .default_headers(headers)
139            .build()
140            .expect("failed to build HTTP client");
141
142        Self {
143            manager_url: manager_url.trim_end_matches('/').to_string(),
144            deployment_id: deployment_id.to_string(),
145            http_client,
146            config,
147        }
148    }
149
150    /// Build a client over a caller-supplied HTTP client, reusing the headers
151    /// it already carries (the auth header, and the workspace header used in
152    /// platform mode). `with_config` builds a token-only client and can't add
153    /// those.
154    pub fn with_http_client(
155        manager_url: &str,
156        deployment_id: &str,
157        http_client: reqwest::Client,
158        config: CommandsClientConfig,
159    ) -> Self {
160        Self {
161            manager_url: manager_url.trim_end_matches('/').to_string(),
162            deployment_id: deployment_id.to_string(),
163            http_client,
164            config,
165        }
166    }
167
168    /// Invoke a command and wait for the result.
169    ///
170    /// Sends params inline, polls for completion, and decodes the response.
171    pub async fn invoke<P: Serialize, R: DeserializeOwned>(
172        &self,
173        command: &str,
174        params: P,
175    ) -> Result<R, CommandError> {
176        self.invoke_with_options(command, params, None).await
177    }
178
179    /// Invoke a command with options and wait for the result.
180    pub async fn invoke_with_options<P: Serialize, R: DeserializeOwned>(
181        &self,
182        command: &str,
183        params: P,
184        options: Option<InvokeOptions>,
185    ) -> Result<R, CommandError> {
186        let timeout = options
187            .as_ref()
188            .and_then(|o| o.timeout)
189            .unwrap_or(self.config.timeout);
190
191        // Step 1: Create the command (always inline — server handles storage)
192        let command_id = self.create(command, params, options.as_ref()).await?;
193
194        debug!(command_id = %command_id, command = %command, "Command created, polling for result");
195
196        // Step 2: Poll for completion with exponential backoff
197        let start = tokio::time::Instant::now();
198        let mut interval = self.config.poll_interval;
199
200        loop {
201            if start.elapsed() > timeout {
202                return Err(CommandError::Timeout {
203                    command_id,
204                    last_state: "polling".to_string(),
205                });
206            }
207
208            tokio::time::sleep(interval).await;
209
210            let status = self.get_status(&command_id).await?;
211
212            match status.state.as_str() {
213                "SUCCEEDED" => {
214                    return self.decode_response(&command_id, status.response).await;
215                }
216                "FAILED" => {
217                    let (code, message) = status
218                        .response
219                        .as_ref()
220                        .map(|r| {
221                            (
222                                r.code.clone().unwrap_or_default(),
223                                r.message.clone().unwrap_or_default(),
224                            )
225                        })
226                        .unwrap_or_default();
227                    return Err(CommandError::DeploymentError {
228                        command_id,
229                        code,
230                        message,
231                    });
232                }
233                "EXPIRED" => {
234                    return Err(CommandError::Expired { command_id });
235                }
236                _ => {
237                    // Still in progress — backoff
238                    interval = Duration::from_secs_f64(
239                        (interval.as_secs_f64() * self.config.poll_backoff)
240                            .min(self.config.max_poll_interval.as_secs_f64()),
241                    );
242                }
243            }
244        }
245    }
246
247    /// Create a command without waiting for the result. Returns the command ID.
248    pub async fn create<P: Serialize>(
249        &self,
250        command: &str,
251        params: P,
252        options: Option<&InvokeOptions>,
253    ) -> Result<String, CommandError> {
254        let params_json = serde_json::to_vec(&params)?;
255        let params_base64 = general_purpose::STANDARD.encode(&params_json);
256
257        let mut body = serde_json::json!({
258            "deploymentId": self.deployment_id,
259            "command": command,
260            "params": {
261                "mode": "inline",
262                "inlineBase64": params_base64,
263            },
264        });
265
266        if let Some(opts) = options {
267            if let Some(deadline) = opts.deadline {
268                body["deadline"] = serde_json::Value::String(deadline.to_rfc3339());
269            }
270            if let Some(ref key) = opts.idempotency_key {
271                body["idempotencyKey"] = serde_json::Value::String(key.clone());
272            }
273        }
274
275        let url = format!("{}/commands", self.manager_url);
276        let resp = self.http_client.post(&url).json(&body).send().await?;
277
278        if !resp.status().is_success() {
279            let status = resp.status().as_u16();
280            let body = resp.text().await.unwrap_or_default();
281            return Err(CommandError::CreationFailed { status, body });
282        }
283
284        let result: CreateCommandResponse = resp.json().await?;
285        Ok(result.command_id)
286    }
287
288    /// Poll for a command's status.
289    async fn get_status(&self, command_id: &str) -> Result<CommandStatusResponse, CommandError> {
290        let url = format!("{}/commands/{}", self.manager_url, command_id);
291        let resp = self.http_client.get(&url).send().await?;
292
293        if !resp.status().is_success() {
294            let status = resp.status().as_u16();
295            let body = resp.text().await.unwrap_or_default();
296            return Err(CommandError::CreationFailed { status, body });
297        }
298
299        Ok(resp.json().await?)
300    }
301
302    // -- Internal helpers --
303
304    async fn decode_response<R: DeserializeOwned>(
305        &self,
306        command_id: &str,
307        response: Option<CommandResponseBody>,
308    ) -> Result<R, CommandError> {
309        let resp = response.ok_or_else(|| CommandError::ResponseDecodingFailed {
310            command_id: command_id.to_string(),
311            reason: "No response body in SUCCEEDED status".to_string(),
312        })?;
313
314        let body = resp
315            .response
316            .ok_or_else(|| CommandError::ResponseDecodingFailed {
317                command_id: command_id.to_string(),
318                reason: "No response field in success response".to_string(),
319            })?;
320
321        let bytes = match body.mode.as_str() {
322            "inline" => {
323                let base64_data =
324                    body.inline_base64
325                        .ok_or_else(|| CommandError::ResponseDecodingFailed {
326                            command_id: command_id.to_string(),
327                            reason: "Inline response missing inlineBase64 field".to_string(),
328                        })?;
329
330                general_purpose::STANDARD
331                    .decode(&base64_data)
332                    .map_err(|e| CommandError::ResponseDecodingFailed {
333                        command_id: command_id.to_string(),
334                        reason: format!("Base64 decode failed: {}", e),
335                    })?
336            }
337            "storage" => {
338                let get_request = body.storage_get_request.ok_or_else(|| {
339                    CommandError::ResponseDecodingFailed {
340                        command_id: command_id.to_string(),
341                        reason: "Storage response missing storageGetRequest".to_string(),
342                    }
343                })?;
344
345                self.download_from_storage(&get_request).await?
346            }
347            other => {
348                return Err(CommandError::ResponseDecodingFailed {
349                    command_id: command_id.to_string(),
350                    reason: format!("Unknown response mode: {}", other),
351                })
352            }
353        };
354
355        serde_json::from_slice(&bytes).map_err(|e| CommandError::ResponseDecodingFailed {
356            command_id: command_id.to_string(),
357            reason: format!("JSON decode failed: {}", e),
358        })
359    }
360
361    async fn download_from_storage(
362        &self,
363        get_request: &StorageGetRequest,
364    ) -> Result<Vec<u8>, CommandError> {
365        match get_request.backend.backend_type.as_str() {
366            "http" => {
367                let url = get_request.backend.url.as_deref().ok_or_else(|| {
368                    CommandError::StorageOperationFailed {
369                        reason: "HTTP storage backend missing url".to_string(),
370                    }
371                })?;
372
373                let method = get_request.backend.method.as_deref().unwrap_or("GET");
374
375                // Use a plain client (no auth headers — presigned URL carries auth)
376                let plain_http = reqwest::Client::new();
377                let mut req = match method {
378                    "PUT" => plain_http.put(url),
379                    "POST" => plain_http.post(url),
380                    _ => plain_http.get(url),
381                };
382
383                if let Some(headers) = &get_request.backend.headers {
384                    for (k, v) in headers {
385                        req = req.header(k.as_str(), v.as_str());
386                    }
387                }
388
389                let resp = req
390                    .send()
391                    .await
392                    .map_err(|e| CommandError::StorageOperationFailed {
393                        reason: format!("Storage download failed: {}", e),
394                    })?;
395
396                if !resp.status().is_success() {
397                    return Err(CommandError::StorageOperationFailed {
398                        reason: format!("Storage download returned HTTP {}", resp.status()),
399                    });
400                }
401
402                resp.bytes().await.map(|b| b.to_vec()).map_err(|e| {
403                    CommandError::StorageOperationFailed {
404                        reason: format!("Failed to read storage response bytes: {}", e),
405                    }
406                })
407            }
408            "local" if self.config.allow_local_storage => {
409                let file_path = get_request.backend.file_path.as_deref().ok_or_else(|| {
410                    CommandError::StorageOperationFailed {
411                        reason: "Local storage backend missing filePath".to_string(),
412                    }
413                })?;
414
415                let path = std::path::Path::new(file_path);
416                if path.is_absolute() || file_path.contains("..") {
417                    return Err(CommandError::StorageOperationFailed {
418                        reason: "Local storage path traversal detected".to_string(),
419                    });
420                }
421
422                tokio::fs::read(file_path)
423                    .await
424                    .map_err(|e| CommandError::StorageOperationFailed {
425                        reason: format!("Failed to read local file {}: {}", file_path, e),
426                    })
427            }
428            "local" => Err(CommandError::StorageOperationFailed {
429                reason: "Local storage backend not allowed (set allow_local_storage: true)"
430                    .to_string(),
431            }),
432            other => Err(CommandError::StorageOperationFailed {
433                reason: format!("Unknown storage backend type: {}", other),
434            }),
435        }
436    }
437}