cindy 0.1.0

Managing infrastructure at breakneck speed.
Documentation
use std::str::FromStr;

use zbus_systemd::systemd1::{ManagerProxy, UnitProxy};
use zbus_systemd::zbus;

use crate as cindy;
use crate::Context;

/// Runtime action attached to an `Enablement`.
///
/// `Started` / `Stopped` are convergent: `systemd` only issues `StartUnit` /
/// `StopUnit` when the unit is not already in the wanted state.
/// `Restarted` / `Reloaded` are imperative: the corresponding call is made on
/// every run and is always reported as a change.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[crate::wire]
pub enum RuntimeAction {
    /// Ensure the unit is running; started only if it is not already running.
    Started,
    /// Ensure the unit is not running; stopped only if it is currently running.
    Stopped,
    /// Always restart the unit.
    Restarted,
    /// Always request a reload of the unit.
    Reloaded,
}

/// Unit-file enablement of a systemd unit, optionally paired with a runtime
/// action.
///
/// Used both for the desired state (see `State`) and to describe what was
/// observed. When observing, any unit-file state other than enabled,
/// enabled-runtime or masked is reported as `Disabled`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[crate::wire]
pub enum Enablement {
    /// Unit file is masked. Implies the unit is stopped.
    Masked,
    /// Unit file is not enabled. `None` leaves the running state untouched.
    Disabled(Option<RuntimeAction>),
    /// Unit file is enabled. `None` leaves the running state untouched.
    Enabled(Option<RuntimeAction>),
}

impl Enablement {
    /// Runtime action implied by this enablement.
    ///
    /// An enabled or disabled unit carries whatever action was attached to it
    /// (possibly none); a masked unit always has to be stopped.
    fn runtime(self) -> Option<RuntimeAction> {
        if let Self::Disabled(action) | Self::Enabled(action) = self {
            action
        } else {
            Some(RuntimeAction::Stopped)
        }
    }
}

/// Desired state of a single systemd unit; the input to `systemd`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[crate::wire]
pub struct State {
    /// Unit name. Bare names (`"nginx"`) get `.service` appended; anything
    /// containing a `.` is passed through unchanged (so `"foo.timer"`,
    /// `"multi-user.target"`, `"docker.socket"` etc. work as-is).
    pub name: String,
    /// Desired enablement (and optional runtime). `None` = don't touch: the
    /// current state is only reported, never modified.
    pub enablement: Option<Enablement>,
}

/// Default `Diff` impl renders `State` via `{:#?}` + a unified line diff.
/// No binary payloads here, so no custom rendering is needed.
/// `systemd` uses it to print the observed-vs-desired view to stderr.
impl crate::Diff for State {}

/// Mirror of systemd's `UnitFileState` string, as returned by
/// `GetUnitFileState`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum UnitState {
    Enabled,
    EnabledRuntime,
    Linked,
    LinkedRuntime,
    Alias,
    Masked,
    MaskedRuntime,
    Static,
    Disabled,
    Generated,
    Transient,
    Indirect,
    Bad,
}

impl FromStr for UnitState {
    type Err = String;

    /// Look the string up in [`UnitState::BY_NAME`]; unknown strings are an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::BY_NAME
            .iter()
            .find(|(wire_name, _)| *wire_name == s)
            .map(|&(_, state)| state)
            .ok_or_else(|| format!("unknown UnitFileState: {s:?}"))
    }
}

