meld-server 0.1.0

Single-port REST + gRPC server framework with a FastAPI-like developer experience (DX) for Rust.
use std::{
    any::{Any, TypeId},
    collections::HashMap,
    sync::{Arc, Mutex},
};

use axum::{
    extract::{FromRef, FromRequestParts},
    http::{request::Parts, StatusCode},
    Extension, Json, Router,
};

use crate::api::{ApiError, ApiErrorResponse};

/// Extractor wrapper that hands a handler a resolved dependency of type `T`.
///
/// The `FromRequestParts` implementation in this file resolves the value in this order:
/// the request-scoped cache, a `DependencyOverrides` table stored in the request
/// extensions, a `DependencyOverride<T>` stored in the request extensions, and finally
/// `T::from_ref(state)`.
#[derive(Debug, Clone)]
pub struct Depends<T>(pub T);

/// Replacement value for a single dependency type `T`.
///
/// `with_dependency_override` installs it as a request extension; the `Depends<T>` extractor
/// clones the wrapped value when no `DependencyOverrides` entry for `T` is present.
#[derive(Debug, Clone)]
pub struct DependencyOverride<T>(pub T);

/// Type-indexed table of replacement dependency values.
///
/// Entries are keyed by the [`TypeId`] of the stored value, so the table holds at most one
/// value per concrete type. The map lives behind an [`Arc`]: cloning the table only bumps a
/// reference count, and [`DependencyOverrides::with`] copies the map only while it is shared.
///
/// The derived `Debug` output lists the registered type ids; the stored values themselves are
/// type-erased and are not printed.
#[derive(Clone, Debug, Default)]
pub struct DependencyOverrides {
    values: Arc<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
}

impl DependencyOverrides {
    /// Returns the table with `value` registered under its own type `T`.
    ///
    /// A value previously registered for `T` is replaced. The method consumes `self`, so the
    /// returned table must be kept; dropping it loses the registration.
    #[must_use]
    pub fn with<T>(mut self, value: T) -> Self
    where
        T: Clone + Send + Sync + 'static,
    {
        // `make_mut` clones the map only when another handle still points at it, so tables
        // cloned earlier never observe this insertion.
        Arc::make_mut(&mut self.values).insert(TypeId::of::<T>(), Arc::new(value));
        self
    }

    /// Returns a clone of the value registered for `T`, or `None` when nothing is registered
    /// for that type.
    pub fn get<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        let stored = self.values.get(&TypeId::of::<T>())?;
        stored.downcast_ref::<T>().cloned()
    }
}

/// Memo of dependency values that have already been resolved, keyed by their type.
///
/// The storage sits behind `Arc<Mutex<..>>`, so clones of the cache share one map.
#[derive(Clone, Default)]
struct DependencyCache {
    values: Arc<Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
}

/// Failure modes of [`DependencyCache`] operations.
#[derive(Debug)]
enum DependencyCacheError {
    /// Locking the cache mutex failed because it was poisoned.
    Poisoned,
}

impl DependencyCache {
    /// Looks up the cached value of type `T` and returns a clone of it.
    ///
    /// Yields `Ok(None)` when nothing of that type is cached, and
    /// `Err(DependencyCacheError::Poisoned)` when the mutex cannot be locked.
    fn get<T>(&self) -> Result<Option<T>, DependencyCacheError>
    where
        T: Clone + Send + Sync + 'static,
    {
        match self.values.lock() {
            Ok(entries) => {
                let hit = entries
                    .get(&TypeId::of::<T>())
                    .and_then(|slot| slot.downcast_ref::<T>())
                    .cloned();
                Ok(hit)
            }
            Err(_) => Err(DependencyCacheError::Poisoned),
        }
    }

    /// Stores `value` under its own type, replacing any value cached for that type.
    ///
    /// Yields `Err(DependencyCacheError::Poisoned)` when the mutex cannot be locked.
    fn insert<T>(&mut self, value: T) -> Result<(), DependencyCacheError>
    where
        T: Clone + Send + Sync + 'static,
    {
        match self.values.lock() {
            Ok(mut entries) => {
                entries.insert(TypeId::of::<T>(), Box::new(value));
                Ok(())
            }
            Err(_) => Err(DependencyCacheError::Poisoned),
        }
    }
}

