1use crate::error::{InitError, Result};
4use std::collections::HashMap;
5use std::time::Duration;
6use tokio::process::Command;
7use tokio::time::{sleep, timeout};
8
/// Init action that waits until a TCP connection to `host:port` can be opened.
#[derive(Debug, Clone)]
pub struct WaitTcp {
    /// Host name or IP address to connect to.
    pub host: String,
    /// TCP port to connect to.
    pub port: u16,
    /// Overall time budget for all connection attempts.
    pub timeout: Duration,
    /// Pause between consecutive connection attempts.
    pub interval: Duration,
}
16
17impl WaitTcp {
18 pub async fn execute(&self) -> Result<()> {
21 let start = std::time::Instant::now();
22
23 loop {
24 if tokio::net::TcpStream::connect(&format!("{}:{}", self.host, self.port))
25 .await
26 .is_ok()
27 {
28 return Ok(());
29 }
30
31 if start.elapsed() >= self.timeout {
32 return Err(InitError::TcpFailed {
33 host: self.host.clone(),
34 port: self.port,
35 reason: format!("timeout after {:?}", self.timeout),
36 });
37 }
38
39 sleep(self.interval).await;
40 }
41 }
42}
43
/// Init action that polls an HTTP URL until it answers with an acceptable status.
#[derive(Debug, Clone)]
pub struct WaitHttp {
    /// URL to send `GET` requests to.
    pub url: String,
    /// Exact status code to wait for. `None` accepts any 2xx status.
    pub expect_status: Option<u16>,
    /// Overall time budget for all requests.
    pub timeout: Duration,
    /// Pause between consecutive requests.
    pub interval: Duration,
}
51
52impl WaitHttp {
53 pub async fn execute(&self) -> Result<()> {
56 let start = std::time::Instant::now();
57 let client = reqwest::Client::builder()
58 .timeout(Duration::from_secs(5))
59 .build()
60 .map_err(|e| InitError::HttpFailed {
61 url: self.url.clone(),
62 reason: format!("failed to create client: {e}"),
63 })?;
64
65 loop {
66 let response = client.get(&self.url).send().await;
67
68 if let Ok(resp) = response {
69 let status = resp.status().as_u16();
70
71 if let Some(expected) = self.expect_status {
72 if status == expected {
73 return Ok(());
74 }
75 } else if (200..300).contains(&status) {
76 return Ok(());
77 }
78 }
79
80 if start.elapsed() >= self.timeout {
81 return Err(InitError::HttpFailed {
82 url: self.url.clone(),
83 reason: format!("timeout after {:?}", self.timeout),
84 });
85 }
86
87 sleep(self.interval).await;
88 }
89 }
90}
91
/// Init action that runs a shell command and requires it to exit successfully.
#[derive(Debug, Clone)]
pub struct RunCommand {
    /// Command line passed verbatim to `sh -c`.
    pub command: String,
    /// Maximum time the command may run.
    pub timeout: Duration,
}
97
98impl RunCommand {
99 pub async fn execute(&self) -> Result<()> {
102 match timeout(
103 self.timeout,
104 Command::new("sh").arg("-c").arg(&self.command).output(),
105 )
106 .await
107 {
108 Ok(Ok(output)) => {
109 if output.status.success() {
110 Ok(())
111 } else {
112 Err(InitError::CommandFailed {
113 command: self.command.clone(),
114 code: output.status.code().unwrap_or(-1),
115 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
116 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
117 })
118 }
119 }
120 Ok(Err(_)) => Err(InitError::CommandFailed {
121 command: self.command.clone(),
122 code: -1,
123 stdout: String::new(),
124 stderr: "timeout".to_string(),
125 }),
126 Err(_) => Err(InitError::Timeout {
127 timeout: self.timeout,
128 }),
129 }
130 }
131}
132
/// Init action that uploads a local file or directory tree to S3.
#[cfg(feature = "s3")]
#[derive(Debug, Clone)]
pub struct S3Push {
    /// Local file or directory to upload.
    pub source: String,
    /// Destination bucket name.
    pub bucket: String,
    /// Object key for a file, or key prefix for a directory.
    pub key: String,
    /// Optional custom S3 endpoint URL. When set, path-style addressing is used.
    pub endpoint: Option<String>,
    /// Optional region overriding the one from the default AWS configuration.
    pub region: Option<String>,
    /// Time limit applied to each individual object upload.
    pub timeout: Duration,
}
149
#[cfg(feature = "s3")]
impl S3Push {
    /// Uploads `source` to bucket `bucket` under `key`.
    ///
    /// A file is stored as a single object named `key`. A directory is walked
    /// recursively, and `key` is used as the prefix for every object.
    ///
    /// When `endpoint` is set, the client uses it with path-style addressing. When
    /// `region` is set, it overrides the region from the default AWS configuration.
    ///
    /// # Errors
    ///
    /// Returns [`InitError::S3Failed`] in these cases:
    /// - `source` is neither a file nor a directory;
    /// - local data cannot be read;
    /// - a `put_object` call fails.
    ///
    /// Returns [`InitError::Timeout`] when a single upload exceeds `timeout`.
    pub async fn execute(&self) -> Result<()> {
        use aws_sdk_s3::Client;

        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(ref region) = self.region {
            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
        }
        let sdk_config = config_loader.load().await;

        let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config);
        if let Some(ref endpoint) = self.endpoint {
            s3_config = s3_config.endpoint_url(endpoint).force_path_style(true);
        }
        let client = Client::from_conf(s3_config.build());

