use super::Error;
use crate::comms::WebsocketSender;
use crate::config::get_config;
use crate::http::{Authorization, Request, Response};
use crate::view::{ToTemplateValue, Value};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use time::{Duration, OffsetDateTime};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
#[derive(Clone)]
pub struct AuthHandler {
auth: Arc<Box<dyn Authentication>>,
}
impl Default for AuthHandler {
fn default() -> Self {
Self::new(AllowAll {})
}
}
impl AuthHandler {
pub fn new(auth: impl Authentication + 'static) -> Self {
AuthHandler {
auth: Arc::new(Box::new(auth)),
}
}
pub fn auth(&self) -> &Box<dyn Authentication> {
&self.auth
}
}
#[async_trait]
#[allow(unused_variables)]
pub trait Authentication: Sync + Send {
async fn authorize(&self, request: &Request) -> Result<bool, Error>;
async fn denied(&self, request: &Request) -> Result<Response, Error> {
Ok(Response::forbidden())
}
fn handler(self) -> AuthHandler
where
Self: Sized + 'static,
{
AuthHandler::new(self)
}
}
/// Authenticator that unconditionally admits every request.
pub struct AllowAll;

#[async_trait]
impl Authentication for AllowAll {
    /// Always grants access; the request is never inspected.
    async fn authorize(&self, _request: &Request) -> Result<bool, Error> {
        let granted = true;
        Ok(granted)
    }
}
/// Authenticator that unconditionally rejects every request.
pub struct DenyAll;

#[async_trait]
impl Authentication for DenyAll {
    /// Always refuses access; the request is never inspected.
    async fn authorize(&self, _request: &Request) -> Result<bool, Error> {
        let granted = false;
        Ok(granted)
    }
}
/// HTTP Basic authentication against a single fixed credential pair.
pub struct BasicAuth {
    /// Expected user name.
    pub user: String,
    /// Expected password.
    pub password: String,
}

#[async_trait]
impl Authentication for BasicAuth {
    /// Accepts the request only when it carries `Authorization::Basic`
    /// credentials equal to both `self.user` and `self.password`.
    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
        // Compares two byte strings without stopping at the first mismatching
        // byte. The running time therefore does not reveal how long a matching
        // prefix the client guessed; only the length can leak.
        fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
            a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
        }

        Ok(match request.authorization() {
            Some(Authorization::Basic { user, password }) => {
                // Both checks always run and are combined with `&` rather than
                // `&&`, so a wrong user name cannot be told apart from a wrong
                // password by timing.
                let user_ok = constant_time_eq(self.user.as_bytes(), user.as_bytes());
                let password_ok =
                    constant_time_eq(self.password.as_bytes(), password.as_bytes());
                user_ok & password_ok
            }
            // Missing header or a different authorization scheme.
            _ => false,
        })
    }

    /// Rejects with a Basic-auth challenge instead of the default forbidden response.
    async fn denied(&self, _request: &Request) -> Result<Response, Error> {
        Ok(Response::unauthorized("Basic"))
    }
}
/// Authentication against a single fixed token carried as
/// `Authorization::Token` credentials.
pub struct Token {
    /// Expected token value.
    pub token: String,
}

#[async_trait]
impl Authentication for Token {
    /// Accepts the request only when its `Authorization::Token` value equals
    /// `self.token`. Rejections use the trait's default `denied` response.
    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
        // Compares two byte strings without stopping at the first mismatching
        // byte. The running time therefore does not reveal how long a matching
        // prefix the client guessed; only the length can leak.
        fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
            a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
        }

        Ok(match request.authorization() {
            Some(Authorization::Token { token }) => {
                constant_time_eq(self.token.as_bytes(), token.as_bytes())
            }
            // Missing header or a different authorization scheme.
            _ => false,
        })
    }
}
/// Identifies who a session belongs to.
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub enum SessionId {
    /// Anonymous visitor. `Default` fills this with a random 16-character
    /// alphanumeric string.
    Guest(String),
    /// Logged-in user, identified by the id returned from `user_id()`.
    Authenticated(i64),
}

impl SessionId {
    /// `true` for [`SessionId::Authenticated`].
    pub fn authenticated(&self) -> bool {
        matches!(self, SessionId::Authenticated(_))
    }

    /// `true` for [`SessionId::Guest`]; always the negation of
    /// [`authenticated`](Self::authenticated).
    pub fn guest(&self) -> bool {
        matches!(self, SessionId::Guest(_))
    }

    /// The user id for authenticated sessions, `None` for guests.
    pub fn user_id(&self) -> Option<i64> {
        match *self {
            SessionId::Authenticated(uid) => Some(uid),
            SessionId::Guest(_) => None,
        }
    }
}

impl std::fmt::Display for SessionId {
    /// Writes the bare inner value: the guest string or the numeric user id.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SessionId::Guest(key) => write!(f, "{}", key),
            SessionId::Authenticated(uid) => write!(f, "{}", uid),
        }
    }
}

