Skip to main content

zlayer_init_actions/
actions.rs

1//! Built-in init actions
2
3use crate::error::{InitError, Result};
4use std::collections::HashMap;
5use std::time::Duration;
6use tokio::process::Command;
7use tokio::time::{sleep, timeout};
8
/// Init action that blocks until a TCP connection to `host:port` can be opened.
pub struct WaitTcp {
    /// Host name or IP address to connect to.
    pub host: String,
    /// TCP port to connect to.
    pub port: u16,
    /// Overall time budget before giving up.
    pub timeout: Duration,
    /// Pause between connection attempts.
    pub interval: Duration,
}
16
17impl WaitTcp {
18    /// # Errors
19    /// Returns `InitError::TcpFailed` if the connection times out.
20    pub async fn execute(&self) -> Result<()> {
21        let start = std::time::Instant::now();
22
23        loop {
24            if tokio::net::TcpStream::connect(&format!("{}:{}", self.host, self.port))
25                .await
26                .is_ok()
27            {
28                return Ok(());
29            }
30
31            if start.elapsed() >= self.timeout {
32                return Err(InitError::TcpFailed {
33                    host: self.host.clone(),
34                    port: self.port,
35                    reason: format!("timeout after {:?}", self.timeout),
36                });
37            }
38
39            sleep(self.interval).await;
40        }
41    }
42}
43
/// Init action that polls an HTTP URL until it answers with an acceptable status.
pub struct WaitHttp {
    /// URL requested with `GET` on every attempt.
    pub url: String,
    /// Exact status code to wait for; `None` accepts any 2xx response.
    pub expect_status: Option<u16>,
    /// Overall time budget before giving up.
    pub timeout: Duration,
    /// Pause between attempts.
    pub interval: Duration,
}
51
52impl WaitHttp {
53    /// # Errors
54    /// Returns `InitError::HttpFailed` if the request times out or the expected status is not received.
55    pub async fn execute(&self) -> Result<()> {
56        let start = std::time::Instant::now();
57        let client = reqwest::Client::builder()
58            .timeout(Duration::from_secs(5))
59            .build()
60            .map_err(|e| InitError::HttpFailed {
61                url: self.url.clone(),
62                reason: format!("failed to create client: {e}"),
63            })?;
64
65        loop {
66            let response = client.get(&self.url).send().await;
67
68            if let Ok(resp) = response {
69                let status = resp.status().as_u16();
70
71                if let Some(expected) = self.expect_status {
72                    if status == expected {
73                        return Ok(());
74                    }
75                } else if (200..300).contains(&status) {
76                    return Ok(());
77                }
78            }
79
80            if start.elapsed() >= self.timeout {
81                return Err(InitError::HttpFailed {
82                    url: self.url.clone(),
83                    reason: format!("timeout after {:?}", self.timeout),
84                });
85            }
86
87            sleep(self.interval).await;
88        }
89    }
90}
91
/// Init action that runs a command line through `sh -c`.
pub struct RunCommand {
    /// Command line, passed verbatim as the argument to `sh -c`.
    pub command: String,
    /// Maximum time the command may run before the action fails.
    pub timeout: Duration,
}
97
98impl RunCommand {
99    /// # Errors
100    /// Returns an error if the command fails, exits non-zero, or times out.
101    pub async fn execute(&self) -> Result<()> {
102        match timeout(
103            self.timeout,
104            Command::new("sh").arg("-c").arg(&self.command).output(),
105        )
106        .await
107        {
108            Ok(Ok(output)) => {
109                if output.status.success() {
110                    Ok(())
111                } else {
112                    Err(InitError::CommandFailed {
113                        command: self.command.clone(),
114                        code: output.status.code().unwrap_or(-1),
115                        stdout: String::from_utf8_lossy(&output.stdout).to_string(),
116                        stderr: String::from_utf8_lossy(&output.stderr).to_string(),
117                    })
118                }
119            }
120            Ok(Err(_)) => Err(InitError::CommandFailed {
121                command: self.command.clone(),
122                code: -1,
123                stdout: String::new(),
124                stderr: "timeout".to_string(),
125            }),
126            Err(_) => Err(InitError::Timeout {
127                timeout: self.timeout,
128            }),
129        }
130    }
131}
132
/// Init action that uploads a local file or directory tree to S3.
#[cfg(feature = "s3")]
pub struct S3Push {
    /// Local path to upload: a single file, or a directory uploaded recursively.
    pub source: String,
    /// Destination bucket name.
    pub bucket: String,
    /// Object key for a single-file upload, or the key prefix under which a
    /// directory's contents are uploaded.
    pub key: String,
    /// Custom endpoint URL for S3-compatible services. When set, requests use
    /// path-style addressing.
    pub endpoint: Option<String>,
    /// Region override. When `None`, the SDK's loaded default is left unchanged.
    pub region: Option<String>,
    /// Timeout applied to each individual object upload.
    pub timeout: Duration,
}
149
#[cfg(feature = "s3")]
impl S3Push {
    /// Upload `source` to the configured bucket.
    ///
    /// - A regular file is uploaded to exactly `key`.
    /// - A directory is walked recursively. Every regular file inside it is
    ///   uploaded to `key/<relative path>`, and entries that are neither files
    ///   nor directories are skipped.
    ///
    /// Each `put_object` call is bounded by `timeout` on its own.
    ///
    /// # Errors
    /// Returns `InitError::S3Failed` if `source` does not exist, a file or
    /// directory cannot be read, or an upload is rejected. Returns
    /// `InitError::Timeout` if a single upload exceeds `timeout`.
    pub async fn execute(&self) -> Result<()> {
        use aws_sdk_s3::Client;

        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(ref region) = self.region {
            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
        }
        let sdk_config = config_loader.load().await;

        let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config);
        if let Some(ref endpoint) = self.endpoint {
            // Custom endpoints are addressed path-style (bucket in the URL path,
            // not in the hostname).
            s3_config = s3_config.endpoint_url(endpoint).force_path_style(true);
        }
        let client = Client::from_conf(s3_config.build());