// Lets handlers take `Depends<T>` as an argument. The value comes from, in order: the
// request-scoped cache, a `DependencyOverrides` table, a `DependencyOverride<T>`, and the
// router state via `FromRef`.
#[axum::async_trait]
impl<T, S> FromRequestParts<S> for Depends<T>
where
    T: FromRef<S> + Clone + Send + Sync + 'static,
    S: Send + Sync,
{
    type Rejection = ApiError;

    /// Resolves `T` for this request and records it in the request-scoped cache.
    ///
    /// Rejects with the error built by `internal_di_error` when the cache cannot be read or
    /// written.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // A value resolved earlier in this request is reused as-is.
        let cached = match parts.extensions.get::<DependencyCache>() {
            Some(cache) => cache.get::<T>().map_err(|err| {
                tracing::error!(?err, "failed to read request dependency cache");
                internal_di_error("request dependency cache is unavailable")
            })?,
            None => None,
        };
        if let Some(hit) = cached {
            return Ok(Self(hit));
        }

        // Precedence: overrides table, then the single-type override, then the state.
        let resolved = parts
            .extensions
            .get::<DependencyOverrides>()
            .and_then(|table| table.get::<T>())
            .or_else(|| {
                parts
                    .extensions
                    .get::<DependencyOverride<T>>()
                    .map(|single| single.0.clone())
            })
            .unwrap_or_else(|| T::from_ref(state));

        // Remember the value so later extractions in this request skip resolution.
        match parts.extensions.get_mut::<DependencyCache>() {
            Some(cache) => cache.insert(resolved.clone()).map_err(|err| {
                tracing::error!(?err, "failed to write request dependency cache");
                internal_di_error("request dependency cache is unavailable")
            })?,
            None => {
                let mut fresh = DependencyCache::default();
                fresh.insert(resolved.clone()).map_err(|err| {
                    tracing::error!(?err, "failed to initialize request dependency cache");
                    internal_di_error("request dependency cache is unavailable")
                })?;
                parts.extensions.insert(fresh);
            }
        }

        Ok(Self(resolved))
    }
}

/// Wraps `router` so that `Depends<T>` extractions receive `value`.
///
/// Builds a one-entry [`DependencyOverrides`] table and installs it through
/// [`with_dependency_overrides`].
///
/// NOTE(review): every call adds its own `Extension<DependencyOverrides>` layer. Extension
/// layers appear to replace an existing value of the same type rather than merge with it, so
/// stacking several `with_dependency` calls probably leaves only one table visible to
/// handlers — confirm, and prefer a single combined table via [`with_dependency_overrides`].
pub fn with_dependency<S, T>(router: Router<S>, value: T) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
    T: Clone + Send + Sync + 'static,
{
    let single_entry = DependencyOverrides::default().with(value);
    with_dependency_overrides(router, single_entry)
}

/// Wraps `router` with an extension layer carrying a [`DependencyOverride`] for `T`.
///
/// The `Depends<T>` extractor uses this value when no [`DependencyOverrides`] entry for `T`
/// is present on the request.
pub fn with_dependency_override<S, T>(router: Router<S>, value: T) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
    T: Clone + Send + Sync + 'static,
{
    let extension = Extension(DependencyOverride(value));
    router.layer(extension)
}

/// Wraps `router` with an extension layer carrying the whole `overrides` table.
///
/// The `Depends<T>` extractor consults this table before any single-type
/// [`DependencyOverride`] and before the router state.
pub fn with_dependency_overrides<S>(router: Router<S>, overrides: DependencyOverrides) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    let extension = Extension(overrides);
    router.layer(extension)
}

/// Builds the error returned when dependency resolution fails internally.
///
/// The result pairs HTTP status 500 with a JSON body whose `code` is `"internal_error"`,
/// whose `message` is the given text, and whose `detail` and `details` are `None`.
pub fn internal_di_error(message: impl Into<String>) -> ApiError {
    let body = ApiErrorResponse {
        code: String::from("internal_error"),
        message: message.into(),
        detail: None,
        details: None,
    };
    (StatusCode::INTERNAL_SERVER_ERROR, Json(body))
}

