use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::fs;
use std::io::Read;
use std::path::PathBuf;
use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::{Duration, Instant};
use tokio::task;
use tracing::{debug, info, warn};
use crate::service_error::ServiceError;
/// Arguments for registering a new SSH target.
///
/// The private key is not part of these arguments: registration reads its
/// location from the `SSH_KEY_PATH` environment variable.
///
/// `Debug` is implemented by hand so that `key_passphrase` is never printed
/// verbatim (e.g. when the arguments end up in a log line).
#[derive(Clone, Deserialize)]
pub struct SshRegisterArgs {
    /// Caller-chosen identifier used to address the target afterwards.
    pub id: String,
    /// Remote host name or address.
    pub host: String,
    /// Remote login user.
    pub user: String,
    /// Remote SSH port; 22 when omitted.
    #[serde(default = "default_port")]
    pub port: u16,
    /// Passphrase for the private key. The `ssh` command-line path cannot use
    /// it; registration only logs a warning when it is supplied.
    pub key_passphrase: Option<String>,
    /// Known-hosts file; `~/.ssh/known_hosts` when omitted.
    /// NOTE(review): nothing in this module reads this field, so it currently
    /// has no effect on the `ssh` invocation — confirm whether it should be
    /// forwarded as `-o UserKnownHostsFile=...`.
    #[serde(default = "default_known_hosts")]
    pub known_hosts_path: String,
}

impl std::fmt::Debug for SshRegisterArgs {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("SshRegisterArgs")
            .field("id", &self.id)
            .field("host", &self.host)
            .field("user", &self.user)
            .field("port", &self.port)
            // Only reveal whether a passphrase was given, never its value.
            .field(
                "key_passphrase",
                &self.key_passphrase.as_ref().map(|_| "<redacted>"),
            )
            .field("known_hosts_path", &self.known_hosts_path)
            .finish()
    }
}
/// Arguments for running one command on a previously registered target.
#[derive(Debug, Clone, Deserialize)]
pub struct SshExecArgs {
/// Id the target was registered under.
pub id: String,
/// Command text passed to `ssh` as the remote command argument.
pub command: String,
/// Timeout in seconds handed to the executor; 120 when omitted
/// (see `default_timeout_secs`). The executor raises values below 1 to 1.
#[serde(default = "default_timeout_secs")]
pub timeout_secs: u64,
}
/// Outcome of one remote command execution.
#[derive(Debug, Clone, Serialize)]
pub struct SshExecResult {
/// Random (v4) UUID generated for this execution.
pub exec_id: String,
/// Exit code of the local `ssh` process, or -1 when it reported none
/// (e.g. it was terminated by a signal).
pub exit_code: i32,
/// Captured standard output, decoded as UTF-8 with invalid bytes replaced.
pub stdout: String,
/// Captured standard error, decoded the same way as `stdout`.
pub stderr: String,
/// Wall-clock time in milliseconds measured around the `ssh` invocation.
pub duration_ms: u128,
}
/// Echo of the connection details of a successfully registered target.
#[derive(Debug, Clone, Serialize)]
pub struct SshRegisterResult {
/// Id the target was registered under.
pub id: String,
/// Remote host name or address.
pub host: String,
/// Remote SSH port.
pub port: u16,
/// Remote login user.
pub user: String,
}
/// Result of an unregister request.
#[derive(Debug, Clone, Serialize)]
pub struct SshUnregisterResult {
/// Id that was requested for removal.
pub id: String,
/// `true` if a target with this id was registered and has now been removed.
pub existed: bool,
}
/// Connection settings stored for one registered target.
///
/// `Debug` is implemented by hand so that `key_passphrase` is never written
/// out verbatim.
#[derive(Clone)]
struct TargetConfig {
    /// Remote host name or address.
    host: String,
    /// Remote login user.
    user: String,
    /// Remote SSH port.
    port: u16,
    /// Private key file handed to `ssh -i`.
    key_path: PathBuf,
    /// Kept for reference only; it is never passed to the `ssh` invocation.
    key_passphrase: Option<String>,
}

