use super::config::SshConfig;
use crate::error::{Result, SyncError};
use ssh2::Session;
use std::io::ErrorKind;
use std::net::TcpStream;
use std::time::Duration;
/// Timeout applied both to establishing the TCP connection and, via
/// `Session::set_timeout`, to the blocking libssh2 calls made on the session.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
pub async fn connect(config: &SshConfig) -> Result<Session> {
let tcp = connect_tcp(&config.hostname, config.port).await?;
let username = config.user.clone();
let identity_files = config.identity_file.clone();
let session = tokio::task::spawn_blocking(move || {
let mut session = Session::new().map_err(|e| SyncError::Io(std::io::Error::other(format!("Failed to create SSH session: {}", e))))?;
session.set_timeout(DEFAULT_TIMEOUT.as_millis() as u32);
session.set_tcp_stream(tcp);
session.handshake().map_err(|e| SyncError::Io(std::io::Error::other(format!("SSH handshake failed: {}", e))))?;
session.set_keepalive(true, 60);
if let Ok(mut agent) = session.agent()
&& agent.connect().is_ok()
&& agent.list_identities().is_ok()
&& let Ok(identities) = agent.identities()
{
for identity in identities {
if agent.userauth(&username, &identity).is_ok() {
tracing::debug!("Authenticated using SSH agent");
return Ok(session);
}
}
}
for identity_file in &identity_files {
if session.userauth_pubkey_file(&username, None, identity_file, None).is_ok() {
tracing::debug!("Authenticated using key: {}", identity_file.display());
return Ok(session);
}
}
if identity_files.is_empty()
&& let Some(home) = dirs::home_dir()
{
let default_keys = [home.join(".ssh/id_rsa"), home.join(".ssh/id_ed25519"), home.join(".ssh/id_ecdsa")];
for key_path in &default_keys {
if key_path.exists() && session.userauth_pubkey_file(&username, None, key_path, None).is_ok() {
tracing::debug!("Authenticated using key: {}", key_path.display());
return Ok(session);
}
}
}
Err(SyncError::Io(std::io::Error::new(
ErrorKind::PermissionDenied,
format!("SSH authentication failed for user {}", username),
)))
})
.await
.map_err(|e| SyncError::Io(std::io::Error::other(e.to_string())))??;
Ok(session)
}
async fn connect_tcp(hostname: &str, port: u16) -> Result<TcpStream> {
let addr = format!("{}:{}", hostname, port);
tokio::time::timeout(DEFAULT_TIMEOUT, async {
TcpStream::connect(&addr).map_err(|e| SyncError::Io(std::io::Error::new(ErrorKind::ConnectionRefused, format!("Failed to connect to {}: {}", addr, e))))
})
.await
.map_err(|_| SyncError::Io(std::io::Error::new(ErrorKind::TimedOut, format!("Connection to {} timed out", addr))))?
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// A hand-built `SshConfig` keeps the host, port and user exactly as given.
    #[test]
    fn test_ssh_config_basic() {
        let expected_host = "localhost";
        let cfg = SshConfig {
            hostname: expected_host.to_owned(),
            user: String::from("testuser"),
            port: 22,
            identity_file: vec![PathBuf::from("~/.ssh/id_rsa")],
            proxy_jump: None,
            control_master: false,
            control_path: None,
            control_persist: None,
            compression: false,
        };

        assert_eq!(cfg.hostname, expected_host);
        assert_eq!(cfg.port, 22);
        assert_eq!(cfg.user, "testuser");
    }
}