pub mod phase;
pub mod wrappers;
use crate::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError};
use crate::config_bag::ConfigBag;
use crate::type_erasure::{TypeErasedBox, TypeErasedError};
use aws_smithy_http::result::SdkError;
use phase::Phase;
use std::fmt::Debug;
use std::mem;
use tracing::{error, trace};
/// Type-erased input held by [`InterceptorContext`].
pub type Input = TypeErasedBox;
/// Type-erased output held by [`InterceptorContext`].
pub type Output = TypeErasedBox;
/// Type-erased error held by [`InterceptorContext`].
pub type Error = TypeErasedError;
/// Either the type-erased output or an orchestrator error wrapping the type-erased error.
pub type OutputOrError = Result<Output, OrchestratorError<Error>>;
// Module-local shorthands for the HTTP request/response types.
type Request = HttpRequest;
type Response = HttpResponse;
/// State passed through the orchestrator's phases and exposed to interceptors.
///
/// Which fields are populated depends on the current [`Phase`]. The `enter_*_phase`
/// methods move the context forward and (in debug builds) assert the expected
/// field states.
#[derive(Debug)]
pub struct InterceptorContext<I = Input, O = Output, E = Error>
where
E: Debug,
{
/// The input; expected to be taken before entering the 'before transmit' phase.
pub(crate) input: Option<I>,
/// The output or error; set during deserialization, or by `fail`.
pub(crate) output_or_error: Option<Result<O, OrchestratorError<E>>>,
/// The HTTP request; set during serialization and taken during transmit.
pub(crate) request: Option<Request>,
/// The HTTP response; expected to be set before the 'before deserialization' phase.
pub(crate) response: Option<Response>,
/// The phase the context is currently in.
phase: Phase,
/// Set once the 'before transmit' phase is entered; while `false`, `rewind` has
/// nothing to undo.
tainted: bool,
/// Clone of the request taken on entering 'before transmit'; `None` if the request
/// body could not be cloned.
request_checkpoint: Option<HttpRequest>,
}
impl InterceptorContext<Input, Output, Error> {
    /// Creates a context in the 'before serialization' phase that holds `input`.
    ///
    /// No request, response, output, or error is set yet, and the context starts
    /// out untainted with no request checkpoint.
    pub fn new(input: Input) -> InterceptorContext<Input, Output, Error> {
        Self {
            phase: Phase::BeforeSerialization,
            input: Some(input),
            request: None,
            response: None,
            output_or_error: None,
            request_checkpoint: None,
            tainted: false,
        }
    }
}
impl<I, O, E> InterceptorContext<I, O, E>
where
E: Debug,
{
#[doc(hidden)]
#[allow(clippy::type_complexity)]
pub fn into_parts(
self,
) -> (
Option<I>,
Option<Result<O, OrchestratorError<E>>>,
Option<Request>,
Option<Response>,
) {
(
self.input,
self.output_or_error,
self.request,
self.response,
)
}
pub fn finalize(self) -> Result<O, SdkError<E, HttpResponse>> {
let Self {
output_or_error,
response,
phase,
..
} = self;
output_or_error
.expect("output_or_error must always beset before finalize is called.")
.map_err(|error| OrchestratorError::into_sdk_error(error, &phase, response))
}
pub fn input(&self) -> &I {
self.input
.as_ref()
.expect("input is present in 'before serialization'")
}
pub fn input_mut(&mut self) -> &mut I {
self.input
.as_mut()
.expect("input is present in 'before serialization'")
}
pub fn take_input(&mut self) -> Option<I> {
self.input.take()
}
pub fn set_request(&mut self, request: Request) {
self.request = Some(request);
}
pub fn request(&self) -> &Request {
self.request
.as_ref()
.expect("request populated in 'before transmit'")
}
pub fn request_mut(&mut self) -> &mut Request {
self.request
.as_mut()
.expect("request populated in 'before transmit'")
}
pub fn take_request(&mut self) -> Request {
self.request
.take()
.expect("take request once during 'transmit'")
}
pub fn set_response(&mut self, response: Response) {
self.response = Some(response);
}
pub fn response(&self) -> &Response {
self.response.as_ref().expect(
"response set in 'before deserialization' and available in the phases following it",
)
}
pub fn response_mut(&mut self) -> &mut Response {
self.response.as_mut().expect(
"response is set in 'before deserialization' and available in the following phases",
)
}
pub fn set_output_or_error(&mut self, output: Result<O, OrchestratorError<E>>) {
self.output_or_error = Some(output);
}
pub fn output_or_error(&self) -> Result<&O, &OrchestratorError<E>> {
self.output_or_error
.as_ref()
.expect("output set in Phase::AfterDeserialization")
.as_ref()
}
pub fn output_or_error_mut(&mut self) -> &mut Result<O, OrchestratorError<E>> {
self.output_or_error
.as_mut()
.expect("output set in 'after deserialization'")
}
#[doc(hidden)]
pub fn enter_serialization_phase(&mut self) {
debug_assert!(
self.phase.is_before_serialization(),
"called enter_serialization_phase but phase is not before 'serialization'"
);
self.phase = Phase::Serialization;
}
#[doc(hidden)]
pub fn enter_before_transmit_phase(&mut self) {
debug_assert!(
self.phase.is_serialization(),
"called enter_before_transmit_phase but phase is not 'serialization'"
);
debug_assert!(
self.input.is_none(),
"input must be taken before calling enter_before_transmit_phase"
);
debug_assert!(
self.request.is_some(),
"request must be set before calling enter_before_transmit_phase"
);
self.request_checkpoint = try_clone(self.request());
self.tainted = true;
self.phase = Phase::BeforeTransmit;
}
#[doc(hidden)]
pub fn enter_transmit_phase(&mut self) {
debug_assert!(
self.phase.is_before_transmit(),
"called enter_transmit_phase but phase is not before transmit"
);
self.phase = Phase::Transmit;
}
#[doc(hidden)]
pub fn enter_before_deserialization_phase(&mut self) {
debug_assert!(
self.phase.is_transmit(),
"called enter_before_deserialization_phase but phase is not 'transmit'"
);
debug_assert!(
self.request.is_none(),
"request must be taken before entering the 'before deserialization' phase"
);
debug_assert!(
self.response.is_some(),
"response must be set to before entering the 'before deserialization' phase"
);
self.phase = Phase::BeforeDeserialization;
}
#[doc(hidden)]
pub fn enter_deserialization_phase(&mut self) {
debug_assert!(
self.phase.is_before_deserialization(),
"called enter_deserialization_phase but phase is not 'before deserialization'"
);
self.phase = Phase::Deserialization;
}
#[doc(hidden)]
pub fn enter_after_deserialization_phase(&mut self) {
debug_assert!(
self.phase.is_deserialization(),
"called enter_after_deserialization_phase but phase is not 'deserialization'"
);
debug_assert!(
self.output_or_error.is_some(),
"output must be set to before entering the 'after deserialization' phase"
);
self.phase = Phase::AfterDeserialization;
}
pub fn rewind(&mut self, _cfg: &mut ConfigBag) -> bool {
if !self.tainted {
return true;
}
if self.request_checkpoint.is_none() {
return false;
}
self.phase = Phase::BeforeTransmit;
self.request = try_clone(self.request_checkpoint.as_ref().expect("checked above"));
self.response = None;
self.output_or_error = None;
true
}
pub fn fail(&mut self, error: OrchestratorError<E>) {
if !self.is_failed() {
trace!(
"orchestrator is transitioning to the 'failure' phase from the '{:?}' phase",
self.phase
);
}
if let Some(Err(existing_err)) = mem::replace(&mut self.output_or_error, Some(Err(error))) {
error!("orchestrator context received an error but one was already present; Throwing away previous error: {:?}", existing_err);
}
}
pub fn is_failed(&self) -> bool {
self.output_or_error
.as_ref()
.map(Result::is_err)
.unwrap_or_default()
}
}
/// Attempts to make an independent copy of `request`.
///
/// Copies the URI, method, HTTP version, and headers, and clones the body with the
/// body's own `try_clone`. Returns `None` if the body cannot be cloned. Request
/// extensions are not copied.
fn try_clone(request: &HttpRequest) -> Option<HttpRequest> {
    let cloned_body = request.body().try_clone()?;
    let mut cloned_request = ::http::Request::builder()
        .uri(request.uri().clone())
        .method(request.method())
        // Preserve the protocol version; without this the clone silently falls back to
        // the builder's default version.
        .version(request.version());
    // Replace the builder's (empty) header map wholesale so that multi-valued headers
    // are copied exactly.
    *cloned_request
        .headers_mut()
        .expect("builder has not been modified, headers must be valid") = request.headers().clone();
    Some(
        cloned_request
            .body(cloned_body)
            .expect("a clone of a valid request should be a valid request"),
    )
}
#[cfg(test)]
mod tests {
use super::*;
use crate::type_erasure::TypedBox;
use aws_smithy_http::body::SdkBody;
use http::header::{AUTHORIZATION, CONTENT_LENGTH};
use http::{HeaderValue, Uri};
// Walks a context through every phase on the success path, calling each accessor
// that should be valid in that phase (a missing field would panic).
#[test]
fn test_success_transitions() {
let input = TypedBox::new("input".to_string()).erase();
let output = TypedBox::new("output".to_string()).erase();
let mut context = InterceptorContext::new(input);
assert_eq!("input", context.input().downcast_ref::<String>().unwrap());
context.input_mut();
context.enter_serialization_phase();
let _ = context.take_input();
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
context.enter_before_transmit_phase();
context.request();
context.request_mut();
context.enter_transmit_phase();
let _ = context.take_request();
context.set_response(http::Response::builder().body(SdkBody::empty()).unwrap());
context.enter_before_deserialization_phase();
context.response();
context.response_mut();
context.enter_deserialization_phase();
context.response();
context.response_mut();
context.set_output_or_error(Ok(output));
context.enter_after_deserialization_phase();
context.response();
context.response_mut();
let _ = context.output_or_error();
let _ = context.output_or_error_mut();
let output = context.output_or_error.unwrap().expect("success");
assert_eq!("output", output.downcast_ref::<String>().unwrap());
}
// Verifies that `rewind` restores the request as it was when 'before transmit' was
// entered, discarding later mutations, and that a second attempt can then succeed.
#[test]
fn test_rewind_for_retry() {
use std::fmt;
// Minimal error type used only to produce a failed first attempt.
#[derive(Debug)]
struct Error;
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("don't care")
}
}
impl std::error::Error for Error {}
let mut cfg = ConfigBag::base();
let input = TypedBox::new("input".to_string()).erase();
let output = TypedBox::new("output".to_string()).erase();
let error = TypedBox::new(Error).erase_error();
let mut context = InterceptorContext::new(input);
assert_eq!("input", context.input().downcast_ref::<String>().unwrap());
context.enter_serialization_phase();
let _ = context.take_input();
context.set_request(
http::Request::builder()
.header("test", "the-original-un-mutated-request")
.body(SdkBody::empty())
.unwrap(),
);
// The checkpoint is taken here, before the header is modified below.
context.enter_before_transmit_phase();
context.request_mut().headers_mut().remove("test");
context.request_mut().headers_mut().insert(
"test",
HeaderValue::from_static("request-modified-after-signing"),
);
context.enter_transmit_phase();
let request = context.take_request();
assert_eq!(
"request-modified-after-signing",
request.headers().get("test").unwrap()
);
context.set_response(http::Response::builder().body(SdkBody::empty()).unwrap());
context.enter_before_deserialization_phase();
context.enter_deserialization_phase();
context.set_output_or_error(Err(OrchestratorError::operation(error)));
// First attempt failed; rewinding must bring back the un-mutated request.
assert!(context.rewind(&mut cfg));
assert_eq!(
"the-original-un-mutated-request",
context.request().headers().get("test").unwrap()
);
// Second attempt: rewind reset the phase to 'before transmit', so transmit can be re-entered.
context.enter_transmit_phase();
let _ = context.take_request();
context.set_response(http::Response::builder().body(SdkBody::empty()).unwrap());
context.enter_before_deserialization_phase();
context.enter_deserialization_phase();
context.set_output_or_error(Ok(output));
context.enter_after_deserialization_phase();
let output = context.output_or_error.unwrap().expect("success");
assert_eq!("output", output.downcast_ref::<String>().unwrap());
}
// Checks that `try_clone` carries over the URI, method, headers, and body.
#[test]
fn try_clone_clones_all_data() {
let request = ::http::Request::builder()
.uri(Uri::from_static("https://www.amazon.com"))
.method("POST")
.header(CONTENT_LENGTH, 456)
.header(AUTHORIZATION, "Token: hello")
.body(SdkBody::from("hello world!"))
.expect("valid request");
let cloned = try_clone(&request).expect("request is cloneable");
assert_eq!(&Uri::from_static("https://www.amazon.com"), cloned.uri());
assert_eq!("POST", cloned.method());
assert_eq!(2, cloned.headers().len());
assert_eq!("Token: hello", cloned.headers().get(AUTHORIZATION).unwrap(),);
assert_eq!("456", cloned.headers().get(CONTENT_LENGTH).unwrap());
assert_eq!("hello world!".as_bytes(), cloned.body().bytes().unwrap());
}
}