use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use rustack_sts_model::error::StsError;
use crate::{
body::StsResponseBody,
dispatch::{StsHandler, dispatch_operation},
request::{extract_access_key_from_auth, parse_form_params},
response::{CONTENT_TYPE, error_to_response},
router::resolve_operation,
};
#[derive(Clone)]
pub struct StsHttpConfig {
pub skip_signature_validation: bool,
pub region: String,
pub credential_provider: Option<Arc<dyn rustack_auth::CredentialProvider>>,
}
impl std::fmt::Debug for StsHttpConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StsHttpConfig")
.field("skip_signature_validation", &self.skip_signature_validation)
.field("region", &self.region)
.field(
"credential_provider",
&self.credential_provider.as_ref().map(|_| "..."),
)
.finish()
}
}
impl Default for StsHttpConfig {
fn default() -> Self {
Self {
skip_signature_validation: true,
region: "us-east-1".to_owned(),
credential_provider: None,
}
}
}
#[derive(Debug)]
pub struct StsHttpService<H: StsHandler> {
handler: Arc<H>,
config: Arc<StsHttpConfig>,
}
impl<H: StsHandler> StsHttpService<H> {
pub fn new(handler: Arc<H>, config: StsHttpConfig) -> Self {
Self {
handler,
config: Arc::new(config),
}
}
}
impl<H: StsHandler> Clone for StsHttpService<H> {
fn clone(&self) -> Self {
Self {
handler: Arc::clone(&self.handler),
config: Arc::clone(&self.config),
}
}
}
impl<H: StsHandler> hyper::service::Service<http::Request<Incoming>> for StsHttpService<H> {
type Response = http::Response<StsResponseBody>;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: http::Request<Incoming>) -> Self::Future {
let handler = Arc::clone(&self.handler);
let config = Arc::clone(&self.config);
let request_id = uuid::Uuid::new_v4().to_string();
Box::pin(async move {
let response = process_request(req, handler.as_ref(), &config, &request_id).await;
let response = add_common_headers(response, &request_id);
Ok(response)
})
}
}
async fn process_request<H: StsHandler>(
req: http::Request<Incoming>,
handler: &H,
config: &StsHttpConfig,
request_id: &str,
) -> http::Response<StsResponseBody> {
let (parts, incoming) = req.into_parts();
if parts.method != http::Method::POST {
let err =
StsError::invalid_action(format!("STS requires POST method, got {}", parts.method));
return error_to_response(&err, request_id);
}
let body = match collect_body(incoming).await {
Ok(body) => body,
Err(err) => return error_to_response(&err, request_id),
};
let params = parse_form_params(&body);
let op = match resolve_operation(¶ms) {
Ok(op) => op,
Err(err) => return error_to_response(&err, request_id),
};
let caller_access_key = parts
.headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(extract_access_key_from_auth);
if !config.skip_signature_validation {
if let Some(ref cred_provider) = config.credential_provider {
let body_hash = rustack_auth::hash_payload(&body);
if let Err(auth_err) =
rustack_auth::verify_sigv4(&parts, &body_hash, cred_provider.as_ref())
{
let err = StsError::invalid_client_token_id(auth_err.to_string());
return error_to_response(&err, request_id);
}
}
}
match dispatch_operation(handler, op, body, caller_access_key, request_id).await {
Ok(response) => response,
Err(err) => error_to_response(&err, request_id),
}
}
async fn collect_body(incoming: Incoming) -> Result<Bytes, StsError> {
incoming
.collect()
.await
.map(http_body_util::Collected::to_bytes)
.map_err(|e| StsError::internal_error(format!("Failed to read request body: {e}")))
}
fn add_common_headers(
mut response: http::Response<StsResponseBody>,
request_id: &str,
) -> http::Response<StsResponseBody> {
let headers = response.headers_mut();
if let Ok(hv) = http::HeaderValue::from_str(request_id) {
headers.entry("x-amzn-requestid").or_insert(hv);
}
headers
.entry("content-type")
.or_insert(http::HeaderValue::from_static(CONTENT_TYPE));
headers.insert("server", http::HeaderValue::from_static("Rustack"));
headers.insert(
"access-control-allow-origin",
http::HeaderValue::from_static("*"),
);
response
}