use crate::error::ConnectError;
use crate::session::{
Command, CommandBranchTarget, CommandFlow, CommandInteraction, CommandOutputBranchRule,
PromptResponseRule,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::collections::HashSet;
/// Wraps `message` in a [`ConnectError::InvalidCommandFlowTemplate`] error.
fn invalid_template(message: impl Into<String>) -> ConnectError {
    let text: String = message.into();
    ConnectError::InvalidCommandFlowTemplate(text)
}
/// Serde default for boolean fields that should be enabled when omitted
/// (used by `CommandFlowTemplate::stop_on_error`).
fn default_true() -> bool {
    true
}
/// Serde default for a var's `type`: vars are plain strings unless stated
/// otherwise.
fn default_var_kind() -> CommandFlowTemplateVarKind {
    CommandFlowTemplateVarKind::String
}
/// Template text that may contain `{{name}}` placeholders.
///
/// `#[serde(transparent)]` makes it (de)serialize as a plain JSON string.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(transparent)]
pub struct CommandFlowTemplateText {
    // Raw, unrendered template source.
    value: String,
}
impl CommandFlowTemplateText {
    /// Wraps `value` as template text. Placeholders are substituted only
    /// when the text is rendered, not here.
    pub fn template(value: impl Into<String>) -> Self {
        let value: String = value.into();
        Self { value }
    }

    /// Substitutes `{{name}}` placeholders with the matching entries of
    /// `values`.
    fn render(&self, values: &Map<String, Value>) -> String {
        render_inline_template(&self.value, values)
    }
}
/// Treats an owned string as template text.
impl From<String> for CommandFlowTemplateText {
    fn from(value: String) -> Self {
        Self { value }
    }
}
/// Treats a borrowed string slice as template text by copying it.
impl From<&str> for CommandFlowTemplateText {
    fn from(value: &str) -> Self {
        Self {
            value: value.to_owned(),
        }
    }
}
/// Converts a JSON value into the text substituted for a placeholder.
///
/// `null` becomes an empty string and strings are inserted without quotes.
/// Numbers and booleans use their display form. Arrays and objects are
/// emitted as JSON.
fn render_value_as_text(value: &Value) -> String {
    match value {
        Value::Null => String::new(),
        Value::String(text) => text.to_owned(),
        Value::Bool(flag) => flag.to_string(),
        Value::Number(number) => number.to_string(),
        Value::Array(_) | Value::Object(_) => value.to_string(),
    }
}
/// Replaces every `{{name}}` in `template` with the text form of
/// `values[name]`.
///
/// - Surrounding whitespace inside the braces is ignored (`{{ host }}`).
/// - A placeholder whose name is not in `values` renders as nothing.
/// - An empty placeholder (`{{}}` or `{{  }}`) is kept verbatim.
/// - An opening `{{` with no closing `}}` is kept verbatim, together with
///   everything after it.
fn render_inline_template(template: &str, values: &Map<String, Value>) -> String {
    let mut rendered = String::with_capacity(template.len());
    let mut remaining = template;
    loop {
        let Some((literal, tail)) = remaining.split_once("{{") else {
            rendered.push_str(remaining);
            break;
        };
        rendered.push_str(literal);
        let Some((raw_name, after)) = tail.split_once("}}") else {
            // Unterminated placeholder: emit the opener and the rest untouched.
            rendered.push_str("{{");
            rendered.push_str(tail);
            break;
        };
        let name = raw_name.trim();
        if name.is_empty() {
            rendered.push_str("{{");
            rendered.push_str(raw_name);
            rendered.push_str("}}");
        } else if let Some(value) = values.get(name) {
            rendered.push_str(&render_value_as_text(value));
        }
        remaining = after;
    }
    rendered
}
/// A named, reusable command flow whose step text can reference declared vars
/// and runtime values through `{{name}}` placeholders.
///
/// Use [`CommandFlowTemplate::to_command_flow`] to render it into a
/// [`CommandFlow`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandFlowTemplate {
    /// Template name; rendering rejects an empty or whitespace-only name.
    pub name: String,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Declared variables, validated and defaulted before rendering.
    #[serde(default)]
    pub vars: Vec<CommandFlowTemplateVar>,
    /// Copied to the rendered flow's `stop_on_error`; `true` when omitted.
    #[serde(default = "default_true")]
    pub stop_on_error: bool,
    /// Mode for steps that don't render their own mode. A runtime
    /// `default_mode` takes precedence over it.
    #[serde(default)]
    pub default_mode: Option<String>,
    /// Steps rendered in order; rendering rejects a template with no steps.
    #[serde(default)]
    pub steps: Vec<CommandFlowTemplateStep>,
}
impl CommandFlowTemplate {
pub fn new(name: impl Into<String>, steps: Vec<CommandFlowTemplateStep>) -> Self {
Self {
name: name.into(),
description: None,
vars: Vec::new(),
stop_on_error: true,
default_mode: None,
steps,
}
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn with_vars(mut self, vars: Vec<CommandFlowTemplateVar>) -> Self {
self.vars = vars;
self
}
pub fn with_default_mode(mut self, default_mode: impl Into<String>) -> Self {
self.default_mode = Some(default_mode.into());
self
}
pub fn with_stop_on_error(mut self, stop_on_error: bool) -> Self {
self.stop_on_error = stop_on_error;
self
}
pub fn to_command_flow(
&self,
runtime: &CommandFlowTemplateRuntime,
) -> Result<CommandFlow, ConnectError> {
self.validate_definition()?;
let resolved_vars = self.resolve_runtime_vars(&runtime.vars)?;
let context = build_command_flow_values(self, runtime, resolved_vars);
let fallback_mode = runtime
.default_mode
.as_deref()
.or(self.default_mode.as_deref())
.unwrap_or_default()
.to_string();
let mut steps = Vec::with_capacity(self.steps.len());
for step in &self.steps {
let command = step.command.render(&context);
if command.trim().is_empty() {
return Err(invalid_template(format!(
"template '{}' rendered an empty command",
self.name
)));
}
let mode = if let Some(mode_template) = &step.mode {
let rendered = mode_template.render(&context);
let normalized = rendered.trim();
if normalized.is_empty() {
fallback_mode.clone()
} else {
normalized.to_string()
}
} else {
fallback_mode.clone()
};
let mut prompts = Vec::with_capacity(step.prompts.len());
for prompt in &step.prompts {
if prompt.patterns.is_empty() {
return Err(invalid_template(format!(
"template '{}' contains a prompt with no patterns",
self.name
)));
}
let mut response = prompt.response.render(&context);
if prompt.append_newline {
response.push('\n');
}
prompts.push(
PromptResponseRule::new(prompt.patterns.clone(), response)
.with_record_input(prompt.record_input),
);
}
steps.push(Command {
mode,
command,
timeout: step.timeout_secs,
dyn_params: Default::default(),
interaction: CommandInteraction { prompts },
output_branches: step.output_branches.clone(),
output_fallback: step.output_fallback.clone(),
});
}
Ok(CommandFlow {
steps,
stop_on_error: self.stop_on_error,
max_steps: None,
})
}
fn validate_definition(&self) -> Result<(), ConnectError> {
if self.name.trim().is_empty() {
return Err(invalid_template("template name cannot be empty"));
}
if self.steps.is_empty() {
return Err(invalid_template(format!(
"template '{}' has no steps",
self.name
)));
}
let mut seen = HashSet::new();
for field in &self.vars {
let name = field.name.trim();
if name.is_empty() {
return Err(invalid_template(format!(
"template '{}' contains a var with an empty name",
self.name
)));
}
if !is_safe_var_name(name) {
return Err(invalid_template(format!(
"template '{}' has invalid var name '{}'",
self.name, field.name
)));
}
if !seen.insert(name.to_string()) {
return Err(invalid_template(format!(
"template '{}' contains duplicate var '{}'",
self.name, field.name
)));
}
if let Some(default_value) = &field.default_value {
field.validate_value(default_value)?;
}
}
Ok(())
}
fn resolve_runtime_vars(&self, raw_vars: &Value) -> Result<Map<String, Value>, ConnectError> {
let mut vars = match raw_vars {
Value::Null => Map::new(),
Value::Object(map) => map.clone(),
_ => {
return Err(invalid_template(format!(
"template '{}' expects vars to be a JSON object",
self.name
)));
}
};
for field in &self.vars {
let key = field.name.trim();
let treat_as_missing =
!vars.contains_key(key) || vars.get(key).is_some_and(Value::is_null);
if treat_as_missing {
vars.remove(key);
if let Some(default_value) = &field.default_value {
vars.insert(key.to_string(), default_value.clone());
continue;
}
if field.required {
return Err(invalid_template(format!(
"template '{}' is missing required var '{}'",
self.name, field.name
)));
}
continue;
}
if let Some(value) = vars.get(key) {
field.validate_value(value)?;
}
}
Ok(vars)
}
}
/// One step of a [`CommandFlowTemplate`], rendered into a single `Command`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandFlowTemplateStep {
    /// Command text; rendering fails if it is empty or whitespace-only.
    pub command: CommandFlowTemplateText,
    /// Optional mode template. When absent, or when it renders to
    /// whitespace, the flow's fallback mode is used.
    #[serde(default)]
    pub mode: Option<CommandFlowTemplateText>,
    /// Optional timeout in seconds, copied into the rendered command's
    /// `timeout`.
    #[serde(default)]
    pub timeout_secs: Option<u64>,
    /// Prompt/response rules rendered for this step.
    #[serde(default)]
    pub prompts: Vec<CommandFlowTemplatePrompt>,
    /// Output branch rules, copied verbatim into the rendered command.
    #[serde(default)]
    pub output_branches: Vec<CommandOutputBranchRule>,
    /// Fallback branch target, copied verbatim into the rendered command.
    /// NOTE(review): serde fills this with `CommandBranchTarget::default()`,
    /// while `CommandFlowTemplateStep::new` uses `CommandBranchTarget::Next`.
    /// Confirm that the two agree.
    #[serde(default)]
    pub output_fallback: CommandBranchTarget,
}
impl CommandFlowTemplateStep {
    /// Creates a step that runs `command`.
    ///
    /// The step has no mode override, timeout, prompts, or output branches,
    /// and its output fallback is `CommandBranchTarget::Next`.
    pub fn new(command: impl Into<CommandFlowTemplateText>) -> Self {
        let command = command.into();
        Self {
            command,
            mode: None,
            timeout_secs: None,
            prompts: Vec::new(),
            output_branches: Vec::new(),
            output_fallback: CommandBranchTarget::Next,
        }
    }

