use std::fmt;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex as StdMutex};
use argon2::password_hash::SaltString;
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use russh::keys::ssh_key::{HashAlg, PublicKey};
use tracing::{info, warn};
use crate::error::AuthError;
use crate::config::{HostKeyPolicy, SecurityConfig, StateConfig};
use crate::error::Result;
use crate::storage::trust::write_authorized_client;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMethod {
PublicKey,
Password,
}
pub trait Authenticator: Send + Sync + fmt::Debug + 'static {
fn supported_methods(&self) -> Vec<AuthMethod>;
fn check_public_key(&self, user: &str, key: &PublicKey) -> Result<bool>;
fn check_password(&self, user: &str, password: &str) -> Result<bool>;
}
#[derive(Debug)]
pub struct KeyOnlyAuth {
policy: HostKeyPolicy,
authorized_keys: Arc<StdMutex<Vec<PublicKey>>>,
state: StateConfig,
}
impl KeyOnlyAuth {
pub fn new(
security: SecurityConfig,
authorized_keys: Vec<PublicKey>,
state: StateConfig,
) -> Self {
Self {
policy: security.host_key_policy,
authorized_keys: Arc::new(StdMutex::new(authorized_keys)),
state,
}
}
fn lock_keys(&self) -> std::sync::MutexGuard<'_, Vec<PublicKey>> {
match self.authorized_keys.lock() {
Ok(guard) => guard,
Err(poisoned) => {
warn!("authorized client state mutex poisoned; recovering");
poisoned.into_inner()
}
}
}
}
impl Authenticator for KeyOnlyAuth {
fn supported_methods(&self) -> Vec<AuthMethod> {
vec![AuthMethod::PublicKey]
}
fn check_public_key(&self, _user: &str, key: &PublicKey) -> Result<bool> {
let fingerprint = key.fingerprint(HashAlg::Sha256).to_string();
if self.policy == HostKeyPolicy::AcceptAll {
info!(%fingerprint, "AcceptAll policy: automatically accepting client key.");
return Ok(true);
}
let mut authorized = self.lock_keys();
if !authorized.is_empty() {
if authorized.contains(key) {
info!(%fingerprint, "Client matched pre-authorized key. Access granted.");
return Ok(true);
}
warn!(%fingerprint, "Client key not in authorized list. Rejecting connection.");
return Ok(false);
}
match self.policy {
HostKeyPolicy::Strict => {
warn!(%fingerprint, "Strict policy: No pre-authorized keys found. Rejecting connection.");
Ok(false)
}
HostKeyPolicy::Tofu => {
info!(%fingerprint, "Tofu policy: No pre-authorized keys found. Trusting first client.");
let _event = write_authorized_client(&self.state, &fingerprint, key)?;
authorized.push(key.clone());
Ok(true)
}
HostKeyPolicy::AcceptAll => unreachable!(),
}
}
fn check_password(&self, _user: &str, _password: &str) -> Result<bool> {
Ok(false) }
}
#[derive(Debug)]
pub struct PasswordAuth {
password_hash: String,
}
impl PasswordAuth {
pub fn new(password_hash: impl Into<String>) -> Self {
Self {
password_hash: password_hash.into(),
}
}
}
impl Authenticator for PasswordAuth {
fn supported_methods(&self) -> Vec<AuthMethod> {
vec![AuthMethod::Password]
}
fn check_public_key(&self, _user: &str, _key: &PublicKey) -> Result<bool> {
Ok(false) }
fn check_password(&self, _user: &str, password: &str) -> Result<bool> {
let parsed_hash = PasswordHash::new(&self.password_hash)
.map_err(|reason| AuthError::VerificationFailed { reason })?;
match Argon2::default().verify_password(password.as_bytes(), &parsed_hash) {
Ok(()) => Ok(true),
Err(argon2::password_hash::Error::Password) => Ok(false),
Err(reason) => Err(AuthError::VerificationFailed { reason }.into()),
}
}
}
pub fn hash_password(password: &str) -> Result<String> {
let mut salt_bytes = [0u8; 16];
rand::fill(&mut salt_bytes);
let salt = SaltString::encode_b64(&salt_bytes)
.map_err(|reason| crate::error::StorageError::PasswordHash { reason })?;
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|reason| crate::error::StorageError::PasswordHash { reason })?
.to_string();
Ok(password_hash)
}
#[derive(Debug)]
pub struct CombinedAuth {
key_auth: KeyOnlyAuth,
password_auth: PasswordAuth,
}
impl CombinedAuth {
pub fn new(key_auth: KeyOnlyAuth, password_auth: PasswordAuth) -> Self {
Self {
key_auth,
password_auth,
}
}
}
impl Authenticator for CombinedAuth {
fn supported_methods(&self) -> Vec<AuthMethod> {
vec![AuthMethod::PublicKey, AuthMethod::Password]
}
fn check_public_key(&self, user: &str, key: &PublicKey) -> Result<bool> {
self.key_auth.check_public_key(user, key)
}
fn check_password(&self, user: &str, password: &str) -> Result<bool> {
self.password_auth.check_password(user, password)
}
}
#[derive(Debug, Clone)]
pub struct Credentials {
pub user: String,
pub password: String,
}
impl Credentials {
pub fn new(user: impl Into<String>, password: impl Into<String>) -> Self {
Self {
user: user.into(),
password: password.into(),
}
}
}
pub trait PasswordPrompter: Send + Sync + std::fmt::Debug + 'static {
fn prompt_password(&self, user: &str) -> Option<String>;
}
pub trait ConfirmationCallback: Send + Sync + std::fmt::Debug + 'static {
fn confirm_pairing(&self, fingerprint: &str, key: &PublicKey) -> bool;
}
#[derive(Debug, Clone)]
pub struct PairingMonitor {
pub success_flag: Arc<std::sync::atomic::AtomicBool>,
pub failed_attempts: Arc<AtomicU32>,
pub success_tx: Option<tokio::sync::mpsc::Sender<()>>,
}
#[derive(Debug)]
pub struct UnifiedAuthenticator {
state: StateConfig,
policy: HostKeyPolicy,
authorized_keys: Arc<StdMutex<Vec<PublicKey>>>,
temp_password_hash: Option<String>,
success_flag: Arc<std::sync::atomic::AtomicBool>,
failed_attempts: Arc<AtomicU32>,
cached_key: Arc<StdMutex<Option<PublicKey>>>,
success_tx: Option<tokio::sync::mpsc::Sender<()>>,
}
impl UnifiedAuthenticator {
pub fn new(
state: StateConfig,
policy: HostKeyPolicy,
authorized_keys: Vec<PublicKey>,
_temp_password_hash: Option<String>,
) -> Self {
Self {
state,
policy,
authorized_keys: Arc::new(StdMutex::new(authorized_keys)),
temp_password_hash: _temp_password_hash,
success_flag: Arc::new(std::sync::atomic::AtomicBool::new(false)),
failed_attempts: Arc::new(AtomicU32::new(0)),
cached_key: Arc::new(StdMutex::new(None)),
success_tx: None,
}
}
pub fn with_tracking(
state: StateConfig,
policy: HostKeyPolicy,
authorized_keys: Vec<PublicKey>,
_temp_password_hash: Option<String>,
monitor: PairingMonitor,
) -> Self {
Self {
state,
policy,
authorized_keys: Arc::new(StdMutex::new(authorized_keys)),
temp_password_hash: _temp_password_hash,
success_flag: monitor.success_flag,
failed_attempts: monitor.failed_attempts,
cached_key: Arc::new(StdMutex::new(None)),
success_tx: monitor.success_tx,
}
}
pub fn was_successful(&self) -> bool {
self.success_flag.load(Ordering::Relaxed)
}
pub fn failed_attempts(&self) -> u32 {
self.failed_attempts.load(Ordering::Relaxed)
}
fn lock_keys(&self) -> std::sync::MutexGuard<'_, Vec<PublicKey>> {
match self.authorized_keys.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
fn refresh_keys(&self) -> Result<()> {
let vault = crate::storage::load_all_authorized_clients(&self.state)?;
let keys: Vec<_> = vault.into_iter().map(|(_, k)| k).collect();
let mut authorized = self.lock_keys();
*authorized = keys;
Ok(())
}
fn check_password_match(&self, password: &str) -> Result<bool> {
let argon2 = Argon2::default();
let node_hash = crate::storage::load_shadow_file(&self.state).unwrap_or_default();
if let Some(hash) = node_hash {
let parsed_hash = PasswordHash::new(&hash)
.map_err(|reason| AuthError::VerificationFailed { reason })?;
if argon2
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok()
{
return Ok(true);
}
}
if let Some(hash) = &self.temp_password_hash {
let parsed_hash = PasswordHash::new(hash)
.map_err(|reason| AuthError::VerificationFailed { reason })?;
if argon2
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok()
{
return Ok(true);
}
}
Ok(false)
}
fn notify_success(&self) {
if let Some(tx) = &self.success_tx {
let tx = tx.clone();
tokio::spawn(async move {
let _ = tx.send(()).await;
});
}
}
}
impl Authenticator for UnifiedAuthenticator {
fn supported_methods(&self) -> Vec<AuthMethod> {
let mut methods = vec![AuthMethod::PublicKey];
let node_pw_exists = crate::storage::load_shadow_file(&self.state)
.unwrap_or_default()
.is_some();
if node_pw_exists || self.temp_password_hash.is_some() {
methods.push(AuthMethod::Password);
}
methods
}
fn check_public_key(&self, _user: &str, key: &PublicKey) -> Result<bool> {
let fingerprint = key.fingerprint(HashAlg::Sha256).to_string();
{
let authorized = self.lock_keys();
if authorized.contains(key) {
info!(%fingerprint, "Client matched pre-authorized key. Access granted.");
return Ok(true);
}
}
let _ = self.refresh_keys();
let mut authorized = self.lock_keys();
if authorized.contains(key) {
info!(%fingerprint, "Client matched key after vault refresh. Access granted.");
return Ok(true);
}
if self.policy == HostKeyPolicy::Strict && !authorized.is_empty() {
warn!(%fingerprint, "Strict policy: unknown key rejected.");
return Ok(false);
}
let node_pw_exists = crate::storage::load_shadow_file(&self.state)
.unwrap_or_default()
.is_some();
if node_pw_exists || self.temp_password_hash.is_some() {
if let Ok(mut cache) = self.cached_key.lock() {
*cache = Some(key.clone());
}
return Ok(false);
}
if authorized.is_empty() {
info!(%fingerprint, "Vault is empty and no password set. Accepting first connection (TOFU).");
let _event =
crate::storage::trust::write_authorized_client(&self.state, &fingerprint, key)?;
authorized.push(key.clone());
self.success_flag.store(true, Ordering::Relaxed);
self.notify_success();
return Ok(true);
}
warn!(%fingerprint, "Vault is claimed and no password is set. Unknown key rejected.");
Ok(false)
}
fn check_password(&self, _user: &str, password: &str) -> Result<bool> {
if self.check_password_match(password)? {
if let Ok(cache) = self.cached_key.lock() {
if let Some(key) = &*cache {
let fingerprint = key.fingerprint(HashAlg::Sha256).to_string();
let mut authorized = self.lock_keys();
if !authorized.contains(key) {
info!(%fingerprint, "Password accepted: Adding new client to vault.");
let _event = write_authorized_client(&self.state, &fingerprint, key)?;
authorized.push(key.clone());
self.success_flag.store(true, Ordering::Relaxed);
self.notify_success();
}
return Ok(true);
}
}
}
self.failed_attempts.fetch_add(1, Ordering::Relaxed);
Ok(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{HostKeyPolicy, SecurityConfig, StateConfig};
fn temp_state(name: &str) -> StateConfig {
let mut path = std::env::temp_dir();
path.push(format!(
"irosh-auth-test-{}-{}",
name,
rand::random::<u32>()
));
StateConfig::new(path)
}
#[test]
fn key_only_accept_all_accepts_any_key() -> crate::Result<()> {
let auth = KeyOnlyAuth::new(
SecurityConfig {
host_key_policy: HostKeyPolicy::AcceptAll,
},
vec![],
temp_state("accept-all"),
);
assert!(auth.supported_methods().contains(&AuthMethod::PublicKey));
assert!(!auth.supported_methods().contains(&AuthMethod::Password));
assert!(!auth.check_password("user", "pass")?);
Ok(())
}
#[test]
fn password_auth_validates_correct_password() -> crate::Result<()> {
let password = "secret123";
let hash = hash_password(password).expect("failed to hash test password");
let auth = PasswordAuth::new(hash);
assert!(auth.check_password("anyone", password)?);
assert!(!auth.check_password("anyone", "wrong")?);
assert!(!auth.check_password("anyone", "")?);
assert!(auth.supported_methods().contains(&AuthMethod::Password));
assert!(!auth.supported_methods().contains(&AuthMethod::PublicKey));
Ok(())
}
#[test]
fn combined_auth_supports_both_methods() -> crate::Result<()> {
let key = KeyOnlyAuth::new(
SecurityConfig {
host_key_policy: HostKeyPolicy::AcceptAll,
},
vec![],
temp_state("combined"),
);
let password = "combo";
let hash = hash_password(password).expect("failed to hash test password");
let pass = PasswordAuth::new(hash);
let auth = CombinedAuth::new(key, pass);
assert_eq!(auth.supported_methods().len(), 2);
assert!(auth.supported_methods().contains(&AuthMethod::PublicKey));
assert!(auth.supported_methods().contains(&AuthMethod::Password));
assert!(auth.check_password("user", password)?);
assert!(!auth.check_password("user", "wrong")?);
Ok(())
}
#[test]
fn unified_auth_tofu_works_with_no_passwords() -> crate::Result<()> {
let state = temp_state("unified-tofu");
let auth = UnifiedAuthenticator::new(state.clone(), HostKeyPolicy::Tofu, vec![], None);
use russh::keys::ssh_key::PrivateKey;
use russh::keys::ssh_key::private::Ed25519Keypair;
let keypair = Ed25519Keypair::from_seed(&[0u8; 32]);
let key = PrivateKey::from(keypair).public_key().clone();
assert!(auth.check_public_key("user", &key)?);
let vault = crate::storage::load_all_authorized_clients(&state)?;
assert_eq!(vault.len(), 1);
Ok(())
}
#[test]
fn credentials_construction() {
let creds = Credentials::new("admin", "pass123");
assert_eq!(creds.user, "admin");
assert_eq!(creds.password, "pass123");
}
}