use cookie::Cookie;
use futures_util::io::{AsyncRead, AsyncWrite};
use http::header;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
use std::fmt;
use std::time::Duration;
use crate::{server::ResponseWriter, Request};
/// Cookie-based JWT authentication settings.
///
/// Construct one with [`Identity::build`] followed by [`IdentityBuilder::finish`].
// NOTE(review): deliberately only `Clone` — a derived `Debug` would print
// `server_key` verbatim; implement `Debug` by hand with redaction if needed.
#[derive(Clone)]
pub struct Identity {
    /// HMAC secret used both to sign and to verify tokens.
    server_key: String,
    /// Value stamped into the `iss` claim, if configured.
    issuer: Option<String>,
    /// Lifetime applied to the token's `exp` claim and the cookie's `Max-Age`.
    expiration_time: Duration,
    /// Name of the cookie that carries the token.
    cookie_name: String,
    /// `Path` attribute of the cookie.
    cookie_path: String,
    /// Whether the cookie is marked `Secure`.
    cookie_secure: bool,
    /// Whether the cookie is marked `HttpOnly`.
    cookie_http_only: bool,
}
impl Identity {
    /// Starts configuring an [`Identity`] whose tokens are signed with `server_key`.
    pub fn build(server_key: &str) -> IdentityBuilder {
        IdentityBuilder::new(server_key)
    }

    /// Returns the subject (`sub` claim) of the token stored in the identity cookie.
    ///
    /// Returns `None` in any of these cases:
    /// - the cookie is missing;
    /// - `jsonwebtoken` rejects the token under `Validation::default()` (bad
    ///   signature, malformed, or expired);
    /// - an issuer is configured and the token's `iss` claim differs from it.
    pub fn authorized_user(&self, req: &Request) -> Option<String> {
        let token_str = get_cookie(req, &self.cookie_name)?;
        let token = decode::<Claims>(
            &token_str,
            &DecodingKey::from_secret(self.server_key.as_bytes()),
            &Validation::default(),
        )
        .ok()?;
        // The default validation leaves `iss` unchecked, yet `make_token` always
        // stamps the configured issuer. Enforce it here so tokens minted for a
        // different issuer (e.g. another service sharing the secret) are refused.
        if let Some(issuer) = &self.issuer {
            if token.claims.iss != *issuer {
                return None;
            }
        }
        Some(token.claims.sub)
    }

    /// Mints a token for `user` and appends it to `resp_wtr` as a `Set-Cookie` header.
    ///
    /// Both the token's `exp` claim and the cookie's `Max-Age` are derived from
    /// the configured expiration time.
    ///
    /// # Panics
    ///
    /// Panics if the rendered cookie string cannot be parsed into a header value
    /// (e.g. the configured cookie name or path contains control characters).
    pub fn set_auth_token<W>(&self, user: &str, resp_wtr: &mut ResponseWriter<W>)
    where
        W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static,
    {
        let token = self
            .make_token(Some(user), None)
            .expect("HMAC-signing plain JWT claims should not fail");
        // Clamp instead of panicking when the configured lifetime is larger than
        // `time::Duration` can hold (e.g. `Duration::MAX` used to mean "never").
        let max_age: time::Duration = self
            .expiration_time
            .try_into()
            .unwrap_or_else(|_| time::Duration::seconds(i64::MAX));
        self.append_cookie(token, max_age, resp_wtr);
    }

    /// Instructs the client to drop the identity cookie.
    ///
    /// Sends a cookie with `Max-Age=0` whose token has an empty subject and an
    /// `exp` of 0, so `jsonwebtoken`'s default expiry check rejects it even if
    /// a client keeps it.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Identity::set_auth_token`].
    pub fn forget<W>(&self, resp_wtr: &mut ResponseWriter<W>)
    where
        W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static,
    {
        let token = self
            .make_token(None, Some(0))
            .expect("HMAC-signing plain JWT claims should not fail");
        self.append_cookie(token, time::Duration::seconds(0), resp_wtr);
    }