    /// Creates a step from raw template source for the command.
    pub fn from_template(command: impl Into<String>) -> Self {
        let text = CommandFlowTemplateText::template(command);
        Self::new(text)
    }

    /// Sets a mode template that overrides the flow's fallback mode for this step.
    pub fn with_mode(self, mode: impl Into<CommandFlowTemplateText>) -> Self {
        Self {
            mode: Some(mode.into()),
            ..self
        }
    }

    /// Sets the step timeout in seconds.
    pub fn with_timeout_secs(self, timeout_secs: u64) -> Self {
        Self {
            timeout_secs: Some(timeout_secs),
            ..self
        }
    }

    /// Replaces the prompt/response rules.
    pub fn with_prompts(self, prompts: Vec<CommandFlowTemplatePrompt>) -> Self {
        Self { prompts, ..self }
    }

    /// Replaces the output branch rules.
    pub fn with_output_branches(self, output_branches: Vec<CommandOutputBranchRule>) -> Self {
        Self {
            output_branches,
            ..self
        }
    }

    /// Sets the output fallback branch target.
    pub fn with_output_fallback(self, output_fallback: CommandBranchTarget) -> Self {
        Self {
            output_fallback,
            ..self
        }
    }
}
/// A prompt/response rule for a template step, rendered into a
/// `PromptResponseRule`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandFlowTemplatePrompt {
    /// Patterns passed to `PromptResponseRule::new`; rendering fails when this
    /// is empty. NOTE(review): the tests use regex syntax here; confirm
    /// against `PromptResponseRule`.
    pub patterns: Vec<String>,
    /// Response template, rendered with the same placeholders as commands.
    pub response: CommandFlowTemplateText,
    /// When true, a `\n` is appended to the rendered response.
    #[serde(default)]
    pub append_newline: bool,
    /// Forwarded to `PromptResponseRule::with_record_input`.
    #[serde(default)]
    pub record_input: bool,
}
impl CommandFlowTemplatePrompt {
    /// Creates a rule that answers any of `patterns` with `response`.
    ///
    /// No newline is appended to the response, and input recording is off.
    pub fn new(patterns: Vec<String>, response: impl Into<CommandFlowTemplateText>) -> Self {
        let response = response.into();
        Self {
            patterns,
            response,
            append_newline: false,
            record_input: false,
        }
    }

