netmito/config/worker.rs

1use clap::Args;
2use figment::{
3    providers::{Env, Format, Serialized, Toml},
4    value::magic::RelativePathBuf,
5    Figment,
6};
7use serde::{Deserialize, Serialize};
8use std::ops::Not;
9use std::{collections::HashSet, time::Duration};
10use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};
11use url::Url;
12
13use crate::error::Error;
14
15use super::{coordinator::DEFAULT_COORDINATOR_ADDR, TracingGuard};
16
17#[derive(Deserialize, Serialize, Debug)]
18pub struct WorkerConfig {
19    pub(crate) coordinator_addr: Url,
20    #[serde(with = "humantime_serde")]
21    pub(crate) polling_interval: Duration,
22    #[serde(with = "humantime_serde")]
23    pub(crate) heartbeat_interval: Duration,
24    pub(crate) credential_path: Option<RelativePathBuf>,
25    pub(crate) user: Option<String>,
26    pub(crate) password: Option<String>,
27    pub(crate) groups: HashSet<String>,
28    pub(crate) tags: HashSet<String>,
29    pub(crate) labels: HashSet<String>,
30    pub(crate) log_path: Option<RelativePathBuf>,
31    pub(crate) file_log: bool,
32    #[serde(with = "humantime_serde")]
33    pub(crate) lifetime: Option<Duration>,
34    #[serde(default)]
35    pub(crate) retain: bool,
36    #[serde(default)]
37    pub(crate) skip_redis: bool,
38}
39
40#[derive(Args, Debug, Serialize, Default, Clone)]
41#[command(rename_all = "kebab-case")]
42pub struct WorkerConfigCli {
43    /// The path of the config file
44    #[arg(long)]
45    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
46    pub config: Option<String>,
47    /// The address of the coordinator
48    #[arg(short, long = "coordinator")]
49    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
50    pub coordinator_addr: Option<String>,
51    /// The interval to poll tasks or resources
52    #[arg(long)]
53    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
54    pub polling_interval: Option<String>,
55    /// The interval to send heartbeat
56    #[arg(long)]
57    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
58    pub heartbeat_interval: Option<String>,
59    /// The path of the user credential file
60    #[arg(long)]
61    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
62    pub credential_path: Option<String>,
63    /// The username of the user
64    #[arg(short, long)]
65    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
66    pub user: Option<String>,
67    /// The password of the user
68    #[arg(short, long)]
69    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
70    pub password: Option<String>,
71    /// The groups allowed to submit tasks to this worker
72    #[arg(short, long, num_args = 0.., value_delimiter = ',')]
73    #[serde(skip_serializing_if = "::std::vec::Vec::is_empty")]
74    pub groups: Vec<String>,
75    /// The tags of this worker
76    #[arg(short, long, num_args = 0.., value_delimiter = ',')]
77    #[serde(skip_serializing_if = "::std::vec::Vec::is_empty")]
78    pub tags: Vec<String>,
79    /// The labels of this worker
80    #[arg(short, long, num_args = 0.., value_delimiter = ',')]
81    #[serde(skip_serializing_if = "::std::vec::Vec::is_empty")]
82    pub labels: Vec<String>,
83    /// The log file path. If not specified, then the default rolling log file path would be used.
84    /// If specified, then the log file would be exactly at the path specified.
85    #[arg(long)]
86    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
87    pub log_path: Option<String>,
88    /// Enable logging to file
89    #[arg(long)]
90    #[serde(skip_serializing_if = "<&bool>::not")]
91    pub file_log: bool,
92    /// The lifetime of the worker to alive (e.g., 7d, 1year)
93    #[arg(long)]
94    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
95    pub lifetime: Option<String>,
96    /// Whether to retain the previous login state without refetching the credential
97    #[arg(long)]
98    #[serde(skip_serializing_if = "<&bool>::not")]
99    pub retain: bool,
100    /// Whether to skip connecting to Redis
101    #[arg(long)]
102    #[serde(skip_serializing_if = "<&bool>::not")]
103    pub skip_redis: bool,
104}
105
106impl Default for WorkerConfig {
107    fn default() -> Self {
108        Self {
109            coordinator_addr: Url::parse(&format!("http://{DEFAULT_COORDINATOR_ADDR}")).unwrap(),
110            polling_interval: Duration::from_secs(180),
111            heartbeat_interval: Duration::from_secs(300),
112            credential_path: None,
113            user: None,
114            password: None,
115            groups: HashSet::new(),
116            tags: HashSet::new(),
117            labels: HashSet::new(),
118            log_path: None,
119            file_log: false,
120            lifetime: None,
121            retain: false,
122            skip_redis: false,
123        }
124    }
125}
126
127impl WorkerConfig {
128    pub fn new(cli: &WorkerConfigCli) -> crate::error::Result<Self> {
129        let global_config = dirs::config_dir().map(|mut p| {
130            p.push("mitosis");
131            p.push("config.toml");
132            p
133        });
134        let mut figment = Figment::new().merge(Serialized::from(Self::default(), "worker"));
135        if let Some(global_config) = global_config {
136            if global_config.exists() {
137                figment = figment.merge(Toml::file(global_config).nested());
138            }
139        }
140        figment = figment
141            .merge(Toml::file(cli.config.as_deref().unwrap_or("config.toml")).nested())
142            .merge(Env::prefixed("MITO_").profile("worker"))
143            .merge(Serialized::from(cli, "worker"))
144            .select("worker");
145        Ok(figment.extract()?)
146    }
147
148    pub fn setup_tracing_subscriber<T, U>(&self, worker_id: U) -> crate::error::Result<TracingGuard>
149    where
150        T: std::fmt::Display,
151        U: Into<T>,
152    {
153        if self.file_log {
154            let file_logger = self
155                .log_path
156                .as_ref()
157                .and_then(|p| {
158                    let path = p.relative();
159                    let dir = path.parent();
160                    let file_name = path.file_name();
161                    match (dir, file_name) {
162                        (Some(dir), Some(file_name)) => {
163                            Some(tracing_appender::rolling::never(dir, file_name))
164                        }
165                        _ => None,
166                    }
167                })
168                .or_else(|| {
169                    dirs::cache_dir()
170                        .map(|mut p| {
171                            p.push("mitosis");
172                            p.push("worker");
173                            p
174                        })
175                        .map(|dir| {
176                            let id = worker_id.into();
177                            tracing_appender::rolling::never(dir, format!("{id}.log"))
178                        })
179                })
180                .ok_or(Error::ConfigError(Box::new(figment::Error::from(
181                    "log path not valid and cache directory not found",
182                ))))?;
183            let (non_blocking, guard) = tracing_appender::non_blocking(file_logger);
184            let env_filter = tracing_subscriber::EnvFilter::try_from_env("MITO_FILE_LOG_LEVEL")
185                .unwrap_or_else(|_| "netmito=info".into());
186            let coordinator_guard = tracing_subscriber::registry()
187                .with(
188                    tracing_subscriber::fmt::layer().with_filter(
189                        tracing_subscriber::EnvFilter::try_from_default_env()
190                            .unwrap_or_else(|_| "netmito=info".into()),
191                    ),
192                )
193                .with(
194                    tracing_subscriber::fmt::layer()
195                        .with_writer(non_blocking)
196                        .with_filter(env_filter),
197                )
198                .set_default();
199            Ok(TracingGuard {
200                subscriber_guard: Some(coordinator_guard),
201                file_guard: Some(guard),
202            })
203        } else {
204            let coordinator_guard = tracing_subscriber::registry()
205                .with(
206                    tracing_subscriber::fmt::layer().with_filter(
207                        tracing_subscriber::EnvFilter::try_from_default_env()
208                            .unwrap_or_else(|_| "netmito=info".into()),
209                    ),
210                )
211                .set_default();
212            Ok(TracingGuard {
213                subscriber_guard: Some(coordinator_guard),
214                file_guard: None,
215            })
216        }
217    }
218}