use std::path::PathBuf;
/// One way of authenticating an SSH session.
///
/// `Debug` is implemented by hand rather than derived so that passwords, key
/// passphrases and keyboard-interactive responses never end up in logs or
/// panic messages: each secret is printed as `<redacted>`, while the variant
/// name, the key path and the number of responses stay visible.
#[derive(Clone)]
pub enum AuthMethod {
    /// Password authentication with the contained password.
    Password(String),
    /// Authentication with the private key stored at `private_key`.
    PublicKey {
        /// Path to the private key file (only the path is stored here).
        private_key: PathBuf,
        /// Optional passphrase for the private key.
        passphrase: Option<String>,
    },
    /// Authentication through an SSH agent; this variant carries no secrets itself.
    Agent,
    /// Keyboard-interactive authentication.
    KeyboardInteractive {
        /// Canned answers to the server's prompts.
        /// NOTE(review): presumably consumed in prompt order — confirm against the client.
        responses: Vec<String>,
    },
    /// No credentials at all (presumably SSH's `none` method — confirm with the client code).
    None,
}

impl std::fmt::Debug for AuthMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Zero-sized placeholder printed wherever secret material would appear.
        #[derive(Clone, Copy)]
        struct Redacted;

        impl std::fmt::Debug for Redacted {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("<redacted>")
            }
        }

        match self {
            Self::Password(_) => f.debug_tuple("Password").field(&Redacted).finish(),
            Self::PublicKey {
                private_key,
                passphrase,
            } => f
                .debug_struct("PublicKey")
                .field("private_key", private_key)
                // Keep `Some`/`None` visible so it is clear whether a passphrase is set.
                .field("passphrase", &passphrase.as_ref().map(|_| Redacted))
                .finish(),
            Self::Agent => f.write_str("Agent"),
            Self::KeyboardInteractive { responses } => f
                .debug_struct("KeyboardInteractive")
                // One placeholder per response preserves the count without the contents.
                .field("responses", &vec![Redacted; responses.len()])
                .finish(),
            Self::None => f.write_str("None"),
        }
    }
}
impl AuthMethod {
    /// Creates an [`AuthMethod::Password`] from `password`.
    #[must_use]
    pub fn password(password: impl Into<String>) -> Self {
        Self::Password(password.into())
    }

    /// Creates an [`AuthMethod::PublicKey`] for the key at `private_key`, with no passphrase.
    ///
    /// The key file is not opened or checked here; only its path is stored.
    #[must_use]
    pub fn public_key(private_key: impl Into<PathBuf>) -> Self {
        Self::PublicKey {
            private_key: private_key.into(),
            passphrase: None,
        }
    }

    /// Creates an [`AuthMethod::PublicKey`] for the key at `private_key`, protected by
    /// `passphrase`.
    #[must_use]
    pub fn public_key_with_passphrase(
        private_key: impl Into<PathBuf>,
        passphrase: impl Into<String>,
    ) -> Self {
        Self::PublicKey {
            private_key: private_key.into(),
            passphrase: Some(passphrase.into()),
        }
    }

    /// Creates an [`AuthMethod::Agent`].
    #[must_use]
    pub const fn agent() -> Self {
        Self::Agent
    }

    /// Creates an [`AuthMethod::KeyboardInteractive`] with the given prompt `responses`.
    #[must_use]
    pub const fn keyboard_interactive(responses: Vec<String>) -> Self {
        Self::KeyboardInteractive { responses }
    }

    /// Creates an [`AuthMethod::KeyboardInteractive`] whose only response is `password`.
    #[must_use]
    pub fn keyboard_interactive_password(password: impl Into<String>) -> Self {
        Self::KeyboardInteractive {
            responses: vec![password.into()],
        }
    }

    /// Returns `true` for [`AuthMethod::KeyboardInteractive`].
    #[must_use]
    pub const fn is_keyboard_interactive(&self) -> bool {
        matches!(self, Self::KeyboardInteractive { .. })
    }

