bbox_core/
config.rs

1use crate::auth::oidc::OidcAuthCfg;
2use crate::cli::GlobalArgs;
3use crate::service::ServiceConfig;
4use actix_web::HttpRequest;
5use clap::{ArgMatches, FromArgMatches};
6use core::fmt::Display;
7use figment::providers::{Env, Format, Toml};
8use figment::Figment;
9use log::info;
10use once_cell::sync::OnceCell;
11use serde::{Deserialize, Serialize};
12use std::env;
13use std::path::PathBuf;
14
15/// Application configuration singleton
16pub fn app_config() -> &'static Figment {
17    static CONFIG: OnceCell<Figment> = OnceCell::new();
18    CONFIG.get_or_init(|| {
19        let config = Figment::new()
20            .merge(Toml::file(
21                env::var("BBOX_CONFIG").unwrap_or("bbox.toml".to_string()),
22            ))
23            .merge(Env::prefixed("BBOX_").split("__"));
24        if let Some(source) = config_source(&config) {
25            // Logger is not initialized yet
26            println!("Reading configuration from `{source}`");
27            info!("Reading configuration from `{source}`");
28        }
29        config
30    })
31}
32
33fn config_source(config: &Figment) -> &Option<figment::Source> {
34    if let Some(meta) = config.metadata().next() {
35        &meta.source
36    } else {
37        &None
38    }
39}
40
41/// Base directory for files referenced in configuration
42pub fn base_dir() -> PathBuf {
43    let config = app_config();
44    if let Some(source) = config_source(config)
45        .as_ref()
46        .and_then(|source| source.file_path())
47    {
48        source
49            .parent()
50            .expect("absolute config file path")
51            .canonicalize()
52            .expect("absolute config file path")
53    } else {
54        env::current_dir().expect("current working dir")
55    }
56}
57
58/// Full path relative to application base directory
59pub fn app_dir(path: impl Into<PathBuf>) -> PathBuf {
60    let path = path.into();
61    if path.is_relative() {
62        base_dir().join(path)
63    } else {
64        path
65    }
66}
67
68#[derive(thiserror::Error, Debug)]
69pub enum ConfigError {
70    #[error("Configuration error")]
71    ConfigurationError,
72}
73
74pub fn from_config_or_exit<'a, T: Default + Deserialize<'a>>(tag: &str) -> T {
75    let config = app_config();
76    match config.extract_inner(tag) {
77        Ok(config) => config,
78        Err(err) => {
79            config_error_exit(err);
80            Default::default()
81        }
82    }
83}
84
85pub fn from_config_root_or_exit<'a, T: Default + Deserialize<'a>>() -> T {
86    let config = app_config();
87    match config.extract() {
88        Ok(config) => config,
89        Err(err) => {
90            config_error_exit(err);
91            Default::default()
92        }
93    }
94}
95
96pub fn from_config_opt_or_exit<'a, T: Deserialize<'a>>(tag: &str) -> Option<T> {
97    let config = app_config();
98    config
99        .find_value(tag)
100        .map(|_| config.extract_inner(tag).unwrap_or_else(error_exit))
101        .ok()
102}
103
/// Print `err` to stderr as an initialization error and exit the process with code 1.
///
/// Although declared to return `()`, this function never returns.
pub fn config_error_exit<T: Display>(err: T) {
    let message = format!("Error during initialization: {err}");
    eprintln!("{message}");
    std::process::exit(1)
}
108
/// Print `err` to stderr as an initialization error and exit the process with code 1.
///
/// The generic return type `R` makes it usable wherever a value is expected, e.g.
/// `result.unwrap_or_else(error_exit)`. It never actually returns.
pub fn error_exit<T: Display, R>(err: T) -> R {
    let message = format!("Error during initialization: {err}");
    eprintln!("{message}");
    std::process::exit(1)
}
113
114// -- Common configuration --
115
116#[derive(Deserialize, Default)]
117pub struct CoreServiceCfg {
118    pub webserver: Option<WebserverCfg>,
119    pub metrics: Option<MetricsCfg>,
120    #[serde(default)]
121    pub datasource: Vec<NamedDatasourceCfg>,
122    pub auth: Option<AuthCfg>,
123}
124
125#[derive(Deserialize, Serialize, Clone, Debug)]
126#[serde(default, deny_unknown_fields)]
127pub struct WebserverCfg {
128    /// IP address of interface and port to bind web server (e.g. 0.0.0.0:8080 for all)
129    pub server_addr: String,
130    /// Number of parallel web server threads. Defaults to number of available logical CPUs
131    worker_threads: Option<usize>,
132    public_server_url: Option<String>,
133    /// Log level (Default: info)
134    pub loglevel: Option<Loglevel>,
135    pub tls_cert: Option<String>,
136    pub tls_key: Option<String>,
137    pub cors: Option<CorsCfg>,
138}
139
140#[derive(clap::ValueEnum, Deserialize, Serialize, Clone, Debug)]
141pub enum Loglevel {
142    Error,
143    Warn,
144    Info,
145    Debug,
146    Trace,
147}
148
149#[derive(Deserialize, Serialize, Clone, Debug)]
150#[serde(deny_unknown_fields)]
151pub struct CorsCfg {
152    pub allow_all_origins: bool,
153    // #[serde(rename = "allowed_origin")]
154    // pub allowed_origins: Vec<String>,
155}
156
157impl ServiceConfig for CoreServiceCfg {
158    fn initialize(args: &ArgMatches) -> Result<Self, ConfigError> {
159        let mut cfg: CoreServiceCfg = from_config_root_or_exit();
160        if let Ok(args) = GlobalArgs::from_arg_matches(args) {
161            if let Some(loglevel) = args.loglevel {
162                let mut webserver = cfg.webserver.unwrap_or_default();
163                webserver.loglevel = Some(loglevel);
164                cfg.webserver = Some(webserver);
165            }
166        };
167        Ok(cfg)
168    }
169}
170
171impl CoreServiceCfg {
172    pub fn loglevel(&self) -> Option<Loglevel> {
173        self.webserver.as_ref().and_then(|cfg| cfg.loglevel.clone())
174    }
175}
176
177impl Default for WebserverCfg {
178    fn default() -> Self {
179        let cors = if cfg!(debug_assertions) {
180            // Enable CORS for debug build
181            Some(CorsCfg {
182                allow_all_origins: true,
183            })
184        } else {
185            None
186        };
187        WebserverCfg {
188            server_addr: "127.0.0.1:8080".to_string(),
189            worker_threads: None,
190            public_server_url: None,
191            loglevel: None,
192            tls_cert: None,
193            tls_key: None,
194            cors,
195        }
196    }
197}
198
199impl WebserverCfg {
200    pub fn worker_threads(&self) -> usize {
201        self.worker_threads.unwrap_or(num_cpus::get())
202    }
203    pub fn public_server_url(&self, req: HttpRequest) -> String {
204        if let Some(url) = &self.public_server_url {
205            url.clone()
206        } else {
207            let conninfo = req.connection_info();
208            format!("{}://{}", conninfo.scheme(), conninfo.host(),)
209        }
210    }
211}
212
213#[derive(Deserialize, Serialize, Default, Clone, Debug)]
214#[serde(default, deny_unknown_fields)]
215pub struct AuthCfg {
216    pub oidc: Option<OidcAuthCfg>,
217}
218
219// -- Metrics --
220
221#[derive(Deserialize, Serialize, Default, Debug)]
222#[serde(deny_unknown_fields)]
223pub struct MetricsCfg {
224    pub prometheus: Option<PrometheusCfg>,
225    pub jaeger: Option<JaegerCfg>,
226}
227
228#[derive(Deserialize, Serialize, Debug)]
229#[serde(deny_unknown_fields)]
230pub struct PrometheusCfg {
231    pub path: String,
232}
233
234#[derive(Deserialize, Serialize, Debug)]
235#[serde(deny_unknown_fields)]
236pub struct JaegerCfg {
237    pub agent_endpoint: String,
238}
239
240impl MetricsCfg {
241    pub fn from_config() -> Option<Self> {
242        from_config_opt_or_exit("metrics")
243    }
244}
245
246// -- Datasources --
247
248#[derive(Deserialize, Serialize, Debug)]
249#[serde(deny_unknown_fields)]
250pub struct NamedDatasourceCfg {
251    pub name: String,
252    #[serde(flatten)]
253    pub datasource: DatasourceCfg,
254}
255
256#[derive(Deserialize, Serialize, Clone, Debug)]
257pub enum DatasourceCfg {
258    // -- vector sources --
259    #[serde(rename = "postgis")]
260    Postgis(DsPostgisCfg),
261    #[serde(rename = "gpkg")]
262    Gpkg(DsGpkgCfg),
263    // GdalData(GdalSource),
264    // -- raster sources --
265    WmsFcgi,
266    #[serde(rename = "wms_proxy")]
267    WmsHttp(WmsHttpSourceProviderCfg),
268    // GdalData(GdalSource),
269    // RasterData(GeorasterSource),
270    // -- direct tile sources --
271    #[serde(rename = "mbtiles")]
272    Mbtiles,
273}
274
275#[derive(Deserialize, Serialize, Clone, Debug)]
276#[serde(deny_unknown_fields)]
277pub struct DsPostgisCfg {
278    pub url: String,
279    // pub pool: Option<u16>,
280    // pub connection_timeout: Option<u64>,
281}
282
283#[derive(Deserialize, Serialize, Clone, Debug)]
284#[serde(deny_unknown_fields)]
285pub struct DsGpkgCfg {
286    pub path: PathBuf,
287    // pub pool_min_connections(0)
288    // pub pool_max_connections(8)
289}
290
291impl DsGpkgCfg {
292    pub fn abs_path(&self) -> PathBuf {
293        app_dir(&self.path)
294    }
295}
296
297#[derive(Deserialize, Serialize, Clone, Debug)]
298#[serde(deny_unknown_fields)]
299pub struct WmsHttpSourceProviderCfg {
300    pub baseurl: String,
301    pub format: String,
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use figment::providers::Env;
    use serde::Deserialize;

    /// The subset of `Cargo.toml`'s `[package]` table needed by the test.
    #[derive(Debug, Deserialize, Serialize)]
    struct Package {
        name: String,
    }

    /// Reading this crate's own `Cargo.toml` through Figment yields its package name.
    #[test]
    fn toml_config() {
        let figment = Figment::new()
            .merge(Toml::file("Cargo.toml"))
            .merge(Env::prefixed("CARGO_"));
        let package = figment.extract_inner::<Package>("package").unwrap();
        assert_eq!(package.name, "bbox-core");
    }
}