        let source_path = std::path::Path::new(&self.source);
        if source_path.is_file() {
            self.upload_file(&client, source_path, &self.key).await
        } else if source_path.is_dir() {
            self.upload_directory(&client, source_path, &self.key).await
        } else {
            Err(InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("source path '{}' does not exist", self.source),
            })
        }
    }

    /// Reads `path` fully into memory and stores it as object `key`.
    ///
    /// The object is stored with content type `application/octet-stream`, and the
    /// `put_object` call is bounded by `timeout`.
    async fn upload_file(
        &self,
        client: &aws_sdk_s3::Client,
        path: &std::path::Path,
        key: &str,
    ) -> Result<()> {
        use aws_sdk_s3::primitives::ByteStream;

        tracing::info!(
            bucket = %self.bucket,
            key = %key,
            source = %path.display(),
            "pushing file to S3"
        );

        let data = tokio::fs::read(path)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: key.to_string(),
                reason: format!("failed to read file: {e}"),
            })?;

        tokio::time::timeout(
            self.timeout,
            client
                .put_object()
                .bucket(&self.bucket)
                .key(key)
                .body(ByteStream::from(data))
                .content_type("application/octet-stream")
                .send(),
        )
        .await
        .map_err(|_| InitError::Timeout {
            timeout: self.timeout,
        })?
        .map_err(|e| InitError::S3Failed {
            bucket: self.bucket.clone(),
            key: key.to_string(),
            reason: format!("put_object failed: {e}"),
        })?;

        tracing::info!(bucket = %self.bucket, key = %key, "S3 push complete");
        Ok(())
    }

    /// Recursively uploads every regular file below `dir`.
    ///
    /// Each file gets a key of the form `prefix/<relative path>`. Entries that
    /// are neither files nor directories are skipped.
    async fn upload_directory(
        &self,
        client: &aws_sdk_s3::Client,
        dir: &std::path::Path,
        prefix: &str,
    ) -> Result<()> {
        let mut entries = tokio::fs::read_dir(dir)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: prefix.to_string(),
                reason: format!("failed to read directory: {e}"),
            })?;

        while let Some(entry) = entries
            .next_entry()
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: prefix.to_string(),
                reason: format!("failed to read directory entry: {e}"),
            })?
        {
            let path = entry.path();
            let key = Self::child_key(prefix, &entry.file_name().to_string_lossy());

            if path.is_file() {
                self.upload_file(client, &path, &key).await?;
            } else if path.is_dir() {
                // Recursive async calls must be boxed to give the future a finite size.
                Box::pin(self.upload_directory(client, &path, &key)).await?;
            }
        }

        Ok(())
    }

    /// Joins an S3 key prefix and a name with exactly one `/`.
    ///
    /// An empty or all-slash prefix yields just `name`. This avoids keys with a
    /// leading `/` such as `/file.txt`.
    fn child_key(prefix: &str, name: &str) -> String {
        let prefix = prefix.trim_end_matches('/');
        if prefix.is_empty() {
            name.to_owned()
        } else {
            format!("{prefix}/{name}")
        }
    }
}
286
/// Init action that downloads a single S3 object to a local file.
#[cfg(feature = "s3")]
#[derive(Debug, Clone)]
pub struct S3Pull {
    /// Source bucket name.
    pub bucket: String,
    /// Key of the object to download.
    pub key: String,
    /// Local file path to write. Parent directories are created as needed.
    pub destination: String,
    /// Optional custom S3 endpoint URL. When set, path-style addressing is used.
    pub endpoint: Option<String>,
    /// Optional region overriding the one from the default AWS configuration.
    pub region: Option<String>,
    /// Time limit for the download.
    pub timeout: Duration,
}
303
#[cfg(feature = "s3")]
impl S3Pull {
    /// Downloads object `key` from `bucket` and writes it to `destination`.
    ///
    /// Parent directories of `destination` are created if missing. The whole
    /// download is bounded by `timeout`: both the `get_object` request and the
    /// reading of the response body.
    ///
    /// # Errors
    ///
    /// Returns [`InitError::Timeout`] if the download does not finish in time.
    ///
    /// Returns [`InitError::S3Failed`] in these cases:
    /// - the request fails;
    /// - the body cannot be read;
    /// - the destination cannot be created, written, or flushed.
    pub async fn execute(&self) -> Result<()> {
        use aws_sdk_s3::Client;
        use tokio::io::AsyncWriteExt;

        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(ref region) = self.region {
            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
        }
        let sdk_config = config_loader.load().await;

        let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config);
        if let Some(ref endpoint) = self.endpoint {
            s3_config = s3_config.endpoint_url(endpoint).force_path_style(true);
        }
        let client = Client::from_conf(s3_config.build());

