use std::panic::Location;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use derive_more::Display;
use log::debug;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use torrust_tracker_located_error::LocatedError;
use crate::shared::bit_torrent::common::AUTH_KEY_LENGTH;
use crate::shared::clock::{convert_from_timestamp_to_datetime_utc, Current, DurationSinceUnixEpoch, Time, TimeNow};
/// Generates a new random authentication key that expires after `lifetime`.
///
/// The key consists of [`AUTH_KEY_LENGTH`] random ASCII alphanumeric characters.
/// Its expiry instant is computed with `Current::add(&lifetime)`, i.e. relative to
/// the clock exposed by [`Current`].
///
/// # Panics
///
/// Panics if `Current::add` cannot compute the expiry instant for `lifetime`
/// (NOTE(review): presumably only on overflow — confirm against `Current::add`).
#[must_use]
pub fn generate(lifetime: Duration) -> ExpiringKey {
    let random_id: String = thread_rng()
        .sample_iter(&Alphanumeric)
        .take(AUTH_KEY_LENGTH)
        .map(char::from)
        .collect();

    // `Duration`'s `Debug` output already carries its unit (e.g. `19s`), so no
    // explicit "seconds" suffix is appended.
    debug!("Generated key: {}, valid for: {:?}", random_id, lifetime);

    ExpiringKey {
        // Exactly AUTH_KEY_LENGTH characters were generated above, which is the only
        // condition `Key::from_str` checks, so this cannot fail.
        key: random_id
            .parse::<Key>()
            .expect("a generated key of AUTH_KEY_LENGTH characters must always parse"),
        valid_until: Current::add(&lifetime).expect("failed to compute the expiry time of the generated key"),
    }
}
/// Checks that `auth_key` has not expired yet, using the clock exposed by [`Current`].
///
/// A key is accepted up to and including the instant stored in `valid_until`. It is
/// only rejected once the current time is strictly later than that instant.
///
/// # Errors
///
/// Returns [`Error::KeyExpired`] if `auth_key.valid_until` is earlier than the current
/// time. The error's `location` is the call site of `verify`.
#[track_caller]
pub fn verify(auth_key: &ExpiringKey) -> Result<(), Error> {
    let current_time: DurationSinceUnixEpoch = Current::now();

    if auth_key.valid_until < current_time {
        Err(Error::KeyExpired {
            // Thanks to `#[track_caller]`, this reports the caller's location rather
            // than this line inside `verify`.
            location: Location::caller(),
        })
    } else {
        Ok(())
    }
}
/// An authentication [`Key`] paired with the instant after which it is no longer valid.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ExpiringKey {
    /// The authentication key itself.
    pub key: Key,
    /// Expiry instant, expressed as a duration since the Unix epoch.
    /// `verify` still accepts the key at exactly this instant.
    pub valid_until: DurationSinceUnixEpoch,
}
impl std::fmt::Display for ExpiringKey {
    /// Renders the key followed by its expiry time converted to a UTC date-time.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let expiry = self.expiry_time();
        write!(formatter, "key: `{}`, valid until `{}`", self.key, expiry)
    }
}
impl ExpiringKey {
#[must_use]
pub fn key(&self) -> Key {
self.key.clone()
}
#[must_use]
pub fn expiry_time(&self) -> chrono::DateTime<chrono::Utc> {
convert_from_timestamp_to_datetime_utc(self.valid_until)
}
}
/// An authentication key, stored as its string representation.
///
/// Construct one through [`FromStr`], which only accepts strings whose byte length is
/// [`AUTH_KEY_LENGTH`]. Its `Display` output is the raw key string.
///
/// NOTE(review): the derived `Deserialize` reads the inner `String` directly and does
/// not go through `FromStr`, so deserialized keys are not length-checked.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Display, Serialize, Deserialize)]
pub struct Key(String);
/// Error returned when a string cannot be parsed into a [`Key`].
#[derive(Debug, PartialEq, Eq)]
pub struct ParseKeyError;

// A public error type should be printable and usable as `dyn std::error::Error`
// (e.g. when propagated with `?` into `Box<dyn std::error::Error>`).
impl std::fmt::Display for ParseKeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("failed to parse authentication key")
    }
}