impl UnitState {
    /// Every recognised `UnitFileState` string and its variant.
    /// An empty string is treated the same as `"bad"`.
    const BY_NAME: [(&'static str, UnitState); 14] = [
        ("enabled", UnitState::Enabled),
        ("enabled-runtime", UnitState::EnabledRuntime),
        ("linked", UnitState::Linked),
        ("linked-runtime", UnitState::LinkedRuntime),
        ("alias", UnitState::Alias),
        ("masked", UnitState::Masked),
        ("masked-runtime", UnitState::MaskedRuntime),
        ("static", UnitState::Static),
        ("disabled", UnitState::Disabled),
        ("generated", UnitState::Generated),
        ("transient", UnitState::Transient),
        ("indirect", UnitState::Indirect),
        ("bad", UnitState::Bad),
        ("", UnitState::Bad),
    ];

    /// Masked either persistently or only for the current boot.
    fn is_masked(self) -> bool {
        self == Self::Masked || self == Self::MaskedRuntime
    }

    /// `true` when calling `EnableUnitFiles` would actually do something.
    fn needs_enable_call(self) -> bool {
        self == Self::Disabled
    }

    /// `true` when calling `DisableUnitFiles` would actively tear down links.
    fn needs_disable_call(self) -> bool {
        [
            Self::Enabled,
            Self::EnabledRuntime,
            Self::Linked,
            Self::LinkedRuntime,
            Self::Alias,
        ]
        .contains(&self)
    }
}

/// Mirror of systemd's `ActiveState` unit property.
///
/// Besides the classic six states this also accepts `maintenance` and
/// `refreshing`, which newer systemd releases report. Observing a unit in one
/// of those states must not abort the run with a parse error.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ActiveState {
    Active,
    Reloading,
    Inactive,
    Failed,
    Activating,
    Deactivating,
    /// Unit is not running; its resources are being cleaned up.
    Maintenance,
    /// Unit keeps running while it is being refreshed (newer systemd).
    Refreshing,
}

impl FromStr for ActiveState {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(match s {
            "active" => Self::Active,
            "reloading" => Self::Reloading,
            "inactive" => Self::Inactive,
            "failed" => Self::Failed,
            "activating" => Self::Activating,
            "deactivating" => Self::Deactivating,
            "maintenance" => Self::Maintenance,
            "refreshing" => Self::Refreshing,
            other => return Err(format!("unknown ActiveState: {other:?}")),
        })
    }
}

impl ActiveState {
    /// `true` if the unit is running or actively transitioning to running.
    ///
    /// Modelled on Ansible's "started" predicate (`active` / `activating`).
    /// `reloading` and `refreshing` are also counted, because the unit keeps
    /// running through both.
    fn is_running(self) -> bool {
        matches!(
            self,
            Self::Active | Self::Reloading | Self::Activating | Self::Refreshing
        )
    }
}

fn is_no_such_unit(err: &zbus::Error) -> bool {
    if let zbus::Error::MethodError(name, _, _) = err {
        let n = name.to_string();
        n.ends_with(".NoSuchUnit") || n.ends_with(".NoSuchUnitFile") || n.ends_with(".LoadFailed")
    } else {
        false
    }
}

/// Read the unit-file state of `name` via `GetUnitFileState`.
///
/// Returns `Ok(None)` when the call fails with an error that
/// `is_no_such_unit` recognises. Any other D-Bus error, or a state string
/// that does not parse as `UnitState`, is returned as an error.
async fn read_file_state(
    manager: &ManagerProxy<'_>,
    name: &str,
) -> crate::Result<Option<UnitState>> {
    let raw_state = match manager.get_unit_file_state(name.to_owned()).await {
        Ok(raw_state) => raw_state,
        Err(err) if is_no_such_unit(&err) => return Ok(None),
        Err(err) => return Err(err).context(format!("GetUnitFileState({name}) failed")),
    };

    let file_state = raw_state
        .parse::<UnitState>()
        .map_err(anyhow_serde::Error::msg)
        .context("Parsing UnitFileState")?;
    Ok(Some(file_state))
}

/// Read the `ActiveState` property of unit `name`.
///
/// The unit is first resolved with `LoadUnit`. If that fails with an error
/// `is_no_such_unit` recognises, `Ok(None)` is returned. Otherwise a `Unit`
/// proxy is built on `conn` for the returned object path, and its
/// `ActiveState` is read and parsed.
async fn read_active_state(
    manager: &ManagerProxy<'_>,
    conn: &zbus::Connection,
    name: &str,
) -> crate::Result<Option<ActiveState>> {
    let object_path = match manager.load_unit(name.to_owned()).await {
        Ok(object_path) => object_path,
        Err(err) if is_no_such_unit(&err) => return Ok(None),
        Err(err) => return Err(err).context(format!("LoadUnit({name}) failed")),
    };

    let proxy_builder = UnitProxy::builder(conn)
        .path(object_path)
        .context("Invalid unit object path")?;
    let unit_proxy = proxy_builder
        .build()
        .await
        .context("Couldn't construct Unit proxy")?;

    let state_text = unit_proxy
        .active_state()
        .await
        .context("Reading ActiveState failed")?;
    let active_state = state_text
        .parse::<ActiveState>()
        .map_err(anyhow_serde::Error::msg)
        .context("Parsing ActiveState")?;
    Ok(Some(active_state))
}

