rok-container 0.3.3

IoC service container and dependency injection for the rok ecosystem
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use axum::extract::{FromRequestParts, Path};
use axum::http::request::Parts;
use sqlx::PgPool;

use crate::container::Container;
use crate::error::ContainerError;

/// Model type that the [`Bound`] extractor can load directly from a route parameter.
///
/// Implementors provide an async lookup,
/// [`find_by_route_key`](BoundModel::find_by_route_key). They may also rename the path
/// parameter the extractor reads via [`route_param_name`](BoundModel::route_param_name).
/// That name is `"id"` unless overridden; return e.g. `"slug"` for slug-based URLs.
///
/// The lookup is handed a [`PgPool`]. When used through [`Bound`], that pool is
/// resolved from the [`Container`], so a `PgPool` singleton must be registered there.
///
/// # Example
///
/// ```rust,ignore
/// use rok_container::{BoundModel, ContainerError};
/// use sqlx::PgPool;
///
/// impl BoundModel for Post {
///     async fn find_by_route_key(key: &str, db: &PgPool) -> Result<Self, ContainerError> {
///         // A key that is not a valid integer can never match a row.
///         let id: i64 = key.parse().map_err(|_| ContainerError::ModelNotFound)?;
///         sqlx::query_as::<_, Post>("SELECT * FROM posts WHERE id = $1")
///             .bind(id)
///             .fetch_optional(db)
///             .await
///             .map_err(|_| ContainerError::ModelNotFound)?
///             .ok_or(ContainerError::ModelNotFound)
///     }
/// }
/// ```
pub trait BoundModel: Sized + Send + Sync + 'static {
    /// Name of the path parameter whose value is passed to
    /// [`find_by_route_key`](BoundModel::find_by_route_key).
    ///
    /// Returns `"id"` unless overridden.
    fn route_param_name() -> &'static str {
        "id"
    }

    /// Load the model identified by `key` (the raw path-parameter string) using `db`.
    ///
    /// The returned future is required to be `Send`.
    ///
    /// # Errors
    ///
    /// Implementations should return [`ContainerError::ModelNotFound`] when no
    /// record matches `key`.
    fn find_by_route_key(
        key: &str,
        db: &PgPool,
    ) -> impl std::future::Future<Output = Result<Self, ContainerError>> + Send;
}

/// Axum extractor that resolves a route model by its path parameter.
///
/// When used as a handler argument it:
///
/// 1. reads the path parameter named by [`BoundModel::route_param_name`];
/// 2. resolves a [`PgPool`] from the [`Arc<Container>`](Container) request extension;
/// 3. calls [`BoundModel::find_by_route_key`] with the parameter value and the pool.
///
/// Any failure rejects the request with a [`ContainerError`]. When nothing matches,
/// the rejection is whatever `find_by_route_key` returns, conventionally
/// [`ContainerError::ModelNotFound`], which is intended to yield a `404`.
///
/// Requires both an [`Arc<Container>`](Container) extension **and** a
/// [`PgPool`] singleton registered in the container.
///
/// The wrapped model is a public field. Destructure it in the handler signature
/// (`Bound(post): Bound<Post>`) or access it as `.0`. The derived traits
/// (`Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`) apply only when `T` implements them.
///
/// # Example
///
/// ```rust,ignore
/// use rok_container::Bound;
///
/// // Route registered as `/posts/{id}` (axum 0.8 path syntax)  →  fetches Post by id
/// async fn show(Bound(post): Bound<Post>) -> Json<PostResource> {
///     Json(PostResource::from(post))
/// }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Bound<T>(pub T);

impl<T, S> FromRequestParts<S> for Bound<T>
where
    T: BoundModel,
    S: Send + Sync,
{
    type Rejection = ContainerError;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let Path(params) = Path::<HashMap<String, String>>::from_request_parts(parts, state)
            .await
            .map_err(|_| ContainerError::MissingRouteParam(T::route_param_name()))?;

        let key = params
            .get(T::route_param_name())
            .ok_or(ContainerError::MissingRouteParam(T::route_param_name()))?;

        let container = parts
            .extensions
            .get::<Arc<Container>>()
            .cloned()
            .ok_or(ContainerError::NotRegistered("Container"))?;

        let pool = container.make::<PgPool>()?;

        let model = T::find_by_route_key(key, &pool).await?;
        Ok(Bound(model))
    }
}