use crate::{
BatchError,
core::step::{RepeatStatus, StepExecution, Tasklet},
tasklet::s3::{S3ClientConfig, build_s3_client},
};
use log::{debug, info};
use std::path::{Path, PathBuf};
use tokio::runtime::Handle;
/// Tasklet that downloads a single S3 object (`s3://{bucket}/{key}`) to a local file.
///
/// Instances are produced by `S3GetTaskletBuilder::build`; the fields are not `pub`.
///
/// NOTE(review): the derived `Debug` includes `config`, which carries the
/// `secret_access_key` set on the builder. If `S3ClientConfig`'s `Debug` prints that
/// value verbatim, logging this struct would expose the secret — confirm.
#[derive(Debug)]
pub struct S3GetTasklet {
    // Bucket the object is read from.
    bucket: String,
    // Key of the object to download.
    key: String,
    // Destination path; missing parent directories are created before writing.
    local_file: PathBuf,
    // Client settings (region, endpoint, credentials) handed to `build_s3_client`.
    config: S3ClientConfig,
}
impl S3GetTasklet {
    /// Downloads `s3://{bucket}/{key}` into `local_file`, creating missing parent
    /// directories first.
    ///
    /// The object body is streamed to disk instead of being buffered in memory. If the
    /// transfer fails part-way, the partially written file is removed (best effort) so
    /// that later steps never pick up a truncated download at `local_file`.
    ///
    /// # Errors
    ///
    /// - whatever `build_s3_client` returns when the client cannot be built;
    /// - `BatchError::ItemReader` if the `GetObject` request fails;
    /// - `BatchError::Io` if the parent directory or the local file cannot be created,
    ///   or if streaming the body to disk fails.
    async fn execute_async(&self) -> Result<RepeatStatus, BatchError> {
        info!(
            "Downloading s3://{}/{} -> {}",
            self.bucket,
            self.key,
            self.local_file.display()
        );
        let client = build_s3_client(&self.config).await?;

        if let Some(parent) = self.local_file.parent() {
            // The tokio variant keeps the blocking syscall off the async worker thread.
            tokio::fs::create_dir_all(parent)
                .await
                .map_err(BatchError::Io)?;
        }

        let resp = client
            .get_object()
            .bucket(&self.bucket)
            .key(&self.key)
            .send()
            .await
            .map_err(|e| {
                BatchError::ItemReader(format!("S3 get_object failed for {}: {}", self.key, e))
            })?;

        let mut body = resp.body.into_async_read();
        let mut file = tokio::fs::File::create(&self.local_file)
            .await
            .map_err(BatchError::Io)?;

        let copied = tokio::io::copy(&mut body, &mut file).await;
        let bytes_written = match copied {
            Ok(bytes) => bytes,
            Err(err) => {
                // Close the handle before deleting: some platforms refuse to remove open
                // files. Removal is best effort; the transfer error is what gets reported.
                drop(file);
                let _ = tokio::fs::remove_file(&self.local_file).await;
                return Err(BatchError::Io(err));
            }
        };

        info!(
            "Download complete: {} bytes written to {}",
            bytes_written,
            self.local_file.display()
        );
        Ok(RepeatStatus::Finished)
    }
}
impl Tasklet for S3GetTasklet {
    /// Synchronously runs the async download on the current Tokio runtime.
    ///
    /// `block_in_place` lets this worker block while `block_on` drives the download to
    /// completion. Per Tokio's documentation, `Handle::current` panics outside a runtime
    /// and `block_in_place` panics on a current-thread runtime. The step execution
    /// argument is not used.
    fn execute(&self, _step_execution: &StepExecution) -> Result<RepeatStatus, BatchError> {
        let run_to_completion = || Handle::current().block_on(self.execute_async());
        tokio::task::block_in_place(run_to_completion)
    }
}
/// Builder for `S3GetTasklet`.
///
/// `bucket`, `key` and `local_file` are required: `build` returns
/// `BatchError::Configuration` naming the first one that is unset. The client settings
/// start from `S3ClientConfig::default()` (via the derived `Default`) and are
/// overridden by the optional setters.
#[derive(Debug, Default)]
pub struct S3GetTaskletBuilder {
    bucket: Option<String>,
    key: Option<String>,
    local_file: Option<PathBuf>,
    // Region / endpoint / credential overrides filled in by the optional setters.
    config: S3ClientConfig,
}
impl S3GetTaskletBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn bucket<S: Into<String>>(mut self, bucket: S) -> Self {
self.bucket = Some(bucket.into());
self
}
pub fn key<S: Into<String>>(mut self, key: S) -> Self {
self.key = Some(key.into());
self
}
pub fn local_file<P: AsRef<Path>>(mut self, path: P) -> Self {
self.local_file = Some(path.as_ref().to_path_buf());
self
}
pub fn region<S: Into<String>>(mut self, region: S) -> Self {
self.config.region = Some(region.into());
self
}
pub fn endpoint_url<S: Into<String>>(mut self, url: S) -> Self {
self.config.endpoint_url = Some(url.into());
self
}
pub fn access_key_id<S: Into<String>>(mut self, key_id: S) -> Self {
self.config.access_key_id = Some(key_id.into());
self
}
pub fn secret_access_key<S: Into<String>>(mut self, secret: S) -> Self {
self.config.secret_access_key = Some(secret.into());
self
}
pub fn build(self) -> Result<S3GetTasklet, BatchError> {
let bucket = self.bucket.ok_or_else(|| {
BatchError::Configuration("S3GetTasklet: 'bucket' is required".to_string())
})?;
let key = self.key.ok_or_else(|| {
BatchError::Configuration("S3GetTasklet: 'key' is required".to_string())
})?;
let local_file = self.local_file.ok_or_else(|| {
BatchError::Configuration("S3GetTasklet: 'local_file' is required".to_string())
})?;
Ok(S3GetTasklet {
bucket,
key,
local_file,
config: self.config,
})
}
}
/// Tasklet that downloads every object under `s3://{bucket}/{prefix}` into a local
/// folder, recreating the key hierarchy that follows the prefix.
///
/// Instances are produced by `S3GetFolderTaskletBuilder::build`; the fields are not
/// `pub`.
///
/// NOTE(review): the derived `Debug` includes `config`, which carries the
/// `secret_access_key` set on the builder. If `S3ClientConfig`'s `Debug` prints that
/// value verbatim, logging this struct would expose the secret — confirm.
#[derive(Debug)]
pub struct S3GetFolderTasklet {
    // Bucket to list and download from.
    bucket: String,
    // Key prefix used to filter the listing and stripped from each key locally.
    prefix: String,
    // Destination directory; created if it does not exist.
    local_folder: PathBuf,
    // Client settings (region, endpoint, credentials) handed to `build_s3_client`.
    config: S3ClientConfig,
}
impl S3GetFolderTasklet {
async fn execute_async(&self) -> Result<RepeatStatus, BatchError> {
info!(
"Downloading s3://{}/{} -> {}",
self.bucket,
self.prefix,
self.local_folder.display()
);
let client = build_s3_client(&self.config).await?;
std::fs::create_dir_all(&self.local_folder).map_err(BatchError::Io)?;
let mut continuation_token: Option<String> = None;
let mut total_files = 0usize;
loop {
let mut req = client
.list_objects_v2()
.bucket(&self.bucket)
.prefix(&self.prefix);
if let Some(token) = continuation_token {
req = req.continuation_token(token);
}
let list_resp = req
.send()
.await
.map_err(|e| BatchError::ItemReader(format!("list_objects_v2 failed: {}", e)))?;
for object in list_resp.contents() {
let key = object.key().unwrap_or_default();
let relative = key.strip_prefix(self.prefix.as_str()).unwrap_or(key);
let relative = relative.strip_prefix('/').unwrap_or(relative);
if relative.is_empty() {
continue; }
let local_path = self.local_folder.join(relative);
if let Some(parent) = local_path.parent() {
std::fs::create_dir_all(parent).map_err(BatchError::Io)?;
}
debug!(
"Downloading s3://{}/{} -> {}",
self.bucket,
key,
local_path.display()
);
let resp = client
.get_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|e| {
BatchError::ItemReader(format!("get_object failed for {}: {}", key, e))
})?;
let mut body = resp.body.into_async_read();
let mut file = tokio::fs::File::create(&local_path)
.await
.map_err(BatchError::Io)?;
tokio::io::copy(&mut body, &mut file)
.await
.map_err(BatchError::Io)?;
total_files += 1;
}
if list_resp.is_truncated().unwrap_or(false) {
continuation_token = list_resp.next_continuation_token().map(str::to_string);
} else {
break;
}
}
info!(
"Folder download complete: {} files downloaded to {}",
total_files,
self.local_folder.display()
);
Ok(RepeatStatus::Finished)
}
}
impl Tasklet for S3GetFolderTasklet {
    /// Synchronously runs the async folder download on the current Tokio runtime.
    ///
    /// `block_in_place` lets this worker block while `block_on` drives the download to
    /// completion. Per Tokio's documentation, `Handle::current` panics outside a runtime
    /// and `block_in_place` panics on a current-thread runtime. The step execution
    /// argument is not used.
    fn execute(&self, _step_execution: &StepExecution) -> Result<RepeatStatus, BatchError> {
        let run_to_completion = || Handle::current().block_on(self.execute_async());
        tokio::task::block_in_place(run_to_completion)
    }
}
/// Builder for `S3GetFolderTasklet`.
///
/// `bucket`, `prefix` and `local_folder` are required: `build` returns
/// `BatchError::Configuration` naming the first one that is unset. Only presence is
/// checked, so an empty `prefix` string is accepted. The client settings start from
/// `S3ClientConfig::default()` (via the derived `Default`) and are overridden by the
/// optional setters.
#[derive(Debug, Default)]
pub struct S3GetFolderTaskletBuilder {
    bucket: Option<String>,
    prefix: Option<String>,
    local_folder: Option<PathBuf>,
    // Region / endpoint / credential overrides filled in by the optional setters.
    config: S3ClientConfig,
}
impl S3GetFolderTaskletBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn bucket<S: Into<String>>(mut self, bucket: S) -> Self {
self.bucket = Some(bucket.into());
self
}
pub fn prefix<S: Into<String>>(mut self, prefix: S) -> Self {
self.prefix = Some(prefix.into());
self
}
pub fn local_folder<P: AsRef<Path>>(mut self, path: P) -> Self {
self.local_folder = Some(path.as_ref().to_path_buf());
self
}
pub fn region<S: Into<String>>(mut self, region: S) -> Self {
self.config.region = Some(region.into());
self
}
pub fn endpoint_url<S: Into<String>>(mut self, url: S) -> Self {
self.config.endpoint_url = Some(url.into());
self
}
pub fn access_key_id<S: Into<String>>(mut self, key_id: S) -> Self {
self.config.access_key_id = Some(key_id.into());
self
}
pub fn secret_access_key<S: Into<String>>(mut self, secret: S) -> Self {
self.config.secret_access_key = Some(secret.into());
self
}
pub fn build(self) -> Result<S3GetFolderTasklet, BatchError> {
let bucket = self.bucket.ok_or_else(|| {
BatchError::Configuration("S3GetFolderTasklet: 'bucket' is required".to_string())
})?;
let prefix = self.prefix.ok_or_else(|| {
BatchError::Configuration("S3GetFolderTasklet: 'prefix' is required".to_string())
})?;
let local_folder = self.local_folder.ok_or_else(|| {
BatchError::Configuration("S3GetFolderTasklet: 'local_folder' is required".to_string())
})?;
Ok(S3GetFolderTasklet {
bucket,
prefix,
local_folder,
config: self.config,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a builder result is an error whose message names `field`.
    fn assert_missing_field<T: std::fmt::Debug>(result: Result<T, BatchError>, field: &str) {
        assert!(result.is_err(), "build should fail without {}", field);
        let message = result.unwrap_err().to_string();
        assert!(
            message.contains(field),
            "error should mention '{}': {}",
            field,
            message
        );
    }

    #[test]
    fn should_fail_build_when_bucket_missing() {
        let builder = S3GetTaskletBuilder::new()
            .key("file.csv")
            .local_file("/tmp/file.csv");
        assert_missing_field(builder.build(), "bucket");
    }

    #[test]
    fn should_fail_build_when_key_missing() {
        let builder = S3GetTaskletBuilder::new()
            .bucket("my-bucket")
            .local_file("/tmp/file.csv");
        assert_missing_field(builder.build(), "key");
    }

    #[test]
    fn should_fail_build_when_local_file_missing() {
        let builder = S3GetTaskletBuilder::new()
            .bucket("my-bucket")
            .key("file.csv");
        assert_missing_field(builder.build(), "local_file");
    }

    #[test]
    fn should_build_with_required_fields() {
        let outcome = S3GetTaskletBuilder::new()
            .bucket("my-bucket")
            .key("file.csv")
            .local_file("/tmp/file.csv")
            .build();
        assert!(
            outcome.is_ok(),
            "build should succeed with required fields: {:?}",
            outcome.err()
        );
    }

    #[test]
    fn should_store_optional_config_fields() {
        let tasklet = S3GetTaskletBuilder::new()
            .bucket("b")
            .key("k")
            .local_file("/tmp/f")
            .region("eu-west-1")
            .endpoint_url("http://localhost:9000")
            .access_key_id("AKID")
            .secret_access_key("SECRET")
            .build()
            .expect("all required fields are set");
        let config = &tasklet.config;
        assert_eq!(config.region.as_deref(), Some("eu-west-1"));
        assert_eq!(config.endpoint_url.as_deref(), Some("http://localhost:9000"));
        assert_eq!(config.access_key_id.as_deref(), Some("AKID"));
        assert_eq!(config.secret_access_key.as_deref(), Some("SECRET"));
    }

    #[test]
    fn should_fail_folder_build_when_bucket_missing() {
        let builder = S3GetFolderTaskletBuilder::new()
            .prefix("backups/")
            .local_folder("/tmp/imports");
        assert_missing_field(builder.build(), "bucket");
    }

    #[test]
    fn should_fail_folder_build_when_prefix_missing() {
        let builder = S3GetFolderTaskletBuilder::new()
            .bucket("my-bucket")
            .local_folder("/tmp/imports");
        assert_missing_field(builder.build(), "prefix");
    }

    #[test]
    fn should_fail_folder_build_when_local_folder_missing() {
        let builder = S3GetFolderTaskletBuilder::new()
            .bucket("my-bucket")
            .prefix("backups/");
        assert_missing_field(builder.build(), "local_folder");
    }

    #[test]
    fn should_build_folder_with_required_fields() {
        let outcome = S3GetFolderTaskletBuilder::new()
            .bucket("my-bucket")
            .prefix("backups/")
            .local_folder("/tmp/imports")
            .build();
        assert!(outcome.is_ok(), "build should succeed: {:?}", outcome.err());
    }

    #[test]
    fn should_strip_leading_slash_from_relative_key() {
        // NOTE(review): this re-implements the prefix/slash stripping inline rather than
        // calling into S3GetFolderTasklet, so it cannot catch regressions in the tasklet.
        let key = "backups/2026/file.csv";
        let local_folder = std::path::Path::new("/tmp/imports");
        let local_path_for = |prefix: &str| {
            let relative = key.strip_prefix(prefix).unwrap_or(key);
            let relative = relative.strip_prefix('/').unwrap_or(relative);
            local_folder.join(relative)
        };
        let expected = std::path::Path::new("/tmp/imports/file.csv");

        assert_eq!(
            local_path_for("backups/2026/"),
            expected,
            "trailing-slash prefix should produce a correct local path"
        );
        assert_eq!(
            local_path_for("backups/2026"),
            expected,
            "non-trailing-slash prefix must not produce an absolute path"
        );
    }
}