/// Manage a single systemd unit on the remote machine.
///
/// Steps, in order:
///
/// 1. Resolve the unit name: bare names get `.service` appended, anything
///    containing a `.` is used verbatim.
/// 2. Daemon-reload, then read the unit's `UnitFileState` and `ActiveState`.
/// 3. Print a diff of the observed vs. desired [`State`] to stderr (best effort).
/// 4. If `state.enablement` is `None`, stop here and report `Unchanged`.
/// 5. Mask / unmask / enable / disable the unit file as required, and
///    daemon-reload afterwards whenever unit files were modified.
/// 6. Apply the runtime action: `Started` / `Stopped` only when needed,
///    `Restarted` / `Reloaded` always. Masking implies `Stopped`.
///
/// Returns `Return::Changed` if any mutating call was made, otherwise
/// `Return::Unchanged`.
///
/// # Errors
///
/// Fails if the system bus is unreachable, any D-Bus call fails, a state
/// string cannot be parsed, or a conflicting entry at the mask target cannot
/// be removed. The values returned by `StartUnit` / `StopUnit` /
/// `RestartUnit` / `ReloadUnit` are discarded, so the eventual outcome of the
/// requested unit operation is not checked here.
#[crate::remote]
pub async fn systemd(state: State) -> crate::Result<super::Return> {
    // Bare names default to a `.service` unit; dotted names are used verbatim.
    let unit_name = if state.name.contains('.') {
        state.name.clone()
    } else {
        format!("{}.service", state.name)
    };
    let conn = zbus::Connection::system()
        .await
        .context("Couldn't connect to the system bus")?;
    let manager = ManagerProxy::new(&conn)
        .await
        .context("Couldn't construct Manager proxy")?;

    // Unconditionally reload daemon - costs basically nothing, and makes the
    // manager's view match the unit files currently on disk.
    manager.reload().await.context("Daemon-reload failed")?;

    let initial_unit_state = read_file_state(&manager, &unit_name).await?;
    let initial_active_state = read_active_state(&manager, &conn, &unit_name).await?;

    // A unit that cannot be loaded at all is reported as stopped.
    let observed_runtime = Some(
        if initial_active_state
            .map(ActiveState::is_running)
            .unwrap_or(false)
        {
            RuntimeAction::Started
        } else {
            RuntimeAction::Stopped
        },
    );
    // Only enabled / enabled-runtime count as enabled; every other non-masked
    // unit-file state (static, linked, indirect, ...) is shown as disabled.
    // A missing unit file yields `None`.
    let observed_enablement = initial_unit_state.map(|f| {
        if f.is_masked() {
            Enablement::Masked
        } else if matches!(f, UnitState::Enabled | UnitState::EnabledRuntime) {
            Enablement::Enabled(observed_runtime)
        } else {
            Enablement::Disabled(observed_runtime)
        }
    });
    let old_view = State {
        name: state.name.clone(),
        enablement: observed_enablement,
    };
    // An unspecified runtime action means "leave it as observed", so it is
    // filled in from the observation to avoid reporting a spurious difference.
    let new_view = State {
        name: state.name.clone(),
        enablement: match state.enablement {
            None => observed_enablement,
            Some(Enablement::Enabled(None)) => Some(Enablement::Enabled(observed_runtime)),
            Some(Enablement::Disabled(None)) => Some(Enablement::Disabled(observed_runtime)),
            explicit => explicit,
        },
    };
    if old_view != new_view {
        // Best effort: failing to print the diff does not abort the run.
        let _ = <State as crate::Diff>::diff(&old_view, &new_view, &mut std::io::stderr().lock());
    }

    let Some(wanted_enablement) = state.enablement else {
        return Ok(super::Return::Unchanged);
    };

    let unit_files = vec![unit_name.clone()];
    let is_masked = initial_unit_state
        .map(UnitState::is_masked)
        .unwrap_or(false);

    // Unit-file state, re-read after an unmask so the enable/disable decision
    // below sees the post-unmask state.
    let mut current_file_state = initial_unit_state;
    // Set whenever a mask/unmask/enable/disable call modified unit files.
    let mut files_changed = false;

    match wanted_enablement {
        Enablement::Masked => {
            if !is_masked {
                // Anything already at the mask target conflicts with the mask
                // symlink, so remove it first. NOTE: this deletes a unit file of
                // the same name placed in /etc/systemd/system.
                let unit_path = std::path::Path::new("/etc/systemd/system").join(&unit_name);
                if unit_path.exists() || unit_path.is_symlink() {
                    if unit_path.is_dir() {
                        std::fs::remove_dir_all(&unit_path)
                            .context("Failed to clear conflicting directory at mask target")?;
                    } else {
                        std::fs::remove_file(&unit_path)
                            .context("Failed to clear conflicting file/symlink at mask target")?;
                    }
                }
                manager
                    .mask_unit_files(unit_files, false, true) // runtime = false, force = true
                    .await
                    .context(format!("MaskUnitFiles({unit_name}) failed"))?;
                files_changed = true;
            }
        }
        Enablement::Disabled(_) => {
            if is_masked {
                manager
                    .unmask_unit_files(unit_files.clone(), false)
                    .await
                    .context(format!("UnmaskUnitFiles({unit_name}) failed"))?;
                files_changed = true;
                current_file_state = read_file_state(&manager, &unit_name).await?;
            }
            if current_file_state
                .map(UnitState::needs_disable_call)
                .unwrap_or(false)
            {
                manager
                    .disable_unit_files(unit_files, false)
                    .await
                    .context(format!("DisableUnitFiles({unit_name}) failed"))?;
                files_changed = true;
            }
        }
        Enablement::Enabled(_) => {
            if is_masked {
                manager
                    .unmask_unit_files(unit_files.clone(), false)
                    .await
                    .context(format!("UnmaskUnitFiles({unit_name}) failed"))?;
                files_changed = true;
                current_file_state = read_file_state(&manager, &unit_name).await?;
            }
            // No unit file found: still attempt the enable call.
            if current_file_state
                .map(UnitState::needs_enable_call)
                .unwrap_or(true)
            {
                manager
                    .enable_unit_files(unit_files, false, true) // Pass force=true to clear stale overrides
                    .await
                    .context(format!("EnableUnitFiles({unit_name}) failed"))?;
                files_changed = true;
            }
        }
    }

    if files_changed {
        // Like `systemctl mask/unmask/enable/disable`, make the manager re-read
        // the unit files after changing them on disk. Otherwise it keeps the
        // unit as last loaded (e.g. still masked right after an unmask), and
        // the runtime action below would be applied to stale state.
        manager.reload().await.context("Daemon-reload failed")?;
    }
    let mut changed = files_changed;

    if let Some(runtime_action) = wanted_enablement.runtime() {
        // Re-read after any unit-file change; otherwise the initial reading is current.
        let active = if changed {
            read_active_state(&manager, &conn, &unit_name).await?
        } else {
            initial_active_state
        };
        let is_running = active.map(ActiveState::is_running).unwrap_or(false);

        match runtime_action {
            RuntimeAction::Started => {
                if !is_running {
                    manager
                        .start_unit(unit_name.clone(), "replace".to_owned())
                        .await
                        .context(format!("StartUnit({unit_name}) failed"))?;
                    changed = true;
                }
            }
            RuntimeAction::Stopped => {
                if is_running {
                    manager
                        .stop_unit(unit_name.clone(), "replace".to_owned())
                        .await
                        .context(format!("StopUnit({unit_name}) failed"))?;
                    changed = true;
                }
            }
            RuntimeAction::Restarted => {
                manager
                    .restart_unit(unit_name.clone(), "replace".to_owned())
                    .await
                    .context(format!("RestartUnit({unit_name}) failed"))?;
                changed = true;
            }
            RuntimeAction::Reloaded => {
                manager
                    .reload_unit(unit_name.clone(), "replace".to_owned())
                    .await
                    .context(format!("ReloadUnit({unit_name}) failed"))?;
                changed = true;
            }
        }
    }

    Ok(if changed {
        super::Return::Changed
    } else {
        super::Return::Unchanged
    })
}