use std::path::PathBuf;
use std::time::Duration;
use super::auth::{AuthMethod, HostKeyVerification, SshCredentials};
use super::session::{SshConfig, SshSession};
#[derive(Debug, Default)]
pub struct SshSessionBuilder {
host: Option<String>,
port: u16,
username: Option<String>,
auth_methods: Vec<AuthMethod>,
connect_timeout: Duration,
host_key_verification: HostKeyVerification,
compression: bool,
tcp_keepalive: Option<Duration>,
}
impl SshSessionBuilder {
#[must_use]
pub fn new() -> Self {
Self {
port: 22,
connect_timeout: Duration::from_secs(30),
host_key_verification: HostKeyVerification::KnownHosts,
tcp_keepalive: Some(Duration::from_secs(60)),
..Default::default()
}
}
#[must_use]
pub fn host(mut self, host: impl Into<String>) -> Self {
self.host = Some(host.into());
self
}
#[must_use]
pub const fn port(mut self, port: u16) -> Self {
self.port = port;
self
}
#[must_use]
pub fn username(mut self, username: impl Into<String>) -> Self {
self.username = Some(username.into());
self
}
#[must_use]
pub fn password(mut self, password: impl Into<String>) -> Self {
self.auth_methods
.push(AuthMethod::Password(password.into()));
self
}
#[must_use]
pub fn private_key(mut self, path: impl Into<PathBuf>) -> Self {
self.auth_methods.push(AuthMethod::PublicKey {
private_key: path.into(),
passphrase: None,
});
self
}
#[must_use]
pub fn private_key_with_passphrase(
mut self,
path: impl Into<PathBuf>,
passphrase: impl Into<String>,
) -> Self {
self.auth_methods.push(AuthMethod::PublicKey {
private_key: path.into(),
passphrase: Some(passphrase.into()),
});
self
}
#[must_use]
pub fn agent(mut self) -> Self {
self.auth_methods.push(AuthMethod::Agent);
self
}
#[must_use]
pub const fn connect_timeout(mut self, timeout: Duration) -> Self {
self.connect_timeout = timeout;
self
}
#[must_use]
pub const fn host_key_verification(mut self, policy: HostKeyVerification) -> Self {
self.host_key_verification = policy;
self
}
#[cfg(feature = "insecure-skip-verify")]
#[must_use]
pub const fn accept_all_keys(mut self) -> Self {
self.host_key_verification = HostKeyVerification::AcceptAll;
self
}
#[must_use]
pub const fn compression(mut self, enabled: bool) -> Self {
self.compression = enabled;
self
}
#[must_use]
pub const fn tcp_keepalive(mut self, interval: Option<Duration>) -> Self {
self.tcp_keepalive = interval;
self
}
pub fn build(self) -> crate::error::Result<SshSession> {
let host = self
.host
.ok_or_else(|| crate::error::ExpectError::config("SSH host is required"))?;
let username = self.username.unwrap_or_else(|| {
std::env::var("USER")
.or_else(|_| std::env::var("USERNAME"))
.unwrap_or_else(|_| "root".to_string())
});
let mut credentials = SshCredentials::new(username);
for method in self.auth_methods {
credentials = credentials.with_auth(method);
}
if credentials.auth_methods.is_empty() {
credentials = credentials.with_defaults();
}
let config = SshConfig {
host,
port: self.port,
credentials,
connect_timeout: self.connect_timeout,
host_key_verification: self.host_key_verification,
compression: self.compression,
tcp_keepalive: self.tcp_keepalive,
};
Ok(SshSession::new(config))
}
pub fn connect(self) -> crate::error::Result<SshSession> {
let mut session = self.build()?;
session.connect()?;
Ok(session)
}
}
/// Splits an SSH target of the form `[user@]host[:port]` into
/// `(user, host, port)`.
///
/// * The user is everything before the first `@`; `None` when there is no `@`.
/// * The port defaults to 22 when it is absent.
/// * When the text after the last `:` is not a valid `u16`, the whole
///   remainder is returned as the host with port 22.
/// * IPv6 literals are supported: `[addr]` and `[addr]:port` yield `addr`
///   without the brackets, and a bare address containing more than one `:`
///   (for example `::1`) is returned whole with port 22 instead of having its
///   last group mistaken for a port.
#[must_use]
pub fn parse_ssh_target(target: &str) -> (Option<String>, String, u16) {
    const DEFAULT_PORT: u16 = 22;

    let (user, rest) = match target.split_once('@') {
        Some((user, rest)) => (Some(user.to_string()), rest),
        None => (None, target),
    };

    // Bracketed IPv6 literal: "[addr]" or "[addr]:port".
    if let Some((addr, tail)) = rest
        .strip_prefix('[')
        .and_then(|inner| inner.split_once(']'))
    {
        let port = match tail.strip_prefix(':') {
            Some(port_str) => port_str.parse::<u16>().ok(),
            None if tail.is_empty() => Some(DEFAULT_PORT),
            None => None,
        };
        if let Some(port) = port {
            return (user, addr.to_string(), port);
        }
        // Malformed bracket form: fall through to the lenient handling below.
    }

    // A host name or IPv4 address holds at most one ':' (the port separator),
    // so several colons mean an unbracketed IPv6 address with no port.
    if rest.matches(':').count() > 1 {
        return (user, rest.to_string(), DEFAULT_PORT);
    }

    if let Some((host, port_str)) = rest.rsplit_once(':') {
        if let Ok(port) = port_str.parse::<u16>() {
            return (user, host.to_string(), port);
        }
    }
    (user, rest.to_string(), DEFAULT_PORT)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Host and username given to the builder end up in the session config.
    #[test]
    fn builder_basic() {
        let builder = SshSessionBuilder::new()
            .host("example.com")
            .username("user")
            .password("pass");
        let session = builder.build().unwrap();

        let config = session.config();
        assert_eq!(config.host, "example.com");
        assert_eq!(config.credentials.username, "user");
    }

    /// `user@host:port` is split into all three components.
    #[test]
    fn parse_target_full() {
        let parsed = parse_ssh_target("admin@server.com:2222");
        let expected = (Some("admin".to_string()), "server.com".to_string(), 2222);
        assert_eq!(parsed, expected);
    }

    /// A bare host has no user and falls back to port 22.
    #[test]
    fn parse_target_simple() {
        let parsed = parse_ssh_target("server.com");
        let expected = (None, "server.com".to_string(), 22);
        assert_eq!(parsed, expected);
    }
}