1mod error;
2
3use axum::{extract::FromRequestParts, http::request::Parts};
4use error::ErrorResponse;
5
6pub use thisconfig::*;
7
8#[cfg(feature = "byte-unit")]
9pub use thisconfig::ByteConfig;
10
11#[cfg(feature = "time-unit")]
12pub use thisconfig::TimeConfig;
13
/// Axum extractor that yields a required configuration item of type `T`.
///
/// The item is read from the [`Config`] stored in the request extensions. When either the
/// extension or the item itself is missing, extraction is rejected with an
/// internal-server-error [`ErrorResponse`]. Use [`ExtractOptionalConfig`] when the item may
/// legitimately be absent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExtractConfig<T>(pub T);
15
16impl<S, T> FromRequestParts<S> for ExtractConfig<T>
17where
18 T: ConfigItem,
19 S: Send + Sync,
20{
21 type Rejection = ErrorResponse;
22
23 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
24 let Some(config) = parts.extensions.get::<Config>() else {
25 tracing::error!("Configuration extension not found in request parts");
26 return Err(ErrorResponse::internal_server_error());
27 };
28
29 let Some(item) = config.get::<T>() else {
30 tracing::error!("Configuration item '{}' not found", T::key());
31 return Err(ErrorResponse::internal_server_error());
32 };
33
34 Ok(ExtractConfig(item))
35 }
36}
37
/// Axum extractor that yields a configuration item of type `T` if it is available.
///
/// Holds `Some(item)` when `Config::get` finds the item and `None` when it returns `None`.
/// Extraction still fails with an internal-server-error [`ErrorResponse`] if the [`Config`]
/// extension itself is missing from the request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExtractOptionalConfig<T>(pub Option<T>);
39
40impl<S, T> FromRequestParts<S> for ExtractOptionalConfig<T>
41where
42 T: ConfigItem,
43 S: Send + Sync,
44{
45 type Rejection = ErrorResponse;
46
47 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
48 let Some(config) = parts.extensions.get::<Config>() else {
49 tracing::error!("Configuration extension not found in request parts");
50 return Err(ErrorResponse::internal_server_error());
51 };
52
53 let item = config.get::<T>();
54
55 Ok(ExtractOptionalConfig(item))
56 }
57}
58
/// Axum extractor that yields a configuration item of type `T` after validation.
///
/// The item is fetched through `Config::get_validated`, so its `validator::Validate` rules are
/// applied before the handler runs. A missing [`Config`] extension or any failure reported by
/// `get_validated` is logged and rejected with an internal-server-error [`ErrorResponse`].
///
/// Only available with the `validation` feature.
#[cfg(feature = "validation")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExtractValidatedConfig<T>(pub T);
61
62#[cfg(feature = "validation")]
63use validator::Validate;
64
/// Extracts a configuration item `T` and runs its validation rules.
///
/// A missing [`Config`] extension, or an error returned by `Config::get_validated`, is logged
/// at error level and rejected with `ErrorResponse::internal_server_error()`.
#[cfg(feature = "validation")]
impl<S, T> FromRequestParts<S> for ExtractValidatedConfig<T>
where
    S: Send + Sync,
    T: ConfigItem + Validate,
{
    type Rejection = ErrorResponse;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let store = match parts.extensions.get::<Config>() {
            Some(store) => store,
            None => {
                tracing::error!("Configuration extension not found in request parts");
                return Err(ErrorResponse::internal_server_error());
            }
        };

        match store.get_validated::<T>() {
            Ok(item) => Ok(ExtractValidatedConfig(item)),
            Err(e) => {
                tracing::error!("Configuration validation failed for '{}': {e}", T::key());
                Err(ErrorResponse::internal_server_error())
            }
        }
    }
}
87
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use std::{
        future::Future,
        pin::pin,
        sync::Arc,
        task::{Context, Poll, Wake, Waker},
    };

    /// Minimal configuration item used to exercise the extractors.
    #[derive(Debug, Clone, Deserialize)]
    struct MockConfig {
        value: String,
    }

    impl ConfigItem for MockConfig {
        fn key() -> &'static str {
            "mock"
        }
    }

    /// Waker that does nothing. The extractors never suspend, so it is never actually woken.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `fut` exactly once and returns its output.
    ///
    /// The extractors under test contain no `.await` points, so one poll is enough and no
    /// async runtime is required.
    ///
    /// # Panics
    ///
    /// Panics if the future returns `Poll::Pending`.
    fn poll_once<F: Future>(fut: F) -> F::Output {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        let fut = pin!(fut);
        match fut.poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("extractor future unexpectedly returned Pending"),
        }
    }

    /// Builds request parts that do NOT carry a `Config` extension.
    fn parts_without_config() -> Parts {
        axum::http::Request::new(()).into_parts().0
    }

    #[test]
    fn test_config_wrapper() {
        let config = ExtractConfig(MockConfig {
            value: "test".to_string(),
        });
        assert_eq!(config.0.value, "test");
    }

    #[test]
    fn test_required_config_rejected_without_extension() {
        let mut parts = parts_without_config();
        let result = poll_once(ExtractConfig::<MockConfig>::from_request_parts(
            &mut parts,
            &(),
        ));
        assert!(result.is_err());
    }

    #[test]
    fn test_optional_config_rejected_without_extension() {
        // Even the optional extractor treats a missing `Config` extension as a server error.
        let mut parts = parts_without_config();
        let result = poll_once(ExtractOptionalConfig::<MockConfig>::from_request_parts(
            &mut parts,
            &(),
        ));
        assert!(result.is_err());
    }

    #[cfg(feature = "validation")]
    #[test]
    fn test_validated_config_wrapper() {
        use validator::Validate;

        #[derive(Debug, Clone, Deserialize, Validate)]
        struct ValidatedMockConfig {
            #[validate(length(min = 1))]
            value: String,
        }

        impl ConfigItem for ValidatedMockConfig {
            fn key() -> &'static str {
                "validated_mock"
            }
        }

        let config = ExtractValidatedConfig(ValidatedMockConfig {
            value: "test".to_string(),
        });
        assert_eq!(config.0.value, "test");
        assert!(config.0.validate().is_ok());

        // The `length(min = 1)` rule must reject an empty value.
        let empty = ValidatedMockConfig {
            value: String::new(),
        };
        assert!(empty.validate().is_err());

        let mut parts = parts_without_config();
        let result = poll_once(
            ExtractValidatedConfig::<ValidatedMockConfig>::from_request_parts(&mut parts, &()),
        );
        assert!(result.is_err());
    }
}