        let source_path = std::path::Path::new(&self.source);

        // An async metadata lookup (which follows symlinks, as `Path::is_file`
        // does) keeps the blocking `stat` off the runtime's worker thread.
        match tokio::fs::metadata(source_path).await {
            Ok(meta) if meta.is_file() => {
                self.upload_file(&client, source_path, &self.key).await
            }
            Ok(meta) if meta.is_dir() => {
                self.upload_directory(&client, source_path, &self.key)
                    .await
            }
            _ => Err(InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("source path '{}' does not exist", self.source),
            }),
        }
    }

    /// Join an S3 key prefix and a child name with exactly one `/`.
    ///
    /// An empty (or all-`/`) prefix yields just `name`, so objects land at the
    /// bucket root instead of under a key starting with `/`.
    fn join_key(prefix: &str, name: &str) -> String {
        let prefix = prefix.trim_end_matches('/');
        if prefix.is_empty() {
            name.to_string()
        } else {
            format!("{prefix}/{name}")
        }
    }

    /// Upload one local file to `key`. The whole file is read into memory first.
    async fn upload_file(
        &self,
        client: &aws_sdk_s3::Client,
        path: &std::path::Path,
        key: &str,
    ) -> Result<()> {
        use aws_sdk_s3::primitives::ByteStream;

        tracing::info!(
            bucket = %self.bucket,
            key = %key,
            source = %path.display(),
            "pushing file to S3"
        );

        let data = tokio::fs::read(path)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: key.to_string(),
                reason: format!("failed to read file: {e}"),
            })?;

        tokio::time::timeout(
            self.timeout,
            client
                .put_object()
                .bucket(&self.bucket)
                .key(key)
                .body(ByteStream::from(data))
                .content_type("application/octet-stream")
                .send(),
        )
        .await
        .map_err(|_| InitError::Timeout {
            timeout: self.timeout,
        })?
        .map_err(|e| InitError::S3Failed {
            bucket: self.bucket.clone(),
            key: key.to_string(),
            reason: format!("put_object failed: {e}"),
        })?;

        tracing::info!(bucket = %self.bucket, key = %key, "S3 push complete");
        Ok(())
    }

    /// Recursively upload every regular file under `dir`, keyed under `prefix`.
    ///
    /// NOTE(review): directory symlinks are followed, so a symlink cycle would
    /// cause repeated uploads until path resolution fails. Confirm that sources
    /// never contain such links.
    async fn upload_directory(
        &self,
        client: &aws_sdk_s3::Client,
        dir: &std::path::Path,
        prefix: &str,
    ) -> Result<()> {
        let mut entries = tokio::fs::read_dir(dir)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: prefix.to_string(),
                reason: format!("failed to read directory: {e}"),
            })?;

        while let Some(entry) = entries
            .next_entry()
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: prefix.to_string(),
                reason: format!("failed to read directory entry: {e}"),
            })?
        {
            let path = entry.path();
            let key = Self::join_key(prefix, &entry.file_name().to_string_lossy());

            match tokio::fs::metadata(&path).await {
                Ok(meta) if meta.is_file() => self.upload_file(client, &path, &key).await?,
                // A recursive async fn must be boxed so its future has a finite size.
                Ok(meta) if meta.is_dir() => {
                    Box::pin(self.upload_directory(client, &path, &key)).await?;
                }
                _ => {}
            }
        }

        Ok(())
    }
}
286
/// Init action that downloads a single S3 object to a local file.
#[cfg(feature = "s3")]
pub struct S3Pull {
    /// Source bucket name.
    pub bucket: String,
    /// Key of the single object to download. Prefixes are not expanded: the
    /// value is sent as-is in one `GetObject` request.
    pub key: String,
    /// Local file path the object is written to. Missing parent directories
    /// are created.
    pub destination: String,
    /// Custom endpoint URL for S3-compatible services. When set, requests use
    /// path-style addressing.
    pub endpoint: Option<String>,
    /// Region override. When `None`, the SDK's loaded default is left unchanged.
    pub region: Option<String>,
    /// Timeout for the S3 download.
    pub timeout: Duration,
}
303
#[cfg(feature = "s3")]
impl S3Pull {
    /// Download `s3://{bucket}/{key}` into the local file `destination`.
    ///
    /// Missing parent directories of `destination` are created. An existing
    /// file at that path is truncated and overwritten.
    ///
    /// `timeout` bounds the whole transfer: the `get_object` request and the
    /// reading of the response body. The file is flushed before returning, so
    /// it is complete once this resolves.
    ///
    /// # Errors
    /// Returns `InitError::Timeout` if the download does not finish within
    /// `timeout`. Returns `InitError::S3Failed` if the request fails, the body
    /// cannot be read, or the destination cannot be created or written.
    pub async fn execute(&self) -> Result<()> {
        use aws_sdk_s3::Client;
        use tokio::io::AsyncWriteExt;

        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(ref region) = self.region {
            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
        }
        let sdk_config = config_loader.load().await;

        let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config);
        if let Some(ref endpoint) = self.endpoint {
            // Custom endpoints are addressed path-style (bucket in the URL path,
            // not in the hostname).
            s3_config = s3_config.endpoint_url(endpoint).force_path_style(true);
        }
        let client = Client::from_conf(s3_config.build());

