use std::sync::Arc;
use bytes::Bytes;
use cookie::{Cookie, CookieBuilder, SameSite};
use http::{
HeaderValue, Request, Response, StatusCode,
header::{COOKIE, LOCATION, SET_COOKIE},
};
use http_body_util::{BodyExt, Empty, combinators::BoxBody};
use hyper::body::Incoming;
use redis::{aio::ConnectionManager, cmd};
use tiny_google_oidc::{
code::{AccessType, AdditionalScope, CodeRequest, RawCodeResponse},
config::Config,
csrf_token::CSRFToken,
id_token::{IDToken, IDTokenRequest, send_id_token_req},
nonce::Nonce,
};
use uuid::Uuid;
use crate::protected::see_location_res;
/// Name of the cookie that carries the random key under which the pending login's
/// CSRF token is stored in Redis.
static CSRF_COOKIE_KEY: &str = "csrf_key";
/// Name of the cookie that carries the session id issued after a successful login.
pub static SESSION_COOKIE_KEY: &str = "session";
/// Drives the Google OpenID Connect login flow (`entry` starts it, `callback` completes it),
/// keeping CSRF tokens and sessions in Redis.
#[derive(Clone)]
pub struct LoginService {
    /// Shared OIDC client configuration used to build the authorization and token requests.
    config: Arc<Config>,
    /// Redis connection used to store CSRF tokens and sessions.
    redis_conn: ConnectionManager,
}
impl LoginService {
pub fn new(config: Arc<Config>, redis_conn: ConnectionManager) -> Self {
Self { config, redis_conn }
}
pub async fn entry(&mut self) -> anyhow::Result<Response<BoxBody<Bytes, std::io::Error>>> {
let csrf_token = CSRFToken::new()?;
let csrf_key = Uuid::new_v4().to_string();
let cookie = CookieBuilder::new(CSRF_COOKIE_KEY, csrf_key.clone())
.same_site(SameSite::Lax)
.http_only(true)
.build();
let scope = AdditionalScope::Both;
let _ = cmd("SET")
.arg(&csrf_key)
.arg(csrf_token.value())
.query_async::<String>(&mut self.redis_conn)
.await?;
let url = CodeRequest::new(
AccessType::Offline,
&self.config,
scope,
&csrf_token,
&Nonce::new(),
)
.try_into_url()?;
let res = Response::builder()
.status(StatusCode::SEE_OTHER)
.header(LOCATION, url.to_string())
.header(SET_COOKIE, cookie.to_string())
.body(Empty::new().map_err(|e| match e {}).boxed())
.unwrap();
Ok(res)
}
pub async fn callback(
&mut self,
req: Request<Incoming>,
) -> anyhow::Result<Response<BoxBody<Bytes, std::io::Error>>> {
let cookie_header_val = match req.headers().get(COOKIE) {
Some(v) => v,
None => return Ok(see_location_res("/")),
};
let cookies = Self::parse_cookies(&cookie_header_val)?;
let csrf_key = match cookies.iter().find(|c| c.name() == CSRF_COOKIE_KEY) {
Some(cookie) => cookie.value(),
None => return Ok(see_location_res("/")),
};
let csrf_token = cmd("GET")
.arg(&csrf_key)
.query_async::<String>(&mut self.redis_conn)
.await?;
let code_res = RawCodeResponse::new(&req)?;
let code = code_res.exchange_with_code(&csrf_token)?;
let id_token_res = send_id_token_req(&IDTokenRequest::new(&self.config, code)).await?;
let _access_token = id_token_res.access_token();
let _refresh_token = id_token_res.refresh_token();
let id_token = IDToken::from_id_token_raw(id_token_res.id_token())?;
let session_id = Uuid::new_v4().to_string();
let _ = cmd("SET")
.arg(&session_id)
.arg(&id_token.sub)
.query_async::<String>(&mut self.redis_conn)
.await?;
let _ = cmd("DEL")
.arg(&csrf_key)
.query_async::<String>(&mut self.redis_conn)
.await?;
let new_cookie = CookieBuilder::new(SESSION_COOKIE_KEY, session_id)
.same_site(SameSite::Lax)
.http_only(true)
.path("/")
.build();
let res = Response::builder()
.status(StatusCode::SEE_OTHER)
.header(SET_COOKIE, new_cookie.to_string())
.header(LOCATION, "/protected")
.body(Empty::new().map_err(|e| match e {}).boxed())
.unwrap();
Ok(res)
}
fn parse_cookies(header_val: &HeaderValue) -> anyhow::Result<Vec<Cookie<'_>>> {
let values = header_val.to_str()?;
let cookies: Vec<Cookie<'_>> = values
.split(';')
.filter_map(|c| Cookie::parse(c.trim().to_string()).ok())
.collect();
Ok(cookies)
}
}