use aws_smithy_protocol_test::{assert_ok, validate_body, MediaType};
use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
use aws_smithy_runtime_api::client::http::{
HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
};
use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
use aws_smithy_runtime_api::client::result::ConnectorError;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::shared::IntoShared;
use http_1x::header::CONTENT_TYPE;
use std::ops::Deref;
use std::sync::{Arc, Mutex, MutexGuard};
/// An ordered collection of replay events.
type ReplayEvents = Vec<ReplayEvent>;

/// Header names that [`StaticReplayClient::relaxed_requests_match`] skips when
/// comparing recorded requests against the expected ones.
pub(crate) const DEFAULT_RELAXED_HEADERS: &[&str] = &[
    "x-amz-user-agent",
    "authorization",
];
/// A single request/response pair replayed by [`StaticReplayClient`].
///
/// `request` is the request the client under test is expected to send, and
/// `response` is what the replay client answers with when this event is consumed.
#[derive(Debug)]
pub struct ReplayEvent {
    request: HttpRequest,
    response: HttpResponse,
}

impl ReplayEvent {
    /// Creates an event from anything convertible into an [`HttpRequest`] and an
    /// [`HttpResponse`] (for example `http` 1.x request/response values).
    ///
    /// # Panics
    ///
    /// Panics with `"invalid request"` / `"invalid response"` if either conversion
    /// fails. The conversion error itself is discarded, since `TryInto::Error` carries
    /// no `Debug` bound here. `#[track_caller]` makes the panic point at the caller
    /// (typically the test building its replay data) instead of at this helper.
    #[track_caller]
    pub fn new(request: impl TryInto<HttpRequest>, response: impl TryInto<HttpResponse>) -> Self {
        Self {
            request: request.try_into().ok().expect("invalid request"),
            response: response.try_into().ok().expect("invalid response"),
        }
    }

    /// The request the client under test is expected to send.
    pub fn request(&self) -> &HttpRequest {
        &self.request
    }

    /// The response returned when this event is replayed.
    pub fn response(&self) -> &HttpResponse {
        &self.response
    }
}

impl From<(HttpRequest, HttpResponse)> for ReplayEvent {
    /// Builds an event from an already-converted pair. Both conversions inside
    /// [`ReplayEvent::new`] are the identity here, so this never panics.
    fn from((request, response): (HttpRequest, HttpResponse)) -> Self {
        Self::new(request, response)
    }
}
/// A request the client actually sent, paired with the request the replay data
/// expected at the same position.
#[derive(Debug)]
struct ValidateRequest {
    /// Request taken from the replay data.
    expected: HttpRequest,
    /// Request the client under test actually sent.
    actual: HttpRequest,
}

impl ValidateRequest {
    /// Asserts that `actual` matches `expected`.
    ///
    /// Checks, in order:
    /// 1. the URIs are equal;
    /// 2. every header on `expected` whose name is not listed in `ignore_headers`
    ///    is present on `actual` with the same value. Extra headers on `actual` are
    ///    allowed. Names in `ignore_headers` are matched ASCII case-insensitively,
    ///    because HTTP header names are case-insensitive;
    /// 3. the bodies match. When both are valid UTF-8 they are compared with
    ///    `validate_body`, as JSON if the actual `Content-Type` contains `"json"`.
    ///    Otherwise their raw bytes are compared.
    ///
    /// `index` is the request's position and only appears in failure messages.
    ///
    /// # Panics
    ///
    /// Panics if any check fails. `#[track_caller]` makes URI and header failures
    /// report the caller's location.
    #[track_caller]
    fn assert_matches(&self, index: usize, ignore_headers: &[&str]) {
        let (actual, expected) = (&self.actual, &self.expected);
        assert_eq!(
            expected.uri(),
            actual.uri(),
            "request[{index}] - URI doesn't match expected value"
        );
        for (name, value) in expected.headers() {
            let ignored = ignore_headers
                .iter()
                .any(|skip| skip.eq_ignore_ascii_case(name));
            if ignored {
                continue;
            }
            // NOTE(review): `get` presumably returns only the first value for `name`,
            // so a header repeated on `expected` is compared against that first value
            // each time. Confirm before relying on multi-valued headers.
            let actual_header = match actual.headers().get(name) {
                Some(actual_header) => actual_header,
                None => panic!("Request #{index} - Header {name:?} is missing"),
            };
            assert_eq!(
                value, actual_header,
                "request[{index}] - Header {name:?} doesn't match expected value",
            );
        }
        // A body with no `bytes()` view (presumably a streaming body) is treated as empty.
        let actual_str = std::str::from_utf8(actual.body().bytes().unwrap_or(&[]));
        let expected_str = std::str::from_utf8(expected.body().bytes().unwrap_or(&[]));
        let media_type = if actual
            .headers()
            .get(CONTENT_TYPE)
            .map(|v| v.contains("json"))
            .unwrap_or(false)
        {
            MediaType::Json
        } else {
            MediaType::Other("unknown".to_string())
        };
        match (actual_str, expected_str) {
            (Ok(actual), Ok(expected)) => assert_ok(validate_body(actual, expected, media_type)),
            // At least one body is not UTF-8: fall back to a byte-for-byte comparison.
            _ => assert_eq!(
                expected.body().bytes(),
                actual.body().bytes(),
                "request[{index}] - Body contents didn't match expected value"
            ),
        };
    }
}
/// An HTTP connector that answers requests from a fixed, ordered list of
/// [`ReplayEvent`]s and records every request it receives.
///
/// The n-th request received is answered with the n-th event's response. It is
/// later compared against the n-th event's request by
/// [`assert_requests_match`](Self::assert_requests_match). Clones share the same
/// replay data and recorded requests.
#[derive(Clone, Debug)]
pub struct StaticReplayClient {
    /// Events not yet consumed, stored in reverse order so that `pop` yields the next one.
    data: Arc<Mutex<ReplayEvents>>,
    /// Every request received so far, paired with the event request it consumed.
    ///
    /// Invariant relied on by `actual_requests`: entries are only ever appended,
    /// never removed, replaced, or mutated. Each `Arc` allocation therefore lives
    /// as long as this `Vec` does.
    requests: Arc<Mutex<Vec<Arc<ValidateRequest>>>>,
}

