use async_ssh2_tokio::client::{AuthMethod, Client};
use async_ssh2_tokio::{Config, ServerCheckMethod};
use log::{debug, trace};
use moka::future::Cache;
use once_cell::sync::Lazy;
use sha2::{Digest, Sha256};
use russh::{ChannelMsg, Preferred};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::{RwLock, oneshot};
use crate::config;
use crate::error::ConnectError;
use super::device::{DeviceHandler, IGNORE_START_LINE};
pub use recording::{
NormalizeOptions, ReplayContext, SessionEvent, SessionRecordEntry, SessionRecordLevel,
SessionRecorder, SessionReplayer,
};
pub use security::{ConnectionSecurityOptions, SecurityLevel};
pub use transaction::{
RollbackPolicy, TxBlock, TxOperationStepResult, TxResult, TxStep, TxStepExecutionState,
TxStepResult, TxStepRollbackState, TxWorkflow, TxWorkflowResult, failed_block_rollback_summary,
workflow_rollback_order,
};
/// Process-wide SSH connection manager, constructed lazily with
/// [`SshConnectionManager::new`] the first time it is accessed.
pub static MANAGER: Lazy<SshConnectionManager> = Lazy::new(SshConnectionManager::new);
/// Everything needed to reach one device over SSH: credentials, address and the
/// device handler to drive the session with.
pub struct ConnectionRequest {
    /// Login user name.
    pub user: String,
    /// Host name or IP address of the device.
    pub addr: String,
    /// SSH port.
    pub port: u16,
    /// Login password.
    pub password: String,
    /// Optional enable password, when the device needs one.
    pub enable_password: Option<String>,
    /// Device handler associated with this connection.
    pub handler: DeviceHandler,
}

impl ConnectionRequest {
    /// Assembles a request; each argument is moved into the field of the same name.
    pub fn new(
        user: String,
        addr: String,
        port: u16,
        password: String,
        enable_password: Option<String>,
        handler: DeviceHandler,
    ) -> Self {
        ConnectionRequest {
            handler,
            enable_password,
            password,
            port,
            addr,
            user,
        }
    }

    /// Renders the target as `user@addr:port`, e.g. `admin@192.168.1.1:22`.
    pub fn device_addr(&self) -> String {
        let Self {
            user, addr, port, ..
        } = self;
        format!("{user}@{addr}:{port}")
    }
}
/// Per-call execution settings: the connection security options and an optional
/// `sys` string. Both start from their `Default` values.
#[derive(Clone, Default)]
pub struct ExecutionContext {
    /// Connection security options carried by this context.
    pub security_options: ConnectionSecurityOptions,
    /// Optional `sys` value; `None` unless set.
    pub sys: Option<String>,
}

impl ExecutionContext {
    /// Same as [`ExecutionContext::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this context with `security_options` replaced.
    pub fn with_security_options(self, security_options: ConnectionSecurityOptions) -> Self {
        Self {
            security_options,
            ..self
        }
    }

    /// Returns this context with `sys` replaced.
    pub fn with_sys(self, sys: Option<String>) -> Self {
        Self { sys, ..self }
    }
}
/// State of one shared SSH session. Instances live behind `Arc<RwLock<..>>`
/// inside [`SshConnectionManager`]'s cache.
pub struct SharedSshClient {
    /// Underlying `async_ssh2_tokio` client.
    client: Client,
    /// Sending half of a `String` channel.
    /// NOTE(review): presumably carries shell input/output for this session — confirm in `client`.
    sender: Sender<String>,
    /// Receiving half of a `String` channel (see `sender`).
    recv: Receiver<String>,
    /// Device handler driving this session.
    handler: DeviceHandler,
    /// Prompt text for the session (presumably the last one observed — confirm).
    prompt: String,
    /// 32-byte digest of the login password, kept instead of the plaintext.
    /// NOTE(review): presumably SHA-256 (`Sha256` is imported here) — confirm.
    password_hash: [u8; 32],
    /// Digest of the enable password, if one was supplied.
    enable_password_hash: Option<[u8; 32]>,
    /// Security options the session was opened with.
    security_options: ConnectionSecurityOptions,
    /// Optional session recorder.
    recorder: Option<SessionRecorder>,
}
/// Per-command dynamic parameters: optional enable/sudo passwords plus arbitrary
/// extra key/value pairs. All of them are exposed together through
/// [`CommandDynamicParams::runtime_values`].
///
/// `Debug` is implemented by hand instead of derived so that secrets do not leak
/// into logs: the password fields print as `<redacted>` when set, and only the
/// *keys* of `extra` are shown, never their values.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandDynamicParams {
    /// Enable password; the key `EnablePassword` is also accepted when deserializing.
    #[serde(default, alias = "EnablePassword")]
    pub enable_password: Option<String>,
    /// Sudo password; the key `SudoPassword` is also accepted when deserializing.
    #[serde(default, alias = "SudoPassword")]
    pub sudo_password: Option<String>,
    /// Every other key, captured through `#[serde(flatten)]`.
    #[serde(default, flatten)]
    pub extra: HashMap<String, String>,
}