    /// Creates a rule from raw template source for the response.
    pub fn from_template(patterns: Vec<String>, response: impl Into<String>) -> Self {
        let text = CommandFlowTemplateText::template(response);
        Self::new(patterns, text)
    }

    /// Sets whether a `\n` is appended to the rendered response.
    pub fn with_append_newline(self, append_newline: bool) -> Self {
        Self {
            append_newline,
            ..self
        }
    }

    /// Sets the flag forwarded to `PromptResponseRule::with_record_input`.
    pub fn with_record_input(self, record_input: bool) -> Self {
        Self {
            record_input,
            ..self
        }
    }
}
/// Expected JSON type of a template var's value.
///
/// Serialized in snake_case (`"string"`, `"secret"`, `"number"`, `"boolean"`,
/// `"json"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CommandFlowTemplateVarKind {
    /// A JSON string.
    String,
    /// A JSON string. In this file it is validated exactly like `String`.
    /// NOTE(review): any masking is presumably handled elsewhere; confirm.
    Secret,
    /// A JSON number.
    Number,
    /// A JSON boolean.
    Boolean,
    /// Any JSON value. It gets no type check and no `options` check.
    Json,
}
impl CommandFlowTemplateVarKind {
fn validate_value(self, value: &Value) -> bool {
match self {
Self::String | Self::Secret => value.is_string(),
Self::Number => value.is_number(),
Self::Boolean => value.is_boolean(),
Self::Json => true,
}
}
fn label(self) -> &'static str {
match self {
Self::String => "string",
Self::Secret => "secret",
Self::Number => "number",
Self::Boolean => "boolean",
Self::Json => "json",
}
}
}
/// A variable declared by a [`CommandFlowTemplate`] and referenced from step
/// text as `{{name}}`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandFlowTemplateVar {
    /// Placeholder name. After trimming, it must start with an ASCII letter
    /// or `_` and contain only ASCII alphanumerics or `_`.
    pub name: String,
    /// Optional display label; see `display_label`.
    #[serde(default)]
    pub label: Option<String>,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Expected value type. Serialized as `type`; defaults to `string`.
    #[serde(rename = "type", default = "default_var_kind")]
    pub kind: CommandFlowTemplateVarKind,
    /// When true, rendering fails if the var is missing or `null` and has no
    /// default.
    #[serde(default)]
    pub required: bool,
    /// Optional placeholder hint. It is not read by the rendering code in
    /// this file.
    #[serde(default)]
    pub placeholder: Option<String>,
    /// Allow-list checked by `validate_value` when non-empty and the kind is
    /// not `json`.
    #[serde(default)]
    pub options: Vec<String>,
    /// Value used when the var is missing or `null`. Serialized as `default`.
    #[serde(rename = "default", default)]
    pub default_value: Option<Value>,
}
impl CommandFlowTemplateVar {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
label: None,
description: None,
kind: default_var_kind(),
required: false,
placeholder: None,
options: Vec::new(),
default_value: None,
}
}
pub fn with_label(mut self, label: impl Into<String>) -> Self {
self.label = Some(label.into());
self
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn with_kind(mut self, kind: CommandFlowTemplateVarKind) -> Self {
self.kind = kind;
self
}
pub fn with_required(mut self, required: bool) -> Self {
self.required = required;
self
}
pub fn with_placeholder(mut self, placeholder: impl Into<String>) -> Self {
self.placeholder = Some(placeholder.into());
self
}
pub fn with_options<I, S>(mut self, options: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.options = options.into_iter().map(Into::into).collect();
self
}
pub fn with_default_value(mut self, default_value: Value) -> Self {
self.default_value = Some(default_value);
self
}
pub fn display_label(&self) -> &str {
self.label
.as_deref()
.filter(|value| !value.trim().is_empty())
.unwrap_or(self.name.as_str())
}
fn validate_value(&self, value: &Value) -> Result<(), ConnectError> {
if !self.kind.validate_value(value) {
return Err(invalid_template(format!(
"var '{}' expected {}",
self.name,
self.kind.label()
)));
}
if !self.options.is_empty() && !matches!(self.kind, CommandFlowTemplateVarKind::Json) {
let Some(text) = value.as_str() else {
return Err(invalid_template(format!(
"var '{}' expected one of [{}]",
self.name,
self.options.join(", ")
)));
};
if !self.options.iter().any(|option| option == text) {
return Err(invalid_template(format!(
"var '{}' expected one of [{}]",
self.name,
self.options.join(", ")
)));
}
}
Ok(())
}
}
/// Per-invocation inputs for [`CommandFlowTemplate::to_command_flow`].
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct CommandFlowTemplateRuntime {
    /// Fallback mode that takes precedence over the template's
    /// `default_mode`. Also exposed to templates as `{{default_mode}}`.
    #[serde(default)]
    pub default_mode: Option<String>,
    /// Exposed to templates as `{{connection_name}}`.
    #[serde(default)]
    pub connection_name: Option<String>,
    /// Exposed to templates as `{{host}}`.
    #[serde(default)]
    pub host: Option<String>,
    /// Exposed to templates as `{{username}}`.
    #[serde(default)]
    pub username: Option<String>,
    /// Exposed to templates as `{{device_profile}}`.
    #[serde(default)]
    pub device_profile: Option<String>,
    /// Caller-supplied var values. Must be a JSON object, or `null` for
    /// "no vars".
    #[serde(default)]
    pub vars: Value,
}
impl CommandFlowTemplateRuntime {
    /// Creates an empty runtime: no mode override, no connection metadata,
    /// and `vars` left at its default (`null`, meaning "no vars").
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the runtime fallback mode, which overrides the template's default.
    pub fn with_default_mode(self, default_mode: impl Into<String>) -> Self {
        Self {
            default_mode: Some(default_mode.into()),
            ..self
        }
    }

