use crate::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::fmt;
/// A [`Layer`] that wraps an inner service in [`AddExtension`].
///
/// The layer owns a single `value`; every service it produces receives that
/// value (cloned by `layer`, moved by `into_layer`).
pub struct AddExtensionLayer<T> {
/// The value passed on to each [`AddExtension`] built by this layer.
value: T,
}
/// Debug-formats the layer as a struct named `AddExtensionLayer` exposing its
/// `value` field. Only available when the stored value is itself `Debug`.
impl<T> fmt::Debug for AddExtensionLayer<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { value } = self;
        let mut out = f.debug_struct("AddExtensionLayer");
        out.field("value", value);
        out.finish()
    }
}
/// Cloning the layer clones the value it carries.
impl<T: Clone> Clone for AddExtensionLayer<T> {
    fn clone(&self) -> Self {
        let Self { value } = self;
        Self {
            value: T::clone(value),
        }
    }
}
impl<T> AddExtensionLayer<T> {
pub const fn new(value: T) -> Self {
AddExtensionLayer { value }
}
}
/// Turns any inner service `S` into an [`AddExtension<S, T>`] that carries
/// this layer's value.
impl<S, T: Clone> Layer<S> for AddExtensionLayer<T> {
    type Service = AddExtension<S, T>;

    /// Wraps `inner`, giving the new service a clone of the layer's value so
    /// the layer itself stays usable.
    fn layer(&self, inner: S) -> Self::Service {
        let value = self.value.clone();
        AddExtension { inner, value }
    }

    /// Wraps `inner`, consuming the layer and moving its value into the new
    /// service without cloning.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self { value } = self;
        AddExtension { inner, value }
    }
}
/// Middleware that inserts a clone of `value` into the request [`Context`]
/// before handing the request to the wrapped `inner` service.
pub struct AddExtension<S, T> {
/// The wrapped service that receives the request after the value is inserted.
inner: S,
/// The value cloned into the [`Context`] on every call to `serve`.
value: T,
}
/// Debug-formats the middleware as a struct named `AddExtension` with its
/// `inner` and `value` fields, in that order.
impl<S, T> fmt::Debug for AddExtension<S, T>
where
    S: fmt::Debug,
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { inner, value } = self;
        let mut out = f.debug_struct("AddExtension");
        out.field("inner", inner);
        out.field("value", value);
        out.finish()
    }
}
/// Cloning the middleware clones both the wrapped service and the stored value.
impl<S: Clone, T: Clone> Clone for AddExtension<S, T> {
    fn clone(&self) -> Self {
        let Self { inner, value } = self;
        Self {
            inner: S::clone(inner),
            value: T::clone(value),
        }
    }
}
impl<S, T> AddExtension<S, T> {
pub const fn new(inner: S, value: T) -> Self {
Self { inner, value }
}
define_inner_service_accessors!();
}
/// Serves requests by first placing a clone of the stored value into the
/// [`Context`] and then delegating to the inner service.
///
/// Response and error types are exactly those of the inner service.
impl<State, Request, S, T> Service<State, Request> for AddExtension<S, T>
where
    S: Service<State, Request>,
    T: Clone + Send + Sync + 'static,
    State: Clone + Send + Sync + 'static,
    Request: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    fn serve(
        &self,
        mut ctx: Context<State>,
        req: Request,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
        // A fresh clone per request: the middleware keeps its own copy for
        // subsequent calls.
        let extension = self.value.clone();
        ctx.insert(extension);

        // Hand the enriched context to the wrapped service and return its
        // future unchanged.
        self.inner.serve(ctx, req)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Context, service::service_fn};
    use std::{convert::Infallible, sync::Arc};

    /// Small payload type used as the extension value in the test.
    struct State(i32);

    /// The value given to the layer must be retrievable from the context
    /// inside the wrapped service.
    #[tokio::test]
    async fn basic() {
        let shared = Arc::new(State(1));

        // Inner service: reads the extension back out of the context and
        // returns the integer it holds.
        let handler = service_fn(async |ctx: Context<()>, _req: ()| {
            let state = ctx.get::<Arc<State>>().unwrap();
            Ok::<_, Infallible>(state.0)
        });
        let svc = AddExtensionLayer::new(shared).into_layer(handler);

        let res = svc.serve(Context::default(), ()).await.unwrap();
        assert_eq!(1, res);
    }
}