Skip to main content

axum_config/
lib.rs

1mod error;
2
3use axum::{extract::FromRequestParts, http::request::Parts};
4use error::ErrorResponse;
5
6pub use thisconfig::*;
7
8#[cfg(feature = "byte-unit")]
9pub use thisconfig::ByteConfig;
10
11#[cfg(feature = "time-unit")]
12pub use thisconfig::TimeConfig;
13
/// Axum extractor that yields the configuration item `T` taken from the
/// [`Config`] stored in the request extensions.
///
/// The wrapped value is public, so handlers can destructure it directly:
/// `ExtractConfig(settings): ExtractConfig<MySettings>`.
///
/// # Errors
///
/// Extraction is rejected with `ErrorResponse::internal_server_error()` when
/// no [`Config`] extension is present on the request, or when the config
/// contains no entry for `T::key()`. Both cases are logged via `tracing`.
#[derive(Debug, Clone)]
pub struct ExtractConfig<T>(pub T);
15
16impl<S, T> FromRequestParts<S> for ExtractConfig<T>
17where
18    T: ConfigItem,
19    S: Send + Sync,
20{
21    type Rejection = ErrorResponse;
22
23    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
24        let Some(config) = parts.extensions.get::<Config>() else {
25            tracing::error!("Configuration extension not found in request parts");
26            return Err(ErrorResponse::internal_server_error());
27        };
28
29        let Some(item) = config.get::<T>() else {
30            tracing::error!("Configuration item '{}' not found", T::key());
31            return Err(ErrorResponse::internal_server_error());
32        };
33
34        Ok(ExtractConfig(item))
35    }
36}
37
/// Axum extractor that yields the configuration item `T` if the request's
/// [`Config`] extension contains it, and `None` otherwise.
///
/// Unlike [`ExtractConfig`], a missing item is not an error. The [`Config`]
/// extension itself must still be present.
///
/// # Errors
///
/// Extraction is rejected with `ErrorResponse::internal_server_error()` (and
/// logged via `tracing`) only when no [`Config`] extension is present on the
/// request.
#[derive(Debug, Clone)]
pub struct ExtractOptionalConfig<T>(pub Option<T>);
39
40impl<S, T> FromRequestParts<S> for ExtractOptionalConfig<T>
41where
42    T: ConfigItem,
43    S: Send + Sync,
44{
45    type Rejection = ErrorResponse;
46
47    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
48        let Some(config) = parts.extensions.get::<Config>() else {
49            tracing::error!("Configuration extension not found in request parts");
50            return Err(ErrorResponse::internal_server_error());
51        };
52
53        let item = config.get::<T>();
54
55        Ok(ExtractOptionalConfig(item))
56    }
57}
58
/// Axum extractor that yields the configuration item `T` from the request's
/// [`Config`] extension after it has passed its `validator` rules.
///
/// Only available with the `validation` feature enabled.
///
/// # Errors
///
/// Extraction is rejected with `ErrorResponse::internal_server_error()` (and
/// logged via `tracing`) when no [`Config`] extension is present on the
/// request, or when `Config::get_validated::<T>()` returns an error.
#[cfg(feature = "validation")]
#[derive(Debug, Clone)]
pub struct ExtractValidatedConfig<T>(pub T);
61
62#[cfg(feature = "validation")]
63use validator::Validate;
64
/// Resolves and validates `T` from the shared [`Config`] held in the request
/// extensions.
///
/// A missing `Config` extension, or any error returned by
/// `Config::get_validated::<T>()`, is logged through `tracing` and rejected
/// with `ErrorResponse::internal_server_error()`.
#[cfg(feature = "validation")]
impl<S, T> FromRequestParts<S> for ExtractValidatedConfig<T>
where
    T: ConfigItem + Validate,
    S: Send + Sync,
{
    type Rejection = ErrorResponse;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let config = parts.extensions.get::<Config>().ok_or_else(|| {
            tracing::error!("Configuration extension not found in request parts");
            ErrorResponse::internal_server_error()
        })?;

        match config.get_validated::<T>() {
            Ok(item) => Ok(ExtractValidatedConfig(item)),
            Err(e) => {
                // The validation error is logged only; it is not exposed to the client.
                tracing::error!("Configuration validation failed for '{}': {e}", T::key());
                Err(ErrorResponse::internal_server_error())
            }
        }
    }
}
87
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Minimal configuration item used to exercise the extractor wrappers.
    #[derive(Debug, Clone, Deserialize)]
    struct MockConfig {
        value: String,
    }

    impl ConfigItem for MockConfig {
        fn key() -> &'static str {
            "mock"
        }
    }

    #[test]
    fn test_config_wrapper() {
        let mock = MockConfig {
            value: "test".to_string(),
        };

        // `ExtractConfig` wraps the item itself (not an `Option`), which is
        // what `from_request_parts` produces on success.
        let config = ExtractConfig(mock);
        assert_eq!(config.0.value, "test");
    }

    #[test]
    fn test_optional_config_wrapper() {
        let present = ExtractOptionalConfig(Some(MockConfig {
            value: "test".to_string(),
        }));
        assert_eq!(
            present.0.as_ref().map(|c| c.value.as_str()),
            Some("test")
        );

        let absent: ExtractOptionalConfig<MockConfig> = ExtractOptionalConfig(None);
        assert!(absent.0.is_none());
    }

    #[cfg(feature = "validation")]
    #[test]
    fn test_validated_config_wrapper() {
        use validator::Validate;

        #[derive(Debug, Clone, Deserialize, Validate)]
        struct ValidatedMockConfig {
            #[validate(length(min = 1))]
            value: String,
        }

        impl ConfigItem for ValidatedMockConfig {
            fn key() -> &'static str {
                "validated_mock"
            }
        }

        let mock = ValidatedMockConfig {
            value: "test".to_string(),
        };

        // Exercise the rule the validated extractor relies on. The call is fully
        // qualified so it cannot clash with any same-named `ConfigItem` method.
        assert!(Validate::validate(&mock).is_ok());
        let empty = ValidatedMockConfig {
            value: String::new(),
        };
        assert!(Validate::validate(&empty).is_err());

        let config = ExtractValidatedConfig(mock);
        assert_eq!(config.0.value, "test");
    }
}