        tracing::info!(
            bucket = %self.bucket,
            key = %self.key,
            destination = %self.destination,
            "pulling from S3"
        );

        // `send()` resolves once the response headers arrive. Timing out only
        // that call would let a stalled body transfer hang indefinitely, so the
        // request and the body read share one deadline.
        let data = tokio::time::timeout(self.timeout, async {
            let object = client
                .get_object()
                .bucket(&self.bucket)
                .key(&self.key)
                .send()
                .await
                .map_err(|e| self.s3_failed(format!("get_object failed: {e}")))?;
            let body = object
                .body
                .collect()
                .await
                .map_err(|e| self.s3_failed(format!("failed to read body: {e}")))?;
            Ok::<_, InitError>(body.into_bytes())
        })
        .await
        .map_err(|_| InitError::Timeout {
            timeout: self.timeout,
        })??;

        let dest_path = std::path::Path::new(&self.destination);
        if let Some(parent) = dest_path.parent() {
            tokio::fs::create_dir_all(parent).await.map_err(|e| {
                self.s3_failed(format!("failed to create destination directory: {e}"))
            })?;
        }

        let mut file = tokio::fs::File::create(dest_path)
            .await
            .map_err(|e| self.s3_failed(format!("failed to create file: {e}")))?;

        file.write_all(&data)
            .await
            .map_err(|e| self.s3_failed(format!("failed to write file: {e}")))?;
        // tokio's `File` completes writes on a background blocking task. Without
        // an explicit flush, a write error could be lost and the file could still
        // be incomplete when this function returns.
        file.flush()
            .await
            .map_err(|e| self.s3_failed(format!("failed to write file: {e}")))?;

