amq/
config.rs

1use std::{env, error::Error};
2#[cfg(unix)]
3use std::{ffi::OsString, path::Path};
4
5use serde::{Deserialize, Serialize};
6
7#[cfg(unix)]
8use crate::utils::normalize_path;
9use crate::utils::resolve_config_path;
10
11/// # Configuration.
12///
13/// ## Fields
14///
15/// - `path`: Path to the unix socket or tcp address.
16/// - `host`: Host of the tcp address.
17/// - `port`: Port of the tcp address.
18/// - `access_key`: Access key for authentication.
19/// - `access_secret`: Access secret for authentication.
20/// - `retry_times`: Retry times when connection failed.
21/// - `retry_interval`: Retry interval when connection failed.
22///
23/// ## Default values
24///
25/// - `path`: Empty string.
26/// - `host`: "0.0.0.0".
27/// - `port`: 60001.
28/// - `access_key`: Empty string.
29/// - `access_secret`: Empty string.
30/// - `retry_times`: 3.
31/// - `retry_interval`: 60.
32///
33/// ## Environment variables (when config file not exists, cover default values)
34///
35/// - `AMQ_PATH`: Path to the unix socket or tcp address.
36/// - `AMQ_HOST`: Host of the tcp address.
37/// - `AMQ_PORT`: Port of the tcp address.
38/// - `AMQ_ACCESS_KEY`: Access key for authentication.
39/// - `AMQ_ACCESS_SECRET`: Access secret for authentication.
40/// - `AMQ_RETRY_TIMES`: Retry times when connection failed.
41/// - `AMQ_RETRY_INTERVAL`: Retry interval when connection failed.
42///
43/// ## Example
44///
45/// ```toml
46/// path = "/tmp/amq.sock"  # use host and port if path is empty
47/// host = "127.0.0.1"      # ignored if path is not empty
48/// port = 60001            # ignored if path is not empty
49/// access_key = "access_key"
50/// access_secret = "access_secret"
51/// retry_times = 3
52/// retry_interval = 60
53/// ```
54///
55/// ```bash
56/// export AMQ_PATH="/tmp/amq.sock"
57/// export AMQ_HOST="127.0.0.1"
58/// export AMQ_PORT=60001
59/// export AMQ_ACCESS_KEY="access_key"
60/// export AMQ_ACCESS_SECRET="access_secret"
61/// export AMQ_RETRY_TIMES=3
62/// export AMQ_RETRY_INTERVAL=60
63/// ```
64#[derive(Deserialize, Serialize, Clone)]
65pub struct Config {
66    #[serde(default = "default_string")]
67    pub path: String,
68
69    #[serde(default = "default_host")]
70    pub host: String,
71
72    #[serde(default = "default_port")]
73    pub port: u16,
74
75    #[serde(default = "default_string")]
76    pub access_key: String,
77
78    #[serde(default = "default_string")]
79    pub access_secret: String,
80
81    #[serde(default = "default_retry_times")]
82    pub retry_times: u8,
83
84    #[serde(default = "default_retry_interval")]
85    pub retry_interval: u64,
86}
87
/// Host used when none is configured: `"0.0.0.0"`.
fn default_host() -> String {
    String::from("0.0.0.0")
}
/// Port used when none is configured.
fn default_port() -> u16 {
    60_001
}
94
/// Empty string used as the default for optional text settings.
fn default_string() -> String {
    String::new()
}
98
/// Number of connection retries used when none is configured.
fn default_retry_times() -> u8 {
    3_u8
}
102
/// Retry interval used when none is configured.
fn default_retry_interval() -> u64 {
    60_u64
}
106
107impl Config {
108    /// # Create a new configuration.
109    ///
110    /// 1. use the first argument as config file path, if exists.
111    /// 2. use the `./config.toml` as config file path.
112    /// 3. use environment variables to cover default values.
113    /// 4. use default values.
114    pub fn new() -> Result<Self, Box<dyn Error>> {
115        let args: Vec<String> = env::args().collect();
116        let config_file = if args.len() > 1 {
117            &args[1]
118        } else {
119            "./config.toml"
120        };
121
122        match resolve_config_path(config_file) {
123            Ok(abs_path) => {
124                if !abs_path.exists() || !abs_path.is_file() {
125                    Ok(Self::default())
126                } else {
127                    let content = std::fs::read_to_string(abs_path)?;
128                    Ok(toml::from_str::<Config>(&content)?)
129                }
130            }
131            Err(e) => {
132                panic!("Failed to resolve config file path: {}", e);
133            }
134        }
135    }
136
137    pub fn get_address(&self) -> String {
138        format!("{}:{}", self.host, self.port)
139    }
140
141    #[cfg(unix)]
142    pub fn get_unix_path(&self) -> String {
143        let path = self.path.trim();
144        if path.starts_with("./") {
145            let current_dir = env::current_dir().unwrap();
146            let abs_path = current_dir.join(path);
147            let abs_path = normalize_path(&abs_path);
148            abs_path.to_str().unwrap().to_string()
149        } else if path.starts_with("~") {
150            let home_dir = env::var_os("HOME")
151                .or_else(|| env::var_os("USERPROFILE")) // Windows 兼容
152                .unwrap_or(OsString::from("./"));
153            let abs_path = Path::new(&home_dir).join(path.trim_start_matches("~/"));
154            abs_path.to_str().unwrap().to_string()
155        } else {
156            path.to_string()
157        }
158    }
159}
160
161impl Default for Config {
162    fn default() -> Self {
163        let mut c = Self {
164            path: default_string(),
165            host: default_host(),
166            port: default_port(),
167            access_key: default_string(),
168            access_secret: default_string(),
169            retry_times: default_retry_times(),
170            retry_interval: default_retry_interval(),
171        };
172        // AMQ_PATH
173        if let Ok(path) = std::env::var("AMQ_PATH") {
174            c.path = path;
175        }
176        // AMQ_HOST
177        if let Ok(host) = std::env::var("AMQ_HOST") {
178            c.host = host;
179        }
180        // AMQ_PORT
181        if let Ok(port) = std::env::var("AMQ_PORT") {
182            c.port = port.parse().unwrap();
183        }
184        // AMQ_ACCESS_KEY
185        if let Ok(key) = std::env::var("AMQ_ACCESS_KEY") {
186            c.access_key = key;
187        }
188        // AMQ_ACCESS_SECRET
189        if let Ok(secret) = std::env::var("AMQ_ACCESS_SECRET") {
190            c.access_secret = secret;
191        }
192        // AMQ_RETRY_TIMES
193        if let Ok(times) = std::env::var("AMQ_RETRY_TIMES") {
194            c.retry_times = times.parse().unwrap();
195        }
196        // AMQ_RETRY_INTERVAL
197        if let Ok(interval) = std::env::var("AMQ_RETRY_INTERVAL") {
198            c.retry_interval = interval.parse().unwrap();
199        }
200        c
201    }
202}