        tracing::info!(
            bucket = %self.bucket,
            key = %self.key,
            destination = %self.destination,
            "pulling from S3"
        );

        // `send()` resolves once the response headers arrive. The body is streamed
        // afterwards, so the timeout has to cover both steps to bound the download.
        let download = async {
            let output = client
                .get_object()
                .bucket(&self.bucket)
                .key(&self.key)
                .send()
                .await
                .map_err(|e| InitError::S3Failed {
                    bucket: self.bucket.clone(),
                    key: self.key.clone(),
                    reason: format!("get_object failed: {e}"),
                })?;
            let body = output
                .body
                .collect()
                .await
                .map_err(|e| InitError::S3Failed {
                    bucket: self.bucket.clone(),
                    key: self.key.clone(),
                    reason: format!("failed to read body: {e}"),
                })?;
            Ok::<_, InitError>(body.into_bytes())
        };

        let data = tokio::time::timeout(self.timeout, download)
            .await
            .map_err(|_| InitError::Timeout {
                timeout: self.timeout,
            })??;

        let dest_path = std::path::Path::new(&self.destination);
        if let Some(parent) = dest_path.parent() {
            tokio::fs::create_dir_all(parent)
                .await
                .map_err(|e| InitError::S3Failed {
                    bucket: self.bucket.clone(),
                    key: self.key.clone(),
                    reason: format!("failed to create destination directory: {e}"),
                })?;
        }

