use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use russh::keys::ssh_key::PublicKey;
use super::password::{PasswordAuthConfig, PasswordVerifier};
use super::provider::AuthProvider;
use super::publickey::{PublicKeyAuthConfig, PublicKeyVerifier};
use crate::shared::auth_types::{AuthResult, UserInfo};
/// Authentication provider that combines optional public-key and password verification.
///
/// Each method is enabled only when its verifier is present; a method whose verifier is
/// `None` rejects every attempt made through the [`AuthProvider`] implementation.
pub struct CompositeAuthProvider {
    /// Public-key verifier; `None` means public-key authentication is disabled.
    publickey_verifier: Option<PublicKeyVerifier>,
    /// Password verifier, held behind an `Arc` so callers can obtain a shared handle via
    /// `password_verifier()`; `None` means password authentication is disabled.
    password_verifier: Option<Arc<PasswordVerifier>>,
}
impl CompositeAuthProvider {
    /// Creates a provider from optional per-method configurations.
    ///
    /// A `None` config leaves that authentication method disabled. Once both verifiers are
    /// built, an `info`-level trace event records which methods are enabled.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `PasswordVerifier::new` for the password config.
    pub async fn new(
        publickey_config: Option<PublicKeyAuthConfig>,
        password_config: Option<PasswordAuthConfig>,
    ) -> Result<Self> {
        // Struct-literal fields are evaluated in source order, so the public-key verifier is
        // still constructed before the (fallible, async) password verifier.
        let provider = Self {
            publickey_verifier: publickey_config.map(PublicKeyVerifier::new),
            password_verifier: if let Some(config) = password_config {
                Some(Arc::new(PasswordVerifier::new(config).await?))
            } else {
                None
            },
        };
        tracing::info!(
            publickey_enabled = provider.publickey_verifier.is_some(),
            password_enabled = provider.password_verifier.is_some(),
            "Composite auth provider initialized"
        );
        Ok(provider)
    }

    /// Creates a provider with only public-key authentication enabled.
    pub fn publickey_only(config: PublicKeyAuthConfig) -> Self {
        let publickey_verifier = Some(PublicKeyVerifier::new(config));
        Self {
            publickey_verifier,
            password_verifier: None,
        }
    }

    /// Creates a provider with only password authentication enabled.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `PasswordVerifier::new`.
    pub async fn password_only(config: PasswordAuthConfig) -> Result<Self> {
        let verifier = PasswordVerifier::new(config).await?;
        Ok(Self {
            publickey_verifier: None,
            password_verifier: Some(Arc::new(verifier)),
        })
    }

    /// Returns `true` when a public-key verifier is configured.
    pub fn publickey_enabled(&self) -> bool {
        self.publickey_verifier.is_some()
    }

    /// Returns `true` when a password verifier is configured.
    pub fn password_enabled(&self) -> bool {
        self.password_verifier.is_some()
    }

    /// Borrows the shared password verifier, if password authentication is enabled.
    pub fn password_verifier(&self) -> Option<&Arc<PasswordVerifier>> {
        self.password_verifier.as_ref()
    }

    /// Asks the password verifier to reload its users; a no-op when password auth is disabled.
    ///
    /// NOTE(review): presumably re-reads the configured user source — confirm against
    /// `PasswordVerifier::reload_users`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `PasswordVerifier::reload_users`.
    pub async fn reload_password_users(&self) -> Result<()> {
        match &self.password_verifier {
            Some(verifier) => {
                verifier.reload_users().await?;
                Ok(())
            }
            None => Ok(()),
        }
    }
}
#[async_trait]
impl AuthProvider for CompositeAuthProvider {
    /// Delegates to the public-key verifier; rejects when public-key auth is disabled.
    async fn verify_publickey(&self, username: &str, key: &PublicKey) -> Result<AuthResult> {
        match &self.publickey_verifier {
            Some(verifier) => verifier.verify_publickey(username, key).await,
            None => Ok(AuthResult::Reject),
        }
    }

    /// Delegates to the password verifier; rejects when password auth is disabled.
    async fn verify_password(&self, username: &str, password: &str) -> Result<AuthResult> {
        match &self.password_verifier {
            Some(verifier) => verifier.verify_password(username, password).await,
            None => Ok(AuthResult::Reject),
        }
    }

    /// Looks up user info, preferring the password verifier's entry over the public-key one.
    ///
    /// An error from the password verifier is propagated immediately without consulting
    /// the public-key verifier.
    async fn get_user_info(&self, username: &str) -> Result<Option<UserInfo>> {
        let from_password = match &self.password_verifier {
            Some(verifier) => verifier.get_user_info(username).await?,
            None => None,
        };
        if from_password.is_some() {
            return Ok(from_password);
        }
        match &self.publickey_verifier {
            Some(verifier) => verifier.get_user_info(username).await,
            None => Ok(None),
        }
    }

