use axum::extract::FromRequestParts;
use std::{ops::Deref, sync::Arc};
use crate::{Error, app::get_app};
/// Wrapper around a shared, reference-counted `T`.
///
/// The inner [`Arc`] is public, so it can be cloned or taken out directly.
/// This file also implements `FromRequestParts` for the wrapper, which lets
/// it appear as a handler argument.
pub struct Inject<T>(pub Arc<T>);

/// Lets an `Inject<T>` be read through as if it were a `&T`.
impl<T> Deref for Inject<T> {
    type Target = T;

    /// Borrows the value stored behind the inner [`Arc`].
    fn deref(&self) -> &T {
        let Self(shared) = self;
        // `shared` is `&Arc<T>`: one deref reaches the `Arc`, the second reaches `T`.
        &**shared
    }
}
impl<S, T> FromRequestParts<S> for Inject<T>
where
T: 'static,
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
Ok(Inject(get_app()?.get::<T>()?))
}
}
#[cfg(test)]
mod tests {
    use super::super::Router;
    use super::*;
    use crate::{
        app::{App, context::run_with_app},
        container::Container,
    };
    use axum::body::Body;
    use http::Request;
    use http_body_util::BodyExt;
    use tower::Service;

    /// End-to-end check: a value bound on the app's container is delivered to
    /// a handler through the `Inject` extractor and shows up in the response.
    #[tokio::test]
    async fn test_inject() {
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct MyService {
            value: u32,
        }

        // Handler echoes the injected service's value as the response body.
        let handler = async |svc: Inject<MyService>| svc.value.to_string();

        let mut app = App::new();
        app.container.bind(MyService { value: 42 });

        let router = Router::compose(|routes| {
            routes
                .layer(crate::app::layer::AppLayer(app.clone()))
                .get("/", handler);
        });
        let mut service = router.build();

        let request = Request::builder()
            .method("GET")
            .uri("/")
            .body(Body::empty())
            .unwrap();

        // Drive the request future through `run_with_app`, handing it the app.
        let response = run_with_app(app, service.call(request))
            .await
            .unwrap();

        let collected = response.into_body().collect().await.unwrap();
        let bytes = collected.to_bytes();
        let body_text = std::str::from_utf8(&bytes).unwrap();
        assert_eq!(body_text, "42");
    }
}