use std::panic::Location;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use derive_more::Display;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use torrust_tracker_clock::clock::Time;
use torrust_tracker_clock::conv::convert_from_timestamp_to_datetime_utc;
use torrust_tracker_located_error::{DynError, LocatedError};
use torrust_tracker_primitives::DurationSinceUnixEpoch;
use crate::shared::bit_torrent::common::AUTH_KEY_LENGTH;
use crate::CurrentClock;
#[must_use]
pub fn generate_permanent_key() -> PeerKey {
generate_key(None)
}
/// Generates a new random authentication key.
///
/// The key is made of `AUTH_KEY_LENGTH` random ASCII alphanumeric characters.
/// When `lifetime` is `Some`, the key's `valid_until` is set to the current clock
/// time plus `lifetime`; when it is `None`, the key is permanent (`valid_until == None`).
///
/// # Panics
///
/// Panics if `CurrentClock::now_add(&lifetime)` yields no value (e.g. the resulting
/// time cannot be represented).
#[must_use]
pub fn generate_key(lifetime: Option<Duration>) -> PeerKey {
    let random_id: String = thread_rng()
        .sample_iter(&Alphanumeric)
        .take(AUTH_KEY_LENGTH)
        .map(char::from)
        .collect();

    let valid_until = if let Some(lifetime) = lifetime {
        // `{:?}` on a `Duration` already includes the unit (e.g. `9999s`).
        tracing::debug!("Generated key: {}, valid for: {:?}", random_id, lifetime);
        Some(CurrentClock::now_add(&lifetime).expect("key expiration time should be representable"))
    } else {
        tracing::debug!("Generated key: {}, permanent", random_id);
        None
    };

    // `Alphanumeric` only yields [0-9a-zA-Z] and exactly AUTH_KEY_LENGTH chars are taken,
    // so the string always passes `Key`'s validation.
    let key = random_id
        .parse::<Key>()
        .expect("randomly generated key should be AUTH_KEY_LENGTH alphanumeric chars");

    PeerKey { key, valid_until }
}
/// Checks that `auth_key` has not expired according to the current clock.
///
/// Permanent keys (`valid_until == None`) always pass. An expiring key passes as long
/// as its `valid_until` is not strictly earlier than the current time.
///
/// # Errors
///
/// Returns [`Error::KeyExpired`] when the key's `valid_until` is before the current time.
pub fn verify_key_expiration(auth_key: &PeerKey) -> Result<(), Error> {
    let now: DurationSinceUnixEpoch = CurrentClock::now();

    let has_expired = matches!(auth_key.valid_until, Some(deadline) if deadline < now);
    if !has_expired {
        return Ok(());
    }

    // This function is not `#[track_caller]`, so the recorded location is this call site.
    Err(Error::KeyExpired {
        location: Location::caller(),
    })
}
/// An authentication key together with its optional expiration time.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct PeerKey {
    /// The key string.
    pub key: Key,
    /// Expiration time as a duration since the Unix epoch.
    /// `None` means the key is permanent and never expires.
    pub valid_until: Option<DurationSinceUnixEpoch>,
}
impl std::fmt::Display for PeerKey {
    /// Renders the key and, if it expires, its UTC expiration date-time;
    /// otherwise marks it as permanent.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(expire_time) = self.expiry_time() {
            write!(f, "key: `{}`, valid until `{}`", self.key, expire_time)
        } else {
            write!(f, "key: `{}`, permanent", self.key)
        }
    }
}
impl PeerKey {
#[must_use]
pub fn key(&self) -> Key {
self.key.clone()
}
#[must_use]
pub fn expiry_time(&self) -> Option<chrono::DateTime<chrono::Utc>> {
self.valid_until.map(convert_from_timestamp_to_datetime_utc)
}
}
/// An authentication key string.
///
/// When built through [`Key::new`] or `FromStr`, the value is exactly
/// `AUTH_KEY_LENGTH` ASCII alphanumeric characters. `Display` prints the inner string.
///
/// NOTE(review): the derived `Deserialize` does not go through `Key::new`, so a
/// deserialized `Key` is not validated — confirm whether callers depend on that.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone, Display, Hash)]
pub struct Key(String);
impl Key {
    /// Creates a validated key from `value`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseKeyError::InvalidKeyLength`] if `value` is not exactly
    /// `AUTH_KEY_LENGTH` bytes long. Otherwise returns [`ParseKeyError::InvalidChars`]
    /// if it contains anything other than ASCII letters and digits.
    pub fn new(value: &str) -> Result<Self, ParseKeyError> {
        match value {
            v if v.len() != AUTH_KEY_LENGTH => Err(ParseKeyError::InvalidKeyLength),
            // Any non-ASCII char encodes to bytes >= 0x80, which are never ASCII alphanumeric,
            // so checking bytes is equivalent to checking chars here.
            v if !v.bytes().all(|b| b.is_ascii_alphanumeric()) => Err(ParseKeyError::InvalidChars),
            v => Ok(Self(String::from(v))),
        }
    }

