use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
use aws_smithy_runtime_api::client::result::SdkError;
use aws_smithy_runtime_api::http::StatusCode;
use aws_smithy_types::body::SdkBody;
use std::fmt;
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// A canned response that a mock rule can produce for an operation.
///
/// `O` is the operation's output type and `E` its error type; the `Http` variant
/// carries an [`HttpResponse`] instead of a typed value.
#[derive(Debug)]
// NOTE(review): presumably allowed because `HttpResponse` can be much larger than
// typical `O`/`E` values and boxing it is not worth the ergonomics cost — confirm.
#[allow(clippy::large_enum_variant)]
pub enum MockResponse<O, E> {
    /// A successful, typed operation output.
    Output(O),
    /// A typed operation error.
    Error(E),
    /// An HTTP response, used instead of a typed output or error.
    Http(HttpResponse),
}
/// Predicate deciding whether a rule applies to a given type-erased operation input.
type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
/// Produces the (type-erased) response for a rule's call with the given 0-based index,
/// or `None` when the rule has no response for that index.
type ServeFn = Arc<dyn Fn(usize, &Input) -> Option<MockResponse<Output, Error>> + Send + Sync>;
/// A mock rule: an input matcher paired with a response handler.
///
/// Cloning a `Rule` is cheap, and clones share the same call counter because
/// `call_count` is behind an `Arc`.
#[derive(Clone)]
pub struct Rule {
    /// Returns `true` when this rule applies to an input.
    pub(crate) matcher: MatchFn,
    /// Maps a 0-based call index (and the input) to a type-erased response.
    response_handler: ServeFn,
    /// Number of times a response has been requested from this rule; shared by clones.
    call_count: Arc<AtomicUsize>,
    /// Number of responses this rule can serve; requests at or beyond this index get `None`.
    pub(crate) max_responses: usize,
    /// Whether the rule was built through `build_simple` (the `RuleBuilder::then_*` shortcuts).
    is_simple: bool,
}
impl fmt::Debug for Rule {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Rule")
}
}
impl Rule {
#[allow(clippy::type_complexity)]
pub(crate) fn new<O, E>(
matcher: MatchFn,
response_handler: Arc<dyn Fn(usize, &Input) -> Option<MockResponse<O, E>> + Send + Sync>,
max_responses: usize,
is_simple: bool,
) -> Self
where
O: fmt::Debug + Send + Sync + 'static,
E: fmt::Debug + Send + Sync + std::error::Error + 'static,
{
Rule {
matcher,
response_handler: Arc::new(move |idx: usize, input: &Input| {
if idx < max_responses {
response_handler(idx, input).map(|resp| match resp {
MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
})
} else {
None
}
}),
call_count: Arc::new(AtomicUsize::new(0)),
max_responses,
is_simple,
}
}
pub(crate) fn is_simple(&self) -> bool {
self.is_simple
}
pub(crate) fn next_response(&self, input: &Input) -> Option<MockResponse<Output, Error>> {
let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
(self.response_handler)(idx, input)
}
pub fn num_calls(&self) -> usize {
self.call_count.load(Ordering::SeqCst)
}
pub fn is_exhausted(&self) -> bool {
self.num_calls() >= self.max_responses
}
}
/// How a collection of mock rules is consulted.
///
/// NOTE(review): the code that interprets `RuleMode` is not in this file; the variant
/// descriptions below are inferred from their names — confirm against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Presumably: rules are used one after another, in the order they were added.
    Sequential,
    /// Presumably: any rule whose matcher accepts the input may serve it.
    MatchAny,
}
/// Builder for a [`Rule`] targeting operation input type `I`, output type `O`
/// and error type `E`.
pub struct RuleBuilder<I, O, E> {
    /// Predicate over the type-erased input; `RuleBuilder::new` defaults it to
    /// "input downcasts to `I`".
    pub(crate) input_filter: MatchFn,
    /// Carries the `I`, `O`, `E` type parameters without storing values of them.
    pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
}
impl<I, O, E> RuleBuilder<I, O, E>
where
    I: fmt::Debug + Send + Sync + 'static,
    O: fmt::Debug + Send + Sync + 'static,
    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
{
    /// Creates a builder whose filter accepts every input of type `I`.
    #[doc(hidden)]
    pub fn new() -> Self {
        RuleBuilder {
            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
            _ty: std::marker::PhantomData,
        }
    }

    /// Creates a builder using closures purely as type hints.
    ///
    /// Neither hint closure is called; they only let `I`, `O` and `E` be inferred from
    /// a closure returning the input and a closure returning a future `F` that resolves
    /// to `Result<O, SdkError<E, R>>`. The result is identical to [`RuleBuilder::new`].
    #[doc(hidden)]
    pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
    where
        F: Future<Output = Result<O, SdkError<E, R>>>,
    {
        // Delegate so both constructors share one definition of the default filter.
        Self::new()
    }

    /// Restricts the rule to inputs of type `I` for which `filter` returns `true`.
    ///
    /// Replaces any previously configured filter. Inputs that are not of type `I`
    /// never match.
    pub fn match_requests<F>(mut self, filter: F) -> Self
    where
        F: Fn(&I) -> bool + Send + Sync + 'static,
    {
        self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
            Some(typed_input) => filter(typed_input),
            None => false,
        });
        self
    }

    /// Starts building an ordered sequence of responses for this rule.
    pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
        ResponseSequenceBuilder::new(self.input_filter)
    }

    /// Builds a simple rule that answers every matching request with `output_fn()`.
    pub fn then_output<F>(self, output_fn: F) -> Rule
    where
        F: Fn() -> O + Send + Sync + 'static,
    {
        self.sequence().output(output_fn).build_simple()
    }

    /// Builds a simple rule that answers every matching request with `error_fn()`.
    pub fn then_error<F>(self, error_fn: F) -> Rule
    where
        F: Fn() -> E + Send + Sync + 'static,
    {
        self.sequence().error(error_fn).build_simple()
    }

    /// Builds a simple rule that answers every matching request with `response_fn()`.
    pub fn then_http_response<F>(self, response_fn: F) -> Rule
    where
        F: Fn() -> HttpResponse + Send + Sync + 'static,
    {
        self.sequence().http_response(response_fn).build_simple()
    }

    /// Builds a simple rule whose output is computed from each matching input.
    pub fn then_compute_output<F>(self, compute_fn: F) -> Rule
    where
        F: Fn(&I) -> O + Send + Sync + 'static,
    {
        self.sequence().compute_output(compute_fn).build_simple()
    }

    /// Builds a simple rule whose full [`MockResponse`] is computed from each matching input.
    pub fn then_compute_response<F>(self, compute_fn: F) -> Rule
    where
        F: Fn(&I) -> MockResponse<O, E> + Send + Sync + 'static,
    {
        self.sequence().compute_response(compute_fn).build_simple()
    }
}
/// Produces one typed response (output `O` or error `E`) from a type-erased input.
type SequenceGeneratorFn<O, E> = Arc<dyn Fn(&Input) -> MockResponse<O, E> + Send + Sync>;
/// Builder for a [`Rule`] that serves an ordered sequence of responses.
pub struct ResponseSequenceBuilder<I, O, E> {
    /// Response generators in order, each paired with how many consecutive calls it serves.
    generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,
    /// Predicate selecting the inputs the resulting rule applies to.
    input_filter: MatchFn,
    /// Set by `build_simple`, i.e. for rules created via the `RuleBuilder::then_*` shortcuts.
    is_simple: bool,
    /// Carries the input type `I` without storing a value of it.
    _marker: std::marker::PhantomData<I>,
}
/// A [`ResponseSequenceBuilder`] whose last response was made to repeat without limit
/// (via `repeatedly()`); it can only be built, not extended.
pub struct FinalizedResponseSequenceBuilder<I, O, E> {
    /// The sequence being finalized.
    inner: ResponseSequenceBuilder<I, O, E>,
}
impl<I, O, E> ResponseSequenceBuilder<I, O, E>
where
    I: fmt::Debug + Send + Sync + 'static,
    O: fmt::Debug + Send + Sync + 'static,
    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
{
    /// Creates an empty sequence for inputs accepted by `input_filter`.
    pub(crate) fn new(input_filter: MatchFn) -> Self {
        Self {
            generators: Vec::new(),
            input_filter,
            is_simple: false,
            _marker: std::marker::PhantomData,
        }
    }

    /// Appends a response produced by `output_fn`, served once unless changed with `times`.
    pub fn output<F>(mut self, output_fn: F) -> Self
    where
        F: Fn() -> O + Send + Sync + 'static,
    {
        let generator = Arc::new(move |_input: &Input| MockResponse::Output(output_fn()));
        self.generators.push((generator, 1));
        self
    }

    /// Appends an error produced by `error_fn`, served once unless changed with `times`.
    pub fn error<F>(mut self, error_fn: F) -> Self
    where
        F: Fn() -> E + Send + Sync + 'static,
    {
        let generator = Arc::new(move |_input: &Input| MockResponse::Error(error_fn()));
        self.generators.push((generator, 1));
        self
    }

    /// Appends an HTTP response with the given status code and optional body, served once
    /// unless changed with `times`.
    ///
    /// # Panics
    ///
    /// Panics if `status` is rejected by `StatusCode::try_from`.
    pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
        // Validate eagerly so an invalid status fails while building the rule,
        // not when the mock is first hit.
        let status_code = StatusCode::try_from(status)
            .expect("http_status: `status` is not a valid HTTP status code");
        let generator: SequenceGeneratorFn<O, E> = Arc::new(move |_input: &Input| {
            // The generator may run many times, so the body is cloned for each response.
            let sdk_body = match &body {
                Some(text) => SdkBody::from(text.clone()),
                None => SdkBody::empty(),
            };
            MockResponse::Http(HttpResponse::new(status_code, sdk_body))
        });
        self.generators.push((generator, 1));
        self
    }

    /// Appends an HTTP response produced by `response_fn`, served once unless changed
    /// with `times`.
    pub fn http_response<F>(mut self, response_fn: F) -> Self
    where
        F: Fn() -> HttpResponse + Send + Sync + 'static,
    {
        let generator = Arc::new(move |_input: &Input| MockResponse::Http(response_fn()));
        self.generators.push((generator, 1));
        self
    }

    /// Appends an output computed from the typed input.
    ///
    /// # Panics
    ///
    /// The generated response panics if the input does not downcast to `I`.
    fn compute_output<F>(mut self, compute_fn: F) -> Self
    where
        F: Fn(&I) -> O + Send + Sync + 'static,
    {
        let generator = Arc::new(move |input: &Input| {
            if let Some(typed_input) = input.downcast_ref::<I>() {
                MockResponse::Output(compute_fn(typed_input))
            } else {
                panic!("Input type mismatch in compute_output")
            }
        });
        self.generators.push((generator, 1));
        self
    }

    /// Appends a full [`MockResponse`] computed from the typed input.
    ///
    /// # Panics
    ///
    /// The generated response panics if the input does not downcast to `I`.
    fn compute_response<F>(mut self, compute_fn: F) -> Self
    where
        F: Fn(&I) -> MockResponse<O, E> + Send + Sync + 'static,
    {
        let generator = Arc::new(move |input: &Input| {
            if let Some(typed_input) = input.downcast_ref::<I>() {
                compute_fn(typed_input)
            } else {
                panic!("Input type mismatch in compute_response")
            }
        });
        self.generators.push((generator, 1));
        self
    }

    /// Sets how many consecutive calls the most recently added response serves.
    ///
    /// Calling it again for the same response replaces the earlier count (including
    /// resetting it to `1`).
    ///
    /// # Panics
    ///
    /// Panics if no response has been added yet, or if `count` is zero.
    pub fn times(mut self, count: usize) -> Self {
        if self.generators.is_empty() {
            panic!("times(n) called before adding a response to the sequence");
        }
        if count == 0 {
            panic!("repeat count must be greater than zero");
        }
        // Always overwrite, even for `count == 1`, so `.times(3).times(1)` ends at 1.
        if let Some((_, repeat_count)) = self.generators.last_mut() {
            *repeat_count = count;
        }
        self
    }

    /// Makes the most recently added response repeat without limit and finalizes the
    /// sequence.
    ///
    /// # Panics
    ///
    /// Panics if no response has been added yet.
    pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
        if self.generators.is_empty() {
            panic!("repeatedly() called before adding a response to the sequence");
        }
        let inner = self.times(usize::MAX);
        FinalizedResponseSequenceBuilder { inner }
    }

    /// Builds a rule flagged as simple whose last response repeats without limit.
    pub(crate) fn build_simple(mut self) -> Rule {
        self.is_simple = true;
        self.repeatedly().build()
    }

    /// Builds the rule.
    ///
    /// Responses are served in the order they were added, each for its configured number
    /// of consecutive calls; after that the rule yields no response. The total is summed
    /// with saturating addition, so a `times(usize::MAX)` entry cannot overflow it.
    pub fn build(self) -> Rule {
        let generators = self.generators;
        let is_simple = self.is_simple;
        let total_responses: usize = generators
            .iter()
            .map(|(_, count)| *count)
            .fold(0, usize::saturating_add);
        Rule::new(
            self.input_filter,
            Arc::new(move |idx, input| {
                // Walk the segments, subtracting each segment's length until `idx`
                // falls inside one.
                let mut current_idx = idx;
                for (generator, repeat_count) in &generators {
                    if current_idx < *repeat_count {
                        return Some(generator(input));
                    }
                    current_idx -= repeat_count;
                }
                None
            }),
            total_responses,
            is_simple,
        )
    }
}
impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
where
I: fmt::Debug + Send + Sync + 'static,
O: fmt::Debug + Send + Sync + 'static,
E: fmt::Debug + Send + Sync + std::error::Error + 'static,
{
pub fn build(self) -> Rule {
self.inner.build()
}
}