use aws_config::SdkConfig;
use hyper::client::HttpConnector;
use std::error::Error;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use serde_json::Value;
use anyhow::Result;
use aws_sdk_s3::{primitives::ByteStream, Client};
use indicatif::{ProgressBar, ProgressStyle};
pub const BAR_FORMAT: &str = "[{elapsed_precise}] {bar:40.cyan/blue} ({bytes_per_sec}) {bytes:>7}/{total_bytes:7} {msg} [ETA {eta}]";
pub async fn s3_auth(
shasta_token: &str,
shasta_base_url: &str,
shasta_root_cert: &[u8],
) -> Result<Value, reqwest::Error> {
let client_builder = reqwest::Client::builder()
.add_root_certificate(reqwest::Certificate::from_pem(shasta_root_cert)?);
let client = if let Ok(socks5_env) = std::env::var("SOCKS5") {
log::debug!("SOCKS5 enabled");
let socks5proxy = reqwest::Proxy::all(socks5_env).unwrap();
client_builder.proxy(socks5proxy).build().unwrap()
} else {
client_builder.build().unwrap()
};
let api_url = shasta_base_url.to_owned() + "/sts/token";
let resp = client
.put(api_url)
.bearer_auth(shasta_token)
.send()
.await?
.error_for_status()?;
let sts_value = resp.json::<serde_json::Value>().await.unwrap();
log::debug!("-- STS Token retrieved --");
log::debug!("Debug - STS token:\n{:#?}", sts_value);
std::env::set_var(
"AWS_SESSION_TOKEN",
sts_value["Credentials"]["SessionToken"].as_str().unwrap(),
);
std::env::set_var(
"AWS_ACCESS_KEY_ID",
sts_value["Credentials"]["AccessKeyId"].as_str().unwrap(),
);
std::env::set_var(
"AWS_SECRET_ACCESS_KEY",
sts_value["Credentials"]["SecretAccessKey"]
.as_str()
.unwrap(),
);
Ok(sts_value)
}
async fn setup_client(sts_value: &Value) -> Client {
use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
let region_provider =
aws_config::meta::region::RegionProviderChain::default_provider().or_else("us-east-1");
let config: SdkConfig;
if let Ok(socks5_env) = std::env::var("SOCKS5") {
log::debug!("SOCKS5 enabled");
let mut http_connector: HttpConnector = hyper::client::HttpConnector::new();
http_connector.enforce_http(false);
let socks_http_connector = hyper_socks2::SocksConnector {
proxy_addr: socks5_env.to_string().parse::<hyper::Uri>().unwrap(), auth: None,
connector: http_connector.clone(),
};
let http_client = HyperClientBuilder::new().build(socks_http_connector);
config = aws_config::from_env()
.region(region_provider)
.http_client(http_client)
.endpoint_url(sts_value["Credentials"]["EndpointURL"].as_str().unwrap())
.app_name(aws_config::AppName::new("manta").unwrap())
.load()
.await;
} else {
config = aws_config::from_env()
.region(region_provider)
.endpoint_url(sts_value["Credentials"]["EndpointURL"].as_str().unwrap())
.app_name(aws_config::AppName::new("manta").unwrap())
.load()
.await;
}
let client = aws_sdk_s3::Client::from_conf(
aws_sdk_s3::Client::new(&config)
.config()
.to_builder()
.force_path_style(true)
.build(),
);
client
}
pub async fn s3_get_object_size(
sts_value: &Value,
key: &str,
bucket: &str,
) -> Result<i64, Box<dyn Error>> {
let client = setup_client(sts_value).await;
match client.get_object().bucket(bucket).key(key).send().await {
Ok(object) => Ok(object.content_length().unwrap()),
Err(e) => panic!("Error, unable to get object size from s3. Error msg: {}", e),
}
}
pub async fn s3_download_object(
sts_value: &Value,
object_path: &str,
bucket: &str,
destination_path: &str,
) -> Result<String, Box<dyn Error>> {
let client = setup_client(sts_value).await;
let filename = Path::new(object_path).file_name().unwrap();
let file_path = Path::new(destination_path).join(filename);
log::debug!("Create directory '{}'", destination_path);
match std::fs::create_dir_all(destination_path) {
Ok(_) => log::debug!("Created directory '{}' successfully", destination_path),
Err(error) => panic!("Error creating directory {}: {}", destination_path, error),
};
let mut file = match File::create(&file_path) {
Ok(file) => {
log::debug!(
"Created file '{}' successfully",
&file_path.to_string_lossy()
);
file
}
Err(error) => panic!(
"Error creating file {}: {}",
&file_path.to_string_lossy(),
error
),
};
let mut object = client
.get_object()
.bucket(bucket)
.key(object_path)
.send()
.await?;
let bar_size = object.content_length().unwrap();
let bar = ProgressBar::new(bar_size as u64);
bar.set_style(ProgressStyle::with_template(BAR_FORMAT).unwrap());
while let Some(bytes) = object.body.try_next().await? {
let bytes = file.write(&bytes)?;
bar.inc(bytes as u64);
}
bar.finish();
Ok(file_path.to_string_lossy().to_string())
}
pub async fn s3_upload_object(
sts_value: &Value,
object_path: &str,
bucket: &str,
file_path: &str,
) -> Result<String, Box<dyn Error>> {
let client = setup_client(sts_value).await;
let body = ByteStream::from_path(Path::new(&file_path)).await;
match client
.put_object()
.bucket(bucket)
.key(object_path)
.body(body.unwrap())
.send()
.await
{
Ok(put_object_output) => {
log::debug!("Uploaded file '{}' successfully", &file_path);
Ok(put_object_output.e_tag.unwrap())
}
Err(error) => panic!("Error uploading file {}: {}", &file_path, error),
}
}
pub async fn s3_remove_object(
sts_value: &Value,
object_path: &str,
bucket: &str,
) -> Result<String, Box<dyn Error>> {
let client = setup_client(sts_value).await;
match client
.delete_object()
.bucket(bucket)
.key(object_path)
.send()
.await
{
Ok(_file) => {
log::debug!("Cleaned file '{}' successfully", &object_path);
Ok(String::from("client"))
}
Err(error) => panic!("Error cleaning file {}: {}", &object_path, error),
}
}
pub async fn s3_multipart_upload_object(
sts_value: &Value,
object_path: &str,
bucket: &str,
file_path: &str,
) -> Result<String, Box<dyn Error>> {
use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadOutput;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
use aws_smithy_types::byte_stream::Length;
let client = setup_client(sts_value).await;
const CHUNK_SIZE: u64 = 1024 * 1024 * 5;
const MAX_CHUNKS: u64 = 10000;
let multipart_upload_res: CreateMultipartUploadOutput = client
.create_multipart_upload()
.bucket(bucket)
.key(object_path)
.send()
.await
.unwrap();
let upload_id = multipart_upload_res.upload_id().unwrap();
let path = Path::new(&file_path);
let file_size = std::fs::metadata(path).expect("it exists I swear").len();
let mut chunk_count = (file_size / CHUNK_SIZE) + 1;
let mut size_of_last_chunk = file_size % CHUNK_SIZE;
if size_of_last_chunk == 0 {
size_of_last_chunk = CHUNK_SIZE;
chunk_count -= 1;
}
let bar = ProgressBar::new(file_size);
bar.set_style(ProgressStyle::with_template(BAR_FORMAT).unwrap());
if file_size == 0 {
panic!("Bad file size.");
}
if chunk_count > MAX_CHUNKS {
panic!("Too many chunks! Try increasing your chunk size.")
}
let mut upload_parts: Vec<CompletedPart> = Vec::new();
for chunk_index in 0..chunk_count {
let this_chunk = if chunk_count - 1 == chunk_index {
size_of_last_chunk
} else {
CHUNK_SIZE
};
let stream = ByteStream::read_from()
.path(path)
.offset(chunk_index * CHUNK_SIZE)
.length(Length::Exact(this_chunk))
.build()
.await
.unwrap();
let part_number = (chunk_index as i32) + 1;
let upload_part_res = client
.upload_part()
.key(object_path)
.bucket(bucket)
.upload_id(upload_id)
.body(stream)
.part_number(part_number)
.send()
.await?;
upload_parts.push(
CompletedPart::builder()
.e_tag(upload_part_res.e_tag.unwrap_or_default())
.part_number(part_number)
.build(),
);
bar.inc(this_chunk);
}
let completed_multipart_upload: CompletedMultipartUpload = CompletedMultipartUpload::builder()
.set_parts(Some(upload_parts))
.build();
let _complete_multipart_upload_res = client
.complete_multipart_upload()
.bucket(bucket)
.key(object_path)
.multipart_upload(completed_multipart_upload)
.upload_id(upload_id)
.send()
.await
.unwrap();
bar.finish();
Ok(_complete_multipart_upload_res.e_tag.clone().unwrap())
}