use chrono::{DateTime, Utc};
use once_cell::sync::Lazy;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::process::Command;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::task;
use tracing::{debug, info, warn};
use crate::service_error::ServiceError;
// Arguments for registering an SSH target under a caller-chosen id.
// Deserialized from tool input; the JSON schema is derived from this struct.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct SshRegisterArgs {
// Key under which the target is stored in the registry.
pub id: String,
// Remote host; combined with `user` as `user@host` when ssh is invoked.
pub host: String,
// Remote login user.
pub user: String,
// SSH port; falls back to `default_port()` when omitted.
#[serde(default = "default_port")]
#[schemars(default = "default_port")]
pub port: u16,
// Stored on the target config, but registration only logs a warning that
// passphrases are not supported by the ssh-command transport.
pub key_passphrase: Option<String>,
// Falls back to `default_known_hosts()` when omitted.
// NOTE(review): no code in this file reads this field — confirm whether it is meant to be honored.
#[serde(default = "default_known_hosts")]
#[schemars(default = "default_known_hosts")]
pub known_hosts_path: String,
// Optional client identifier; copied into the target config and emitted as
// `client` in exec log entries.
#[serde(default)]
pub client_id: Option<String>,
}
// Arguments for running a command on a previously registered target.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct SshExecArgs {
// Id of the registered target to run against.
pub id: String,
// Command string handed to ssh as the remote command.
pub command: String,
// Timeout in seconds; falls back to `default_timeout_secs()` when omitted.
// `exec_impl` raises values below 1 to 1.
#[serde(default = "default_timeout_secs")]
#[schemars(default = "default_timeout_secs")]
pub timeout_secs: u64,
// Optional free-form context; recorded as `ctx` in the exec log entry.
#[serde(default)]
pub context: Option<String>,
}
// Arguments for removing a registered target.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct SshUnregisterArgs {
// Id of the target to remove from the registry.
pub id: String,
}
/// Outcome of a single remote command execution.
#[derive(Debug, Clone, Serialize)]
pub struct SshExecResult {
/// Freshly generated UUID (v4) identifying this execution.
pub exec_id: String,
/// Exit code of the local `ssh` process; `-1` when no exit code is available.
pub exit_code: i32,
/// Captured standard output, lossily decoded as UTF-8.
pub stdout: String,
/// Captured standard error, lossily decoded as UTF-8.
pub stderr: String,
/// Wall-clock duration of the ssh invocation in milliseconds.
pub duration_ms: u128,
}
/// Echo of the connection settings stored by a successful registration.
#[derive(Debug, Clone, Serialize)]
pub struct SshRegisterResult {
/// Id the target was registered under.
pub id: String,
/// Registered host.
pub host: String,
/// Registered port.
pub port: u16,
/// Registered login user.
pub user: String,
}
/// Result of an unregister request.
#[derive(Debug, Clone, Serialize)]
pub struct SshUnregisterResult {
/// Id that was requested for removal.
pub id: String,
/// `true` when an entry with that id was present and has been removed.
pub existed: bool,
}
/// Description of one tool exposed by this module: its name, a human-readable
/// description and the JSON schema of its input arguments.
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
/// Tool name, e.g. `ssh_exec`.
pub name: String,
/// Short human-readable description of what the tool does.
pub description: String,
/// JSON schema for the tool's argument struct.
pub input_schema: Value,
}
/// Returns the definitions of the three SSH tools provided by this module
/// (`ssh_register_target`, `ssh_exec`, `ssh_unregister_target`), each paired
/// with the JSON schema derived from its argument struct.
pub fn get_tool_definitions() -> Vec<ToolDefinition> {
    vec![
        ToolDefinition {
            name: "ssh_register_target".to_string(),
            // Registration rejects an id that is already present, so the
            // description must not promise in-place replacement.
            description: "Register a new SSH target configuration (fails if the id is already registered; unregister it first to replace it)".to_string(),
            input_schema: schemars::schema_for!(SshRegisterArgs).to_value(),
        },
        ToolDefinition {
            name: "ssh_exec".to_string(),
            description: "Execute a command on a registered target".to_string(),
            input_schema: schemars::schema_for!(SshExecArgs).to_value(),
        },
        ToolDefinition {
            name: "ssh_unregister_target".to_string(),
            description: "Unregister a previously registered target".to_string(),
            input_schema: schemars::schema_for!(SshUnregisterArgs).to_value(),
        },
    ]
}
/// Private helper trait for turning a generated schema into a `serde_json::Value`.
trait SchemaToValue {
/// Returns the schema as a JSON value.
fn to_value(&self) -> Value;
}
impl SchemaToValue for schemars::schema::RootSchema {
    /// Serializes the root schema to JSON, falling back to an empty object
    /// if serialization fails.
    fn to_value(&self) -> Value {
        match serde_json::to_value(self) {
            Ok(schema) => schema,
            Err(_) => json!({}),
        }
    }
}
/// One line of the remote exec log, serialized as JSON with short keys.
/// Optional fields are omitted when `None`; `out`/`err` are omitted when empty.
#[derive(Debug, Clone, Serialize)]
struct ExecLogEntry {
/// Time the entry was created (UTC).
ts: DateTime<Utc>,
/// Id of the target the command ran on.
target: String,
/// Client id taken from the target configuration, if any.
#[serde(skip_serializing_if = "Option::is_none")]
client: Option<String>,
/// Caller-supplied context from the exec request, if any.
#[serde(skip_serializing_if = "Option::is_none")]
ctx: Option<String>,
/// Command text, truncated by `truncate_command`.
cmd: String,
/// Exit code reported for the command.
exit: i32,
/// Execution duration in milliseconds.
dur_ms: u128,
/// Captured stdout, truncated by `truncate_output`.
#[serde(skip_serializing_if = "String::is_empty")]
out: String,
/// Captured stderr, truncated by `truncate_output`.
#[serde(skip_serializing_if = "String::is_empty")]
err: String,
}
/// Connection settings stored in the registry for one SSH target.
#[derive(Debug, Clone)]
pub struct TargetConfig {
/// Remote host.
pub host: String,
/// Remote login user.
pub user: String,
/// SSH port passed to `ssh -p`.
pub port: u16,
/// Private key path passed to `ssh -i`; resolved from `SSH_KEY_PATH` at registration.
pub key_path: PathBuf,
/// Passphrase supplied at registration; not used when invoking ssh in this file.
pub key_passphrase: Option<String>,
/// Optional client identifier, written to exec log entries.
pub client_id: Option<String>,
}
/// Map from target id to its configuration, guarded by a `std::sync::RwLock`.
type TargetRegistry = RwLock<HashMap<String, TargetConfig>>;
/// Reports whether remote exec logging is switched off through the
/// `SSH_EXEC_LOG_DISABLED` environment variable (`1` or any casing of `true`).
/// An unset or unreadable variable means logging stays enabled.
fn is_logging_disabled() -> bool {
    match std::env::var("SSH_EXEC_LOG_DISABLED") {
        Ok(flag) => flag == "1" || flag.to_lowercase().as_str() == "true",
        Err(_) => false,
    }
}
/// Maximum command length (in bytes) kept in an exec log entry before truncation.
const MAX_CMD_LOG_LENGTH: usize = 500;
/// Maximum stdout/stderr length (in bytes) kept in an exec log entry before truncation.
const MAX_OUTPUT_LOG_LENGTH: usize = 2000;
/// File on the remote host that exec log lines are appended to.
const REMOTE_LOG_PATH: &str = "~/.cnctd/ssh_exec.jsonl";
/// Handle to a target registry. Cloning the service clones the `Arc`, so all
/// clones share the same registry.
#[derive(Clone)]
pub struct SshService {
/// Shared registry of targets this service operates on.
targets: Arc<TargetRegistry>,
}
impl Default for SshService {
fn default() -> Self {
Self::new()
}
}
impl SshService {
pub fn new() -> Self {
Self {
targets: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn global() -> Self {
Self {
targets: Arc::clone(&GLOBAL_TARGETS),
}
}
pub async fn register(&self, args: SshRegisterArgs) -> Result<SshRegisterResult, ServiceError> {
register_impl(&self.targets, args).await
}
pub async fn exec(&self, args: SshExecArgs) -> Result<SshExecResult, ServiceError> {
exec_impl(&self.targets, args).await
}
pub async fn unregister(&self, id: String) -> Result<SshUnregisterResult, ServiceError> {
unregister_impl(&self.targets, id).await
}
pub fn list_targets(&self) -> Result<Vec<String>, ServiceError> {
let map = self
.targets
.read()
.map_err(|e| ServiceError::Internal(format!("lock poisoned: {}", e)))?;
Ok(map.keys().cloned().collect())
}
pub fn has_target(&self, id: &str) -> Result<bool, ServiceError> {
let map = self
.targets
.read()
.map_err(|e| ServiceError::Internal(format!("lock poisoned: {}", e)))?;
Ok(map.contains_key(id))
}
}
/// Process-wide target registry, initialized empty on first access. Used by the
/// free functions `ssh_register`, `ssh_exec`, `ssh_unregister`, `lookup_target`
/// and by `SshService::global()`.
static GLOBAL_TARGETS: Lazy<Arc<TargetRegistry>> =
Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Registers a target in the global registry and returns the result as JSON.
pub async fn ssh_register(args: SshRegisterArgs) -> Result<Value, ServiceError> {
    register_impl(&GLOBAL_TARGETS, args)
        .await
        .map(|registered| json!(registered))
}
/// Runs a command on a target from the global registry and returns the result as JSON.
pub async fn ssh_exec(args: SshExecArgs) -> Result<Value, ServiceError> {
    exec_impl(&GLOBAL_TARGETS, args)
        .await
        .map(|outcome| json!(outcome))
}
/// Removes a target from the global registry and returns the result as JSON.
pub async fn ssh_unregister(id: String) -> Result<Value, ServiceError> {
    unregister_impl(&GLOBAL_TARGETS, id)
        .await
        .map(|removed| json!(removed))
}
/// Returns a copy of the configuration stored for `id` in the global registry.
///
/// # Errors
/// `ServiceError::Internal` if the registry lock is poisoned,
/// `ServiceError::NotFound` if no target with that id is registered.
pub async fn lookup_target(id: &str) -> Result<TargetConfig, ServiceError> {
    let registry = match GLOBAL_TARGETS.read() {
        Ok(guard) => guard,
        Err(e) => {
            return Err(ServiceError::Internal(format!(
                "target registry lock poisoned: {}",
                e
            )));
        }
    };
    match registry.get(id) {
        Some(found) => Ok(found.clone()),
        None => Err(ServiceError::NotFound(format!(
            "unknown target id: {}. Use register first.",
            id
        ))),
    }
}
async fn register_impl(
targets: &TargetRegistry,
args: SshRegisterArgs,
) -> Result<SshRegisterResult, ServiceError> {
info!("Registering SSH target: {}", args.id);
let key_path_str = std::env::var("SSH_KEY_PATH").map_err(|_| {
ServiceError::InvalidParams("SSH_KEY_PATH environment variable not set".to_string())
})?;
let key_path = expand_tilde(&key_path_str);
if !key_path.exists() {
return Err(ServiceError::InvalidParams(format!(
"key_path not found: {}",
key_path.display()
)));
}
if let Err(e) = fs::metadata(&key_path) {
return Err(ServiceError::InvalidParams(format!(
"cannot read key_path {}: {}",
key_path.display(),
e
)));
}
if args.key_passphrase.is_some() {
warn!(
"key_passphrase is not supported when using ssh command - key must be unencrypted or use ssh-agent"
);
}
let cfg = TargetConfig {
host: args.host.clone(),
user: args.user.clone(),
port: args.port,
key_path,
key_passphrase: args.key_passphrase,
client_id: args.client_id,
};
{
let mut map = targets
.write()
.map_err(|e| ServiceError::Internal(format!("target registry lock poisoned: {}", e)))?;
if map.contains_key(&args.id) {
return Err(ServiceError::InvalidParams(format!(
"target '{}' already exists. Use unregister first to replace it.",
args.id
)));
}
map.insert(args.id.clone(), cfg);
}
info!("Successfully registered SSH target: {}", args.id);
Ok(SshRegisterResult {
id: args.id,
host: args.host,
port: args.port,
user: args.user,
})
}
async fn exec_impl(
targets: &TargetRegistry,
args: SshExecArgs,
) -> Result<SshExecResult, ServiceError> {
info!("Executing command on target {}: {}", args.id, args.command);
let cfg = {
let map = targets
.read()
.map_err(|e| ServiceError::Internal(format!("target registry lock poisoned: {}", e)))?;
match map.get(&args.id) {
Some(c) => c.clone(),
None => {
return Err(ServiceError::NotFound(format!(
"unknown target id: {}. Use register first.",
args.id
)));
}
}
};
let command = args.command.clone();
let timeout = Duration::from_secs(args.timeout_secs.max(1));
let result = task::spawn_blocking(move || exec_over_ssh(&cfg, &command, timeout))
.await
.map_err(|e| ServiceError::Internal(format!("task join error: {}", e)))??;
info!(
"Command completed on {}: exit_code={}, duration={}ms",
args.id, result.exit_code, result.duration_ms
);
if !is_logging_disabled() {
let log_cfg = {
let map = targets
.read()
.map_err(|e| ServiceError::Internal(format!("lock poisoned: {}", e)))?;
map.get(&args.id).cloned()
};
if let Some(cfg) = log_cfg {
let log_entry = ExecLogEntry {
ts: Utc::now(),
target: args.id.clone(),
client: cfg.client_id.clone(),
ctx: args.context.clone(),
cmd: truncate_command(&args.command),
exit: result.exit_code,
dur_ms: result.duration_ms,
out: truncate_output(&result.stdout),
err: truncate_output(&result.stderr),
};
tokio::spawn(async move {
if let Err(e) = write_exec_log(&cfg, &log_entry).await {
warn!("Failed to write exec log: {}", e);
}
});
}
}
Ok(result)
}
/// Removes `id` from `targets` and reports whether an entry was present.
///
/// # Errors
/// `ServiceError::Internal` when the registry lock is poisoned.
async fn unregister_impl(
    targets: &TargetRegistry,
    id: String,
) -> Result<SshUnregisterResult, ServiceError> {
    info!("Unregistering SSH target: {}", id);

    let existed = targets
        .write()
        .map_err(|e| ServiceError::Internal(format!("target registry lock poisoned: {}", e)))?
        .remove(&id)
        .is_some();

    if !existed {
        warn!("Target {} was not registered", id);
    } else {
        info!("Successfully unregistered SSH target: {}", id);
    }

    Ok(SshUnregisterResult { id, existed })
}
/// Drains `pipe` to EOF on a dedicated thread so the child can never block on
/// a full pipe buffer while the caller is polling for its exit.
fn spawn_pipe_reader<R>(pipe: Option<R>) -> Option<std::thread::JoinHandle<Vec<u8>>>
where
    R: std::io::Read + Send + 'static,
{
    pipe.map(|mut stream| {
        std::thread::spawn(move || {
            let mut collected = Vec::new();
            // Best effort: keep whatever was read if the pipe errors part-way.
            let _ = std::io::Read::read_to_end(&mut stream, &mut collected);
            collected
        })
    })
}

/// Runs `command` on the target described by `cfg` by invoking the local
/// `ssh` binary, capturing stdout and stderr.
///
/// Blocks the calling thread; callers in async code should run it through
/// `spawn_blocking`. The ssh process is killed once `timeout` has elapsed.
///
/// # Errors
/// * `ServiceError::InvalidParams` if the key file cannot be stat'ed.
/// * `ServiceError::Internal` if ssh cannot be started or waited on, or if the
///   command does not finish within `timeout`.
fn exec_over_ssh(
    cfg: &TargetConfig,
    command: &str,
    timeout: Duration,
) -> Result<SshExecResult, ServiceError> {
    // How often the child is polled for completion while waiting.
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    if let Err(e) = fs::metadata(&cfg.key_path) {
        return Err(ServiceError::InvalidParams(format!(
            "cannot read key_path {}: {}",
            cfg.key_path.display(),
            e
        )));
    }

    let start = Instant::now();
    let target = format!("{}@{}", cfg.user, cfg.host);

    // NOTE(review): the println! calls in this function write straight to stdout
    // and bypass the tracing macros used elsewhere in this file.
    println!(
        "[SSH] Executing: ssh -i {} -p {} {} '{}'",
        cfg.key_path.display(),
        cfg.port,
        target,
        command
    );

    // NOTE(review): UserKnownHostsFile=/dev/null means accepted host keys are
    // never persisted, so `accept-new` cannot detect a changed key later.
    let mut child = Command::new("ssh")
        .arg("-i")
        .arg(&cfg.key_path)
        .arg("-p")
        .arg(cfg.port.to_string())
        .arg("-o")
        .arg("BatchMode=yes")
        .arg("-o")
        .arg("ConnectTimeout=30")
        .arg("-o")
        .arg("StrictHostKeyChecking=accept-new")
        .arg("-o")
        .arg("UserKnownHostsFile=/dev/null")
        .arg(&target)
        .arg(command)
        // Same stdio setup `Command::output()` uses: no stdin, captured output.
        .stdin(std::process::Stdio::null())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| ServiceError::Internal(format!("failed to execute ssh command: {}", e)))?;

    let stdout_reader = spawn_pipe_reader(child.stdout.take());
    let stderr_reader = spawn_pipe_reader(child.stderr.take());

    // std has no wait-with-timeout, so poll until the child exits or the
    // deadline passes.
    let status = loop {
        match child.try_wait() {
            Ok(Some(status)) => break status,
            Ok(None) => {
                if start.elapsed() >= timeout {
                    // Kill and reap; the reader threads are left detached and
                    // end when the pipes close.
                    let _ = child.kill();
                    let _ = child.wait();
                    warn!("SSH command exceeded timeout of {}s", timeout.as_secs());
                    return Err(ServiceError::Internal(format!(
                        "ssh command timed out after {}s",
                        timeout.as_secs()
                    )));
                }
                std::thread::sleep(POLL_INTERVAL);
            }
            Err(e) => {
                let _ = child.kill();
                let _ = child.wait();
                return Err(ServiceError::Internal(format!(
                    "failed to wait for ssh command: {}",
                    e
                )));
            }
        }
    };

    let duration_ms = start.elapsed().as_millis();
    // `code()` is `None` when the process was terminated by a signal.
    let exit_code = status.code().unwrap_or(-1);

    let stdout_bytes = stdout_reader
        .map(|handle| handle.join().unwrap_or_default())
        .unwrap_or_default();
    let stderr_bytes = stderr_reader
        .map(|handle| handle.join().unwrap_or_default())
        .unwrap_or_default();
    let stdout = String::from_utf8_lossy(&stdout_bytes).to_string();
    let stderr = String::from_utf8_lossy(&stderr_bytes).to_string();

    println!(
        "[SSH] Command completed: exit_code={}, stdout_len={}, stderr_len={}",
        exit_code,
        stdout.len(),
        stderr.len()
    );
    if !stderr.is_empty() {
        println!("[SSH] stderr: {}", stderr);
    }

    Ok(SshExecResult {
        exec_id: uuid::Uuid::new_v4().to_string(),
        exit_code,
        stdout,
        stderr,
        duration_ms,
    })
}
/// Appends `entry` as one JSON line to `REMOTE_LOG_PATH` on the target host by
/// running a small shell command over ssh on a blocking thread.
///
/// The remote command exits without writing when the marker file
/// `~/.cnctd/ssh_exec_log_disable` exists. Failures of the ssh invocation are
/// logged and otherwise ignored (best effort).
///
/// # Errors
/// `ServiceError::Internal` if the entry cannot be serialized or the blocking
/// task fails to join.
async fn write_exec_log(cfg: &TargetConfig, entry: &ExecLogEntry) -> Result<(), ServiceError> {
    let json_line = serde_json::to_string(entry)
        .map_err(|e| ServiceError::Internal(format!("failed to serialize log entry: {}", e)))?;

    // Single-quote for the remote shell: close the quote, emit a double-quoted
    // `'`, then reopen the quote.
    let escaped_json = json_line.replace("'", "'\"'\"'");

    // `printf '%s\n'` writes the argument verbatim. `echo` is avoided because
    // some shells expand backslash sequences such as `\n`, which would corrupt
    // the JSON line.
    let log_command = format!(
        "[ -f ~/.cnctd/ssh_exec_log_disable ] && exit 0; mkdir -p ~/.cnctd && printf '%s\\n' '{}' >> {}",
        escaped_json, REMOTE_LOG_PATH
    );

    let target = format!("{}@{}", cfg.user, cfg.host);
    let key_path = cfg.key_path.clone();
    let port = cfg.port;

    task::spawn_blocking(move || {
        // NOTE(review): UserKnownHostsFile=/dev/null means accepted host keys
        // are never persisted between invocations.
        let output = Command::new("ssh")
            .arg("-i")
            .arg(&key_path)
            .arg("-p")
            .arg(port.to_string())
            .arg("-o")
            .arg("BatchMode=yes")
            .arg("-o")
            .arg("ConnectTimeout=5")
            .arg("-o")
            .arg("StrictHostKeyChecking=accept-new")
            .arg("-o")
            .arg("UserKnownHostsFile=/dev/null")
            .arg(&target)
            .arg(&log_command)
            .output();

        match output {
            Ok(out) if out.status.success() => {
                debug!("Exec log written successfully");
            }
            Ok(out) => {
                let stderr = String::from_utf8_lossy(&out.stderr);
                warn!("Exec log write returned non-zero: {}", stderr);
            }
            Err(e) => {
                warn!("Failed to write exec log: {}", e);
            }
        }
    })
    .await
    .map_err(|e| ServiceError::Internal(format!("log task join error: {}", e)))?;

    Ok(())
}
/// Returns `cmd` unchanged when it is at most `MAX_CMD_LOG_LENGTH` bytes long;
/// otherwise returns a prefix of at most that many bytes followed by
/// `...[truncated]`.
///
/// The cut is moved back to the nearest UTF-8 character boundary so that
/// multi-byte text never causes a slicing panic.
fn truncate_command(cmd: &str) -> String {
    if cmd.len() <= MAX_CMD_LOG_LENGTH {
        return cmd.to_string();
    }
    let mut cut = MAX_CMD_LOG_LENGTH;
    // Index 0 is always a boundary, so this loop terminates.
    while !cmd.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...[truncated]", &cmd[..cut])
}
/// Returns `output` unchanged when it is at most `MAX_OUTPUT_LOG_LENGTH` bytes
/// long; otherwise returns a prefix of at most that many bytes followed by a
/// marker that includes the original byte length.
///
/// The cut is moved back to the nearest UTF-8 character boundary so that
/// multi-byte text never causes a slicing panic.
fn truncate_output(output: &str) -> String {
    if output.len() <= MAX_OUTPUT_LOG_LENGTH {
        return output.to_string();
    }
    let mut cut = MAX_OUTPUT_LOG_LENGTH;
    // Index 0 is always a boundary, so this loop terminates.
    while !output.is_char_boundary(cut) {
        cut -= 1;
    }
    format!(
        "{}...[truncated, {} total]",
        &output[..cut],
        output.len()
    )
}
/// Default SSH port used when a registration request omits `port`.
fn default_port() -> u16 {
    const SSH_PORT: u16 = 22;
    SSH_PORT
}
/// Default `known_hosts_path` used when a registration request omits it.
fn default_known_hosts() -> String {
    String::from("~/.ssh/known_hosts")
}
/// Default exec timeout (two minutes, in seconds) used when a request omits
/// `timeout_secs`.
fn default_timeout_secs() -> u64 {
    2 * 60
}
/// Expands a leading `~/` in `p` to the value of the `HOME` environment
/// variable. Paths without that prefix, or any path when `HOME` is not
/// available, are returned unchanged.
fn expand_tilde(p: &str) -> PathBuf {
    p.strip_prefix("~/")
        .and_then(|rest| {
            std::env::var("HOME")
                .ok()
                .map(|home| PathBuf::from(home).join(rest))
        })
        .unwrap_or_else(|| PathBuf::from(p))
}