use crate::options::{FailedLoginsBlock, FailedLoginsPolicy};
use super::shutdown;
use slog::Logger;
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock};
/// Identity under which failed logins are counted.
///
/// Which fields are populated depends on the configured blocking mode; a field left as `None`
/// does not distinguish one entry from another.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct FailedLoginsKey {
    /// Client address, when the blocking mode takes the IP into account.
    ip: Option<IpAddr>,
    /// Login name, when the blocking mode takes the user into account.
    username: Option<String>,
}
/// Failure counter for a single key.
#[derive(Clone, Debug)]
struct FailedLoginsEntry {
    /// Failures counted since the counter was last (re)started.
    attempts: u32,
    /// When the most recent failure was recorded.
    last_attempt_at: Instant,
}
impl FailedLoginsEntry {
    /// Builds the entry for a key's first failure: a count of 1 stamped with the current time,
    /// already wrapped in the async mutex it is stored under in the cache map.
    fn new() -> Mutex<FailedLoginsEntry> {
        let first_failure = Self {
            last_attempt_at: Instant::now(),
            attempts: 1,
        };
        Mutex::new(first_failure)
    }

    /// How long ago the most recent failure was recorded.
    fn time_elapsed(&self) -> Duration {
        let since_last = self.last_attempt_at.elapsed();
        since_last
    }

    /// Restamps the entry with the current time, leaving the counter untouched.
    fn touch(&mut self) {
        let now = Instant::now();
        self.last_attempt_at = now;
    }
}
/// Tracks failed login attempts and reports when a key should be locked out.
///
/// Each counter sits behind its own async mutex, so an existing counter can be updated while the
/// map itself is only read-locked.
#[derive(Debug)]
pub struct FailedLoginsCache {
    /// Limits, expiry and keying mode applied to every entry.
    policy: FailedLoginsPolicy,
    /// Failure counters keyed per the policy's blocking mode.
    failed_logins: Arc<RwLock<HashMap<FailedLoginsKey, Mutex<FailedLoginsEntry>>>>,
}
/// Lockout outcome reported by [`FailedLoginsCache`].
///
/// Derives `Copy`/`PartialEq` so callers can compare or pass outcomes around freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LockState {
    /// This failure is the one that brought the attempt count to the policy's `max_attempts`.
    MaxFailuresReached,
    /// The key had already reached its failure limit and the lockout has not yet expired.
    AlreadyLocked,
}
impl FailedLoginsCache {
pub fn new(policy: FailedLoginsPolicy) -> Arc<FailedLoginsCache> {
Arc::new(FailedLoginsCache {
policy,
failed_logins: Arc::new(RwLock::new(HashMap::new())),
})
}
fn is_expired(&self, time_elapsed: Duration) -> bool {
time_elapsed > self.policy.expires_after
}
fn is_locked(&self, attempts: u32) -> bool {
attempts >= self.policy.max_attempts
}
fn getkey(&self, ip: IpAddr, user: String) -> FailedLoginsKey {
match self.policy.block_by {
FailedLoginsBlock::UserAndIP => FailedLoginsKey {
ip: Some(ip),
username: Some(user),
},
FailedLoginsBlock::IP => FailedLoginsKey { ip: Some(ip), username: None },
FailedLoginsBlock::User => FailedLoginsKey {
ip: None,
username: Some(user),
},
}
}
pub async fn failed(&self, ip: IpAddr, user: String) -> Option<LockState> {
let map = self.failed_logins.read().await;
let key = self.getkey(ip, user);
let entry = map.get(&key);
let attempts = match entry {
Some(entry) => {
let mut entry = entry.lock().await;
if self.is_expired(entry.time_elapsed()) {
entry.attempts = 1;
} else {
entry.attempts += 1;
}
entry.touch();
entry.attempts
}
None => {
drop(map);
let mut map = self.failed_logins.write().await;
let entry = FailedLoginsEntry::new();
map.insert(key, entry);
1 }
};
match attempts {
a if a == self.policy.max_attempts => Some(LockState::MaxFailuresReached),
a if a > self.policy.max_attempts => Some(LockState::AlreadyLocked),
_ => None,
}
}
pub async fn success(&self, ip: IpAddr, user: String) -> Option<LockState> {
let map = self.failed_logins.read().await;
let key = self.getkey(ip, user);
let entry = map.get(&key);
let (is_expired, is_locked) = if let Some(entry) = entry {
let entry = entry.lock().await;
(self.is_expired(entry.time_elapsed()), self.is_locked(entry.attempts))
} else {
return None;
};
drop(map);
match (is_expired, is_locked) {
(false, true) => Some(LockState::AlreadyLocked),
(_, _) => {
let mut map = self.failed_logins.write().await;
map.remove(&key);
None
}
}
}
pub async fn sweeper(&self, logger: Logger, shutdown_topic: Arc<shutdown::Notifier>) {
let mut shutdown_listener = shutdown_topic.subscribe().await;
let interval = std::time::Duration::new(10, 0);
loop {
let mut expire_check_interval = Box::pin(tokio::time::sleep(interval));
tokio::select! {
_ = &mut expire_check_interval => {
let map = self.failed_logins.read().await;
let mut expired_entries: Vec<FailedLoginsKey> = Vec::new();
for (key, entry) in map.iter() {
let entry = entry.lock().await;
slog::debug!(logger, "Checking expired entry: key={:?} attempts={} elapsed={:?} policy={:?}", key, entry.attempts, entry.time_elapsed(), self.policy);
if self.is_expired(entry.time_elapsed()) {
expired_entries.push(key.clone());
}
}
drop(map);
if !expired_entries.is_empty() {
let mut map = self.failed_logins.write().await;
for key in expired_entries {
slog::debug!(logger, "Failed logins entry expired: {:?}", key);
map.remove(&key);
}
}
}
_ = shutdown_listener.listen() => {
slog::info!(logger, "Sweeper received shutdown signal.");
return;
}
}
}
}
}