use std::ffi::CString;
use std::os::fd::BorrowedFd;
use std::path::{Path, PathBuf};
use std::process::ExitCode;
use clap::Parser;
use nix::sys::stat::Mode;
use blivet::{DaemonConfig, DaemonizeError};
// Command-line interface. The `///` comments on fields are turned into
// `--help` text by clap's derive. Everything from the first positional
// argument onward is passed to the program verbatim (`trailing_var_arg`).
#[derive(Parser)]
#[command(version, about)]
#[command(trailing_var_arg = true)]
struct Args {
    /// PID file path (also used as the lock file unless --lock is given)
    #[arg(short = 'p', long = "pidfile")]
    pidfile: Option<PathBuf>,
    /// Working directory for the daemon
    #[arg(short = 'c', long = "chdir")]
    chdir: Option<PathBuf>,
    /// File-creation mask in octal, e.g. 022 (0-777)
    #[arg(short = 'm', long = "umask", value_parser = parse_octal_mode)]
    umask: Option<Mode>,
    /// File that receives the daemon's standard output
    #[arg(short = 'o', long = "stdout")]
    stdout: Option<PathBuf>,
    /// File that receives standard error. Defaults to the --stdout path with
    /// .stdout replaced by .stderr or .out by .err; other names share the stdout file
    #[arg(short = 'e', long = "stderr")]
    stderr: Option<PathBuf>,
    /// Open the --stdout/--stderr files in append mode
    #[arg(short = 'a', long = "append")]
    append: bool,
    /// Lock file kept open across exec (defaults to the --pidfile path)
    #[arg(short = 'l', long = "lock")]
    lockfile: Option<PathBuf>,
    /// Set an environment variable for the program (repeatable; KEY alone sets an empty value)
    #[arg(short = 'E', long = "env", value_name = "KEY=VALUE")]
    env: Vec<String>,
    /// Run the program as this user
    #[arg(short = 'u', long = "user")]
    user: Option<String>,
    /// Run the program with this group
    #[arg(short = 'g', long = "group")]
    group: Option<String>,
    /// Stay in the foreground instead of detaching
    #[arg(short = 'f', long = "foreground")]
    foreground: bool,
    /// Do not close inherited file descriptors
    #[arg(long = "no-close-fds")]
    no_close_fds: bool,
    /// Print the resolved program path and configuration to stderr before daemonizing
    #[arg(short = 'v', long = "verbose")]
    verbose: bool,
    /// Program to run, followed by its arguments
    #[arg(required = true, allow_hyphen_values = true)]
    program: Vec<String>,
}
/// Parse a umask written in octal (e.g. `022`, `0027`) into a [`Mode`].
///
/// Only permission bits (`0`–`777`) are accepted, matching the shell `umask`
/// builtin. Larger values are rejected rather than silently truncated. Without
/// this check, extra bits such as setuid/sticky, or high bits lost in the
/// `mode_t` cast, would slip through unnoticed.
///
/// # Errors
/// Returns a human-readable message (used by clap) when `s` is not valid octal
/// or exceeds `777`.
fn parse_octal_mode(s: &str) -> Result<Mode, String> {
    let bits = u32::from_str_radix(s, 8).map_err(|e| format!("invalid octal umask: {e}"))?;
    if bits > 0o777 {
        return Err(format!("invalid octal umask: {s} is out of range (max 777)"));
    }
    // Lossless: bits <= 0o777 fits in every platform's `mode_t`.
    Ok(Mode::from_bits_truncate(bits as libc::mode_t))
}
/// Split a `--env` argument into `(key, value)` at the first `=`.
///
/// Everything after the first `=` belongs to the value, so `A=b=c` yields
/// `("A", "b=c")`. An argument with no `=` becomes the key with an empty value.
fn parse_env_pair(s: &str) -> (String, String) {
    let (key, value) = s.split_once('=').unwrap_or((s, ""));
    (key.to_owned(), value.to_owned())
}
fn main() -> ExitCode {
let args = Args::parse();
let mut config = DaemonConfig::new();
if let Some(ref p) = args.pidfile {
config.pidfile(p);
}
if let Some(ref p) = args.chdir {
config.chdir(p);
}
if let Some(m) = args.umask {
config.umask(m);
}
if let Some(ref p) = args.stdout {
config.stdout(p);
}
let stderr = args.stderr.as_ref().or(args.stdout.as_ref());
let stderr = stderr.map(|p| {
if args.stderr.is_some() {
p.clone()
} else {
derive_stderr_path(p)
}
});
if let Some(ref p) = stderr {
config.stderr(p);
}
config.append(args.append);
let lockfile = args.lockfile.as_ref().or(args.pidfile.as_ref());
if let Some(p) = lockfile {
config.lockfile(p);
}
for env_str in &args.env {
let (key, value) = parse_env_pair(env_str);
config.env(key, value);
}
if let Some(ref u) = args.user {
config.user(u);
}
if let Some(ref g) = args.group {
config.group(g);
}
config.foreground(args.foreground);
config.close_fds(!args.no_close_fds);
let program_path = resolve_program_path(&args.program[0]);
let program_path = match program_path {
Ok(p) => p,
Err(e) => {
eprintln!("{e}");
return ExitCode::from(e.exit_code());
}
};
if args.verbose {
eprintln!("daemonize: program={program_path}");
eprintln!("daemonize: config={config:?}");
}
#[allow(unsafe_code)]
let mut ctx = match unsafe { blivet::daemonize(&config) } {
Ok(ctx) => ctx,
Err(e) => {
eprintln!("{e}");
return ExitCode::from(e.exit_code());
}
};
if args.user.is_some() || args.group.is_some() {
if let Err(e) = ctx.chown_paths() {
ctx.report_error(&e);
}
if let Err(e) = ctx.drop_privileges() {
ctx.report_error(&e);
}
}
if let Some(lockfile_fd) = ctx.lockfile_fd() {
if let Err(e) = clear_cloexec(lockfile_fd) {
ctx.report_error(&DaemonizeError::ExecFailed(format!(
"failed to clear CLOEXEC on lockfile fd: {e}"
)));
}
}
let c_program = CString::new(program_path.as_str()).unwrap_or_else(|_| {
ctx.report_error(&DaemonizeError::ExecFailed(
"program path contains null byte".into(),
));
});
let c_args: Vec<CString> = args
.program
.iter()
.enumerate()
.map(|(i, a)| {
if i == 0 {
c_program.clone()
} else {
CString::new(a.as_str()).unwrap_or_else(|_| {
ctx.report_error(&DaemonizeError::ExecFailed(format!(
"argument contains null byte: {a}"
)));
})
}
})
.collect();
let Err(err) = nix::unistd::execvp(&c_program, &c_args);
ctx.report_error(&DaemonizeError::ExecFailed(format!(
"exec {program_path}: {err}"
)));
}
/// Resolve the program to exec.
///
/// If `program` contains a `/`, it is treated as a path. It is canonicalized to an
/// absolute path here, before daemonizing, so a relative path is unaffected
/// by `--chdir`. It must then be an executable regular file. Otherwise `program`
/// is returned unchanged and `execvp` performs the `PATH` search at exec time.
///
/// # Errors
/// `DaemonizeError::ProgramNotFound` if the path cannot be resolved, is not
/// valid UTF-8, is not a regular file, or is not executable.
fn resolve_program_path(program: &str) -> Result<String, DaemonizeError> {
    if !program.contains('/') {
        return Ok(program.to_string());
    }
    let canonical = std::fs::canonicalize(program).map_err(|e| {
        DaemonizeError::ProgramNotFound(format!("cannot resolve {program}: {e}"))
    })?;
    let path_str = canonical.to_str().ok_or_else(|| {
        DaemonizeError::ProgramNotFound(format!(
            "path is not valid UTF-8: {}",
            canonical.display()
        ))
    })?;
    // `access(X_OK)` also succeeds on searchable directories. Rule those out here
    // so the failure is reported before daemonizing, not by `execvp` afterwards.
    if !canonical.is_file() {
        return Err(DaemonizeError::ProgramNotFound(format!(
            "not a regular file: {path_str}"
        )));
    }
    if nix::unistd::access(&canonical, nix::unistd::AccessFlags::X_OK).is_err() {
        return Err(DaemonizeError::ProgramNotFound(format!(
            "not executable: {path_str}"
        )));
    }
    Ok(path_str.to_string())
}
/// Derive a default stderr path from a stdout path.
///
/// `*.stdout` becomes `*.stderr` and `*.out` becomes `*.err`. Any other name,
/// including one with no extension, is returned unchanged.
fn derive_stderr_path(stdout: &Path) -> PathBuf {
    let sibling_ext = match stdout.extension() {
        Some(ext) if ext == "stdout" => Some("stderr"),
        Some(ext) if ext == "out" => Some("err"),
        _ => None,
    };
    match sibling_ext {
        Some(new_ext) => stdout.with_extension(new_ext),
        None => stdout.to_path_buf(),
    }
}
/// Clear `FD_CLOEXEC` on `fd` so the descriptor stays open across `exec`.
///
/// `main` uses this on the lock file's descriptor. It performs a
/// read-modify-write, so any other descriptor flags are preserved instead of
/// being reset to zero.
///
/// # Errors
/// Returns the `fcntl` error if either `F_GETFD` or `F_SETFD` fails.
fn clear_cloexec(fd: BorrowedFd<'_>) -> Result<(), nix::Error> {
    use nix::fcntl::{fcntl, FcntlArg, FdFlag};
    // `from_bits_retain` keeps bits nix has no name for, so they are written back untouched.
    let mut flags = FdFlag::from_bits_retain(fcntl(fd, FcntlArg::F_GETFD)?);
    // `remove` (unlike `flags & !FD_CLOEXEC`) does not truncate unknown bits.
    flags.remove(FdFlag::FD_CLOEXEC);
    fcntl(fd, FcntlArg::F_SETFD(flags))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_octal_mode_valid() {
        let mode = parse_octal_mode("022").unwrap();
        assert_eq!(mode.bits(), 0o022);
    }

    #[test]
    fn parse_octal_mode_zero() {
        let mode = parse_octal_mode("000").unwrap();
        assert_eq!(mode.bits(), 0);
    }

    #[test]
    fn parse_octal_mode_full() {
        let mode = parse_octal_mode("777").unwrap();
        assert_eq!(mode.bits(), 0o777);
    }

    #[test]
    fn parse_octal_mode_leading_zeros() {
        // Users commonly write a four-digit umask such as 0027.
        let mode = parse_octal_mode("0027").unwrap();
        assert_eq!(mode.bits(), 0o027);
    }

    #[test]
    fn parse_octal_mode_invalid() {
        assert!(parse_octal_mode("999").is_err());
        assert!(parse_octal_mode("abc").is_err());
        assert!(parse_octal_mode("").is_err());
    }

    #[test]
    fn parse_env_pair_key_value() {
        assert_eq!(parse_env_pair("FOO=bar"), ("FOO".into(), "bar".into()));
    }

    #[test]
    fn parse_env_pair_empty_value() {
        assert_eq!(parse_env_pair("FOO="), ("FOO".into(), String::new()));
    }

    #[test]
    fn parse_env_pair_no_equals() {
        assert_eq!(parse_env_pair("FOO"), ("FOO".into(), String::new()));
    }

    #[test]
    fn parse_env_pair_multiple_equals() {
        assert_eq!(
            parse_env_pair("FOO=bar=baz"),
            ("FOO".into(), "bar=baz".into())
        );
    }

    #[test]
    fn derive_stderr_stdout_extension() {
        let result = derive_stderr_path(Path::new("/var/log/app.stdout"));
        assert_eq!(result, PathBuf::from("/var/log/app.stderr"));
    }

    #[test]
    fn derive_stderr_out_extension() {
        let result = derive_stderr_path(Path::new("/var/log/app.out"));
        assert_eq!(result, PathBuf::from("/var/log/app.err"));
    }

    #[test]
    fn derive_stderr_other_extension() {
        let result = derive_stderr_path(Path::new("/var/log/app.log"));
        assert_eq!(result, PathBuf::from("/var/log/app.log"));
    }

    #[test]
    fn derive_stderr_no_extension() {
        let result = derive_stderr_path(Path::new("/var/log/app"));
        assert_eq!(result, PathBuf::from("/var/log/app"));
    }

    #[test]
    fn derive_stderr_dotfile_unchanged() {
        // `Path::extension` treats a leading-dot name like ".out" as having no
        // extension, so it must not be rewritten to ".err".
        let result = derive_stderr_path(Path::new("/var/log/.out"));
        assert_eq!(result, PathBuf::from("/var/log/.out"));
    }

    #[test]
    fn clear_cloexec_removes_flag() {
        use nix::fcntl::{fcntl, open, FcntlArg, FdFlag, OFlag};
        use nix::sys::stat::Mode;
        use std::os::fd::AsFd;
        let fd = open(
            c"/dev/null",
            OFlag::O_RDONLY | OFlag::O_CLOEXEC,
            Mode::empty(),
        )
        .unwrap();
        let flags = fcntl(fd.as_fd(), FcntlArg::F_GETFD).unwrap();
        assert!(FdFlag::from_bits_truncate(flags).contains(FdFlag::FD_CLOEXEC));
        clear_cloexec(fd.as_fd()).unwrap();
        let flags = fcntl(fd.as_fd(), FcntlArg::F_GETFD).unwrap();
        assert!(!FdFlag::from_bits_truncate(flags).contains(FdFlag::FD_CLOEXEC));
    }

    #[test]
    fn clear_cloexec_without_flag_is_ok() {
        // Clearing an already-clear flag must succeed and leave it clear.
        use nix::fcntl::{fcntl, open, FcntlArg, FdFlag, OFlag};
        use nix::sys::stat::Mode;
        use std::os::fd::AsFd;
        let fd = open(c"/dev/null", OFlag::O_RDONLY, Mode::empty()).unwrap();
        clear_cloexec(fd.as_fd()).unwrap();
        let flags = fcntl(fd.as_fd(), FcntlArg::F_GETFD).unwrap();
        assert!(!FdFlag::from_bits_truncate(flags).contains(FdFlag::FD_CLOEXEC));
    }
}