rneter 0.4.5

SSH connection manager for network devices with intelligent state machine handling
Documentation
use std::collections::HashMap;

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::SessionOperation;

/// A point in the session lifecycle that hooks can be attached to.
///
/// The state-scoped variants borrow the name of the state involved; the
/// connection-scoped variants carry no payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum HookTrigger<'a> {
    /// Connection-scoped trigger, labelled `after_connect`.
    AfterConnect,
    /// Connection-scoped trigger, labelled `before_disconnect`.
    BeforeDisconnect,
    /// State-scoped trigger carrying the name of the state being exited.
    BeforeExitState(&'a str),
    /// State-scoped trigger carrying the name of the state being entered.
    AfterEnterState(&'a str),
}

impl<'a> HookTrigger<'a> {
    /// Returns the fixed snake_case label of this trigger kind.
    ///
    /// The label identifies the kind only; the state name carried by the
    /// state-scoped variants is not part of it.
    pub(crate) fn label(&self) -> &'static str {
        match self {
            Self::AfterConnect => "after_connect",
            Self::BeforeDisconnect => "before_disconnect",
            Self::BeforeExitState(_) => "before_exit_state",
            Self::AfterEnterState(_) => "after_enter_state",
        }
    }

    /// Returns the state name carried by a state-scoped trigger, or `None`
    /// for the connection-scoped triggers.
    ///
    /// The returned string borrows from the original state name (`'a`), not
    /// from this trigger value, so it stays usable after a short-lived
    /// trigger has gone out of scope.
    pub(crate) fn state(&self) -> Option<&'a str> {
        // `HookTrigger` is `Copy`, so matching on `*self` hands out the inner
        // `&'a str` directly instead of a reborrow tied to `&self`.
        match *self {
            Self::BeforeExitState(state) | Self::AfterEnterState(state) => Some(state),
            Self::AfterConnect | Self::BeforeDisconnect => None,
        }
    }
}

/// Failure policy attached to a hook action.
///
/// Serialized in snake_case (`required` / `best_effort`). The default is
/// `BestEffort`.
// NOTE(review): the code that enforces these policies is not in this module.
// Presumably a failing `Required` hook aborts the surrounding operation while
// a failing `BestEffort` hook is tolerated — confirm against the hook runner.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum HookFailurePolicy {
    /// Serialized as `required`.
    Required,
    /// Serialized as `best_effort`; the default variant.
    #[default]
    BestEffort,
}

/// A named session operation registered on a hook point, together with its
/// failure policy and output-recording flag.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct HookAction {
    /// Name identifying this action.
    pub name: String,
    /// The session operation associated with this action.
    pub operation: SessionOperation,
    /// Failure policy for this action. Falls back to the policy's default
    /// (`BestEffort`) when the field is absent from deserialized input.
    #[serde(default)]
    pub failure_policy: HookFailurePolicy,
    /// Output-recording flag. `false` when the field is absent from
    /// deserialized input.
    // NOTE(review): presumably controls whether the operation's output is
    // captured in the session recording — confirm against the hook runner.
    #[serde(default)]
    pub record_output: bool,
}

impl HookAction {
    /// Creates a hook action from a name and an operation.
    ///
    /// The failure policy is set to [`HookFailurePolicy::BestEffort`] and
    /// output recording is turned off; adjust the public fields afterwards
    /// if different settings are needed.
    pub fn new(name: impl Into<String>, operation: SessionOperation) -> Self {
        let name = name.into();
        let failure_policy = HookFailurePolicy::BestEffort;
        let record_output = false;

        Self {
            name,
            operation,
            failure_policy,
            record_output,
        }
    }
}

/// Collection of hook actions grouped by the lifecycle point they belong to.
///
/// Every field falls back to an empty collection when it is absent from
/// deserialized input.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct SessionHooks {
    /// Actions registered for the `after_connect` point.
    #[serde(default)]
    pub after_connect: Vec<HookAction>,
    /// Actions registered for the `before_disconnect` point.
    #[serde(default)]
    pub before_disconnect: Vec<HookAction>,
    /// Actions registered per state name for the `after_enter_state` point.
    /// Use the `after_enter_state` method for a case-insensitive lookup.
    #[serde(default)]
    pub after_enter_state: HashMap<String, Vec<HookAction>>,
    /// Actions registered per state name for the `before_exit_state` point.
    /// Use the `before_exit_state` method for a case-insensitive lookup.
    #[serde(default)]
    pub before_exit_state: HashMap<String, Vec<HookAction>>,
}

impl SessionHooks {
    /// Consumes the hook set and returns it with both state-keyed maps
    /// rebuilt under ASCII-lowercased keys.
    ///
    /// The connection-level lists are moved over unchanged.
    pub(crate) fn normalized(self) -> Self {
        let Self {
            after_connect,
            before_disconnect,
            after_enter_state,
            before_exit_state,
        } = self;

        let after_enter_state = normalize_state_hook_map(after_enter_state);
        let before_exit_state = normalize_state_hook_map(before_exit_state);

        Self {
            after_connect,
            before_disconnect,
            after_enter_state,
            before_exit_state,
        }
    }

    /// Returns the `after_enter_state` actions registered for `state`,
    /// matching the state name without regard to ASCII case.
    ///
    /// An empty slice is returned when nothing is registered.
    pub fn after_enter_state(&self, state: &str) -> &[HookAction] {
        let hooks_by_state = &self.after_enter_state;
        state_actions(hooks_by_state, state)
    }

