Skip to main content

netmito/config/
worker.rs

1use clap::Args;
2use figment::{
3    providers::{Env, Format, Serialized, Toml},
4    value::magic::RelativePathBuf,
5    Figment,
6};
7use serde::{Deserialize, Serialize};
8use std::io::Write;
9use std::ops::Not;
10use std::{collections::HashSet, time::Duration};
11use tracing_appender::rolling::{RollingFileAppender, Rotation};
12use tracing_subscriber::{fmt::MakeWriter, layer::SubscriberExt, util::SubscriberInitExt, Layer};
13use url::Url;
14
15use crate::error::Error;
16
17use super::{coordinator::DEFAULT_COORDINATOR_ADDR, TracingGuard};
18
/// Line-aware writer that tags log output with the id of the worker producing it.
///
/// Bytes handed to [`Write::write`] are staged in an internal buffer. The
/// `[worker:<id>] ` tag is inserted in front of every line that has content;
/// a line consisting of nothing but `\n` is passed through untagged. Staged
/// bytes are forwarded to the wrapped writer only after a complete line has
/// been staged and the buffer has grown past [`Self::FLUSH_THRESHOLD`], on an
/// explicit [`Write::flush`], or when the writer is dropped.
struct WorkerIdWriter<W: Write> {
    /// Destination that eventually receives the tagged bytes.
    sink: W,
    /// The rendered `[worker:<id>] ` tag, built once in [`Self::new`].
    tag: Vec<u8>,
    /// `true` while the next byte staged would be the first byte of a line.
    line_start: bool,
    /// Bytes staged but not yet forwarded to `sink`.
    pending: Vec<u8>,
}

impl<W: Write> WorkerIdWriter<W> {
    /// Initial capacity reserved for the staging buffer (8 KiB).
    const INITIAL_CAPACITY: usize = 8192;
    /// Staged size (4 KiB) above which a completed line triggers a forward to `sink`.
    const FLUSH_THRESHOLD: usize = 4096;

    /// Wraps `inner`, tagging every non-empty line with `[worker:<worker_id>] `.
    fn new(inner: W, worker_id: String) -> Self {
        Self {
            sink: inner,
            tag: format!("[worker:{}] ", worker_id).into_bytes(),
            line_start: true,
            pending: Vec::with_capacity(Self::INITIAL_CAPACITY),
        }
    }

    /// Forwards all staged bytes to the wrapped writer and empties the buffer.
    ///
    /// Does nothing when the buffer is empty. On error the staged bytes are kept.
    fn flush_buffer(&mut self) -> std::io::Result<()> {
        if self.pending.is_empty() {
            return Ok(());
        }
        self.sink.write_all(&self.pending)?;
        self.pending.clear();
        Ok(())
    }
}

impl<W: Write> Write for WorkerIdWriter<W> {
    /// Stages `buf`, inserting the worker tag at the start of each non-empty line.
    ///
    /// Always reports the whole of `buf` as consumed unless forwarding the
    /// staged bytes to the wrapped writer fails.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // Each chunk is either a full line (terminated by `\n`) or, for the
        // final chunk only, a trailing fragment without a terminator.
        for chunk in buf.split_inclusive(|&byte| byte == b'\n') {
            let ends_line = chunk.last() == Some(&b'\n');
            let bare_newline = ends_line && chunk.len() == 1;

            if self.line_start && !bare_newline {
                self.pending.extend_from_slice(&self.tag);
            }
            self.pending.extend_from_slice(chunk);
            self.line_start = ends_line;

            // Only forward on a line boundary so a line is never split
            // across two writes to the underlying sink.
            if ends_line && self.pending.len() > Self::FLUSH_THRESHOLD {
                self.flush_buffer()?;
            }
        }

        Ok(buf.len())
    }

    /// Forwards any staged bytes and then flushes the wrapped writer.
    fn flush(&mut self) -> std::io::Result<()> {
        self.flush_buffer()?;
        self.sink.flush()
    }
}