    /// Borrows the key as a string slice.
    #[must_use]
    pub fn value(&self) -> &str {
        self.0.as_str()
    }
}
#[derive(Debug, Error)]
pub enum ParseKeyError {
#[error("Invalid key length. Key must be have 32 chars")]
InvalidKeyLength,
#[error("Invalid chars for key. Key can only alphanumeric chars (0-9, a-z, A-Z)")]
InvalidChars,
}
impl FromStr for Key {
    type Err = ParseKeyError;

    /// Parses and validates a key; identical to [`Key::new`].
    ///
    /// The validated value from `Key::new` is returned directly instead of being
    /// discarded and rebuilt with a second allocation.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Key::new(s)
    }
}
/// Errors produced while verifying or loading authentication keys.
#[derive(Debug, Error)]
// NOTE(review): presumably some variants are only constructed in other modules or not at
// all yet — confirm why `dead_code` is allowed here and document it.
#[allow(dead_code)]
pub enum Error {
    /// Wraps an underlying error (e.g. a SQLite error via the `From` impl below) as a `LocatedError`.
    #[error("Key could not be verified: {source}")]
    KeyVerificationError {
        source: LocatedError<'static, dyn std::error::Error + Send + Sync>,
    },
    /// A specific key could not be read; carries the key and the source location.
    #[error("Failed to read key: {key}, {location}")]
    UnableToReadKey {
        location: &'static Location<'static>,
        key: Box<Key>,
    },
    /// The key's `valid_until` is earlier than the current time.
    #[error("Key has expired, {location}")]
    KeyExpired { location: &'static Location<'static> },
}
impl From<r2d2_sqlite::rusqlite::Error> for Error {
fn from(e: r2d2_sqlite::rusqlite::Error) -> Self {
Error::KeyVerificationError {
source: (Arc::new(e) as DynError).into(),
}
}
}
#[cfg(test)]
mod tests {
    /// Parsing and validation of the bare `Key` type.
    mod key {
        use std::str::FromStr;

        use crate::core::auth::Key;

        #[test]
        fn should_be_parsed_from_an_string() {
            let key_string = "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ";
            let key = Key::from_str(key_string);
            assert!(key.is_ok());
            assert_eq!(key.unwrap().to_string(), key_string);
        }

        #[test]
        fn length_should_be_32() {
            let key = Key::new("");
            assert!(key.is_err());
            // 33 chars: one more than the required length.
            let string_longer_than_32 = "012345678901234567890123456789012";
            let key = Key::new(string_longer_than_32);
            assert!(key.is_err());
        }

        #[test]
        fn should_only_include_alphanumeric_chars() {
            // Correct length (32) but non-alphanumeric, so only the char check can reject it.
            let key = Key::new("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%");
            assert!(key.is_err());
        }
    }

    /// Generation, display and expiration of `PeerKey`s, driven by the stopped test clock.
    mod expiring_auth_key {
        use std::str::FromStr;
        use std::time::Duration;

        use torrust_tracker_clock::clock;
        use torrust_tracker_clock::clock::stopped::Stopped as _;

        use crate::core::auth;

        #[test]
        fn should_be_parsed_from_an_string() {
            let key_string = "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ";
            let auth_key = auth::Key::from_str(key_string);
            assert!(auth_key.is_ok());
            assert_eq!(auth_key.unwrap().to_string(), key_string);
        }

        #[test]
        fn should_be_displayed() {
            // Clock at the Unix epoch + zero lifetime => `valid_until` is the epoch itself.
            clock::Stopped::local_set_to_unix_epoch();
            let expiring_key = auth::generate_key(Some(Duration::from_secs(0)));
            assert_eq!(
                expiring_key.to_string(),
                format!("key: `{}`, valid until `1970-01-01 00:00:00 UTC`", expiring_key.key)
            );
        }

        #[test]
        fn should_be_generated_with_a_expiration_time() {
            let expiring_key = auth::generate_key(Some(Duration::new(9999, 0)));
            assert!(auth::verify_key_expiration(&expiring_key).is_ok());
        }

        #[test]
        fn should_be_generate_and_verified() {
            // Key valid for 19s: still valid after 10s, expired after 20s.
            clock::Stopped::local_set_to_system_time_now();
            let expiring_key = auth::generate_key(Some(Duration::from_secs(19)));
            clock::Stopped::local_add(&Duration::from_secs(10)).unwrap();
            assert!(auth::verify_key_expiration(&expiring_key).is_ok());
            clock::Stopped::local_add(&Duration::from_secs(10)).unwrap();
            assert!(auth::verify_key_expiration(&expiring_key).is_err());
        }
    }
}