impl Default for SessionId {
    /// A fresh guest id of 16 random ASCII alphanumeric characters.
    fn default() -> Self {
        use rand::{distributions::Alphanumeric, thread_rng, Rng};
        const GUEST_ID_LEN: usize = 16;
        let key: String = thread_rng()
            .sample_iter(&Alphanumeric)
            .take(GUEST_ID_LEN)
            .map(char::from)
            .collect();
        SessionId::Guest(key)
    }
}
/// Per-client session state.
///
/// The serde field renames (`p`, `e`, `s`) keep the serialized form short.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Session {
    /// Arbitrary application data stored with the session.
    #[serde(rename = "p")]
    pub payload: serde_json::Value,
    /// Expiration instant as a Unix timestamp in seconds.
    #[serde(rename = "e")]
    pub expiration: i64,
    /// Owner of the session (guest or authenticated user).
    #[serde(rename = "s")]
    pub session_id: SessionId,
}

impl Default for Session {
    /// A guest session with an empty JSON object as its payload.
    fn default() -> Self {
        let empty_object = serde_json::json!({});
        Session::new(empty_object).expect("json")
    }
}
impl ToTemplateValue for Session {
    /// Exposes the session to templates as a hash with three keys:
    /// - `expiration`: the Unix timestamp, as an integer;
    /// - `session_id`: the `Display` form of the id;
    /// - `payload`: the payload rendered as compact JSON text.
    fn to_template_value(&self) -> Result<Value, crate::view::Error> {
        let mut fields = HashMap::with_capacity(3);
        fields.insert("session_id".into(), Value::String(self.session_id.to_string()));
        fields.insert("expiration".into(), Value::Integer(self.expiration));
        // `serde_json::Value`'s `Display` emits the same compact JSON as
        // `serde_json::to_string`, without a fallible step to unwrap.
        fields.insert("payload".into(), Value::String(self.payload.to_string()));
        Ok(Value::Hash(fields))
    }
}
impl Session {
pub fn anonymous() -> Self {
Self::default()
}
pub fn empty() -> Self {
Self::default()
}
pub fn new(payload: impl Serialize) -> Result<Self, Error> {
Ok(Self {
payload: serde_json::to_value(payload)?,
expiration: (OffsetDateTime::now_utc() + get_config().general.session_duration())
.unix_timestamp(),
session_id: SessionId::default(),
})
}
pub fn new_authenticated(payload: impl Serialize, user_id: i64) -> Result<Self, Error> {
let mut session = Self::new(payload)?;
session.session_id = SessionId::Authenticated(user_id);
Ok(session)
}
pub fn renew(mut self, renew_for: Duration) -> Self {
self.expiration = (OffsetDateTime::now_utc() + renew_for).unix_timestamp();
self
}
pub fn should_renew(&self) -> bool {
if let Ok(expiration) = OffsetDateTime::from_unix_timestamp(self.expiration) {
let now = OffsetDateTime::now_utc();
let remains = expiration - now;
let session_duration = get_config().general.session_duration();
remains < session_duration / 2 && remains.is_positive() } else {
true
}
}
pub fn expired(&self) -> bool {
if let Ok(expiration) = OffsetDateTime::from_unix_timestamp(self.expiration) {
let now = OffsetDateTime::now_utc();
expiration < now
} else {
false
}
}
pub fn websocket(&self) -> WebsocketSender {
use crate::comms::Comms;
Comms::websocket(&self.session_id)
}
pub fn authenticated(&self) -> bool {
!self.expired() && self.session_id.authenticated()
}
pub fn guest(&self) -> bool {
!self.expired() && self.session_id.guest()
}
}
/// Authenticator that admits only requests whose session is authenticated and
/// not expired. It can optionally redirect rejected requests.
#[derive(Default)]
pub struct SessionAuth {
    /// Target URL for rejected requests. `None` means respond with forbidden.
    redirect: Option<String>,
}

impl SessionAuth {
    /// Builds a `SessionAuth` that redirects rejected requests to `url`.
    pub fn redirect(url: impl ToString) -> Self {
        let target = url.to_string();
        Self {
            redirect: Some(target),
        }
    }
}

#[async_trait]
impl Authentication for SessionAuth {
    /// Delegates to the request session's `authenticated()` check.
    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
        Ok(request.session().authenticated())
    }

    /// Redirects to the configured URL if one is set; otherwise returns
    /// [`Response::forbidden`].
    async fn denied(&self, _request: &Request) -> Result<Response, Error> {
        let response = match &self.redirect {
            Some(target) => Response::new().redirect(target),
            None => Response::forbidden(),
        };
        Ok(response)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Unix timestamp `offset` away from the current instant.
    fn timestamp_in(offset: Duration) -> i64 {
        (OffsetDateTime::now_utc() + offset).unix_timestamp()
    }

    #[test]
    fn test_should_renew() {
        let mut session = Session::default();
        // A brand-new session has its full duration left: no renewal yet.
        assert!(!session.should_renew());

        // The boundaries below assume a four-week configured session duration.
        assert_eq!(get_config().general.session_duration(), Duration::weeks(4));

        // Just under half of the duration remaining: time to renew.
        session.expiration = timestamp_in(Duration::weeks(2) - Duration::seconds(5));
        assert!(session.should_renew());

        // Just over half remaining: not yet.
        session.expiration = timestamp_in(Duration::weeks(2) + Duration::seconds(5));
        assert!(!session.should_renew());
    }
}