impl std::fmt::Debug for CommandDynamicParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a secret is present, never its value.
        fn redact(value: Option<&String>) -> Option<&'static str> {
            value.map(|_| "<redacted>")
        }
        // Sorted so the output is deterministic regardless of `HashMap` order.
        let mut extra_keys: Vec<&str> = self.extra.keys().map(String::as_str).collect();
        extra_keys.sort_unstable();
        f.debug_struct("CommandDynamicParams")
            .field("enable_password", &redact(self.enable_password.as_ref()))
            .field("sudo_password", &redact(self.sudo_password.as_ref()))
            .field("extra_keys", &extra_keys)
            .finish()
    }
}

impl CommandDynamicParams {
    /// Returns `true` when neither password nor any extra value is set.
    pub fn is_empty(&self) -> bool {
        self.enable_password.is_none() && self.sudo_password.is_none() && self.extra.is_empty()
    }

    /// Inserts or replaces an extra parameter, returning the previous value for
    /// `key`, if there was one.
    pub fn insert_extra(
        &mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Option<String> {
        self.extra.insert(key.into(), value.into())
    }

    /// Flattens all parameters into one map keyed by name.
    ///
    /// The typed passwords are stored under `EnablePassword` / `SudoPassword`.
    /// Because they are inserted after `extra` is copied, they override any
    /// `extra` entry with the same key.
    pub(crate) fn runtime_values(&self) -> HashMap<String, String> {
        let mut values = self.extra.clone();
        if let Some(value) = self.enable_password.as_ref() {
            values.insert("EnablePassword".to_string(), value.clone());
        }
        if let Some(value) = self.sudo_password.as_ref() {
            values.insert("SudoPassword".to_string(), value.clone());
        }
        values
    }
}
/// Pairs prompt patterns (regex-style strings in this module's tests) with the
/// response text associated with them.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct PromptResponseRule {
    /// Prompt patterns this rule applies to.
    pub patterns: Vec<String>,
    /// Response text for a matching prompt.
    pub response: String,
    /// Input-recording flag; `false` when absent from serialized input.
    #[serde(default)]
    pub record_input: bool,
}

impl PromptResponseRule {
    /// Builds a rule with `record_input` turned off.
    pub fn new(patterns: Vec<String>, response: String) -> Self {
        Self {
            response,
            patterns,
            ..Self::default()
        }
    }

    /// Returns this rule with `record_input` set to the given value.
    pub fn with_record_input(self, record_input: bool) -> Self {
        Self {
            record_input,
            ..self
        }
    }
}
/// The prompt/response rules attached to a command.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandInteraction {
    /// Configured rules; empty when absent from serialized input.
    #[serde(default)]
    pub prompts: Vec<PromptResponseRule>,
}

impl CommandInteraction {
    /// `true` when no prompt rule has been configured.
    pub fn is_empty(&self) -> bool {
        self.prompts.is_empty()
    }