impl std::fmt::Debug for TargetConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("TargetConfig")
            .field("host", &self.host)
            .field("user", &self.user)
            .field("port", &self.port)
            .field("key_path", &self.key_path)
            // Only reveal whether a passphrase is present, never its value.
            .field(
                "key_passphrase",
                &self.key_passphrase.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

/// Registered targets keyed by their caller-chosen id.
type TargetRegistry = RwLock<HashMap<String, TargetConfig>>;
#[derive(Clone)]
pub struct SshService {
targets: Arc<TargetRegistry>,
}
impl Default for SshService {
fn default() -> Self {
Self::new()
}
}
impl SshService {
pub fn new() -> Self {
Self {
targets: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn global() -> Self {
Self {
targets: Arc::clone(&GLOBAL_TARGETS),
}
}
pub async fn register(&self, args: SshRegisterArgs) -> Result<SshRegisterResult, ServiceError> {
register_impl(&self.targets, args).await
}
pub async fn exec(&self, args: SshExecArgs) -> Result<SshExecResult, ServiceError> {
exec_impl(&self.targets, args).await
}
pub async fn unregister(&self, id: String) -> Result<SshUnregisterResult, ServiceError> {
unregister_impl(&self.targets, id).await
}
pub fn list_targets(&self) -> Result<Vec<String>, ServiceError> {
let map = self.targets.read()
.map_err(|e| ServiceError::Internal(format!("lock poisoned: {}", e)))?;
Ok(map.keys().cloned().collect())
}
pub fn has_target(&self, id: &str) -> Result<bool, ServiceError> {
let map = self.targets.read()
.map_err(|e| ServiceError::Internal(format!("lock poisoned: {}", e)))?;
Ok(map.contains_key(id))
}
}
static GLOBAL_TARGETS: Lazy<Arc<TargetRegistry>> =
Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Registers a target in the process-wide registry and returns the
/// resulting [`SshRegisterResult`] serialized as JSON.
pub async fn ssh_register(args: SshRegisterArgs) -> Result<Value, ServiceError> {
    register_impl(&GLOBAL_TARGETS, args)
        .await
        .map(|registered| json!(registered))
}
/// Runs a command on a target from the process-wide registry and returns
/// the resulting [`SshExecResult`] serialized as JSON.
pub async fn ssh_exec(args: SshExecArgs) -> Result<Value, ServiceError> {
    exec_impl(&GLOBAL_TARGETS, args)
        .await
        .map(|outcome| json!(outcome))
}
/// Removes a target from the process-wide registry and returns the
/// resulting [`SshUnregisterResult`] serialized as JSON.
pub async fn ssh_unregister(id: String) -> Result<Value, ServiceError> {
    unregister_impl(&GLOBAL_TARGETS, id)
        .await
        .map(|removal| json!(removal))
}
/// Registers a new SSH target under `args.id` in `targets`.
///
/// The private key is not taken from `args`. It is read from the
/// `SSH_KEY_PATH` environment variable (a leading `~/` is expanded against
/// `$HOME`) and must be stat-able at registration time. An existing id is never
/// overwritten; the caller has to unregister it first.
///
/// `args.key_passphrase` is stored but cannot be used by the `ssh` command
/// line, so a warning is logged when it is supplied.
///
/// # Errors
///
/// * `InvalidParams` if any of the following holds:
///   - `id`, `host` or `user` is blank;
///   - `host` or `user` starts with `-`;
///   - `port` is 0;
///   - `SSH_KEY_PATH` is unset;
///   - the key file is missing or unreadable;
///   - `id` is already registered.
/// * `Internal` if the registry lock is poisoned.
async fn register_impl(
    targets: &TargetRegistry,
    args: SshRegisterArgs,
) -> Result<SshRegisterResult, ServiceError> {
    info!("Registering SSH target: {}", args.id);
    validate_register_args(&args)?;
    let key_path_str = std::env::var("SSH_KEY_PATH").map_err(|_| {
        ServiceError::InvalidParams("SSH_KEY_PATH environment variable not set".to_string())
    })?;
    let key_path = expand_tilde(&key_path_str);
    // A single stat distinguishes "missing" from other I/O failures
    // (e.g. permission denied), reporting each with its real cause.
    if let Err(e) = fs::metadata(&key_path) {
        let message = if e.kind() == std::io::ErrorKind::NotFound {
            format!("key_path not found: {}", key_path.display())
        } else {
            format!("cannot read key_path {}: {}", key_path.display(), e)
        };
        return Err(ServiceError::InvalidParams(message));
    }
    if args.key_passphrase.is_some() {
        warn!("key_passphrase is not supported when using ssh command - key must be unencrypted or use ssh-agent");
    }
    let cfg = TargetConfig {
        host: args.host.clone(),
        user: args.user.clone(),
        port: args.port,
        key_path,
        key_passphrase: args.key_passphrase,
    };
    {
        let mut map = targets
            .write()
            .map_err(|e| ServiceError::Internal(format!("target registry lock poisoned: {}", e)))?;
        if map.contains_key(&args.id) {
            return Err(ServiceError::InvalidParams(format!(
                "target '{}' already exists. Use unregister first to replace it.",
                args.id
            )));
        }
        map.insert(args.id.clone(), cfg);
    }
    info!("Successfully registered SSH target: {}", args.id);
    Ok(SshRegisterResult {
        id: args.id,
        host: args.host,
        port: args.port,
        user: args.user,
    })
}

/// Rejects registration arguments that would yield an unusable or unsafe
/// `ssh` invocation.
fn validate_register_args(args: &SshRegisterArgs) -> Result<(), ServiceError> {
    for (field, value) in [("id", &args.id), ("host", &args.host), ("user", &args.user)] {
        if value.trim().is_empty() {
            return Err(ServiceError::InvalidParams(format!(
                "{} must not be empty",
                field
            )));
        }
    }
    // `user@host` becomes an ssh argument; a leading '-' could otherwise be
    // parsed as an ssh option (e.g. `-oProxyCommand=...`).
    for (field, value) in [("host", &args.host), ("user", &args.user)] {
        if value.starts_with('-') {
            return Err(ServiceError::InvalidParams(format!(
                "{} must not start with '-'",
                field
            )));
        }
    }
    if args.port == 0 {
        return Err(ServiceError::InvalidParams(
            "port must be between 1 and 65535".to_string(),
        ));
    }
    Ok(())
}
async fn exec_impl(
targets: &TargetRegistry,
args: SshExecArgs,
) -> Result<SshExecResult, ServiceError> {
info!("Executing command on target {}: {}", args.id, args.command);
let cfg = {
let map = targets.read()
.map_err(|e| ServiceError::Internal(format!("target registry lock poisoned: {}", e)))?;
match map.get(&args.id) {
Some(c) => c.clone(),
None => {
return Err(ServiceError::NotFound(format!(
"unknown target id: {}. Use register first.",
args.id
)))
}
}
};
let command = args.command.clone();
let timeout = Duration::from_secs(args.timeout_secs.max(1));
let result = task::spawn_blocking(move || exec_over_ssh(&cfg, &command, timeout))
.await
.map_err(|e| ServiceError::Internal(format!("task join error: {}", e)))??;
info!(
"Command completed on {}: exit_code={}, duration={}ms",
args.id, result.exit_code, result.duration_ms
);
Ok(result)
}
/// Removes the target registered under `id`, if there is one.
///
/// An unknown id is not an error; the returned `existed` flag tells the
/// caller whether anything was actually removed.
///
/// # Errors
///
/// `Internal` if the registry lock is poisoned.
async fn unregister_impl(
    targets: &TargetRegistry,
    id: String,
) -> Result<SshUnregisterResult, ServiceError> {
    info!("Unregistering SSH target: {}", id);
    // The write guard is a temporary of this statement, so the lock is
    // released before any logging below.
    let removed = targets
        .write()
        .map_err(|e| ServiceError::Internal(format!("target registry lock poisoned: {}", e)))?
        .remove(&id);
    let existed = removed.is_some();
    match removed {
        Some(_) => info!("Successfully unregistered SSH target: {}", id),
        None => warn!("Target {} was not registered", id),
    }
    Ok(SshUnregisterResult { id, existed })
}
fn exec_over_ssh(
cfg: &TargetConfig,
command: &str,
timeout: Duration,
) -> Result<SshExecResult, ServiceError> {
if let Err(e) = fs::metadata(&cfg.key_path) {
return Err(ServiceError::InvalidParams(format!(
"cannot read key_path {}: {}", cfg.key_path.display(), e
)));
}
let start = Instant::now();
let target = format!("{}@{}", cfg.user, cfg.host);
debug!("Executing SSH command: ssh -i {} -p {} {} {}",
cfg.key_path.display(), cfg.port, target, command);
let output = Command::new("ssh")
.arg("-i")
.arg(&cfg.key_path)
.arg("-p")
.arg(cfg.port.to_string())
.arg("-o")
.arg("BatchMode=yes") .arg("-o")
.arg("ConnectTimeout=30")
.arg(&target)
.arg(command)
.output()
.map_err(|e| ServiceError::Internal(format!(
"failed to execute ssh command: {}", e
)))?;
let duration_ms = start.elapsed().as_millis();
if duration_ms > timeout.as_millis() {
warn!("SSH command exceeded timeout of {}s", timeout.as_secs());
}
let exit_code = output.status.code().unwrap_or(-1);
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
debug!("SSH command completed: exit_code={}, stdout_len={}, stderr_len={}",
exit_code, stdout.len(), stderr.len());
Ok(SshExecResult {
exec_id: uuid::Uuid::new_v4().to_string(),
exit_code,
stdout,
stderr,
duration_ms,
})
}
/// Serde default for `SshRegisterArgs::port`: the standard SSH port.
fn default_port() -> u16 {
    const STANDARD_SSH_PORT: u16 = 22;
    STANDARD_SSH_PORT
}
/// Serde default for `SshRegisterArgs::known_hosts_path`.
///
/// The leading `~` is returned literally; this function does not expand it.
fn default_known_hosts() -> String {
    String::from("~/.ssh/known_hosts")
}
/// Serde default for `SshExecArgs::timeout_secs`: two minutes.
fn default_timeout_secs() -> u64 {
    2 * 60
}
/// Expands a leading `~` (on its own) or `~/` in `p` to the value of `$HOME`.
///
/// Other forms such as `~user/...` are left untouched. If `$HOME` is unset or
/// not valid Unicode, `p` is returned unchanged.
fn expand_tilde(p: &str) -> PathBuf {
    let rest = if p == "~" {
        ""
    } else {
        match p.strip_prefix("~/") {
            Some(rest) => rest,
            None => return PathBuf::from(p),
        }
    };
    match std::env::var("HOME") {
        // Return $HOME as-is rather than joining "", which would append a
        // trailing separator.
        Ok(home) if rest.is_empty() => PathBuf::from(home),
        Ok(home) => PathBuf::from(home).join(rest),
        Err(_) => PathBuf::from(p),
    }
}