use super::{ForwardingConfig, ForwardingStats, ForwardingStatus, ForwardingType};
use crate::ssh::tokio_client::Client;
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{Mutex, RwLock, mpsc};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
pub type ForwardingId = Uuid;
/// Events sent on the manager's channel and applied to the session table by the
/// manager's message loop. Forwarder tasks receive a sender in `start_forwarding`.
#[derive(Debug)]
pub enum ForwardingMessage {
    /// The session identified by `id` has moved to `status`.
    StatusUpdate {
        id: ForwardingId,
        status: ForwardingStatus,
    },
    /// A new statistics snapshot for the session; it replaces the stored one.
    StatsUpdate {
        id: ForwardingId,
        stats: ForwardingStats,
    },
    /// The session's forwarder ended. `Some(text)` is recorded as
    /// `ForwardingStatus::Failed(text)`; `None` is recorded as `ForwardingStatus::Stopped`.
    SessionTerminated {
        id: ForwardingId,
        reason: Option<String>,
    },
}
/// Book-keeping for one forwarding session. Each session sits behind its own mutex
/// inside the manager's session table.
#[derive(Debug)]
// `created_at` and `updated_at` are only ever written in this file, never read,
// which is what the dead-code allowance silences.
#[allow(dead_code)]
struct ForwardingSession {
    /// Key under which the session is stored in the manager's table.
    id: ForwardingId,
    /// What to forward (local, remote, or dynamic).
    spec: ForwardingType,
    /// Last known status, set by the manager and by forwarder messages.
    status: ForwardingStatus,
    /// Last statistics snapshot reported for this session.
    stats: ForwardingStats,
    /// Handle of the spawned forwarder task, if one has been started.
    task_handle: Option<JoinHandle<Result<()>>>,
    /// Token handed to the forwarder; `stop_forwarding` cancels it before awaiting the task.
    cancel_token: CancellationToken,
    /// When the session was added.
    created_at: Instant,
    /// When the session's status, stats or task last changed.
    updated_at: Instant,
}
/// Owns a set of port-forwarding sessions and the background loop that applies the
/// status and stats messages their forwarder tasks send.
pub struct ForwardingManager {
    /// Configuration cloned into every forwarder task.
    config: ForwardingConfig,
    /// Session table keyed by id; every session is individually locked.
    sessions: Arc<RwLock<HashMap<ForwardingId, Arc<Mutex<ForwardingSession>>>>>,
    /// Sender cloned into each forwarder task.
    message_tx: mpsc::UnboundedSender<ForwardingMessage>,
    /// Receiving end; the message loop keeps it locked for as long as it runs.
    message_rx: Arc<Mutex<mpsc::UnboundedReceiver<ForwardingMessage>>>,
    /// Cancelled by `shutdown` and on drop to stop the message loop.
    shutdown_token: CancellationToken,
    /// Handle of the message loop spawned by `start`, if any.
    manager_task: Option<JoinHandle<()>>,
}
impl ForwardingManager {
pub fn new(config: ForwardingConfig) -> Self {
let (message_tx, message_rx) = mpsc::unbounded_channel();
Self {
config,
sessions: Arc::new(RwLock::new(HashMap::new())),
message_tx,
message_rx: Arc::new(Mutex::new(message_rx)),
shutdown_token: CancellationToken::new(),
manager_task: None,
}
}
pub async fn start(&mut self) -> Result<()> {
if self.manager_task.is_some() {
return Err(anyhow::anyhow!("ForwardingManager is already started"));
}
let sessions = Arc::clone(&self.sessions);
let message_rx = Arc::clone(&self.message_rx);
let shutdown_token = self.shutdown_token.clone();
let task = tokio::spawn(async move {
Self::message_loop(sessions, message_rx, shutdown_token).await;
});
self.manager_task = Some(task);
tracing::info!("ForwardingManager started");
Ok(())
}
async fn message_loop(
sessions: Arc<RwLock<HashMap<ForwardingId, Arc<Mutex<ForwardingSession>>>>>,
message_rx: Arc<Mutex<mpsc::UnboundedReceiver<ForwardingMessage>>>,
shutdown_token: CancellationToken,
) {
let mut rx = message_rx.lock().await;
loop {
tokio::select! {
msg = rx.recv() => {
match msg {
Some(message) => {
if let Err(e) = Self::handle_message(&sessions, message).await {
tracing::error!("Error handling forwarding message: {}", e);
}
}
None => {
tracing::info!("Message channel closed, stopping manager");
break;
}
}
}
_ = shutdown_token.cancelled() => {
tracing::info!("Shutdown requested, stopping manager");
break;
}
}
}
tracing::info!("ForwardingManager message loop stopped");
}
#[allow(clippy::type_complexity)] async fn handle_message(
sessions: &Arc<RwLock<HashMap<ForwardingId, Arc<Mutex<ForwardingSession>>>>>,
message: ForwardingMessage,
) -> Result<()> {
match message {
ForwardingMessage::StatusUpdate { id, status } => {
let sessions_read = sessions.read().await;
if let Some(session_arc) = sessions_read.get(&id) {
let mut session = session_arc.lock().await;
session.status = status;
session.updated_at = Instant::now();
tracing::debug!("Updated status for forwarding {}: {}", id, session.status);
}
}
ForwardingMessage::StatsUpdate { id, stats } => {
let sessions_read = sessions.read().await;
if let Some(session_arc) = sessions_read.get(&id) {
let mut session = session_arc.lock().await;
session.stats = stats;
session.updated_at = Instant::now();
}
}
ForwardingMessage::SessionTerminated { id, reason } => {
let sessions_write = sessions.write().await;
if let Some(session_arc) = sessions_write.get(&id) {
let mut session = session_arc.lock().await;
session.status = if let Some(err) = reason {
ForwardingStatus::Failed(err)
} else {
ForwardingStatus::Stopped
};
session.updated_at = Instant::now();
session.task_handle = None; tracing::info!("Forwarding session {} terminated: {}", id, session.status);
}
}
}
Ok(())
}
pub async fn add_forwarding(&mut self, spec: ForwardingType) -> Result<ForwardingId> {
super::spec::ForwardingSpec::validate(&spec)
.with_context(|| "Invalid forwarding specification")?;
let id = Uuid::new_v4();
let now = Instant::now();
let session = ForwardingSession {
id,
spec,
status: ForwardingStatus::Initializing,
stats: ForwardingStats::default(),
task_handle: None,
cancel_token: CancellationToken::new(),
created_at: now,
updated_at: now,
};
let mut sessions = self.sessions.write().await;
sessions.insert(id, Arc::new(Mutex::new(session)));
tracing::info!("Added forwarding session {}", id);
Ok(id)
}
pub async fn start_forwarding(&self, id: ForwardingId, ssh_client: Arc<Client>) -> Result<()> {
let sessions = self.sessions.read().await;
let session_arc = sessions
.get(&id)
.ok_or_else(|| anyhow::anyhow!("Forwarding session {id} not found"))?;
let mut session = session_arc.lock().await;
if session.task_handle.is_some() {
return Err(anyhow::anyhow!(
"Forwarding session {id} is already started"
));
}
let session_id = session.id;
let spec = session.spec.clone();
let config = self.config.clone();
let cancel_token = session.cancel_token.clone();
let message_tx = self.message_tx.clone();
let task = match &spec {
ForwardingType::Local { .. } => tokio::spawn(async move {
super::local::LocalForwarder::run(
session_id,
spec.clone(),
ssh_client,
config,
cancel_token,
message_tx,
)
.await
}),
ForwardingType::Remote { .. } => tokio::spawn(async move {
super::remote::RemoteForwarder::run(
session_id,
spec.clone(),
ssh_client,
config,
cancel_token,
message_tx,
)
.await
}),
ForwardingType::Dynamic { .. } => tokio::spawn(async move {
super::dynamic::DynamicForwarder::run(
session_id,
spec.clone(),
ssh_client,
config,
cancel_token,
message_tx,
)
.await
}),
};
session.task_handle = Some(task);
session.status = ForwardingStatus::Initializing;
session.updated_at = Instant::now();
tracing::info!("Started forwarding session {}", id);
Ok(())
}
pub async fn start_all(&self, ssh_client: Arc<Client>) -> Result<()> {
let sessions = self.sessions.read().await;
let ids: Vec<ForwardingId> = sessions.keys().copied().collect();
drop(sessions);
for id in ids {
if let Err(e) = self.start_forwarding(id, Arc::clone(&ssh_client)).await {
tracing::error!("Failed to start forwarding session {}: {}", id, e);
}
}
Ok(())
}
pub async fn stop_forwarding(&self, id: ForwardingId) -> Result<()> {
let sessions = self.sessions.read().await;
let session_arc = sessions
.get(&id)
.ok_or_else(|| anyhow::anyhow!("Forwarding session {id} not found"))?;
let mut session = session_arc.lock().await;
session.cancel_token.cancel();
if let Some(task) = session.task_handle.take() {
let _ = task.await; }
session.status = ForwardingStatus::Stopped;
session.updated_at = Instant::now();
tracing::info!("Stopped forwarding session {}", id);
Ok(())
}
pub async fn stop_all(&self) -> Result<()> {
let sessions = self.sessions.read().await;
let ids: Vec<ForwardingId> = sessions.keys().copied().collect();
drop(sessions);
for id in ids {
if let Err(e) = self.stop_forwarding(id).await {
tracing::error!("Failed to stop forwarding session {}: {}", id, e);
}
}
Ok(())
}
pub async fn get_status(&self, id: ForwardingId) -> Result<ForwardingStatus> {
let sessions = self.sessions.read().await;
let session_arc = sessions
.get(&id)
.ok_or_else(|| anyhow::anyhow!("Forwarding session {id} not found"))?;
let session = session_arc.lock().await;
Ok(session.status.clone())
}
pub async fn get_stats(&self, id: ForwardingId) -> Result<ForwardingStats> {
let sessions = self.sessions.read().await;
let session_arc = sessions
.get(&id)
.ok_or_else(|| anyhow::anyhow!("Forwarding session {id} not found"))?;
let session = session_arc.lock().await;
Ok(session.stats.clone())
}
pub async fn list_sessions(&self) -> HashMap<ForwardingId, (ForwardingType, ForwardingStatus)> {
let sessions = self.sessions.read().await;
let mut result = HashMap::new();
for (id, session_arc) in sessions.iter() {
if let Ok(session) = session_arc.try_lock() {
result.insert(*id, (session.spec.clone(), session.status.clone()));
}
}
result
}
pub async fn remove_forwarding(&mut self, id: ForwardingId) -> Result<()> {
let _ = self.stop_forwarding(id).await;
let mut sessions = self.sessions.write().await;
sessions
.remove(&id)
.ok_or_else(|| anyhow::anyhow!("Forwarding session {id} not found"))?;
tracing::info!("Removed forwarding session {}", id);
Ok(())
}
pub async fn shutdown(&mut self) -> Result<()> {
tracing::info!("Shutting down ForwardingManager");
self.stop_all().await?;
self.shutdown_token.cancel();
if let Some(task) = self.manager_task.take() {
let _ = task.await;
}
let mut sessions = self.sessions.write().await;
sessions.clear();
tracing::info!("ForwardingManager shutdown complete");
Ok(())
}
}
impl Drop for ForwardingManager {
    /// Signals shutdown by cancelling the manager's shutdown token, which ends the
    /// message loop.
    ///
    /// `drop` cannot await, so nothing here waits for tasks to finish. Use
    /// [`ForwardingManager::shutdown`] for an orderly teardown.
    fn drop(&mut self) {
        let token = &self.shutdown_token;
        token.cancel();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// A valid local forward from 127.0.0.1:8080 to example.com:80.
    fn sample_local_spec() -> ForwardingType {
        ForwardingType::Local {
            bind_addr: IpAddr::V4(Ipv4Addr::LOCALHOST),
            bind_port: 8080,
            remote_host: "example.com".to_string(),
            remote_port: 80,
        }
    }

    #[tokio::test]
    async fn test_manager_lifecycle() {
        let mut mgr = ForwardingManager::new(ForwardingConfig::default());
        assert!(mgr.start().await.is_ok());

        let session_id = mgr.add_forwarding(sample_local_spec()).await.unwrap();
        assert_eq!(
            mgr.get_status(session_id).await.unwrap(),
            ForwardingStatus::Initializing
        );

        let listed = mgr.list_sessions().await;
        assert_eq!(listed.len(), 1);
        assert!(listed.contains_key(&session_id));

        mgr.remove_forwarding(session_id).await.unwrap();
        assert!(mgr.list_sessions().await.is_empty());

        assert!(mgr.shutdown().await.is_ok());
    }

    #[tokio::test]
    async fn test_invalid_forwarding_spec() {
        let mut mgr = ForwardingManager::new(ForwardingConfig::default());
        // A bind port of 0 is expected to fail spec validation.
        let rejected = ForwardingType::Local {
            bind_addr: IpAddr::V4(Ipv4Addr::LOCALHOST),
            bind_port: 0,
            remote_host: "example.com".to_string(),
            remote_port: 80,
        };
        assert!(mgr.add_forwarding(rejected).await.is_err());
    }

    #[tokio::test]
    async fn test_duplicate_start() {
        let mut mgr = ForwardingManager::new(ForwardingConfig::default());
        assert!(mgr.start().await.is_ok());
        assert!(mgr.start().await.is_err());
        let _ = mgr.shutdown().await;
    }
}