    /// Builder-style append: returns this interaction with `prompt` added last.
    pub fn push_prompt(self, prompt: PromptResponseRule) -> Self {
        let mut updated = self;
        updated.prompts.push(prompt);
        updated
    }
}
/// Selects which part of a command's output a branch rule is checked against.
/// Serialized in `snake_case`: `all`, `content`, `prompt`.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CommandOutputBranchSource {
    /// Full output (the default); presumably `Output::all` — confirm.
    #[default]
    All,
    /// Content only; presumably `Output::content` — confirm.
    Content,
    /// Prompt only; presumably `Output::prompt` — confirm.
    Prompt,
}
/// Destination chosen by a branch rule or a command's fallback.
///
/// Internally tagged by `kind` in `snake_case`, for example
/// `{"kind":"next"}` or `{"kind":"jump","step_index":2}`.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum CommandBranchTarget {
    /// Continue with the next step (the default).
    #[default]
    Next,
    /// Stop, presumably reporting success.
    StopSuccess,
    /// Stop, presumably reporting failure.
    StopFailure,
    /// Jump to another step.
    Jump {
        /// Target step; presumably a zero-based index into `CommandFlow::steps` — confirm.
        step_index: usize,
    },
}
/// Output patterns mapped to a [`CommandBranchTarget`]; `source` picks which
/// part of the output they are checked against.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandOutputBranchRule {
    /// Patterns to look for.
    pub patterns: Vec<String>,
    /// Output part to inspect; `All` when absent from serialized input.
    #[serde(default)]
    pub source: CommandOutputBranchSource,
    /// Where to go on a match; `Next` when absent from serialized input.
    #[serde(default)]
    pub target: CommandBranchTarget,
}

impl CommandOutputBranchRule {
    /// Creates a rule that inspects [`CommandOutputBranchSource::All`].
    pub fn new(patterns: Vec<String>, target: CommandBranchTarget) -> Self {
        Self {
            target,
            patterns,
            source: CommandOutputBranchSource::All,
        }
    }

    /// Returns this rule with `source` replaced.
    pub fn with_source(self, source: CommandOutputBranchSource) -> Self {
        Self { source, ..self }
    }
}
/// One command to execute in a given mode, with optional timeout, dynamic
/// parameters, prompt interaction and output-based branching.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct Command {
    /// Mode to run the command in (the tests use `"Enable"`).
    pub mode: String,
    /// Command text.
    pub command: String,
    /// Optional timeout; `None` when absent from serialized input.
    /// NOTE(review): unit is not visible here — presumably seconds, confirm.
    pub timeout: Option<u64>,
    /// Dynamic parameters; empty by default.
    #[serde(default)]
    pub dyn_params: CommandDynamicParams,
    /// Prompt/response rules; empty by default.
    #[serde(default)]
    pub interaction: CommandInteraction,
    /// Output-based branch rules; empty by default.
    #[serde(default)]
    pub output_branches: Vec<CommandOutputBranchRule>,
    /// Fallback target (presumably used when no branch rule matches); defaults to `Next`.
    #[serde(default)]
    pub output_fallback: CommandBranchTarget,
}
/// Something that can be executed on a session: a single command, a multi-step
/// flow, or a command-flow template with its runtime.
/// Internally tagged by `kind`: `command`, `flow`, `template`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SessionOperation {
    /// A single command.
    Command(Command),
    /// An ordered flow of commands.
    Flow(CommandFlow),
    /// A command-flow template plus the runtime it is used with.
    Template {
        template: crate::templates::CommandFlowTemplate,
        runtime: crate::templates::CommandFlowTemplateRuntime,
    },
}
/// Summary of a [`SessionOperation`], as returned by [`SessionOperation::summary`].
/// NOTE(review): the fields are filled in by `summary_impl`, which is not visible
/// here; the descriptions below are inferred from the field names.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct SessionOperationSummary {
    /// Operation kind label.
    pub kind: String,
    /// Mode associated with the operation.
    pub mode: String,
    /// Human-readable description.
    pub description: String,
    /// Number of steps in the operation.
    pub step_count: usize,
}
impl SessionOperation {
pub fn command(command: Command) -> Self {
Self::Command(command)
}
pub fn flow(flow: CommandFlow) -> Self {
Self::Flow(flow)
}
pub fn template(
template: crate::templates::CommandFlowTemplate,
runtime: crate::templates::CommandFlowTemplateRuntime,
) -> Self {
Self::Template { template, runtime }
}
pub fn summary(&self) -> Result<SessionOperationSummary, ConnectError> {
self.summary_impl()
}
}
impl From<Command> for SessionOperation {
fn from(value: Command) -> Self {
Self::Command(value)
}
}
impl From<CommandFlow> for SessionOperation {
fn from(value: CommandFlow) -> Self {
Self::Flow(value)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct FileUploadRequest {
pub local_path: String,
pub remote_path: String,
pub timeout_secs: Option<u64>,
pub buffer_size: Option<usize>,
pub show_progress: bool,
}
impl FileUploadRequest {
pub fn new(local_path: String, remote_path: String) -> Self {
Self {
local_path,
remote_path,
timeout_secs: None,
buffer_size: None,
show_progress: false,
}
}
pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
self.timeout_secs = Some(timeout_secs);
self
}
pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
self.buffer_size = Some(buffer_size);
self
}
pub fn with_progress_reporting(mut self, show_progress: bool) -> Self {
self.show_progress = show_progress;
self
}
}
/// Default for [`CommandFlow::stop_on_error`]. It is used both by serde and by
/// `CommandFlow`'s `Default` impl, so the two defaults cannot drift apart.
fn default_stop_on_error() -> bool {
    true
}