    /// Returns `true` for [`AuthMethod::Password`].
    #[must_use]
    pub const fn is_password(&self) -> bool {
        matches!(self, Self::Password(_))
    }

    /// Returns `true` for [`AuthMethod::PublicKey`], with or without a passphrase.
    #[must_use]
    pub const fn is_public_key(&self) -> bool {
        matches!(self, Self::PublicKey { .. })
    }

    /// Returns `true` for [`AuthMethod::Agent`].
    #[must_use]
    pub const fn is_agent(&self) -> bool {
        matches!(self, Self::Agent)
    }

    /// Returns `true` for the [`AuthMethod::None`] variant (the "no credentials" method).
    #[must_use]
    pub const fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }
}
/// A username together with the authentication methods configured for it.
///
/// Methods are stored in the order they were added by the builder methods.
/// NOTE(review): presumably the client tries them in that order — confirm.
#[derive(Debug, Clone)]
pub struct SshCredentials {
    /// Remote account name to log in as.
    pub username: String,
    /// Configured authentication methods, in insertion order.
    pub auth_methods: Vec<AuthMethod>,
}
impl SshCredentials {
    /// Creates credentials for `username` with no authentication methods yet.
    #[must_use]
    pub fn new(username: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            auth_methods: Vec::new(),
        }
    }

    /// Appends `method` to the end of [`SshCredentials::auth_methods`].
    #[must_use]
    pub fn with_auth(mut self, method: AuthMethod) -> Self {
        self.auth_methods.push(method);
        self
    }

    /// Shorthand for `with_auth(AuthMethod::password(password))`.
    #[must_use]
    pub fn with_password(self, password: impl Into<String>) -> Self {
        self.with_auth(AuthMethod::password(password))
    }

    /// Shorthand for `with_auth(AuthMethod::public_key(private_key))`.
    #[must_use]
    pub fn with_key(self, private_key: impl Into<PathBuf>) -> Self {
        self.with_auth(AuthMethod::public_key(private_key))
    }

    /// Shorthand for `with_auth(AuthMethod::public_key_with_passphrase(private_key, passphrase))`.
    #[must_use]
    pub fn with_key_passphrase(
        self,
        private_key: impl Into<PathBuf>,
        passphrase: impl Into<String>,
    ) -> Self {
        self.with_auth(AuthMethod::public_key_with_passphrase(
            private_key,
            passphrase,
        ))
    }

    /// Shorthand for `with_auth(AuthMethod::agent())`.
    #[must_use]
    pub fn with_agent(self) -> Self {
        self.with_auth(AuthMethod::agent())
    }

    /// Shorthand for `with_auth(AuthMethod::keyboard_interactive_password(password))`.
    #[must_use]
    pub fn with_keyboard_interactive(self, password: impl Into<String>) -> Self {
        self.with_auth(AuthMethod::keyboard_interactive_password(password))
    }

    /// Shorthand for `with_auth(AuthMethod::keyboard_interactive(responses))`.
    #[must_use]
    pub fn with_keyboard_interactive_responses(self, responses: Vec<String>) -> Self {
        self.with_auth(AuthMethod::keyboard_interactive(responses))
    }

    /// Appends the conventional defaults: the agent, then `~/.ssh/id_ed25519`,
    /// then `~/.ssh/id_rsa`.
    ///
    /// The home directory is taken from `HOME`, falling back to `USERPROFILE`
    /// (Windows). Empty values are ignored, and non-UTF-8 paths are kept intact.
    /// If no home directory can be determined, only the agent is added. Guessing
    /// instead would yield root-level paths such as `/.ssh/id_rsa`. The key files
    /// are not checked for existence.
    #[must_use]
    pub fn with_defaults(self) -> Self {
        let creds = self.with_agent();
        match Self::home_dir() {
            Some(home) => {
                let ssh_dir = home.join(".ssh");
                creds
                    .with_key(ssh_dir.join("id_ed25519"))
                    .with_key(ssh_dir.join("id_rsa"))
            }
            None => creds,
        }
    }

    // Resolves the current user's home directory from the environment, if any.
    fn home_dir() -> Option<PathBuf> {
        ["HOME", "USERPROFILE"]
            .iter()
            .filter_map(std::env::var_os)
            .find(|dir| !dir.is_empty())
            .map(PathBuf::from)
    }
}
impl Default for SshCredentials {
    /// Builds credentials for the current OS user, with no authentication methods.
    ///
    /// The username is read from `USER`, then `USERNAME`. If neither variable is
    /// set to valid UTF-8, it falls back to `"root"`.
    fn default() -> Self {
        let user = ["USER", "USERNAME"]
            .iter()
            .find_map(|key| std::env::var(key).ok())
            .unwrap_or_else(|| "root".to_string());
        Self::new(user)
    }
}
/// Policy applied when checking the server's host key.
///
/// Defaults to [`HostKeyVerification::KnownHosts`]. The enum is
/// `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum HostKeyVerification {
    /// Only compiled in with the `insecure-skip-verify` cargo feature.
    /// NOTE(review): presumably accepts every host key unchecked — confirm.
    #[cfg(feature = "insecure-skip-verify")]
    AcceptAll,
    /// NOTE(review): presumably refuses hosts whose key is not already known — confirm.
    RejectUnknown,
    /// The default policy.
    /// NOTE(review): presumably checks against a known_hosts file — confirm.
    #[default]
    KnownHosts,
    /// NOTE(review): presumably trust-on-first-use — confirm.
    Tofu,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Borrows the prompt responses of a keyboard-interactive method; any other
    // variant fails the calling test.
    fn responses_of(method: &AuthMethod) -> &[String] {
        match method {
            AuthMethod::KeyboardInteractive { responses } => responses,
            _ => panic!("Expected KeyboardInteractive variant"),
        }
    }

    #[test]
    fn auth_method_password() {
        let method = AuthMethod::password("secret");
        assert!(method.is_password());
        assert!(!(method.is_public_key() || method.is_keyboard_interactive()));
    }

    #[test]
    fn auth_method_keyboard_interactive() {
        let method = AuthMethod::keyboard_interactive_password("secret");
        assert!(method.is_keyboard_interactive());
        assert!(!(method.is_password() || method.is_public_key()));
        assert_eq!(responses_of(&method), ["secret"]);
    }

    #[test]
    fn auth_method_keyboard_interactive_multi_response() {
        let answers = vec!["password".to_string(), "123456".to_string()];
        let method = AuthMethod::keyboard_interactive(answers);
        assert!(method.is_keyboard_interactive());
        assert_eq!(responses_of(&method), ["password", "123456"]);
    }

    #[test]
    fn credentials_builder() {
        let creds = SshCredentials::new("user").with_password("pass").with_agent();
        assert_eq!(creds.username, "user");
        assert_eq!(creds.auth_methods.len(), 2);
    }

    #[test]
    fn credentials_keyboard_interactive() {
        let creds = SshCredentials::new("user").with_keyboard_interactive("password");
        let methods = &creds.auth_methods;
        assert_eq!(creds.username, "user");
        assert_eq!(methods.len(), 1);
        assert!(methods[0].is_keyboard_interactive());
    }

    #[test]
    fn credentials_keyboard_interactive_multi_response() {
        let answers = vec!["password".to_string(), "otp_code".to_string()];
        let creds = SshCredentials::new("user").with_keyboard_interactive_responses(answers);
        assert_eq!(creds.username, "user");
        assert_eq!(creds.auth_methods.len(), 1);
        assert_eq!(responses_of(&creds.auth_methods[0]).len(), 2);
    }

    #[test]
    fn credentials_multiple_auth_methods() {
        let creds = SshCredentials::new("user")
            .with_agent()
            .with_keyboard_interactive("password")
            .with_password("fallback");
        let methods = &creds.auth_methods;
        assert_eq!(methods.len(), 3);
        assert!(matches!(methods[0], AuthMethod::Agent));
        assert!(methods[1].is_keyboard_interactive());
        assert!(methods[2].is_password());
    }
}