use openssh::{ForwardType, KnownHosts, Session, SessionBuilder, Socket};
use std::fs::File;
use std::io::Write;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use thiserror::Error;
use crate::types::db_connection::SshSettings;
#[derive(Debug, Error)]
pub enum SshTunnelError {
#[error("io error: {0}")]
IoError(std::io::Error),
#[error("ssh error: {0}")]
SshError(openssh::Error),
}
impl From<std::io::Error> for SshTunnelError {
fn from(error: std::io::Error) -> Self {
SshTunnelError::IoError(error)
}
}
impl From<openssh::Error> for SshTunnelError {
fn from(error: openssh::Error) -> Self {
SshTunnelError::SshError(error)
}
}
/// Writes `private_key` to a new, owner-only temporary file and returns its path.
///
/// The file lives in [`std::env::temp_dir`] under a random
/// `ssh_tunnel_key-<uuid>.pem` name. It is created with mode `0o600` in the
/// same `open` call, so the key is never readable by other users (there is
/// no window between creation and a later `chmod`). `create_new` makes the
/// open fail instead of writing through a pre-existing file or symlink at
/// that path.
///
/// A trailing newline is appended if `private_key` lacks one, because
/// OpenSSH may reject a key file whose last line is unterminated.
///
/// The caller owns the returned file and is responsible for deleting it.
///
/// # Errors
///
/// Returns [`SshTunnelError::IoError`] if the file cannot be created or
/// written. If writing fails, the partially written file is removed (best
/// effort) before returning.
async fn generate_temp_keyfile(private_key: &str) -> Result<PathBuf, SshTunnelError> {
    let temp_keyfile_path =
        std::env::temp_dir().join(format!("ssh_tunnel_key-{}.pem", uuid::Uuid::new_v4()));

    let mut options = File::options();
    options.write(true).create_new(true);
    // Restrict permissions at creation time; the process umask can only
    // narrow these bits further, never widen them.
    std::os::unix::fs::OpenOptionsExt::mode(&mut options, 0o600);
    let mut temp_keyfile = options.open(&temp_keyfile_path)?;

    let mut write_key = || -> std::io::Result<()> {
        temp_keyfile.write_all(private_key.as_bytes())?;
        if !private_key.ends_with('\n') {
            temp_keyfile.write_all(b"\n")?;
        }
        Ok(())
    };
    if let Err(err) = write_key() {
        // Don't leave a partial private key on disk. The write error is more
        // useful to the caller than a secondary cleanup failure, so the
        // removal result is deliberately ignored.
        let _ = std::fs::remove_file(&temp_keyfile_path);
        return Err(err.into());
    }

    Ok(temp_keyfile_path)
}
/// Tears down a tunnel created by `create_tunnel` and deletes its key file.
///
/// Closes the local forward `127.0.0.1:port` -> `127.0.0.1:server_port`,
/// closes `session`, and removes the temporary private-key file at `path`.
///
/// Every step is attempted even if an earlier one fails. In particular, the
/// private-key file is always deleted regardless of SSH errors.
///
/// # Errors
///
/// Returns the first error encountered, checked in this order: closing the
/// port forward, closing the session, removing the key file.
pub async fn cleanup_tunnel(
    session: Session,
    path: &PathBuf,
    port: u16,
    server_port: u16,
) -> Result<(), SshTunnelError> {
    let local = Socket::from(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port));
    let remote = Socket::from(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), server_port));

    let forward_result = session
        .close_port_forward(ForwardType::Local, local, remote)
        .await;
    let close_result = session.close().await;
    let remove_result = std::fs::remove_file(path);

    forward_result?;
    close_result?;
    remove_result?;
    Ok(())
}
pub async fn create_tunnel(
setting: SshSettings,
server_port: u16,
) -> Result<(Session, PathBuf, u16), SshTunnelError> {
let temp_keyfile = generate_temp_keyfile(&setting.private_key).await?;
let moved_keyfile = &temp_keyfile;
let mut builder = SessionBuilder::default();
builder
.user(setting.username)
.keyfile(moved_keyfile)
.known_hosts_check(KnownHosts::Accept);
if !setting.jump_servers.is_empty() {
builder.jump_hosts(setting.jump_servers);
}
let session = builder.connect(setting.host).await?;
session.check().await?;
let port = server_port; let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), server_port);
let local = Socket::from(local_addr);
let remote = Socket::from(remote_addr);
session
.request_port_forward(ForwardType::Local, local.clone(), remote.clone())
.await
.map_err(|e| {
eprintln!("Error establishing port forwarding: {e}");
e
})?;
println!("Port forwarding established on port {port}");
Ok((session, moved_keyfile.to_owned(), port))
}