use std::collections::HashMap;
use std::sync::Arc;
use axum::extract::{FromRequestParts, Path};
use axum::http::request::Parts;
use sqlx::PgPool;
use super::container::Container;
use super::error::ContainerError;
/// A model that can be loaded from a single route parameter.
///
/// Implementing this trait lets the model be used with the [`Bound`]
/// extractor: the extractor reads the path parameter named by
/// [`BoundModel::route_param_name`] and hands its value to
/// [`BoundModel::find_by_route_key`].
pub trait BoundModel
where
    Self: Sized + Send + Sync + 'static,
{
    /// Name of the path parameter that carries this model's lookup key.
    ///
    /// Defaults to `"id"`. Override it when the route names the parameter
    /// differently (for example `slug`).
    fn route_param_name() -> &'static str {
        "id"
    }

    /// Loads the model identified by `key`, using `db` for the lookup.
    ///
    /// `key` is the raw string value of the route parameter, before any
    /// parsing. The returned future must be `Send`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ContainerError`] the implementation chooses.
    /// When called through the [`Bound`] extractor, that error becomes the
    /// request's rejection.
    fn find_by_route_key(
        key: &str,
        db: &PgPool,
    ) -> impl std::future::Future<Output = Result<Self, ContainerError>> + Send;
}
/// Request extractor that yields a model `T` looked up from a route parameter.
///
/// `T` is resolved through its `BoundModel` implementation by the
/// `FromRequestParts` impl for `Bound<T>`. The wrapped value is public, so
/// handlers can destructure it directly in the argument list, e.g.
/// `Bound(model): Bound<MyModel>`.
pub struct Bound<T>(pub T);
impl<T, S> FromRequestParts<S> for Bound<T>
where
T: BoundModel,
S: Send + Sync,
{
type Rejection = ContainerError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Path(params) = Path::<HashMap<String, String>>::from_request_parts(parts, state)
.await
.map_err(|_| ContainerError::MissingRouteParam(T::route_param_name()))?;
let key = params
.get(T::route_param_name())
.ok_or(ContainerError::MissingRouteParam(T::route_param_name()))?;
let container = parts
.extensions
.get::<Arc<Container>>()
.cloned()
.ok_or(ContainerError::NotRegistered("Container"))?;
let pool = container.make::<PgPool>()?;
let model = T::find_by_route_key(key, &pool).await?;
Ok(Bound(model))
}
}