    /// Appends a `Set-Cookie` header carrying `token`, using this identity's
    /// cookie name, path, `HttpOnly` and `Secure` settings and the given `max_age`.
    fn append_cookie<W>(
        &self,
        token: String,
        max_age: time::Duration,
        resp_wtr: &mut ResponseWriter<W>,
    ) where
        W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static,
    {
        let cookie = Cookie::build(&self.cookie_name, token)
            .path(&self.cookie_path)
            .max_age(max_age)
            .http_only(self.cookie_http_only)
            .secure(self.cookie_secure)
            .finish();
        resp_wtr.append_header(
            header::SET_COOKIE,
            cookie
                .to_string()
                .parse()
                .expect("cookie should render as a valid header value"),
        );
    }

    /// Encodes and signs (HS256 via `Header::default()`) a token for `user`.
    ///
    /// - `user`: becomes the `sub` claim; `None` yields an empty subject.
    /// - `expiration`: an absolute `exp` in seconds since the Unix epoch. When
    ///   `None`, the current time plus the configured expiration time is used.
    ///
    /// The `iss` claim is the configured issuer, or empty when none is set.
    ///
    /// # Errors
    ///
    /// Returns [`IdentityFail::Encode`] if `jsonwebtoken::encode` fails.
    fn make_token(
        &self,
        user: Option<&str>,
        expiration: Option<u64>,
    ) -> Result<String, IdentityFail> {
        let claims = Claims {
            // Saturate so a huge configured lifetime cannot overflow `u64`
            // (a debug-build panic, or a wrap to an already-expired `exp`).
            exp: expiration.unwrap_or_else(|| {
                current_numeric_date().saturating_add(self.expiration_time.as_secs())
            }),
            iss: self.issuer.clone().unwrap_or_default(),
            sub: user.unwrap_or_default().to_owned(),
        };
        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(self.server_key.as_bytes()),
        )
        .map_err(IdentityFail::Encode)
    }
}
/// Builder for [`Identity`].
///
/// Cookie name and path are optional here; `IdentityBuilder::finish` fills in
/// defaults for any left unset.
pub struct IdentityBuilder {
    /// HMAC secret used to sign and verify tokens.
    server_key: String,
    /// Optional value for the `iss` claim.
    issuer: Option<String>,
    /// Token and cookie lifetime.
    expiration_time: Duration,
    /// Cookie name override; `None` selects the default.
    cookie_name: Option<String>,
    /// Cookie path override; `None` selects the default.
    cookie_path: Option<String>,
    /// Whether the cookie will be marked `Secure`.
    cookie_secure: bool,
    /// Whether the cookie will be marked `HttpOnly`.
    cookie_http_only: bool,
}
impl IdentityBuilder {
    /// Creates a builder whose tokens will be signed with `server_key`.
    ///
    /// Defaults:
    /// - no issuer;
    /// - a 24-hour lifetime;
    /// - no explicit cookie name or path (`finish` substitutes `jwt` and `/`);
    /// - both `Secure` and `HttpOnly` enabled.
    pub fn new(server_key: &str) -> IdentityBuilder {
        const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);
        Self {
            server_key: String::from(server_key),
            issuer: None,
            expiration_time: ONE_DAY,
            cookie_name: None,
            cookie_path: None,
            cookie_secure: true,
            cookie_http_only: true,
        }
    }

    /// Sets the name of the cookie that carries the token.
    pub fn cookie_name(self, name: &str) -> Self {
        Self {
            cookie_name: Some(name.to_string()),
            ..self
        }
    }

    /// Sets the cookie's `Path` attribute.
    pub fn cookie_path(self, path: &str) -> Self {
        Self {
            cookie_path: Some(path.to_string()),
            ..self
        }
    }

    /// Toggles the cookie's `Secure` attribute.
    pub fn cookie_secure(self, secure: bool) -> Self {
        Self {
            cookie_secure: secure,
            ..self
        }
    }

    /// Toggles the cookie's `HttpOnly` attribute.
    pub fn cookie_http_only(self, http_only: bool) -> Self {
        Self {
            cookie_http_only: http_only,
            ..self
        }
    }

    /// Sets the value written into each token's `iss` claim.
    pub fn issuer(self, issuer: &str) -> Self {
        Self {
            issuer: Some(issuer.to_string()),
            ..self
        }
    }

    /// Sets how long issued tokens (and their cookies) remain valid.
    pub fn expiration_time(self, expiration_time: Duration) -> Self {
        Self {
            expiration_time,
            ..self
        }
    }

    /// Produces the configured [`Identity`].
    ///
    /// An unset cookie name becomes `jwt` and an unset cookie path becomes `/`.
    pub fn finish(self) -> Identity {
        let IdentityBuilder {
            server_key,
            issuer,
            expiration_time,
            cookie_name,
            cookie_path,
            cookie_secure,
            cookie_http_only,
        } = self;
        Identity {
            server_key,
            issuer,
            expiration_time,
            cookie_name: cookie_name.unwrap_or_else(|| String::from("jwt")),
            cookie_path: cookie_path.unwrap_or_else(|| String::from("/")),
            cookie_secure,
            cookie_http_only,
        }
    }
}
/// Returns the value of the first request cookie named `name`, if any.
///
/// Every `Cookie` header is scanned. A single header may carry several
/// `name=value` pairs separated by `;` (RFC 6265, section 4.2), so each pair is
/// parsed on its own. Header values that are not visible ASCII, and pairs that
/// fail to parse, are skipped rather than aborting the whole search.
fn get_cookie(req: &Request, name: &str) -> Option<String> {
    req.headers()
        .get_all(header::COOKIE)
        .into_iter()
        .filter_map(|value| value.to_str().ok())
        // `Cookie::parse` handles a single cookie and would treat any later
        // pairs as attributes, so split the header into its pairs first.
        .flat_map(|line| line.split(';'))
        .filter_map(|pair| Cookie::parse(pair.trim()).ok())
        .find(|cookie| cookie.name() == name)
        .map(|cookie| cookie.value().to_string())
}
/// Returns the current Unix time in whole seconds (a JWT "NumericDate").
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn current_numeric_date() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs()
}
/// JWT payload carried in the identity cookie.
///
/// serde's derive serializes fields in declaration order, so reordering them
/// changes the bytes of newly minted tokens.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
/// Expiration (`exp`) in whole seconds since the Unix epoch; `0` in the token
/// written by `Identity::forget`.
exp: u64,
/// Issuer (`iss`): the configured issuer, or an empty string when none is set.
iss: String,
/// Subject (`sub`): the user passed to `Identity::set_auth_token`, or an empty
/// string in the token written by `Identity::forget`.
sub: String,
}
/// Errors produced while handling identity tokens.
#[derive(Debug)]
pub enum IdentityFail {
    /// `jsonwebtoken::encode` failed while minting a token.
    Encode(jsonwebtoken::errors::Error),
    /// A `jsonwebtoken` decoding error.
    ///
    /// `Identity::authorized_user` currently maps decode failures to `None`
    /// rather than to this variant.
    Decode(jsonwebtoken::errors::Error),
}

// `source()` keeps its default (`None`): `Display` below already embeds the
// wrapped error's message, and exposing it again via `source()` would print it
// twice in error-chain reporters.
impl std::error::Error for IdentityFail {}

impl fmt::Display for IdentityFail {
    /// Renders `jwt encoding error: …` or `jwt decoding error: …`, followed by
    /// the wrapped `jsonwebtoken` error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Encode(err) => write!(f, "jwt encoding error: {}", err),
            Self::Decode(err) => write!(f, "jwt decoding error: {}", err),
        }
    }
}