use std::collections::HashSet;
use std::sync::Arc;
use parking_lot::Mutex;
/// Shared, cloneable registry of RPC method names that should fail on purpose.
///
/// The set lives behind an `Arc`, so every clone observes the same contents:
/// a handle kept by a test can toggle failures seen by services that were given
/// another clone.
#[derive(Clone)]
pub struct InjectedErrors(Arc<Mutex<HashSet<String>>>);

impl InjectedErrors {
    /// Creates a registry with no failures injected.
    pub fn new() -> Self {
        let methods: HashSet<String> = HashSet::new();
        Self(Arc::new(Mutex::new(methods)))
    }

    /// Registers `method` so that matching requests are failed.
    ///
    /// `method` is compared against the text after the last `/` of the request
    /// path (see `check_path`).
    pub fn inject(&self, method: &str) {
        let mut methods = self.0.lock();
        methods.insert(method.to_owned());
    }

    /// Stops failing `method`; a no-op if it was never injected.
    pub fn clear(&self, method: &str) {
        let mut methods = self.0.lock();
        methods.remove(method);
    }

    /// Removes every injected failure.
    pub fn clear_all(&self) {
        let mut methods = self.0.lock();
        methods.clear();
    }

    /// Returns the method name taken from `path` when it is currently injected.
    ///
    /// The method name is everything after the last `/`, or the whole `path`
    /// when it contains no `/`. Returns `None` when that name is not registered.
    fn check_path(&self, path: &str) -> Option<String> {
        let method = match path.rsplit_once('/') {
            Some((_, last_segment)) => last_segment,
            None => path,
        };
        let injected = self.0.lock().contains(method);
        injected.then(|| method.to_owned())
    }
}

impl Default for InjectedErrors {
    /// Equivalent to [`InjectedErrors::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
/// Tower layer that wraps each service in an [`ErrorInjectionService`].
///
/// Every wrapped service receives a clone of the same [`InjectedErrors`]
/// registry, so they all share one set of injected failures.
#[derive(Clone)]
pub struct ErrorInjectionLayer {
    errors: InjectedErrors,
}

impl ErrorInjectionLayer {
    /// Builds a layer whose services consult `errors` on every request.
    pub fn new(errors: InjectedErrors) -> Self {
        ErrorInjectionLayer { errors }
    }
}

impl<S> tower::Layer<S> for ErrorInjectionLayer {
    type Service = ErrorInjectionService<S>;

    /// Wraps `inner`, handing it a clone of this layer's registry handle.
    fn layer(&self, inner: S) -> Self::Service {
        let errors = self.errors.clone();
        ErrorInjectionService { inner, errors }
    }
}
/// Service produced by [`ErrorInjectionLayer`].
///
/// Requests whose method is registered in the shared [`InjectedErrors`] are
/// answered with a `tonic::Status::not_found` response without calling the
/// wrapped service. All other requests are forwarded to `inner`.
#[derive(Clone)]
pub struct ErrorInjectionService<S> {
    inner: S,
    errors: InjectedErrors,
}

impl<S, ReqBody, ResBody> tower::Service<http::Request<ReqBody>> for ErrorInjectionService<S>
where
    S: tower::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
    S::Error: Send,
    ReqBody: Send + 'static,
    // `Default` is required to build the body of the synthesized error response.
    ResBody: Default + Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future =
        std::pin::Pin<Box<dyn std::future::Future<Output = Result<S::Response, S::Error>> + Send>>;

    /// Delegates readiness entirely to the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Fails injected methods locally; forwards everything else to `inner`.
    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
        match self.errors.check_path(req.uri().path()) {
            Some(method) => {
                // Short-circuit: the readied `inner` is left untouched.
                let status = tonic::Status::not_found(format!(
                    "injected error for testing: {method} deliberately failed"
                ));
                Box::pin(async move { Ok(status.into_http()) })
            }
            None => {
                // Take the instance that was driven through `poll_ready` and leave a
                // fresh clone in its place, so the ready one is what gets called.
                let replacement = self.inner.clone();
                let mut ready = std::mem::replace(&mut self.inner, replacement);
                Box::pin(async move { ready.call(req).await })
            }
        }
    }
}