// Fully qualified because `Error` in this module names the local error enum.
impl std::error::Error for ParseKeyError {}
impl FromStr for Key {
    type Err = ParseKeyError;

    /// Parses `s` as a [`Key`].
    ///
    /// The only check performed is that the byte length of `s` equals
    /// [`AUTH_KEY_LENGTH`]; the characters themselves are not validated.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.len() {
            len if len == AUTH_KEY_LENGTH => Ok(Self(s.to_owned())),
            _ => Err(ParseKeyError),
        }
    }
}
/// Errors produced while verifying or reading authentication keys.
#[derive(Debug, Error)]
// NOTE(review): this suggests some variants are not constructed inside the crate;
// confirm whether the allowance is still needed.
#[allow(dead_code)]
pub enum Error {
    /// Verification failed because of an underlying error, such as a SQLite error
    /// converted through `From<r2d2_sqlite::rusqlite::Error>`.
    #[error("Key could not be verified: {source}")]
    KeyVerificationError {
        source: LocatedError<'static, dyn std::error::Error + Send + Sync>,
    },

    /// The given key could not be read.
    #[error("Failed to read key: {key}, {location}")]
    UnableToReadKey {
        location: &'static Location<'static>,
        key: Box<Key>,
    },

    /// The key's expiry instant is earlier than the current time.
    #[error("Key has expired, {location}")]
    KeyExpired {
        location: &'static Location<'static>,
    },
}
impl From<r2d2_sqlite::rusqlite::Error> for Error {
fn from(e: r2d2_sqlite::rusqlite::Error) -> Self {
Error::KeyVerificationError {
source: (Arc::new(e) as Arc<dyn std::error::Error + Send + Sync>).into(),
}
}
}
#[cfg(test)]
mod tests {
    mod key {
        use std::str::FromStr;

        use crate::shared::bit_torrent::common::AUTH_KEY_LENGTH;
        use crate::tracker::auth::Key;

        #[test]
        fn should_be_parsed_from_a_string() {
            let key_string = "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ";

            let key = Key::from_str(key_string);

            assert!(key.is_ok());
            assert_eq!(key.unwrap().to_string(), key_string);
        }

        #[test]
        fn should_not_be_parsed_from_a_string_with_the_wrong_length() {
            let too_short = "a".repeat(AUTH_KEY_LENGTH - 1);
            let too_long = "a".repeat(AUTH_KEY_LENGTH + 1);

            assert!(Key::from_str("").is_err());
            assert!(Key::from_str(&too_short).is_err());
            assert!(Key::from_str(&too_long).is_err());
        }
    }

    mod expiring_auth_key {
        use std::time::Duration;

        use crate::shared::clock::{Current, StoppedTime};
        use crate::tracker::auth;

        #[test]
        fn should_be_displayed() {
            Current::local_set_to_unix_epoch();

            let expiring_key = auth::generate(Duration::from_secs(0));

            assert_eq!(
                expiring_key.to_string(),
                format!("key: `{}`, valid until `1970-01-01 00:00:00 UTC`", expiring_key.key)
            );
        }

        #[test]
        fn should_be_generated_with_an_expiration_time() {
            let expiring_key = auth::generate(Duration::new(9999, 0));

            assert!(auth::verify(&expiring_key).is_ok());
        }

        #[test]
        fn should_be_generated_and_verified() {
            Current::local_set_to_system_time_now();

            let expiring_key = auth::generate(Duration::from_secs(19));

            Current::local_add(&Duration::from_secs(10)).unwrap();
            assert!(auth::verify(&expiring_key).is_ok());

            Current::local_add(&Duration::from_secs(10)).unwrap();
            assert!(auth::verify(&expiring_key).is_err());
        }

        #[test]
        fn should_still_be_valid_at_the_exact_expiry_instant() {
            Current::local_set_to_unix_epoch();

            let expiring_key = auth::generate(Duration::from_secs(10));

            // `verify` only rejects keys whose expiry is strictly in the past.
            Current::local_add(&Duration::from_secs(10)).unwrap();
            assert!(auth::verify(&expiring_key).is_ok());

            Current::local_add(&Duration::from_secs(1)).unwrap();
            assert!(auth::verify(&expiring_key).is_err());
        }
    }
}