    /// Returns the `before_exit_state` actions registered for `state`,
    /// matching the state name without regard to ASCII case.
    ///
    /// An empty slice is returned when nothing is registered.
    pub fn before_exit_state(&self, state: &str) -> &[HookAction] {
        let hooks_by_state = &self.before_exit_state;
        state_actions(hooks_by_state, state)
    }
}

/// Rebuilds a state-keyed map so that every key is ASCII-lowercased.
///
/// Keys that differ only by ASCII case collapse into a single entry whose
/// value is the concatenation of the colliding entries. Entries are merged in
/// ascending order of their original keys, so the resulting order does not
/// depend on `HashMap` iteration order.
///
/// Generic over the element type: values are only moved, never inspected.
fn normalize_state_hook_map<T>(map: HashMap<String, Vec<T>>) -> HashMap<String, Vec<T>> {
    let mut normalized: HashMap<String, Vec<T>> = HashMap::with_capacity(map.len());

    // `HashMap` iteration order is unspecified and differs between map
    // instances; sorting first keeps the merge order of colliding keys stable.
    // Original keys are unique, so an unstable sort cannot reorder equal keys.
    let mut entries: Vec<(String, Vec<T>)> = map.into_iter().collect();
    entries.sort_unstable_by(|left, right| left.0.cmp(&right.0));

    for (mut state, actions) in entries {
        // The key is owned, so lowercase it in place rather than allocating.
        state.make_ascii_lowercase();
        normalized.entry(state).or_default().extend(actions);
    }
    normalized
}

/// Looks up the actions registered for `state`, ignoring ASCII case.
///
/// Resolution order:
/// 1. the key exactly as given;
/// 2. the ASCII-lowercased key (the form produced by map normalization);
/// 3. a scan for any key equal to `state` ignoring ASCII case. If several
///    keys match, the lexicographically smallest one is used, so the result
///    does not depend on `HashMap` iteration order.
///
/// Returns an empty slice when no key matches.
///
/// Generic over the element type: values are only borrowed, never inspected.
fn state_actions<'a, T>(map: &'a HashMap<String, Vec<T>>, state: &str) -> &'a [T] {
    if let Some(actions) = map.get(state) {
        return actions.as_slice();
    }

    // Without ASCII uppercase the lowercased key equals `state`, which has
    // already missed above; skip the allocation in that case.
    if state.bytes().any(|byte| byte.is_ascii_uppercase()) {
        if let Some(actions) = map.get(&state.to_ascii_lowercase()) {
            return actions.as_slice();
        }
    }

    // Fallback for maps that were never normalized (mixed-case keys).
    map.iter()
        .filter(|(key, _)| key.eq_ignore_ascii_case(state))
        .min_by_key(|(key, _)| *key)
        .map(|(_, actions)| actions.as_slice())
        .unwrap_or_default()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{Command, SessionOperation};

    /// Builds a hook wrapping a `Command` whose `mode` is `"enable"` and whose
    /// command text is `command`; all other command fields use their defaults.
    fn enable_mode_hook(name: &str, command: &str) -> HookAction {
        let operation = SessionOperation::from(Command {
            mode: "enable".to_string(),
            command: command.to_string(),
            ..Command::default()
        });
        HookAction::new(name, operation)
    }

    /// Builds a hook set whose state maps use mixed-case keys: `"Config"` for
    /// `after_enter_state` and `"ENABLE"` for `before_exit_state`.
    fn mixed_case_hooks(hook: &HookAction) -> SessionHooks {
        let mut hooks = SessionHooks::default();
        hooks
            .after_enter_state
            .insert("Config".to_string(), vec![hook.clone()]);
        hooks
            .before_exit_state
            .insert("ENABLE".to_string(), vec![hook.clone()]);
        hooks
    }

    #[test]
    fn hook_action_defaults_to_best_effort_and_no_record_output() {
        let hook = enable_mode_hook("disable-paging", "terminal length 0");

        assert_eq!(hook.failure_policy, HookFailurePolicy::BestEffort);
        assert!(!hook.record_output);
    }

    #[test]
    fn session_hooks_state_lookup_returns_empty_when_missing() {
        let hooks = SessionHooks::default();

        assert!(hooks.after_enter_state("config").is_empty());
        assert!(hooks.before_exit_state("enable").is_empty());
    }

    #[test]
    fn session_hooks_state_lookup_is_case_insensitive() {
        let hook = enable_mode_hook("prepare-config", "configure terminal");
        let hooks = mixed_case_hooks(&hook);
        let expected = std::slice::from_ref(&hook);

        assert_eq!(hooks.after_enter_state("config"), expected);
        assert_eq!(hooks.after_enter_state("CONFIG"), expected);
        assert_eq!(hooks.before_exit_state("enable"), expected);
        assert_eq!(hooks.before_exit_state("Enable"), expected);
    }

    #[test]
    fn session_hooks_normalized_lowercases_state_maps() {
        let hook = enable_mode_hook("prepare-config", "configure terminal");
        let hooks = mixed_case_hooks(&hook).normalized();
        let expected = std::slice::from_ref(&hook);

        // The stored keys themselves must have been rewritten to lowercase.
        assert!(hooks.after_enter_state.contains_key("config"));
        assert!(!hooks.after_enter_state.contains_key("Config"));
        assert!(hooks.before_exit_state.contains_key("enable"));

        // Lookups still resolve regardless of the caller's casing.
        assert_eq!(hooks.after_enter_state("CONFIG"), expected);
        assert_eq!(hooks.before_exit_state("Enable"), expected);
    }
}