    /// Reports whether either verifier knows `username`.
    ///
    /// Checks the password verifier first; the public-key verifier is only queried when the
    /// password verifier is absent or does not know the user. Errors from either propagate.
    async fn user_exists(&self, username: &str) -> Result<bool> {
        let in_password_store = match &self.password_verifier {
            Some(verifier) => verifier.user_exists(username).await?,
            None => false,
        };
        if in_password_store {
            return Ok(true);
        }
        let in_publickey_store = match &self.publickey_verifier {
            Some(verifier) => verifier.user_exists(username).await?,
            None => false,
        };
        Ok(in_publickey_store)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::auth::hash_password;
    use crate::server::config::UserDefinition;
    use std::collections::HashMap;

    /// Builds a user named `name` whose password is `"password"`, with no shell/home/env overrides.
    fn test_user(name: &str) -> UserDefinition {
        UserDefinition {
            name: name.to_string(),
            password_hash: hash_password("password").unwrap(),
            shell: None,
            home: None,
            env: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_composite_provider_publickey_only() {
        let config = PublicKeyAuthConfig::with_directory("/tmp/nonexistent");
        let provider = CompositeAuthProvider::publickey_only(config);
        assert!(provider.publickey_enabled());
        assert!(!provider.password_enabled());
        assert!(provider.password_verifier().is_none());
    }

    #[tokio::test]
    async fn test_composite_provider_password_only() {
        let config = PasswordAuthConfig::with_users(vec![test_user("testuser")]);
        let provider = CompositeAuthProvider::password_only(config).await.unwrap();
        assert!(!provider.publickey_enabled());
        assert!(provider.password_enabled());
        assert!(provider.password_verifier().is_some());
        let result = provider
            .verify_password("testuser", "password")
            .await
            .unwrap();
        assert!(result.is_accepted());
        let result = provider.verify_password("testuser", "wrong").await.unwrap();
        assert!(result.is_rejected());
    }

    #[tokio::test]
    async fn test_composite_provider_both() {
        let pubkey_config = PublicKeyAuthConfig::with_directory("/tmp/nonexistent");
        let password_config = PasswordAuthConfig::with_users(vec![test_user("testuser")]);
        let provider = CompositeAuthProvider::new(Some(pubkey_config), Some(password_config))
            .await
            .unwrap();
        assert!(provider.publickey_enabled());
        assert!(provider.password_enabled());
        // Password auth must still work when both methods are configured.
        let result = provider
            .verify_password("testuser", "password")
            .await
            .unwrap();
        assert!(result.is_accepted());
    }

    /// With no verifiers configured, every path must fall back to reject / none / false / no-op.
    #[tokio::test]
    async fn test_composite_provider_no_methods() {
        let provider = CompositeAuthProvider::new(None, None).await.unwrap();
        assert!(!provider.publickey_enabled());
        assert!(!provider.password_enabled());
        assert!(provider.password_verifier().is_none());
        let result = provider.verify_password("user", "pass").await.unwrap();
        assert!(result.is_rejected());
        assert!(provider.get_user_info("user").await.unwrap().is_none());
        assert!(!provider.user_exists("user").await.unwrap());
        provider.reload_password_users().await.unwrap();
    }

    #[tokio::test]
    async fn test_composite_provider_user_info() {
        let users = vec![UserDefinition {
            shell: Some("/bin/bash".into()),
            home: Some("/home/testuser".into()),
            ..test_user("testuser")
        }];
        let config = PasswordAuthConfig::with_users(users);
        let provider = CompositeAuthProvider::password_only(config).await.unwrap();
        let info = provider.get_user_info("testuser").await.unwrap();
        assert!(info.is_some());
        let info = info.unwrap();
        assert_eq!(info.username, "testuser");
        assert_eq!(info.shell.to_str().unwrap(), "/bin/bash");
        assert_eq!(info.home_dir.to_str().unwrap(), "/home/testuser");
    }

    #[tokio::test]
    async fn test_composite_provider_user_exists() {
        let config = PasswordAuthConfig::with_users(vec![test_user("existinguser")]);
        let provider = CompositeAuthProvider::password_only(config).await.unwrap();
        assert!(provider.user_exists("existinguser").await.unwrap());
        assert!(!provider.user_exists("nonexistent").await.unwrap());
    }

    #[tokio::test]
    async fn test_composite_provider_disabled_methods() {
        let pubkey_config = PublicKeyAuthConfig::with_directory("/tmp/nonexistent");
        let provider = CompositeAuthProvider::publickey_only(pubkey_config);
        let result = provider.verify_password("user", "pass").await.unwrap();
        assert!(result.is_rejected());
        // Reloading password users is a no-op when password auth is disabled.
        provider.reload_password_users().await.unwrap();
    }
}