use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use anyhow::Context;
use axum::{
Json, Router,
extract::{Request, State},
response::{IntoResponse, Redirect},
routing::{get, post},
};
use axum_extra::extract::{CookieJar, cookie::Cookie};
use http::StatusCode;
use serde::Deserialize;
use tiny_google_oidc::{
code::{AccessType, AdditionalScope, CodeRequest, RawCodeResponse},
config::{Config, ConfigBuilder},
csrf_token::CSRFToken,
id_token::{IDToken, IDTokenRequest, send_id_token_req},
nonce::Nonce,
refresh_token::{RefreshToken, RefreshTokenRequest, send_refresh_token_req},
revoke_token::{RevokeToken, RevokeTokenRequest, send_revoke_token_req},
};
use tracing::error;
use uuid::Uuid;
extern crate tiny_google_oidc;
/// Reads the OIDC client settings and serves the example routes on port 80.
///
/// Routes:
/// * `GET /` — [`start_auth`]: redirects the browser to the authorization endpoint.
/// * `GET /auth/callback` — [`call_back`]: finishes the code flow and returns the ID token.
/// * `POST /revoke` — [`revoke_token`]: revokes a token.
/// * `POST /refresh` — [`refresh_token`]: exchanges a refresh token for new tokens.
///
/// # Errors
/// Fails if any required setting is missing, if the listener cannot bind, or if the
/// server stops with an I/O error.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();

    // All five settings are required; `?` aborts startup on the first missing one.
    let auth_endpoint = read_env("auth_endpoint")?;
    let client_id = read_env("client_id")?;
    let client_secret = read_env("client_secret")?;
    let token_endpoint = read_env("token_endpoint")?;
    let redirect_uri = read_env("redirect_uri")?;
    let config = ConfigBuilder::new()
        .auth_endpoint(auth_endpoint)
        .client_id(client_id)
        .client_secret(client_secret)
        .token_endpoint(token_endpoint)
        .redirect_uri(redirect_uri)
        .build();
    let app_state = AppState::new(config);

    let addr = "0.0.0.0:80";
    // Binding can fail at runtime (port in use, missing privileges for port 80), so
    // report it as an error instead of panicking.
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .with_context(|| format!("Failed to bind {addr}"))?;

    let app = Router::new()
        .route("/auth/callback", get(call_back))
        .route("/", get(start_auth))
        .route("/revoke", post(revoke_token))
        .route("/refresh", post(refresh_token))
        .with_state(Arc::new(app_state));

    axum::serve(listener, app)
        .await
        .context("Server terminated with an error")?;
    Ok(())
}
/// Name of the cookie that carries the lookup key for the server-side CSRF `state` token.
const COOKIE_KEY: &str = "csrf_token";
async fn start_auth(
State(app_state): State<Arc<AppState>>,
jar: CookieJar,
) -> Result<impl IntoResponse, StatusCode> {
let state = CSRFToken::new().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let csrf_key = Uuid::new_v4().to_string();
let cookie = Cookie::new(COOKIE_KEY, csrf_key.clone());
{
app_state
.token
.lock()
.unwrap()
.insert(csrf_key, state.clone());
}
let nonce = Nonce::new();
let scope = AdditionalScope::Both;
let req = CodeRequest::new(
AccessType::Offline,
&app_state.config,
scope,
&state,
&nonce,
);
let url = req
.try_into_url()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((jar.add(cookie), Redirect::to(&url.to_string())))
}
async fn call_back(
State(app_state): State<Arc<AppState>>,
jar: CookieJar,
req: Request,
) -> Result<impl IntoResponse, StatusCode> {
let code_res = RawCodeResponse::new(req).map_err(|e| {
error!("Failed to parse url: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let csrf_token: CSRFToken;
let cookie = jar.get(COOKIE_KEY).ok_or_else(|| StatusCode::BAD_REQUEST)?;
let csrf_key = cookie.value();
{
let lock = app_state.token.lock().unwrap();
csrf_token = lock
.get(csrf_key)
.ok_or_else(|| StatusCode::BAD_REQUEST)?
.to_owned();
}
let code = code_res
.exchange_with_code(csrf_token.value())
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let id_token_req = IDTokenRequest::new(&app_state.config, code);
let res = send_id_token_req(&id_token_req)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
println!("{:#?}", res);
let refresh_token = res.access_token();
println!("{:?}", refresh_token);
let id_token_raw = res.id_token();
let id_token = IDToken::from_id_token_raw(id_token_raw).unwrap();
Ok((StatusCode::OK, Json(id_token)))
}
async fn revoke_token(Json(refresh_token): Json<Token>) -> Result<impl IntoResponse, StatusCode> {
let token = RevokeToken::new_access_token(&refresh_token.token);
let req = RevokeTokenRequest::new(&token);
send_revoke_token_req(&req)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(StatusCode::OK)
}
async fn refresh_token(
State(app_state): State<Arc<AppState>>,
Json(refresh_token): Json<Token>,
) -> Result<impl IntoResponse, StatusCode> {
let refresh_token = RefreshToken::new(&refresh_token.token);
let req = RefreshTokenRequest::new(&app_state.config, &refresh_token);
let res = send_refresh_token_req(&req)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((StatusCode::OK, Json(res)))
}
/// Reads a required configuration value through `dotenvy::var`.
///
/// # Errors
/// Fails if the variable is unavailable. The error message names the missing key, so a
/// misconfigured deployment can be fixed without guessing.
fn read_env(key: &str) -> anyhow::Result<String> {
    dotenvy::var(key).with_context(|| format!("Failed to read env var `{key}`"))
}
/// State shared by every handler; `main` wraps it in an `Arc` before passing it to the router.
#[derive(Debug, Clone)]
struct AppState {
    /// Client configuration built from environment variables at startup.
    config: Arc<Config>,
    /// CSRF `state` tokens issued by `start_auth`, keyed by the random UUID string that is
    /// also sent to the browser in the `csrf_token` cookie.
    token: Arc<Mutex<HashMap<String, CSRFToken>>>,
}
impl AppState {
fn new(config: Config) -> Self {
Self {
config: Arc::new(config),
token: Arc::default(),
}
}
}
/// JSON request body accepted by `/revoke` and `/refresh`: `{"token": "<token string>"}`.
#[derive(Debug, Clone, Deserialize)]
struct Token {
    /// The raw token string supplied by the client.
    token: String,
}