1use std::{collections::BTreeMap, net::SocketAddr, time::Duration};
2
3use serde::Deserialize;
4
5use crate::core::{
6 ConfigFeatureWarning, CoreError, CoreResult, DatabaseSection, LogConfig, LogSection,
7 RpcClientSection, ServiceConfig, dependency_feature_warnings, load_config,
8};
9
10#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
12#[serde(default, deny_unknown_fields)]
13pub struct RestServiceConfig {
14 pub name: String,
16 pub mode: String,
18 pub server: RestServerSection,
20 pub log: LogSection,
22 pub auth: Option<RestAuthSection>,
24 pub middlewares: RestMiddlewaresSection,
26 pub rpc_clients: BTreeMap<String, RpcClientSection>,
28 pub database: Option<DatabaseSection>,
30}
31
32impl Default for RestServiceConfig {
33 fn default() -> Self {
34 let service = ServiceConfig::default();
35 Self {
36 name: service.name,
37 mode: service.mode,
38 server: RestServerSection::default(),
39 log: service.log,
40 auth: None,
41 middlewares: RestMiddlewaresSection::default(),
42 rpc_clients: BTreeMap::new(),
43 database: None,
44 }
45 }
46}
47
48impl RestServiceConfig {
49 pub fn load(basename: &str, env_prefix: &str) -> Result<Self, config::ConfigError> {
51 load_config(basename, env_prefix)
52 }
53
54 pub fn addr(&self) -> CoreResult<SocketAddr> {
56 format!("{}:{}", self.server.host, self.server.port)
57 .parse()
58 .map_err(|error| {
59 config::ConfigError::Message(format!("invalid REST listen address: {error}")).into()
60 })
61 }
62
63 pub fn log_config(&self) -> LogConfig {
65 self.log.to_log_config(&self.name)
66 }
67
68 pub fn rest_config(&self) -> crate::rest::RestConfig {
70 let mut config = if self.middlewares.resilience || self.middlewares.metrics {
71 crate::rest::RestConfig::production_defaults(self.name.clone())
72 } else {
73 crate::rest::RestConfig {
74 name: self.name.clone(),
75 ..crate::rest::RestConfig::default()
76 }
77 };
78 config.timeout = Duration::from_millis(self.server.timeout_ms);
79 config.max_body_bytes = self.server.max_body_bytes;
80 config.middlewares.metrics.enabled = self.middlewares.metrics;
81 if !self.middlewares.resilience {
82 config.middlewares.resilience = crate::rest::RestResilienceConfig::default();
83 }
84 config.auth = self.auth.as_ref().and_then(RestAuthSection::auth_config);
85 config
86 }
87
88 pub fn validate_features(&self) -> Vec<ConfigFeatureWarning> {
90 let mut warnings = Vec::new();
91 if self.middlewares.metrics && !cfg!(feature = "observability") {
92 warnings.push(ConfigFeatureWarning::ignored(
93 "middlewares.metrics",
94 "observability",
95 ));
96 }
97 if self.middlewares.resilience && !cfg!(feature = "resil") {
98 warnings.push(ConfigFeatureWarning::ignored(
99 "middlewares.resilience",
100 "resil",
101 ));
102 }
103 warnings.extend(dependency_feature_warnings(
104 &self.rpc_clients,
105 self.database.as_ref(),
106 ));
107 warnings
108 }
109
110 pub fn jwt_expires(&self) -> Option<u64> {
112 self.auth.as_ref().map(RestAuthSection::jwt_expires)
113 }
114
115 pub fn rpc_client(&self, name: &str) -> CoreResult<&RpcClientSection> {
117 self.rpc_clients.get(name).ok_or_else(|| {
118 CoreError::Config(config::ConfigError::Message(format!(
119 "missing rpc client config: {name}"
120 )))
121 })
122 }
123
124 #[cfg(feature = "rpc")]
126 pub fn rpc_client_config(&self, name: &str) -> CoreResult<crate::rpc::RpcClientConfig> {
127 self.rpc_client(name)?.to_rpc_client_config()
128 }
129
130 #[cfg(feature = "db")]
132 pub fn database_config(&self) -> Option<crate::db::DatabaseConfig> {
133 self.database
134 .as_ref()
135 .map(DatabaseSection::to_database_config)
136 }
137}
138
139#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
141#[serde(default, deny_unknown_fields)]
142pub struct RestServerSection {
143 pub host: String,
144 pub port: u16,
145 pub timeout_ms: u64,
146 pub max_body_bytes: usize,
147}
148
149impl Default for RestServerSection {
150 fn default() -> Self {
151 Self {
152 host: "127.0.0.1".to_string(),
153 port: 8080,
154 timeout_ms: 5000,
155 max_body_bytes: 1024 * 1024,
156 }
157 }
158}
159
160#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
162#[serde(default, deny_unknown_fields)]
163pub struct RestAuthSection {
164 pub jwt_secret: String,
165 pub jwt_expires: u64,
166 pub public_paths: Vec<String>,
167}
168
169impl Default for RestAuthSection {
170 fn default() -> Self {
171 Self {
172 jwt_secret: String::new(),
173 jwt_expires: 7200,
174 public_paths: Vec::new(),
175 }
176 }
177}
178
179impl RestAuthSection {
180 fn secret(&self) -> String {
181 std::env::var("JWT_AUTH_SECRET").unwrap_or_else(|_| self.jwt_secret.clone())
182 }
183
184 fn jwt_expires(&self) -> u64 {
185 std::env::var("JWT_AUTH_EXPIRES")
186 .ok()
187 .and_then(|value| value.parse().ok())
188 .unwrap_or(self.jwt_expires)
189 }
190
191 fn auth_config(&self) -> Option<crate::rest::AuthConfig> {
192 let secret = self.secret();
193 (!secret.is_empty()).then(|| crate::rest::AuthConfig {
194 secret,
195 public_paths: self.public_paths.clone(),
196 })
197 }
198}
199
200#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
202#[serde(default, deny_unknown_fields)]
203pub struct RestMiddlewaresSection {
204 pub metrics: bool,
205 pub resilience: bool,
206}
207
208impl Default for RestMiddlewaresSection {
209 fn default() -> Self {
210 Self {
211 metrics: true,
212 resilience: true,
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::RestServiceConfig;

    /// The default config maps onto the expected runtime REST settings.
    #[test]
    fn maps_runtime_values() {
        let rest = RestServiceConfig::default().rest_config();

        assert_eq!(rest.name, "rs-zero");
        assert_eq!(rest.timeout, Duration::from_millis(5000));
        assert!(rest.middlewares.metrics.enabled);
    }

    /// A warning for a feature appears exactly when that feature is not compiled in.
    #[test]
    fn validate_features_reflects_compile_time_features() {
        let warnings = RestServiceConfig::default().validate_features();
        let warns_about = |feature: &'static str| {
            warnings
                .iter()
                .any(|warning| warning.required_feature == feature)
        };

        assert_eq!(
            warns_about("observability"),
            !cfg!(feature = "observability")
        );
        assert_eq!(warns_about("resil"), !cfg!(feature = "resil"));
    }
}