impl StaticReplayClient {
    /// Creates a replay client that answers requests with the events in `data`, in order.
    pub fn new(mut data: ReplayEvents) -> Self {
        data.reverse();
        StaticReplayClient {
            data: Arc::new(Mutex::new(data)),
            requests: Default::default(),
        }
    }

    /// Returns an iterator over the requests this client has actually received,
    /// in the order they were received.
    ///
    /// The iterator works on a snapshot taken when this method is called. The
    /// internal lock is released before returning, so sending more requests while
    /// iterating neither deadlocks nor invalidates references already yielded.
    /// Those later requests are not included in this iterator.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn actual_requests(&self) -> impl Iterator<Item = &HttpRequest> + '_ {
        let recorded: MutexGuard<'_, Vec<Arc<ValidateRequest>>> =
            self.requests.lock().unwrap();
        let entries: Vec<*const ValidateRequest> = recorded.iter().map(Arc::as_ptr).collect();
        // The returned references do not depend on the lock, so release it right away.
        drop(recorded);
        entries.into_iter().map(|entry| {
            // SAFETY: `entry` came from `Arc::as_ptr` on an element of `self.requests`.
            // Elements are only ever appended (see the `requests` field). The `Vec`
            // therefore keeps a strong reference to each allocation for as long as
            // `self` (which owns a handle to that `Vec`) is alive, i.e. for all of `'_`.
            // Growing the `Vec` moves only the `Arc` handles, never the
            // `ValidateRequest` they point to. No code mutates a `ValidateRequest`
            // after it is recorded.
            let entry = unsafe { &*entry };
            &entry.actual
        })
    }

    /// Locks and returns the recorded requests.
    fn requests(&self) -> impl Deref<Target = Vec<Arc<ValidateRequest>>> + '_ {
        self.requests.lock().unwrap()
    }

    /// Asserts that every recorded request matches the request of the event it
    /// consumed, and that every replay event was consumed.
    ///
    /// Headers named in `ignore_headers` are not compared.
    ///
    /// # Panics
    ///
    /// Panics on the first mismatching request, or if some events were never requested.
    #[track_caller]
    pub fn assert_requests_match(&self, ignore_headers: &[&str]) {
        for (i, req) in self.requests().iter().enumerate() {
            req.assert_matches(i, ignore_headers)
        }
        let remaining_requests = self.data.lock().unwrap();
        assert!(
            remaining_requests.is_empty(),
            "Expected {} additional requests (only {} sent)",
            remaining_requests.len(),
            self.requests().len()
        );
    }

    /// Like [`assert_requests_match`](Self::assert_requests_match), ignoring the
    /// headers in `DEFAULT_RELAXED_HEADERS`.
    #[track_caller]
    pub fn relaxed_requests_match(&self) {
        self.assert_requests_match(DEFAULT_RELAXED_HEADERS)
    }
}

impl HttpConnector for StaticReplayClient {
    /// Answers `request` with the next event's response and records the request
    /// for later validation.
    ///
    /// Once every event has been consumed, this resolves to a `ConnectorError`
    /// instead, and the request is not recorded.
    fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
        let res = if let Some(event) = self.data.lock().unwrap().pop() {
            // The `data` lock is still held here, so concurrent callers record their
            // requests in the same order in which they consume events.
            self.requests.lock().unwrap().push(Arc::new(ValidateRequest {
                expected: event.request,
                actual: request,
            }));
            Ok(event.response)
        } else {
            Err(ConnectorError::other(
                "StaticReplayClient: no more test data available to respond with".into(),
                None,
            ))
        };
        HttpConnectorFuture::new(async move { res })
    }
}
impl HttpClient for StaticReplayClient {
    /// Returns a connector backed by a clone of this client. The clone shares this
    /// client's replay data and recorded requests. The settings and runtime
    /// components are not consulted.
    fn http_connector(
        &self,
        _settings: &HttpConnectorSettings,
        _components: &RuntimeComponents,
    ) -> SharedHttpConnector {
        let connector = self.clone();
        connector.into_shared()
    }

    /// Identifies this client as `static-replay-client`, with no version.
    fn connector_metadata(&self) -> Option<ConnectorMetadata> {
        let metadata = ConnectorMetadata::new("static-replay-client", None);
        Some(metadata)
    }
}
#[cfg(test)]
mod test {
    use crate::test_util::{ReplayEvent, StaticReplayClient};
    use aws_smithy_types::body::SdkBody;

    /// `ReplayEvent::new` accepts plain `http` 1.x request/response values.
    #[test]
    fn create_from_either_http_type() {
        let request = http_1x::Request::builder()
            .uri("test")
            .body(SdkBody::from("hello"))
            .unwrap();
        let response = http_1x::Response::builder()
            .status(200)
            .body(SdkBody::from("hello"))
            .unwrap();
        let event = ReplayEvent::new(request, response);
        let _client = StaticReplayClient::new(vec![event]);
    }
}