// aster_server/configuration.rs

1use crate::error::{to_env_var, ConfigError};
2use config::{Config, Environment};
3use serde::Deserialize;
4use std::net::SocketAddr;
5
6#[derive(Debug, Default, Deserialize)]
7pub struct Settings {
8    #[serde(default = "default_host")]
9    pub host: String,
10    #[serde(default = "default_port")]
11    pub port: u16,
12}
13
14impl Settings {
15    pub fn socket_addr(&self) -> SocketAddr {
16        format!("{}:{}", self.host, self.port)
17            .parse()
18            .expect("Failed to parse socket address")
19    }
20
21    pub fn new() -> Result<Self, ConfigError> {
22        Self::load_and_validate()
23    }
24
25    fn load_and_validate() -> Result<Self, ConfigError> {
26        // Start with default configuration
27        let config = Config::builder()
28            // Server defaults
29            .set_default("host", default_host())?
30            .set_default("port", default_port())?
31            // Layer on the environment variables
32            .add_source(
33                Environment::with_prefix("ASTER")
34                    .prefix_separator("_")
35                    .separator("__")
36                    .try_parsing(true),
37            )
38            .build()?;
39
40        // Try to deserialize the configuration
41        let result: Result<Self, config::ConfigError> = config.try_deserialize();
42
43        // Handle missing field errors specially
44        match result {
45            Ok(settings) => Ok(settings),
46            Err(err) => {
47                tracing::debug!("Configuration error: {:?}", &err);
48
49                // Handle both NotFound and missing field message variants
50                let error_str = err.to_string();
51                if error_str.starts_with("missing field") {
52                    // Extract field name from error message "missing field `type`"
53                    let field = error_str
54                        .trim_start_matches("missing field `")
55                        .trim_end_matches("`");
56                    let env_var = to_env_var(field);
57                    Err(ConfigError::MissingEnvVar { env_var })
58                } else if let config::ConfigError::NotFound(field) = &err {
59                    let env_var = to_env_var(field);
60                    Err(ConfigError::MissingEnvVar { env_var })
61                } else {
62                    Err(ConfigError::Other(err))
63                }
64            }
65        }
66    }
67}
68
/// Fallback bind host when none is configured: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}
72
/// Fallback TCP port when none is configured.
fn default_port() -> u16 {
    const FALLBACK_PORT: u16 = 3000;
    FALLBACK_PORT
}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// An IPv4 host and port combine into the expected socket address.
    #[test]
    fn test_socket_addr_conversion() {
        let server_settings = Settings {
            host: "127.0.0.1".to_string(),
            port: 3000,
        };
        let addr = server_settings.socket_addr();
        assert_eq!(addr.to_string(), "127.0.0.1:3000");
    }

    /// A bracketed IPv6 host is accepted and yields an IPv6 socket address.
    #[test]
    fn test_socket_addr_bracketed_ipv6() {
        let server_settings = Settings {
            host: "[::1]".to_string(),
            port: 8080,
        };
        let addr = server_settings.socket_addr();
        assert!(addr.is_ipv6());
        assert_eq!(addr.port(), 8080);
        assert_eq!(addr.to_string(), "[::1]:8080");
    }

    /// The serde default helpers produce the documented fallback values.
    #[test]
    fn test_default_helpers() {
        assert_eq!(default_host(), "127.0.0.1");
        assert_eq!(default_port(), 3000);
    }
}