use crate::core::error::{Error, Result};
use blueprint_core::{info, warn};
use std::process::Stdio;
use tokio::process::{Child, Command};
pub struct QosTunnel {
process: Option<Child>,
local_port: u16,
remote_host: String,
remote_port: u16,
ssh_user: String,
ssh_key_path: Option<String>,
}
impl QosTunnel {
pub fn new(
local_port: u16,
remote_host: String,
remote_port: u16,
ssh_user: String,
ssh_key_path: Option<String>,
) -> Self {
Self {
process: None,
local_port,
remote_host,
remote_port,
ssh_user,
ssh_key_path,
}
}
pub async fn connect(&mut self) -> Result<()> {
info!(
"Creating SSH tunnel for QoS metrics: localhost:{} -> {}@{}:{}",
self.local_port, self.ssh_user, self.remote_host, self.remote_port
);
let mut cmd = Command::new("ssh");
cmd.arg("-N") .arg("-L")
.arg(format!(
"{}:localhost:{}",
self.local_port, self.remote_port
))
.arg(format!("{}@{}", self.ssh_user, self.remote_host))
.arg("-o")
.arg("StrictHostKeyChecking=accept-new")
.arg("-o")
.arg("ServerAliveInterval=30")
.arg("-o")
.arg("ServerAliveCountMax=3");
if let Some(ref key_path) = self.ssh_key_path {
cmd.arg("-i").arg(key_path);
}
cmd.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null());
let child = cmd
.spawn()
.map_err(|e| Error::ConfigurationError(format!("Failed to start SSH tunnel: {e}")))?;
self.process = Some(child);
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
match tokio::net::TcpStream::connect(format!("127.0.0.1:{}", self.local_port)).await {
Ok(_) => {
info!(
"QoS tunnel established successfully on localhost:{}",
self.local_port
);
Ok(())
}
Err(e) => {
warn!("QoS tunnel may not be ready yet: {}", e);
Ok(())
}
}
}
pub fn get_local_endpoint(&self) -> String {
format!("http://127.0.0.1:{}", self.local_port)
}
pub async fn disconnect(&mut self) -> Result<()> {
if let Some(mut process) = self.process.take() {
info!("Closing QoS tunnel on localhost:{}", self.local_port);
if let Err(e) = process.kill().await {
warn!("Failed to kill SSH tunnel process: {}", e);
}
let _ = process.wait().await;
}
Ok(())
}
pub async fn is_active(&self) -> bool {
tokio::net::TcpStream::connect(format!("127.0.0.1:{}", self.local_port))
.await
.is_ok()
}
}
impl Drop for QosTunnel {
fn drop(&mut self) {
if let Some(mut process) = self.process.take() {
let _ = process.start_kill();
}
}
}
pub struct QosTunnelManager {
tunnels: Vec<QosTunnel>,
next_local_port: u16,
}
impl QosTunnelManager {
pub fn new(starting_port: u16) -> Self {
Self {
tunnels: Vec::new(),
next_local_port: starting_port,
}
}
pub async fn create_tunnel(
&mut self,
remote_host: String,
ssh_user: String,
ssh_key_path: Option<String>,
) -> Result<String> {
let local_port = self.next_local_port;
self.next_local_port += 1;
let mut tunnel = QosTunnel::new(
local_port,
remote_host,
9615, ssh_user,
ssh_key_path,
);
tunnel.connect().await?;
let endpoint = tunnel.get_local_endpoint();
self.tunnels.push(tunnel);
Ok(endpoint)
}
pub async fn close_all(&mut self) -> Result<()> {
for mut tunnel in self.tunnels.drain(..) {
tunnel.disconnect().await?;
}
Ok(())
}
pub fn active_count(&self) -> usize {
self.tunnels.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Local side of the forward used by the configuration test.
    const LOCAL_PORT: u16 = 19615;
    /// Remote QoS metrics port used by the configuration test.
    const REMOTE_PORT: u16 = 9615;

    /// A freshly built tunnel keeps its ports and reports the loopback HTTP endpoint.
    #[test]
    fn test_tunnel_configuration() {
        let key = Some(String::from("/path/to/key"));
        let host = String::from("remote-host.example.com");
        let user = String::from("ubuntu");
        let tunnel = QosTunnel::new(LOCAL_PORT, host, REMOTE_PORT, user, key);

        assert_eq!(tunnel.get_local_endpoint(), "http://127.0.0.1:19615");
        assert_eq!(tunnel.remote_port, REMOTE_PORT);
        assert_eq!(tunnel.local_port, LOCAL_PORT);
    }

    /// A new manager holds no tunnels and starts allocating at the given port.
    #[tokio::test]
    async fn test_tunnel_manager() {
        let first_port = 20000;
        let manager = QosTunnelManager::new(first_port);

        assert_eq!(manager.next_local_port, first_port);
        assert_eq!(manager.active_count(), 0);
    }
}