// Router-level tests for the `Depends` extractor: request-scoped caching, single and
// multi-type overrides, and coexistence with axum's `State` extractor.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::to_bytes, extract::State, http::Request, response::IntoResponse, routing::get,
    };
    use serde::Serialize;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use tower::util::ServiceExt;

    // Router state used by every test. `build_counter` counts how many times a `LabelDep`
    // is built from the state (see the `FromRef` impl below).
    #[derive(Clone)]
    struct TestState {
        label: String,
        build_counter: Arc<AtomicUsize>,
    }

    // Dependency whose construction from state is counted.
    #[derive(Clone)]
    struct LabelDep {
        label: String,
    }

    // Second dependency type, used to exercise a table holding several overrides.
    #[derive(Clone)]
    struct VersionDep {
        version: String,
    }

    impl FromRef<TestState> for VersionDep {
        fn from_ref(_state: &TestState) -> Self {
            Self {
                version: "from-state".to_string(),
            }
        }
    }

    impl FromRef<TestState> for LabelDep {
        fn from_ref(state: &TestState) -> Self {
            // Each state-based construction is recorded so tests can assert how often the
            // extractor fell back to the state.
            state.build_counter.fetch_add(1, Ordering::SeqCst);
            Self {
                label: state.label.clone(),
            }
        }
    }

    #[derive(Serialize)]
    struct DepResponse {
        a: String,
        b: String,
    }

    // Extracts the same dependency type twice within one request.
    async fn dep_handler(
        Depends(a): Depends<LabelDep>,
        Depends(b): Depends<LabelDep>,
    ) -> impl IntoResponse {
        Json(DepResponse {
            a: a.label,
            b: b.label,
        })
    }

    // Two extractions of `LabelDep` in one request must build it from state only once.
    #[tokio::test]
    async fn depends_uses_request_scoped_cache() {
        let counter = Arc::new(AtomicUsize::new(0));
        let state = TestState {
            label: "state-value".to_string(),
            build_counter: counter.clone(),
        };
        let app = Router::new()
            .route("/dep", get(dep_handler))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/dep")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .expect("request should complete");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body bytes");
        let body = String::from_utf8(bytes.to_vec()).expect("utf8 body");
        assert!(body.contains("state-value"));
        // The second extraction was served from the cache.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // A single-type override must be used instead of the state-derived value.
    #[tokio::test]
    async fn depends_override_wins_over_state_resolution() {
        let counter = Arc::new(AtomicUsize::new(0));
        let state = TestState {
            label: "state-value".to_string(),
            build_counter: counter.clone(),
        };
        let app = with_dependency_override(
            Router::new()
                .route("/dep", get(dep_handler))
                .with_state(state),
            LabelDep {
                label: "override-value".to_string(),
            },
        );

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/dep")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .expect("request should complete");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body bytes");
        let body = String::from_utf8(bytes.to_vec()).expect("utf8 body");
        assert!(body.contains("override-value"));
        // `FromRef` was never invoked because the override supplied the value.
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    // Combines `Depends` with axum's own `State` extractor in one handler.
    async fn state_observe_handler(
        Depends(dep): Depends<LabelDep>,
        State(state): State<TestState>,
    ) -> impl IntoResponse {
        Json(DepResponse {
            a: dep.label,
            b: state.label,
        })
    }

    // `Depends` and `State` must be usable side by side; only the status is checked.
    #[tokio::test]
    async fn depends_is_backward_compatible_with_state_extractor() {
        let state = TestState {
            label: "state-value".to_string(),
            build_counter: Arc::new(AtomicUsize::new(0)),
        };
        let app = Router::new()
            .route("/dep", get(state_observe_handler))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/dep")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .expect("request should complete");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    #[derive(Serialize, serde::Deserialize)]
    struct MultiDepResponse {
        label: String,
        version: String,
    }

    // Extracts two different dependency types in one request.
    async fn multi_dep_handler(
        Depends(label): Depends<LabelDep>,
        Depends(version): Depends<VersionDep>,
    ) -> impl IntoResponse {
        Json(MultiDepResponse {
            label: label.label,
            version: version.version,
        })
    }

    // One `DependencyOverrides` table must be able to override several types at once.
    #[tokio::test]
    async fn depends_supports_multiple_overrides() {
        let counter = Arc::new(AtomicUsize::new(0));
        let state = TestState {
            label: "state-value".to_string(),
            build_counter: counter.clone(),
        };
        let overrides = DependencyOverrides::default()
            .with(LabelDep {
                label: "override-label".to_string(),
            })
            .with(VersionDep {
                version: "v2".to_string(),
            });

        let app = with_dependency_overrides(
            Router::new()
                .route("/dep", get(multi_dep_handler))
                .with_state(state),
            overrides,
        );

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/dep")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .expect("request should complete");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body bytes");
        let body: MultiDepResponse = serde_json::from_slice(&bytes).expect("json body");
        assert_eq!(body.label, "override-label");
        assert_eq!(body.version, "v2");
        // Neither value came from the state.
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}