        tracing::info!(
            bucket = %self.bucket,
            key = %self.key,
            bytes = data.len(),
            "S3 pull complete"
        );

        Ok(())
    }

    /// Build an `InitError::S3Failed` for this action's bucket and key.
    fn s3_failed(&self, reason: String) -> InitError {
        InitError::S3Failed {
            bucket: self.bucket.clone(),
            key: self.key.clone(),
            reason,
        }
    }
}
406
407/// Create an init action from the spec
408///
409/// # Errors
410/// Returns `InitError::InvalidParams` if required parameters are missing or invalid,
411/// or `InitError::UnknownAction` if the action type is not recognized.
412#[allow(clippy::too_many_lines, clippy::implicit_hasher)]
413pub fn from_spec(
414    action: &str,
415    params: &HashMap<String, serde_json::Value>,
416    _default_timeout: Duration,
417) -> Result<InitAction> {
418    match action {
419        "init.wait_tcp" => {
420            let host = params
421                .get("host")
422                .and_then(|v| v.as_str())
423                .ok_or_else(|| InitError::InvalidParams {
424                    action: action.to_string(),
425                    reason: "missing 'host' parameter".to_string(),
426                })?
427                .to_string();
428
429            #[allow(clippy::cast_possible_truncation)]
430            let port = params
431                .get("port")
432                .and_then(serde_json::Value::as_u64)
433                .ok_or_else(|| InitError::InvalidParams {
434                    action: action.to_string(),
435                    reason: "missing or invalid 'port' parameter".to_string(),
436                })? as u16;
437
438            let timeout_secs = params
439                .get("timeout")
440                .and_then(serde_json::Value::as_u64)
441                .unwrap_or(30);
442
443            Ok(InitAction::WaitTcp(WaitTcp {
444                host,
445                port,
446                timeout: Duration::from_secs(timeout_secs),
447                interval: Duration::from_secs(2),
448            }))
449        }
450
451        "init.wait_http" => {
452            let url = params
453                .get("url")
454                .and_then(|v| v.as_str())
455                .ok_or_else(|| InitError::InvalidParams {
456                    action: action.to_string(),
457                    reason: "missing 'url' parameter".to_string(),
458                })?
459                .to_string();
460
461            #[allow(clippy::cast_possible_truncation)]
462            let expect_status = params
463                .get("expect_status")
464                .and_then(serde_json::Value::as_u64)
465                .map(|v| v as u16);
466
467            let timeout_secs = params
468                .get("timeout")
469                .and_then(serde_json::Value::as_u64)
470                .unwrap_or(30);
471
472            Ok(InitAction::WaitHttp(WaitHttp {
473                url,
474                expect_status,
475                timeout: Duration::from_secs(timeout_secs),
476                interval: Duration::from_secs(2),
477            }))
478        }
479
480        "init.run" => {
481            let command = params
482                .get("command")
483                .and_then(|v| v.as_str())
484                .ok_or_else(|| InitError::InvalidParams {
485                    action: action.to_string(),
486                    reason: "missing 'command' parameter".to_string(),
487                })?
488                .to_string();
489
490            let timeout_secs = params
491                .get("timeout")
492                .and_then(serde_json::Value::as_u64)
493                .unwrap_or(300);
494
495            Ok(InitAction::Run(RunCommand {
496                command,
497                timeout: Duration::from_secs(timeout_secs),
498            }))
499        }
500
501        #[cfg(feature = "s3")]
502        "init.s3_push" => {
503            let source = params
504                .get("source")
505                .and_then(|v| v.as_str())
506                .ok_or_else(|| InitError::InvalidParams {
507                    action: action.to_string(),
508                    reason: "missing 'source' parameter".to_string(),
509                })?
510                .to_string();
511
512            let bucket = params
513                .get("bucket")
514                .and_then(|v| v.as_str())
515                .ok_or_else(|| InitError::InvalidParams {
516                    action: action.to_string(),
517                    reason: "missing 'bucket' parameter".to_string(),
518                })?
519                .to_string();
520
521            let key = params
522                .get("key")
523                .and_then(|v| v.as_str())
524                .ok_or_else(|| InitError::InvalidParams {
525                    action: action.to_string(),
526                    reason: "missing 'key' parameter".to_string(),
527                })?
528                .to_string();
529
530            let endpoint = params
531                .get("endpoint")
532                .and_then(|v| v.as_str())
533                .map(String::from);
534            let region = params
535                .get("region")
536                .and_then(|v| v.as_str())
537                .map(String::from);
538            let timeout_secs = params
539                .get("timeout")
540                .and_then(serde_json::Value::as_u64)
541                .unwrap_or(300);
542
543            Ok(InitAction::S3Push(S3Push {
544                source,
545                bucket,
546                key,
547                endpoint,
548                region,
549                timeout: Duration::from_secs(timeout_secs),
550            }))
551        }
552
553        #[cfg(feature = "s3")]
554        "init.s3_pull" => {
555            let bucket = params
556                .get("bucket")
557                .and_then(|v| v.as_str())
558                .ok_or_else(|| InitError::InvalidParams {
559                    action: action.to_string(),
560                    reason: "missing 'bucket' parameter".to_string(),
561                })?
562                .to_string();
563
564            let key = params
565                .get("key")
566                .and_then(|v| v.as_str())
567                .ok_or_else(|| InitError::InvalidParams {
568                    action: action.to_string(),
569                    reason: "missing 'key' parameter".to_string(),
570                })?
571                .to_string();
572
573            let destination = params
574                .get("destination")
575                .and_then(|v| v.as_str())
576                .ok_or_else(|| InitError::InvalidParams {
577                    action: action.to_string(),
578                    reason: "missing 'destination' parameter".to_string(),
579                })?
580                .to_string();
581
582            let endpoint = params
583                .get("endpoint")
584                .and_then(|v| v.as_str())
585                .map(String::from);
586            let region = params
587                .get("region")
588                .and_then(|v| v.as_str())
589                .map(String::from);
590            let timeout_secs = params
591                .get("timeout")
592                .and_then(serde_json::Value::as_u64)
593                .unwrap_or(300);
594
595            Ok(InitAction::S3Pull(S3Pull {
596                bucket,
597                key,
598                destination,
599                endpoint,
600                region,
601                timeout: Duration::from_secs(timeout_secs),
602            }))
603        }
604
605        _ => Err(InitError::UnknownAction(action.to_string())),
606    }
607}
608
609/// Enum of all init actions
610pub enum InitAction {
611    WaitTcp(WaitTcp),
612    WaitHttp(WaitHttp),
613    Run(RunCommand),
614    #[cfg(feature = "s3")]
615    S3Push(S3Push),
616    #[cfg(feature = "s3")]
617    S3Pull(S3Pull),
618}
619
620impl InitAction {
621    /// # Errors
622    /// Returns an error if the underlying action fails.
623    pub async fn execute(&self) -> Result<()> {
624        match self {
625            InitAction::WaitTcp(a) => a.execute().await,
626            InitAction::WaitHttp(a) => a.execute().await,
627            InitAction::Run(a) => a.execute().await,
628            #[cfg(feature = "s3")]
629            InitAction::S3Push(a) => a.execute().await,
630            #[cfg(feature = "s3")]
631            InitAction::S3Pull(a) => a.execute().await,
632        }
633    }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `RunCommand` with the 5 second timeout the command tests share.
    fn shell(command: &str) -> RunCommand {
        RunCommand {
            command: command.to_string(),
            timeout: Duration::from_secs(5),
        }
    }

    #[tokio::test]
    async fn test_run_command_success() {
        shell("echo hello").execute().await.unwrap();
    }

    #[tokio::test]
    async fn test_run_command_failure() {
        assert!(shell("exit 1").execute().await.is_err());
    }

    #[test]
    fn test_from_spec_wait_tcp() {
        let params = HashMap::from([
            ("host".to_string(), serde_json::json!("localhost")),
            ("port".to_string(), serde_json::json!(8080)),
        ]);

        let action = from_spec("init.wait_tcp", &params, Duration::from_secs(30)).unwrap();
        assert!(
            matches!(action, InitAction::WaitTcp(_)),
            "Expected WaitTcp action"
        );
    }

    #[test]
    fn test_from_spec_unknown() {
        let params = HashMap::new();
        assert!(from_spec("unknown.action", &params, Duration::from_secs(30)).is_err());
    }
}