        let mut file = tokio::fs::File::create(&self.destination)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("failed to create file: {e}"),
            })?;

        file.write_all(&data)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("failed to write file: {e}"),
            })?;

        // tokio's `File` completes writes on a background thread. `flush` waits for
        // the in-flight write and surfaces its error before the file is dropped.
        file.flush()
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("failed to flush file: {e}"),
            })?;

        tracing::info!(
            bucket = %self.bucket,
            key = %self.key,
            bytes = data.len(),
            "S3 pull complete"
        );

        Ok(())
    }
}
406
407#[allow(clippy::too_many_lines, clippy::implicit_hasher)]
413pub fn from_spec(
414 action: &str,
415 params: &HashMap<String, serde_json::Value>,
416 _default_timeout: Duration,
417) -> Result<InitAction> {
418 match action {
419 "init.wait_tcp" => {
420 let host = params
421 .get("host")
422 .and_then(|v| v.as_str())
423 .ok_or_else(|| InitError::InvalidParams {
424 action: action.to_string(),
425 reason: "missing 'host' parameter".to_string(),
426 })?
427 .to_string();
428
429 #[allow(clippy::cast_possible_truncation)]
430 let port = params
431 .get("port")
432 .and_then(serde_json::Value::as_u64)
433 .ok_or_else(|| InitError::InvalidParams {
434 action: action.to_string(),
435 reason: "missing or invalid 'port' parameter".to_string(),
436 })? as u16;
437
438 let timeout_secs = params
439 .get("timeout")
440 .and_then(serde_json::Value::as_u64)
441 .unwrap_or(30);
442
443 Ok(InitAction::WaitTcp(WaitTcp {
444 host,
445 port,
446 timeout: Duration::from_secs(timeout_secs),
447 interval: Duration::from_secs(2),
448 }))
449 }
450
451 "init.wait_http" => {
452 let url = params
453 .get("url")
454 .and_then(|v| v.as_str())
455 .ok_or_else(|| InitError::InvalidParams {
456 action: action.to_string(),
457 reason: "missing 'url' parameter".to_string(),
458 })?
459 .to_string();
460
461 #[allow(clippy::cast_possible_truncation)]
462 let expect_status = params
463 .get("expect_status")
464 .and_then(serde_json::Value::as_u64)
465 .map(|v| v as u16);
466
467 let timeout_secs = params
468 .get("timeout")
469 .and_then(serde_json::Value::as_u64)
470 .unwrap_or(30);
471
472 Ok(InitAction::WaitHttp(WaitHttp {
473 url,
474 expect_status,
475 timeout: Duration::from_secs(timeout_secs),
476 interval: Duration::from_secs(2),
477 }))
478 }
479
480 "init.run" => {
481 let command = params
482 .get("command")
483 .and_then(|v| v.as_str())
484 .ok_or_else(|| InitError::InvalidParams {
485 action: action.to_string(),
486 reason: "missing 'command' parameter".to_string(),
487 })?
488 .to_string();
489
490 let timeout_secs = params
491 .get("timeout")
492 .and_then(serde_json::Value::as_u64)
493 .unwrap_or(300);
494
495 Ok(InitAction::Run(RunCommand {
496 command,
497 timeout: Duration::from_secs(timeout_secs),
498 }))
499 }
500
501 #[cfg(feature = "s3")]
502 "init.s3_push" => {
503 let source = params
504 .get("source")
505 .and_then(|v| v.as_str())
506 .ok_or_else(|| InitError::InvalidParams {
507 action: action.to_string(),
508 reason: "missing 'source' parameter".to_string(),
509 })?
510 .to_string();
511
512 let bucket = params
513 .get("bucket")
514 .and_then(|v| v.as_str())
515 .ok_or_else(|| InitError::InvalidParams {
516 action: action.to_string(),
517 reason: "missing 'bucket' parameter".to_string(),
518 })?
519 .to_string();
520
521 let key = params
522 .get("key")
523 .and_then(|v| v.as_str())
524 .ok_or_else(|| InitError::InvalidParams {
525 action: action.to_string(),
526 reason: "missing 'key' parameter".to_string(),
527 })?
528 .to_string();
529
530 let endpoint = params
531 .get("endpoint")
532 .and_then(|v| v.as_str())
533 .map(String::from);
534 let region = params
535 .get("region")
536 .and_then(|v| v.as_str())
537 .map(String::from);
538 let timeout_secs = params
539 .get("timeout")
540 .and_then(serde_json::Value::as_u64)
541 .unwrap_or(300);
542
543 Ok(InitAction::S3Push(S3Push {
544 source,
545 bucket,
546 key,
547 endpoint,
548 region,
549 timeout: Duration::from_secs(timeout_secs),
550 }))
551 }
552
553 #[cfg(feature = "s3")]
554 "init.s3_pull" => {
555 let bucket = params
556 .get("bucket")
557 .and_then(|v| v.as_str())
558 .ok_or_else(|| InitError::InvalidParams {
559 action: action.to_string(),
560 reason: "missing 'bucket' parameter".to_string(),
561 })?
562 .to_string();
563
564 let key = params
565 .get("key")
566 .and_then(|v| v.as_str())
567 .ok_or_else(|| InitError::InvalidParams {
568 action: action.to_string(),
569 reason: "missing 'key' parameter".to_string(),
570 })?
571 .to_string();
572
573 let destination = params
574 .get("destination")
575 .and_then(|v| v.as_str())
576 .ok_or_else(|| InitError::InvalidParams {
577 action: action.to_string(),
578 reason: "missing 'destination' parameter".to_string(),
579 })?
580 .to_string();
581
582 let endpoint = params
583 .get("endpoint")
584 .and_then(|v| v.as_str())
585 .map(String::from);
586 let region = params
587 .get("region")
588 .and_then(|v| v.as_str())
589 .map(String::from);
590 let timeout_secs = params
591 .get("timeout")
592 .and_then(serde_json::Value::as_u64)
593 .unwrap_or(300);
594
595 Ok(InitAction::S3Pull(S3Pull {
596 bucket,
597 key,
598 destination,
599 endpoint,
600 region,
601 timeout: Duration::from_secs(timeout_secs),
602 }))
603 }
604
605 _ => Err(InitError::UnknownAction(action.to_string())),
606 }
607}
608
609pub enum InitAction {
611 WaitTcp(WaitTcp),
612 WaitHttp(WaitHttp),
613 Run(RunCommand),
614 #[cfg(feature = "s3")]
615 S3Push(S3Push),
616 #[cfg(feature = "s3")]
617 S3Pull(S3Pull),
618}
619
620impl InitAction {
621 pub async fn execute(&self) -> Result<()> {
624 match self {
625 InitAction::WaitTcp(a) => a.execute().await,
626 InitAction::WaitHttp(a) => a.execute().await,
627 InitAction::Run(a) => a.execute().await,
628 #[cfg(feature = "s3")]
629 InitAction::S3Push(a) => a.execute().await,
630 #[cfg(feature = "s3")]
631 InitAction::S3Pull(a) => a.execute().await,
632 }
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_run_command_success() {
        let action = RunCommand {
            command: "echo hello".to_string(),
            timeout: Duration::from_secs(5),
        };
        action.execute().await.unwrap();
    }

    #[tokio::test]
    async fn test_run_command_failure() {
        let action = RunCommand {
            command: "exit 1".to_string(),
            timeout: Duration::from_secs(5),
        };
        let result = action.execute().await;
        assert!(matches!(
            result,
            Err(InitError::CommandFailed { code: 1, .. })
        ));
    }

    #[tokio::test]
    async fn test_run_command_timeout() {
        let action = RunCommand {
            command: "sleep 5".to_string(),
            timeout: Duration::from_millis(100),
        };
        let result = action.execute().await;
        assert!(matches!(result, Err(InitError::Timeout { .. })));
    }

    #[tokio::test]
    async fn test_wait_tcp_succeeds_when_port_is_listening() {
        // Bind to an ephemeral port; the OS backlog accepts the connection without
        // an explicit `accept`.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let action = WaitTcp {
            host: "127.0.0.1".to_string(),
            port,
            timeout: Duration::from_secs(5),
            interval: Duration::from_millis(50),
        };
        action.execute().await.unwrap();
    }

    #[test]
    fn test_from_spec_wait_tcp() {
        let mut params = HashMap::new();
        params.insert("host".to_string(), serde_json::json!("localhost"));
        params.insert("port".to_string(), serde_json::json!(8080));

        let action = from_spec("init.wait_tcp", &params, Duration::from_secs(30)).unwrap();
        match action {
            InitAction::WaitTcp(wait) => {
                assert_eq!(wait.host, "localhost");
                assert_eq!(wait.port, 8080);
                assert_eq!(wait.timeout, Duration::from_secs(30));
            }
            _ => panic!("Expected WaitTcp action"),
        }
    }

    #[test]
    fn test_from_spec_wait_tcp_missing_host() {
        let mut params = HashMap::new();
        params.insert("port".to_string(), serde_json::json!(8080));

        let result = from_spec("init.wait_tcp", &params, Duration::from_secs(30));
        assert!(matches!(result, Err(InitError::InvalidParams { .. })));
    }

    #[test]
    fn test_from_spec_unknown() {
        let params = HashMap::new();
        let result = from_spec("unknown.action", &params, Duration::from_secs(30));
        assert!(matches!(
            result,
            Err(InitError::UnknownAction(name)) if name == "unknown.action"
        ));
    }
}