octoproxy_lib/
config.rs

1use anyhow::{bail, Context};
2use std::env;
3use std::{fs::read_to_string, net::SocketAddr, path::PathBuf};
4use toml::macros::Deserialize;
5use tracing::metadata::LevelFilter;
6use tracing::Level;
7use tracing_subscriber::EnvFilter;
8
9pub fn parse_socket_address(address: &str) -> anyhow::Result<SocketAddr> {
10    // 127.0.0.1:8081 or 0.0.0.0:8081
11    if let Ok(listen_address) = address.parse::<SocketAddr>() {
12        return Ok(listen_address);
13    }
14
15    // only a port
16    if let Ok(port) = address.parse::<u16>() {
17        let listen_address: SocketAddr = ([127, 0, 0, 1], port).into();
18        return Ok(listen_address);
19    }
20
21    bail!("cannot parse listen address into socket address")
22}
23
24pub trait TomlFileConfig
25where
26    Self: Sized,
27{
28    fn load_from_file(path: PathBuf) -> anyhow::Result<Self>
29    where
30        for<'de> Self: Deserialize<'de>,
31    {
32        let data = read_to_string(path).context("Failed to read file")?;
33        let config: Self = match toml::from_str(&data) {
34            Ok(config) => config,
35            Err(e) => {
36                // display_toml_error(&data, &e);
37                bail!(e);
38            }
39        };
40        Ok(config)
41    }
42}
43
44/// set `RUST_LOG`
45/// environment variable overwrites file config
46pub fn parse_log_level_str(log_level: &str) -> LevelFilter {
47    let log_level = env::var(EnvFilter::DEFAULT_ENV).unwrap_or(log_level.to_string());
48    LevelFilter::from_level(log_level.parse::<Level>().unwrap_or(Level::ERROR))
49}
50
#[test]
fn test_parse_socket_address() {
    // Full `ip:port` forms are accepted verbatim, for both IPv4 and IPv6.
    assert_eq!(
        parse_socket_address("127.0.0.1:8081").unwrap(),
        SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, 8081))
    );
    assert_eq!(
        parse_socket_address("0.0.0.0:8081").unwrap(),
        SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 8081))
    );
    assert_eq!(
        parse_socket_address("[::1]:8081").unwrap(),
        SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8081))
    );

    // A bare port is bound to the IPv4 loopback address.
    assert_eq!(
        parse_socket_address("8081").unwrap(),
        SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, 8081))
    );

    // Hostnames are not resolved; malformed or out-of-range input is rejected.
    assert!(parse_socket_address("localhost:8081").is_err());
    assert!(parse_socket_address("70000").is_err());
    assert!(parse_socket_address("").is_err());
}
62
#[test]
fn test_parse_log_level_str() {
    /// Restores the log-level environment variable to its previous state on drop,
    /// so the process environment is left untouched even if an assertion fails.
    struct RestoreEnv(Option<std::ffi::OsString>);

    impl Drop for RestoreEnv {
        fn drop(&mut self) {
            match self.0.take() {
                Some(value) => env::set_var(EnvFilter::DEFAULT_ENV, value),
                None => env::remove_var(EnvFilter::DEFAULT_ENV),
            }
        }
    }

    // `parse_log_level_str` reads the environment, so start from a known state
    // regardless of how the test binary was launched (e.g. `RUST_LOG=debug cargo test`).
    let _restore = RestoreEnv(env::var_os(EnvFilter::DEFAULT_ENV));
    env::remove_var(EnvFilter::DEFAULT_ENV);

    // Without the variable, the file-config value decides.
    assert_eq!(parse_log_level_str("INFO"), Level::INFO);
    assert_eq!(parse_log_level_str("DEBUG"), Level::DEBUG);
    // Unparsable values fall back to ERROR.
    assert_eq!(parse_log_level_str("not-a-level"), Level::ERROR);

    // With the variable set, it overrides the file-config value.
    env::set_var(EnvFilter::DEFAULT_ENV, "INFO");
    assert_eq!(parse_log_level_str("DEBUG"), Level::INFO);
    assert_eq!(parse_log_level_str(""), Level::INFO);
}