    /// Replaces the caller-supplied var values. They should be a JSON object.
    pub fn with_vars(self, vars: Value) -> Self {
        Self { vars, ..self }
    }
}
/// Builds the placeholder context used to render a template's steps.
///
/// Starts from the resolved template `vars` and adds these built-ins:
/// - `default_mode`: the runtime value, else the template's;
/// - `connection_name`, `host`, `username`, `device_profile`: from `runtime`.
///
/// A built-in that has a value still overrides a var with the same name.
/// An unset built-in only inserts `null` when the key is not already
/// present. Otherwise a declared, validated var such as `host` would be
/// erased and `{{host}}` would render empty whenever `runtime.host` is
/// `None`.
fn build_command_flow_values(
    template: &CommandFlowTemplate,
    runtime: &CommandFlowTemplateRuntime,
    mut vars: Map<String, Value>,
) -> Map<String, Value> {
    let builtins = [
        (
            "default_mode",
            runtime
                .default_mode
                .as_ref()
                .or(template.default_mode.as_ref()),
        ),
        ("connection_name", runtime.connection_name.as_ref()),
        ("host", runtime.host.as_ref()),
        ("username", runtime.username.as_ref()),
        ("device_profile", runtime.device_profile.as_ref()),
    ];
    for (key, value) in builtins {
        match value {
            Some(text) => {
                vars.insert(key.to_string(), Value::String(text.clone()));
            }
            None => {
                // Keep any caller-supplied value; otherwise expose the key as null.
                vars.entry(key).or_insert(Value::Null);
            }
        }
    }
    vars
}
/// Returns whether `name` is usable as a placeholder identifier.
///
/// A valid name is non-empty, starts with an ASCII letter or `_`, and
/// otherwise contains only ASCII letters, digits, or `_`.
fn is_safe_var_name(name: &str) -> bool {
    match name.as_bytes().split_first() {
        Some((first, rest)) => {
            (first.is_ascii_alphabetic() || *first == b'_')
                && rest
                    .iter()
                    .all(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
        }
        None => false,
    }
}
// Unit tests for rendering `CommandFlowTemplate` into `CommandFlow`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{CommandBranchTarget, CommandOutputBranchRule, CommandOutputBranchSource};
    use serde_json::json;

    // End-to-end render. It covers placeholders in the command and in a
    // prompt response, the template default mode, an options-restricted
    // var, and `append_newline` on the prompt.
    #[test]
    fn renders_template_with_inline_text() {
        let template = CommandFlowTemplate::new(
            "demo",
            vec![
                CommandFlowTemplateStep::new("copy {{protocol}}: {{device_path}}")
                    .with_timeout_secs(300)
                    .with_prompts(vec![
                        CommandFlowTemplatePrompt::new(
                            vec!["(?i)^Address.*$".to_string()],
                            "{{server_addr}}",
                        )
                        .with_append_newline(true)
                        .with_record_input(true),
                    ]),
            ],
        )
        .with_default_mode("Enable")
        .with_vars(vec![
            CommandFlowTemplateVar::new("protocol")
                .with_required(true)
                .with_options(["scp", "tftp"]),
            CommandFlowTemplateVar::new("device_path").with_required(true),
            CommandFlowTemplateVar::new("server_addr").with_required(true),
        ]);
        let flow = template
            .to_command_flow(&CommandFlowTemplateRuntime::new().with_vars(json!({
                "protocol": "scp",
                "device_path": "flash:/image.bin",
                "server_addr": "192.0.2.10",
            })))
            .expect("render flow");
        assert!(flow.stop_on_error);
        assert_eq!(flow.steps.len(), 1);
        assert_eq!(flow.steps[0].mode, "Enable");
        assert_eq!(flow.steps[0].command, "copy scp: flash:/image.bin");
        assert_eq!(
            flow.steps[0].interaction.prompts[0].response,
            "192.0.2.10\n"
        );
    }

    // A required var with no default and no runtime value must fail rendering.
    #[test]
    fn missing_required_var_fails_rendering() {
        let template =
            CommandFlowTemplate::new("demo", vec![CommandFlowTemplateStep::new("show {{host}}")])
                .with_vars(vec![
                    CommandFlowTemplateVar::new("host").with_required(true),
                ]);
        let err = template
            .to_command_flow(&CommandFlowTemplateRuntime::new())
            .expect_err("missing required var should fail");
        assert!(matches!(err, ConnectError::InvalidCommandFlowTemplate(_)));
    }

    // `runtime.vars` must be a JSON object (or null); an array is rejected.
    #[test]
    fn runtime_vars_must_be_json_object() {
        let template =
            CommandFlowTemplate::new("demo", vec![CommandFlowTemplateStep::new("show version")]);
        let err = template
            .to_command_flow(&CommandFlowTemplateRuntime::new().with_vars(json!(["bad"])))
            .expect_err("non-object vars should fail");
        assert!(matches!(err, ConnectError::InvalidCommandFlowTemplate(_)));
    }

    // Runtime vars that the template never declares are still substituted.
    #[test]
    fn inline_template_text_renders_placeholders() {
        let template = CommandFlowTemplate::new(
            "demo",
            vec![CommandFlowTemplateStep::new(
                "copy {{protocol}}: {{device_path}}",
            )],
        );
        let flow = template
            .to_command_flow(&CommandFlowTemplateRuntime::new().with_vars(json!({
                "protocol": "scp",
                "device_path": "flash:/image.bin",
            })))
            .expect("render flow");
        assert_eq!(flow.steps[0].command, "copy scp: flash:/image.bin");
    }

    // Output branch rules and the fallback target are copied through unchanged.
    #[test]
    fn template_step_renders_output_branch_rules() {
        let template = CommandFlowTemplate::new(
            "branch-demo",
            vec![
                CommandFlowTemplateStep::new("show copy status")
                    .with_output_branches(vec![
                        CommandOutputBranchRule::new(
                            vec![r"(?i)retry".to_string()],
                            CommandBranchTarget::Jump { step_index: 0 },
                        )
                        .with_source(CommandOutputBranchSource::Content),
                    ])
                    .with_output_fallback(CommandBranchTarget::StopFailure),
            ],
        );
        let flow = template
            .to_command_flow(&CommandFlowTemplateRuntime::new())
            .expect("render flow");
        assert_eq!(flow.steps.len(), 1);
        assert_eq!(flow.steps[0].output_branches.len(), 1);
        assert_eq!(
            flow.steps[0].output_branches[0].source,
            CommandOutputBranchSource::Content
        );
        assert_eq!(
            flow.steps[0].output_fallback,
            CommandBranchTarget::StopFailure
        );
    }

    // A step's rendered mode overrides the template default mode, and the
    // builders accept `&str` for both mode and prompt response text.
    #[test]
    fn prompt_and_mode_accept_plain_text_builders() {
        let template = CommandFlowTemplate::new(
            "demo",
            vec![
                CommandFlowTemplateStep::new("show {{target}}")
                    .with_mode("{{exec_mode}}")
                    .with_prompts(vec![
                        CommandFlowTemplatePrompt::new(
                            vec!["(?i)^Proceed\\?\\s*$".to_string()],
                            "yes",
                        )
                        .with_append_newline(true),
                    ]),
            ],
        )
        .with_default_mode("Enable")
        .with_vars(vec![
            CommandFlowTemplateVar::new("target").with_required(true),
            CommandFlowTemplateVar::new("exec_mode").with_required(true),
        ]);
        let flow = template
            .to_command_flow(&CommandFlowTemplateRuntime::new().with_vars(json!({
                "target": "version",
                "exec_mode": "Config",
            })))
            .expect("render flow");
        assert_eq!(flow.steps[0].mode, "Config");
        assert_eq!(flow.steps[0].command, "show version");
        assert_eq!(flow.steps[0].interaction.prompts[0].response, "yes\n");
    }
}