/// An ordered list of commands executed together.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandFlow {
    /// Steps of the flow; empty when absent from serialized input.
    #[serde(default)]
    pub steps: Vec<Command>,
    /// Presumably: stop at the first failing step. `true` unless overridden.
    #[serde(default = "default_stop_on_error")]
    pub stop_on_error: bool,
    /// Optional cap on the number of steps (how it is enforced is not visible here).
    #[serde(default)]
    pub max_steps: Option<usize>,
}

impl Default for CommandFlow {
    fn default() -> Self {
        Self {
            steps: Vec::new(),
            // Share the serde default so both construction paths agree.
            stop_on_error: default_stop_on_error(),
            max_steps: None,
        }
    }
}

impl CommandFlow {
    /// Creates a flow from `steps`, taking every other field from `Default`.
    pub fn new(steps: Vec<Command>) -> Self {
        Self {
            steps,
            ..Self::default()
        }
    }

    /// Sets `stop_on_error`.
    pub fn with_stop_on_error(mut self, stop_on_error: bool) -> Self {
        self.stop_on_error = stop_on_error;
        self
    }

    /// Sets `max_steps`.
    pub fn with_max_steps(mut self, max_steps: usize) -> Self {
        self.max_steps = Some(max_steps);
        self
    }
}
/// A unit of work of the kind carried by the per-connection `mpsc::Sender<CmdJob>`
/// stored in [`SshConnectionManager`]'s cache; the result is sent back through
/// `responder`.
pub struct CmdJob {
    /// Command to execute.
    pub data: Command,
    /// Optional `sys` value accompanying the command.
    pub sys: Option<String>,
    /// One-shot channel that receives the command's [`Output`] or a [`ConnectError`].
    pub responder: oneshot::Sender<Result<Output, ConnectError>>,
}
/// Result of running a single command.
#[derive(Debug, Clone)]
pub struct Output {
    /// Whether the command is considered successful.
    pub success: bool,
    /// Exit code, when one is available.
    pub exit_code: Option<i32>,
    /// Command content (presumably the output without echo/prompt — confirm).
    pub content: String,
    /// Full captured text (presumably the raw output including the prompt — confirm).
    pub all: String,
    /// Prompt associated with this output, if any.
    pub prompt: Option<String>,
}
/// Serializable output of one step of a [`SessionOperation`], tagged with the
/// step that produced it.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct SessionOperationStepOutput {
    /// Index of the step that produced this output.
    pub step_index: usize,
    /// Mode the step ran in.
    pub mode: String,
    /// Short description of the step.
    pub operation_summary: String,
    /// Whether the step is considered successful.
    pub success: bool,
    /// Exit code, when available.
    pub exit_code: Option<i32>,
    /// Step content.
    pub content: String,
    /// Full captured text.
    pub all: String,
    /// Prompt associated with the step, if any.
    pub prompt: Option<String>,
}

impl SessionOperationStepOutput {
    /// Converts into an [`Output`], dropping `step_index`, `mode` and
    /// `operation_summary`.
    pub fn into_output(self) -> Output {
        let Self {
            success,
            exit_code,
            content,
            all,
            prompt,
            ..
        } = self;
        Output {
            success,
            exit_code,
            content,
            all,
            prompt,
        }
    }

