use crate::catalog::{AuthProvider, CatalogError, Result};
use async_trait::async_trait;
use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings};
use aws_sigv4::sign::v4;
use http::Request as HttpRequest;
use std::time::SystemTime;
/// [`AuthProvider`] that signs outgoing requests with AWS Signature Version 4.
///
/// The region, service name, and credentials supplied at construction time are
/// handed to the `aws_sigv4` signer each time a request is signed.
#[derive(Debug)]
pub struct SigV4AuthProvider {
    /// Region given to the signer via `.region(..)`.
    region: String,
    /// Service name given to the signer via `.name(..)`.
    service: String,
    /// Credentials that are cloned and converted into the signing identity.
    credentials: aws_credential_types::Credentials,
}
impl SigV4AuthProvider {
pub fn new(
region: String,
service: String,
credentials: aws_credential_types::Credentials,
) -> Self {
Self {
region,
service,
credentials,
}
}
}
#[async_trait]
impl AuthProvider for SigV4AuthProvider {
    /// Signs `req` with AWS Signature Version 4 and returns the same request with
    /// the signing headers (such as `authorization` and `x-amz-date`) inserted.
    ///
    /// The request is updated in place rather than rebuilt, so its method, URL,
    /// body, and existing headers (including repeated header names) are carried
    /// over as they were. The caller's headers are also handed to the signer so
    /// they can be covered by the signature; the signer's default
    /// [`SigningSettings`] decide which of them are actually signed.
    ///
    /// # Errors
    ///
    /// * [`CatalogError::AuthError`] if the body is not available as in-memory
    ///   bytes (the payload hash cannot be computed for a stream), if the signing
    ///   parameters cannot be built, or if the signer rejects the request.
    /// * [`CatalogError::Unexpected`] if the request's method or URL cannot be
    ///   converted into an `http::Request`.
    async fn sign_request(&self, mut req: reqwest::Request) -> Result<reqwest::Request> {
        // Scratch `http::Request` that the signer's instructions are applied to.
        // Built first so an unrepresentable URL is reported before any signing work.
        let mut carrier = HttpRequest::builder()
            .method(req.method().as_str())
            .uri(req.url().as_str())
            .body(())
            .map_err(|e| {
                CatalogError::Unexpected(format!("Failed to build HTTP request: {}", e))
            })?;

        // All borrows of `req` are confined to this block so that the request can
        // be mutated once the (owned) signing instructions have been produced.
        let signing_instructions = {
            // The signature covers a hash of the payload. Signing an empty payload
            // for a body we cannot read would produce a request that no longer
            // matches what the caller intended to send, so refuse instead.
            let body_bytes: &[u8] = match req.body() {
                None => &[],
                Some(body) => body.as_bytes().ok_or_else(|| {
                    CatalogError::AuthError(
                        "Failed to sign request: streaming request bodies cannot be signed"
                            .to_string(),
                    )
                })?,
            };

            // The signer accepts string header values only. Headers whose value is
            // not visible ASCII are still sent, just not offered for signing.
            let signable_headers = req.headers().iter().filter_map(|(name, value)| {
                value.to_str().ok().map(|text| (name.as_str(), text))
            });

            let identity = self.credentials.clone().into();
            let signing_params = v4::SigningParams::builder()
                .identity(&identity)
                .region(&self.region)
                .name(&self.service)
                .time(SystemTime::now())
                .settings(SigningSettings::default())
                .build()
                .map_err(|e| {
                    CatalogError::AuthError(format!("Failed to build signing params: {}", e))
                })?
                .into();

            let signable_request = SignableRequest::new(
                req.method().as_str(),
                req.url().as_str(),
                signable_headers,
                SignableBody::Bytes(body_bytes),
            )
            .map_err(|e| {
                CatalogError::AuthError(format!("Failed to create signable request: {}", e))
            })?;

            let (instructions, _signature) =
                aws_sigv4::http_request::sign(signable_request, &signing_params)
                    .map_err(|e| CatalogError::AuthError(format!("Failed to sign request: {}", e)))?
                    .into_parts();
            instructions
        };

        // Move the caller's header map into the scratch request, let the signer add
        // its headers on top, then move the combined map back. Moving the whole map
        // keeps every value of multi-valued headers intact and avoids a clone.
        *carrier.headers_mut() = std::mem::take(req.headers_mut());
        signing_instructions.apply_to_request_http1x(&mut carrier);
        *req.headers_mut() = carrier.into_parts().0.headers;

        Ok(req)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Static placeholder credentials used by every test in this module.
    fn create_test_credentials() -> aws_credential_types::Credentials {
        aws_credential_types::Credentials::new(
            "AKIAIOSFODNN7EXAMPLE",
            "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            None,
            None,
            "test",
        )
    }

    /// Builds a provider for the `s3tables` service in `region` with the test credentials.
    fn s3tables_provider(region: &str) -> SigV4AuthProvider {
        SigV4AuthProvider::new(
            region.to_string(),
            "s3tables".to_string(),
            create_test_credentials(),
        )
    }

    /// The signed request carries a SigV4-shaped `Authorization` header.
    #[tokio::test]
    async fn test_sigv4_adds_authorization_header() {
        let signer = s3tables_provider("us-west-2");
        let unsigned = reqwest::Client::new()
            .get("https://s3tables.us-west-2.amazonaws.com/iceberg")
            .build()
            .unwrap();

        let signed_req = signer.sign_request(unsigned).await.unwrap();

        let auth_str = signed_req
            .headers()
            .get(reqwest::header::AUTHORIZATION)
            .expect("Authorization header should be present")
            .to_str()
            .unwrap();
        assert!(auth_str.starts_with("AWS4-HMAC-SHA256"));
        for fragment in ["Credential=", "SignedHeaders=", "Signature="] {
            assert!(auth_str.contains(fragment));
        }
    }

    /// Signing adds the AWS date and authorization headers and keeps the body bytes.
    #[tokio::test]
    async fn test_sigv4_adds_aws_headers() {
        let signer = s3tables_provider("us-east-1");
        let unsigned = reqwest::Client::new()
            .post("https://s3tables.us-east-1.amazonaws.com/iceberg")
            .body("test body")
            .build()
            .unwrap();

        let signed_req = signer.sign_request(unsigned).await.unwrap();
        let signed_headers = signed_req.headers();

        assert!(signed_headers.contains_key("x-amz-date"));
        assert!(signed_headers.contains_key(reqwest::header::AUTHORIZATION));

        let payload = signed_req.body().unwrap().as_bytes().unwrap();
        assert_eq!(payload, b"test body");
    }

    /// Headers set by the caller survive signing alongside the added AWS headers.
    #[tokio::test]
    async fn test_sigv4_preserves_original_headers() {
        let signer = s3tables_provider("us-west-2");
        let unsigned = reqwest::Client::new()
            .get("https://s3tables.us-west-2.amazonaws.com/iceberg")
            .header("Content-Type", "application/json")
            .header("X-Custom-Header", "custom-value")
            .build()
            .unwrap();

        let signed_req = signer.sign_request(unsigned).await.unwrap();
        let signed_headers = signed_req.headers();

        let expected_pairs = [
            ("Content-Type", "application/json"),
            ("X-Custom-Header", "custom-value"),
        ];
        for (name, expected) in expected_pairs {
            assert_eq!(signed_headers.get(name).unwrap(), expected);
        }
        assert!(signed_headers.contains_key("x-amz-date"));
        assert!(signed_headers.contains_key(reqwest::header::AUTHORIZATION));
    }

    /// The derived `Debug` output names the provider type.
    #[test]
    fn test_sigv4_provider_debug() {
        let rendered = format!("{:?}", s3tables_provider("us-west-2"));
        assert!(rendered.contains("SigV4AuthProvider"));
    }
}