impl<W: Write> Drop for WorkerIdWriter<W> {
    /// Best-effort delivery of whatever is still staged; errors are ignored
    /// because `drop` has no way to report them.
    fn drop(&mut self) {
        let _ = self.flush_buffer();
    }
}
97
98/// A MakeWriter wrapper that creates WorkerIdWriter instances
99struct WorkerIdMakeWriter<M> {
100    inner: M,
101    worker_id: String,
102}
103
104impl<'a, M> MakeWriter<'a> for WorkerIdMakeWriter<M>
105where
106    M: MakeWriter<'a>,
107{
108    type Writer = WorkerIdWriter<M::Writer>;
109
110    fn make_writer(&'a self) -> Self::Writer {
111        WorkerIdWriter::new(self.inner.make_writer(), self.worker_id.clone())
112    }
113}
114
115#[derive(Deserialize, Serialize, Debug)]
116pub struct WorkerConfig {
117    pub(crate) coordinator_addr: Url,
118    #[serde(with = "humantime_serde")]
119    pub(crate) polling_interval: Duration,
120    #[serde(with = "humantime_serde")]
121    pub(crate) heartbeat_interval: Duration,
122    pub(crate) credential_path: Option<RelativePathBuf>,
123    pub(crate) user: Option<String>,
124    pub(crate) password: Option<String>,
125    pub(crate) groups: HashSet<String>,
126    pub(crate) tags: HashSet<String>,
127    pub(crate) labels: HashSet<String>,
128    pub(crate) log_path: Option<RelativePathBuf>,
129    pub(crate) file_log: bool,
130    #[serde(default)]
131    pub(crate) shared_log: bool,
132    #[serde(with = "humantime_serde")]
133    pub(crate) lifetime: Option<Duration>,
134    #[serde(default)]
135    pub(crate) retain: bool,
136    #[serde(default)]
137    pub(crate) skip_redis: bool,
138}
139
140#[derive(Args, Debug, Serialize, Default, Clone)]
141#[command(rename_all = "kebab-case")]
142pub struct WorkerConfigCli {
143    /// The path of the config file
144    #[arg(long)]
145    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
146    pub config: Option<String>,
147    /// The address of the coordinator
148    #[arg(short, long = "coordinator")]
149    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
150    pub coordinator_addr: Option<String>,
151    /// The interval to poll tasks or resources
152    #[arg(long)]
153    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
154    pub polling_interval: Option<String>,
155    /// The interval to send heartbeat
156    #[arg(long)]
157    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
158    pub heartbeat_interval: Option<String>,
159    /// The path of the user credential file
160    #[arg(long)]
161    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
162    pub credential_path: Option<String>,
163    /// The username of the user
164    #[arg(short, long)]
165    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
166    pub user: Option<String>,
167    /// The password of the user
168    #[arg(short, long)]
169    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
170    pub password: Option<String>,
171    /// The groups allowed to submit tasks to this worker
172    #[arg(short, long, num_args = 0.., value_delimiter = ',')]
173    #[serde(skip_serializing_if = "::std::vec::Vec::is_empty")]
174    pub groups: Vec<String>,
175    /// The tags of this worker
176    #[arg(short, long, num_args = 0.., value_delimiter = ',')]
177    #[serde(skip_serializing_if = "::std::vec::Vec::is_empty")]
178    pub tags: Vec<String>,
179    /// The labels of this worker
180    #[arg(short, long, num_args = 0.., value_delimiter = ',')]
181    #[serde(skip_serializing_if = "::std::vec::Vec::is_empty")]
182    pub labels: Vec<String>,
183    /// The log file path. If not specified, then the default rolling log file path would be used.
184    /// If specified, then the log file would be exactly at the path specified.
185    #[arg(long)]
186    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
187    pub log_path: Option<String>,
188    /// Enable logging to file
189    #[arg(long)]
190    #[serde(skip_serializing_if = "<&bool>::not")]
191    pub file_log: bool,
192    /// Enable shared logging across multiple workers with daily rotation (max 3 files)
193    #[arg(long)]
194    #[serde(skip_serializing_if = "<&bool>::not")]
195    pub shared_log: bool,
196    /// The lifetime of the worker to alive (e.g., 7d, 1year)
197    #[arg(long)]
198    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
199    pub lifetime: Option<String>,
200    /// Whether to retain the previous login state without refetching the credential
201    #[arg(long)]
202    #[serde(skip_serializing_if = "<&bool>::not")]
203    pub retain: bool,
204    /// Whether to skip connecting to Redis
205    #[arg(long)]
206    #[serde(skip_serializing_if = "<&bool>::not")]
207    pub skip_redis: bool,
208}
209
210impl Default for WorkerConfig {
211    fn default() -> Self {
212        Self {
213            coordinator_addr: Url::parse(&format!("http://{DEFAULT_COORDINATOR_ADDR}")).unwrap(),
214            polling_interval: Duration::from_secs(180),
215            heartbeat_interval: Duration::from_secs(300),
216            credential_path: None,
217            user: None,
218            password: None,
219            groups: HashSet::new(),
220            tags: HashSet::new(),
221            labels: HashSet::new(),
222            log_path: None,
223            file_log: false,
224            shared_log: false,
225            lifetime: None,
226            retain: false,
227            skip_redis: false,
228        }
229    }
230}
231
232impl WorkerConfig {
233    pub fn new(cli: &WorkerConfigCli) -> crate::error::Result<Self> {
234        let global_config = dirs::config_dir().map(|mut p| {
235            p.push("mitosis");
236            p.push("config.toml");
237            p
238        });
239        let mut figment = Figment::new().merge(Serialized::from(Self::default(), "worker"));
240        if let Some(global_config) = global_config {
241            if global_config.exists() {
242                figment = figment.merge(Toml::file(global_config).nested());
243            }
244        }
245        figment = figment
246            .merge(Toml::file(cli.config.as_deref().unwrap_or("config.toml")).nested())
247            .merge(Env::prefixed("MITO_").profile("worker"))
248            .merge(Serialized::from(cli, "worker"))
249            .select("worker");
250        Ok(figment.extract()?)
251    }
252
253    pub fn setup_tracing_subscriber<T, U>(&self, worker_id: U) -> crate::error::Result<TracingGuard>
254    where
255        T: std::fmt::Display,
256        U: Into<T>,
257    {
258        if self.file_log {
259            let id = worker_id.into();
260            let id_str = id.to_string();
261
262            // Determine file logger based on shared_log setting
263            let file_logger = if self.shared_log {
264                // Shared logging: use log_path if provided, otherwise use fixed "workers.log"
265                self.log_path
266                    .as_ref()
267                    .and_then(|p| {
268                        let path = p.relative();
269                        let dir = path.parent();
270                        let file_name = path.file_name();
271                        match (dir, file_name) {
272                            (Some(dir), Some(file_name)) => {
273                                // Use daily rotation with max 3 log files for shared log
274                                RollingFileAppender::builder()
275                                    .rotation(Rotation::DAILY)
276                                    .filename_prefix(file_name.to_string_lossy().to_string())
277                                    .max_log_files(3)
278                                    .build(dir)
279                                    .ok()
280                            }
281                            _ => None,
282                        }
283                    })
284                    .or_else(|| {
285                        // Use fixed "workers.log" with daily rotation in cache directory
286                        dirs::cache_dir()
287                            .map(|mut p| {
288                                p.push("mitosis");
289                                p.push("worker");
290                                p
291                            })
292                            .and_then(|dir| {
293                                RollingFileAppender::builder()
294                                    .rotation(Rotation::DAILY)
295                                    .filename_prefix("workers.log")
296                                    .max_log_files(3)
297                                    .build(dir)
298                                    .ok()
299                            })
300                    })
301                    .ok_or(Error::ConfigError(Box::new(figment::Error::from(
302                        "log path not valid and cache directory not found",
303                    ))))?
304            } else {
305                // Non-shared logging: use per-worker log file with no rotation
306                self.log_path
307                    .as_ref()
308                    .and_then(|p| {
309                        let path = p.relative();
310                        let dir = path.parent();
311                        let file_name = path.file_name();
312                        match (dir, file_name) {
313                            (Some(dir), Some(file_name)) => {
314                                Some(tracing_appender::rolling::never(dir, file_name))
315                            }
316                            _ => None,
317                        }
318                    })
319                    .or_else(|| {
320                        dirs::cache_dir()
321                            .map(|mut p| {
322                                p.push("mitosis");
323                                p.push("worker");
324                                p
325                            })
326                            .map(|dir| {
327                                tracing_appender::rolling::never(dir, format!("{id_str}.log"))
328                            })
329                    })
330                    .ok_or(Error::ConfigError(Box::new(figment::Error::from(
331                        "log path not valid and cache directory not found",
332                    ))))?
333            };
334
335            let (non_blocking, guard) = tracing_appender::non_blocking(file_logger);
336            let env_filter = tracing_subscriber::EnvFilter::try_from_env("MITO_FILE_LOG_LEVEL")
337                .unwrap_or_else(|_| "netmito=info".into());
338
339            // If shared_log is enabled, wrap the writer to add worker UUID prefix
340            let coordinator_guard = if self.shared_log {
341                let worker_writer = WorkerIdMakeWriter {
342                    inner: non_blocking,
343                    worker_id: id_str,
344                };
345
346                tracing_subscriber::registry()
347                    .with(
348                        tracing_subscriber::fmt::layer().with_filter(
349                            tracing_subscriber::EnvFilter::try_from_default_env()
350                                .unwrap_or_else(|_| "netmito=info".into()),
351                        ),
352                    )
353                    .with(
354                        tracing_subscriber::fmt::layer()
355                            .with_writer(worker_writer)
356                            .with_filter(env_filter),
357                    )
358                    .set_default()
359            } else {
360                tracing_subscriber::registry()
361                    .with(
362                        tracing_subscriber::fmt::layer().with_filter(
363                            tracing_subscriber::EnvFilter::try_from_default_env()
364                                .unwrap_or_else(|_| "netmito=info".into()),
365                        ),
366                    )
367                    .with(
368                        tracing_subscriber::fmt::layer()
369                            .with_writer(non_blocking)
370                            .with_filter(env_filter),
371                    )
372                    .set_default()
373            };
374
375            Ok(TracingGuard {
376                subscriber_guard: Some(coordinator_guard),
377                file_guard: Some(guard),
378            })
379        } else {
380            let coordinator_guard = tracing_subscriber::registry()
381                .with(
382                    tracing_subscriber::fmt::layer().with_filter(
383                        tracing_subscriber::EnvFilter::try_from_default_env()
384                            .unwrap_or_else(|_| "netmito=info".into()),
385                    ),
386                )
387                .set_default();
388            Ok(TracingGuard {
389                subscriber_guard: Some(coordinator_guard),
390                file_guard: None,
391            })
392        }
393    }
394}