    /// Borrowing counterpart of [`Self::into_output`]; the text fields are cloned.
    fn to_output(&self) -> Output {
        let Self {
            success,
            exit_code,
            content,
            all,
            prompt,
            ..
        } = self;
        Output {
            success: *success,
            exit_code: *exit_code,
            content: content.clone(),
            all: all.clone(),
            prompt: prompt.clone(),
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct SessionOperationOutput {
pub success: bool,
#[serde(default)]
pub steps: Vec<SessionOperationStepOutput>,
}
impl SessionOperationOutput {
pub fn into_command_flow_output(self) -> CommandFlowOutput {
CommandFlowOutput {
success: self.success,
outputs: self
.steps
.into_iter()
.map(SessionOperationStepOutput::into_output)
.collect(),
}
}
pub fn to_command_flow_output(&self) -> CommandFlowOutput {
CommandFlowOutput {
success: self.success,
outputs: self
.steps
.iter()
.map(SessionOperationStepOutput::to_output)
.collect(),
}
}
}
/// Error from executing a [`SessionOperation`]. It keeps the output produced
/// before the failure alongside the underlying [`ConnectError`].
#[derive(Debug)]
pub struct SessionOperationExecutionError {
    error: ConnectError,
    partial_output: SessionOperationOutput,
}

impl SessionOperationExecutionError {
    /// Pairs `error` with the output collected up to the failure.
    pub fn new(error: ConnectError, partial_output: SessionOperationOutput) -> Self {
        Self {
            error,
            partial_output,
        }
    }

    /// The underlying error.
    pub fn error(&self) -> &ConnectError {
        &self.error
    }

    /// Output of the steps that ran before the failure.
    pub fn partial_output(&self) -> &SessionOperationOutput {
        &self.partial_output
    }

    /// Splits into the underlying error and the partial output.
    pub fn into_parts(self) -> (ConnectError, SessionOperationOutput) {
        (self.error, self.partial_output)
    }
}

/// Displays exactly like the wrapped [`ConnectError`].
impl std::fmt::Display for SessionOperationExecutionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully-qualified call: unambiguous even when both `Display` and `Debug`
        // are in scope, and it makes the transparent delegation explicit.
        std::fmt::Display::fmt(&self.error, f)
    }
}

/// Transparent wrapper over [`ConnectError`].
///
/// `Display` already prints the inner error, so `source` forwards to the inner
/// error's *own* source rather than returning the inner error. Returning the
/// inner error would make error-chain reporters print the same message twice.
/// Without this override, the cause chain of the wrapped error was lost.
impl std::error::Error for SessionOperationExecutionError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        std::error::Error::source(&self.error)
    }
}
/// Outputs of a command flow. When built from a [`SessionOperationOutput`], it
/// holds one [`Output`] per step, in the same order.
#[derive(Debug, Clone)]
pub struct CommandFlowOutput {
    /// Overall success of the flow.
    pub success: bool,
    /// Per-step outputs.
    pub outputs: Vec<Output>,
}
/// Cache of live SSH sessions. Each entry pairs the job sender for a session
/// with its shared client state.
/// NOTE(review): the `String` key is presumably derived from
/// `ConnectionRequest::device_addr` — confirm in the `manager` module.
#[derive(Clone)]
pub struct SshConnectionManager {
    cache: Cache<String, (mpsc::Sender<CmdJob>, Arc<RwLock<SharedSshClient>>)>,
}
mod client;
mod manager;
mod recording;
mod security;
mod transaction;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::templates;

    #[test]
    fn connection_request_formats_device_addr() {
        let handler = templates::cisco().expect("template");
        let req = ConnectionRequest::new(
            String::from("admin"),
            String::from("192.168.1.1"),
            22,
            String::from("password"),
            None,
            handler,
        );
        assert_eq!("admin@192.168.1.1:22", req.device_addr());
    }

    #[test]
    fn execution_context_builder_overrides_defaults() {
        let expected = ConnectionSecurityOptions::legacy_compatible();
        let ctx = ExecutionContext::new()
            .with_sys(Some("vsys1".into()))
            .with_security_options(ConnectionSecurityOptions::legacy_compatible());
        assert_eq!(ctx.security_options, expected);
        assert_eq!(Some("vsys1"), ctx.sys.as_deref());
    }

    #[test]
    fn file_upload_request_builder_overrides_defaults() {
        let req = FileUploadRequest::new(
            String::from("./fixtures/config.txt"),
            String::from("/tmp/config.txt"),
        )
        .with_progress_reporting(true)
        .with_buffer_size(8192)
        .with_timeout_secs(30);
        assert_eq!(req.local_path, "./fixtures/config.txt");
        assert_eq!(req.remote_path, "/tmp/config.txt");
        assert_eq!(req.timeout_secs, Some(30));
        assert_eq!(req.buffer_size, Some(8192));
        assert!(req.show_progress);
    }

    #[test]
    fn operation_execution_error_preserves_partial_output() {
        let step = SessionOperationStepOutput {
            step_index: 0,
            mode: "Enable".into(),
            operation_summary: "terminal length 0".into(),
            success: true,
            exit_code: None,
            content: "ok".into(),
            all: "ok".into(),
            prompt: Some("router#".into()),
        };
        let partial = SessionOperationOutput {
            success: false,
            steps: vec![step],
        };
        let err = SessionOperationExecutionError::new(
            ConnectError::ExecTimeout("show version".into()),
            partial,
        );
        assert!(matches!(err.error(), ConnectError::ExecTimeout(_)));
        let steps = &err.partial_output().steps;
        assert_eq!(steps.len(), 1);
        assert_eq!(steps[0].operation_summary, "terminal length 0");
    }

    #[test]
    fn command_default_has_empty_dyn_params() {
        let Command {
            mode,
            command,
            timeout,
            dyn_params,
            interaction,
            output_branches,
            output_fallback,
        } = Command::default();
        assert!(timeout.is_none());
        assert!(mode.is_empty());
        assert!(command.is_empty());
        assert!(dyn_params.is_empty());
        assert!(interaction.is_empty());
        assert!(output_branches.is_empty());
        assert_eq!(CommandBranchTarget::Next, output_fallback);
    }

    #[test]
    fn command_dynamic_params_collect_unknown_keys_into_extra() {
        let raw = serde_json::json!({
            "mode": "Enable",
            "command": "show version",
            "dyn_params": {
                "EnablePassword": "enable\n",
                "SudoPassword": "sudo\n",
                "CustomPrompt": "yes\n"
            }
        });
        let cmd: Command = serde_json::from_value(raw).expect("deserialize command");
        let params = &cmd.dyn_params;
        assert_eq!(params.enable_password.as_deref(), Some("enable\n"));
        assert_eq!(params.sudo_password.as_deref(), Some("sudo\n"));
        assert_eq!(
            params.extra.get("CustomPrompt").map(String::as_str),
            Some("yes\n")
        );
        let runtime = params.runtime_values();
        assert_eq!(
            runtime.get("EnablePassword").map(String::as_str),
            Some("enable\n")
        );
    }

    #[test]
    fn command_flow_defaults_to_stop_on_error() {
        let CommandFlow {
            steps,
            stop_on_error,
            max_steps,
        } = CommandFlow::default();
        assert!(steps.is_empty());
        assert!(stop_on_error);
        assert!(max_steps.is_none());
    }

    #[test]
    fn command_output_branch_rule_builder_sets_source() {
        let pattern = String::from(r"(?i)completed");
        let rule =
            CommandOutputBranchRule::new(vec![pattern.clone()], CommandBranchTarget::StopSuccess)
                .with_source(CommandOutputBranchSource::Content);
        assert_eq!(rule.patterns, [pattern]);
        assert_eq!(rule.source, CommandOutputBranchSource::Content);
        assert_eq!(rule.target, CommandBranchTarget::StopSuccess);
    }

    #[test]
    fn prompt_response_rule_builder_sets_recording_flag() {
        let pattern = String::from(r"^Password:\s*$");
        let rule = PromptResponseRule::new(vec![pattern.clone()], String::from("secret\n"))
            .with_record_input(true);
        assert_eq!(rule.patterns, [pattern]);
        assert_eq!(rule.response, "secret\n");
        assert!(rule.record_input);
    }
}