use actix_web::{web, HttpRequest, HttpResponse, Responder, Result};
use base64::{engine::general_purpose, Engine as _};
use parking_lot::RwLock;
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use time::{Duration, OffsetDateTime};
use url::Url;
/// A registered OAuth client.
#[derive(Debug, Clone)]
pub struct Client {
    /// Public identifier of the client; also the key under which it is stored in `AppState::clients`.
    pub client_id: String,
    /// Shared secret, or `None` when the client has none (the `client_credentials` grant
    /// rejects such clients as "not confidential").
    pub client_secret: Option<String>,
    /// Redirect URIs the `authorize` handler accepts for this client (compared by exact
    /// string equality).
    pub redirect_uris: Vec<String>,
}
/// A pending authorization code created by the `authorize` handler and redeemed (removed)
/// by the `authorization_code` grant of the `token` handler.
#[derive(Debug, Clone)]
pub struct AuthCode {
    /// The opaque code value; also the key in `AppState::codes`.
    pub code: String,
    /// Client that requested the code.
    pub client_id: String,
    /// Redirect URI supplied in the authorization request.
    pub redirect_uri: String,
    /// Subject (user id) the code was issued for.
    pub subject: String,
    /// Instant after which the code is rejected as expired.
    pub expires_at: OffsetDateTime,
    /// PKCE challenge from the authorization request, if one was sent.
    pub code_challenge: Option<String>,
    /// PKCE transformation named in the authorization request, if one was sent.
    pub code_challenge_method: Option<String>,
}
/// A refresh token record; removed from the store when it is redeemed.
#[derive(Debug, Clone)]
pub struct RefreshToken {
    /// The opaque token value; also the key in `AppState::refresh_tokens`.
    pub token: String,
    /// Subject (user id) the token represents.
    pub subject: String,
    /// Client the token was issued to.
    pub client_id: String,
    /// Instant after which the token is rejected as expired.
    pub expires_at: OffsetDateTime,
}
/// An issued bearer access token.
#[derive(Debug, Clone)]
pub struct AccessToken {
    /// The opaque token value; also the key in `AppState::access_tokens`.
    pub token: String,
    /// Subject (user id) the token represents; the empty string for tokens issued through
    /// the `client_credentials` grant.
    pub subject: String,
    /// Client the token was issued to.
    pub client_id: String,
    /// Instant after which the token is no longer accepted.
    pub expires_at: OffsetDateTime,
}
/// In-memory state of the authorization server. Every map is keyed by the identifier the
/// handlers look records up with.
#[derive(Debug, Clone)]
pub struct AppState {
    /// Registered clients, keyed by `client_id`.
    pub clients: HashMap<String, Client>,
    /// Outstanding authorization codes, keyed by the code value.
    pub codes: HashMap<String, AuthCode>,
    /// Issued access tokens, keyed by the token value.
    pub access_tokens: HashMap<String, AccessToken>,
    /// Outstanding refresh tokens, keyed by the token value.
    pub refresh_tokens: HashMap<String, RefreshToken>,
    /// Known users: subject id mapped to the username reported by `userinfo`.
    pub users: HashMap<String, String>,
}
/// Shared handle to the server state: an [`AppState`] guarded by a `parking_lot`
/// read-write lock and reference-counted with [`Arc`] so handlers can share it.
pub type SharedState = Arc<RwLock<AppState>>;
/// Returns `len` random bytes encoded as unpadded URL-safe base64.
///
/// Note that `len` counts raw bytes, not characters of the returned string.
pub fn rand_b64url(len: usize) -> String {
    let mut random_bytes = vec![0u8; len];
    rand::rng().fill_bytes(&mut random_bytes);
    general_purpose::URL_SAFE_NO_PAD.encode(random_bytes)
}
/// Authorization endpoint for the authorization-code flow.
///
/// Reads `response_type`, `client_id`, `redirect_uri` and the optional `state`,
/// `code_challenge` and `code_challenge_method` parameters from the query string. On success
/// it stores a five-minute [`AuthCode`] and answers `302 Found` with the registered redirect
/// URI extended by `code` (and `state`, when one was supplied).
///
/// Every validation failure is answered with `400 Bad Request` and a plain-text reason, and
/// all validation (including parsing the redirect URI and checking the PKCE method) happens
/// *before* a code is stored, so a rejected request never leaves an orphaned code behind.
///
/// NOTE(review): the subject is hard-coded to `"user-1"`; this handler performs no end-user
/// authentication or consent step.
pub async fn authorize(req: HttpRequest, data: web::Data<SharedState>) -> Result<impl Responder> {
    let qp: HashMap<String, String> = url::form_urlencoded::parse(req.query_string().as_bytes())
        .into_owned()
        .collect();

    if qp.get("response_type").map(String::as_str) != Some("code") {
        return Ok(HttpResponse::BadRequest().body("unsupported response_type"));
    }
    let Some(client_id) = qp.get("client_id") else {
        return Ok(HttpResponse::BadRequest().body("missing client_id"));
    };
    let Some(redirect_uri) = qp.get("redirect_uri") else {
        return Ok(HttpResponse::BadRequest().body("missing redirect_uri"));
    };
    let state_param = qp.get("state").cloned();
    let code_challenge = qp.get("code_challenge").cloned();
    let code_challenge_method = qp.get("code_challenge_method").cloned();

    // Reject transformations the token endpoint cannot verify now, instead of letting the
    // client discover the problem only after its code has been consumed.
    if !matches!(
        code_challenge_method.as_deref(),
        None | Some("S256") | Some("plain")
    ) {
        return Ok(HttpResponse::BadRequest().body("unsupported code_challenge_method"));
    }

    // The read guard is scoped so it is released before the write lock is taken below.
    {
        let s = data.read();
        let Some(client) = s.clients.get(client_id) else {
            return Ok(HttpResponse::BadRequest().body("unknown client_id"));
        };
        if !client.redirect_uris.iter().any(|r| r == redirect_uri) {
            return Ok(HttpResponse::BadRequest().body("invalid redirect_uri"));
        }
    }

    // Parse before issuing the code: a malformed URI must not leave a live code in the store.
    let Ok(mut redirect) = Url::parse(redirect_uri) else {
        return Ok(HttpResponse::BadRequest().body("invalid redirect_uri format"));
    };

    let code = rand_b64url(32);
    let auth_code = AuthCode {
        code: code.clone(),
        client_id: client_id.clone(),
        redirect_uri: redirect_uri.clone(),
        subject: "user-1".to_string(),
        expires_at: OffsetDateTime::now_utc() + Duration::minutes(5),
        code_challenge,
        code_challenge_method,
    };
    data.write().codes.insert(code.clone(), auth_code);

    {
        let mut pairs = redirect.query_pairs_mut();
        pairs.append_pair("code", &code);
        if let Some(st) = state_param {
            pairs.append_pair("state", &st);
        }
    }
    Ok(HttpResponse::Found()
        .append_header(("Location", redirect.to_string()))
        .finish())
}
/// Form body accepted by the `token` endpoint. Only `grant_type` is mandatory at the
/// deserialization level; each grant checks the fields it needs.
#[derive(Deserialize)]
pub struct TokenForm {
    /// Requested grant: `authorization_code`, `client_credentials` or `refresh_token`.
    pub grant_type: String,
    /// Authorization code being redeemed (`authorization_code` grant).
    pub code: Option<String>,
    /// Redirect URI to compare against the one stored with the code.
    pub redirect_uri: Option<String>,
    /// Identifier of the requesting client.
    pub client_id: Option<String>,
    /// Secret of the requesting client.
    pub client_secret: Option<String>,
    /// PKCE verifier matching the challenge stored with the code.
    pub code_verifier: Option<String>,
    /// Refresh token being redeemed (`refresh_token` grant).
    pub refresh_token: Option<String>,
    /// Requested scope; accepted but not read by the handlers in this file.
    pub scope: Option<String>,
}
pub async fn token(
form: web::Form<TokenForm>,
_req: HttpRequest,
data: web::Data<SharedState>,
) -> Result<impl Responder> {
let f = form.into_inner();
match f.grant_type.as_str() {
"authorization_code" => {
let code = match f.code {
Some(c) => c,
None => return Ok(HttpResponse::BadRequest().body("missing code")),
};
let client_id = match f.client_id {
Some(c) => c,
None => return Ok(HttpResponse::BadRequest().body("missing client_id")),
};
{
let s = data.read();
if !s.clients.contains_key(&client_id) {
return Ok(HttpResponse::BadRequest().body("unknown client"));
}
}
let auth_code = {
let mut s = data.write();
match s.codes.remove(&code) {
Some(c) => c,
None => return Ok(HttpResponse::BadRequest().body("invalid code")),
}
};
if OffsetDateTime::now_utc() > auth_code.expires_at {
return Ok(HttpResponse::BadRequest().body("code expired"));
}
if let Some(ref r) = f.redirect_uri {
if r != &auth_code.redirect_uri {
return Ok(HttpResponse::BadRequest().body("redirect_uri mismatch"));
}
}
if let Some(challenge) = auth_code.code_challenge {
let method = auth_code
.code_challenge_method
.unwrap_or_else(|| "S256".to_string());
let verifier = match f.code_verifier {
Some(v) => v,
None => {
return Ok(HttpResponse::BadRequest().body("missing code_verifier for PKCE"))
}
};
match method.as_str() {
"S256" => {
let digest = Sha256::digest(verifier.as_bytes());
let enc = general_purpose::URL_SAFE_NO_PAD.encode(digest);
if enc != challenge {
return Ok(HttpResponse::BadRequest().body("pkce verification failed"));
}
}
"plain" => {
if verifier != challenge {
return Ok(HttpResponse::BadRequest().body("pkce verification failed"));
}
}
_ => {
return Ok(
HttpResponse::BadRequest().body("unsupported code_challenge_method")
)
}
}
}
let access_token = rand_b64url(32);
let refresh_token = rand_b64url(48);
let now = OffsetDateTime::now_utc();
let at = AccessToken {
token: access_token.clone(),
subject: auth_code.subject.clone(),
client_id: auth_code.client_id.clone(),
expires_at: now + Duration::minutes(60),
};
let rt = RefreshToken {
token: refresh_token.clone(),
subject: auth_code.subject.clone(),
client_id: auth_code.client_id.clone(),
expires_at: now + Duration::days(30),
};
{
let mut s = data.write();
s.access_tokens.insert(access_token.clone(), at);
s.refresh_tokens.insert(refresh_token.clone(), rt);
}
#[derive(Serialize)]
struct TokenResp {
access_token: String,
token_type: &'static str,
expires_in: u64,
refresh_token: String,
}
let resp = TokenResp {
access_token,
token_type: "Bearer",
expires_in: 3600,
refresh_token,
};
Ok(HttpResponse::Ok().json(resp))
}
"client_credentials" => {
let client_id = match f.client_id {
Some(c) => c,
None => return Ok(HttpResponse::BadRequest().body("missing client_id")),
};
let client_secret = f.client_secret;
{
let s = data.read();
match s.clients.get(&client_id) {
Some(c) => {
if c.client_secret.is_some() {
if client_secret.is_none()
|| client_secret.unwrap() != c.client_secret.clone().unwrap()
{
return Ok(
HttpResponse::Unauthorized().body("invalid client_secret")
);
}
} else {
return Ok(HttpResponse::Unauthorized().body("client not confidential"));
}
}
None => return Ok(HttpResponse::BadRequest().body("unknown client")),
}
}
let access_token = rand_b64url(32);
let now = OffsetDateTime::now_utc();
let at = AccessToken {
token: access_token.clone(),
subject: "".into(), client_id: client_id.clone(),
expires_at: now + Duration::minutes(3600),
};
{
let mut s = data.write();
s.access_tokens.insert(access_token.clone(), at);
}
#[derive(Serialize)]
struct CCResp {
access_token: String,
token_type: &'static str,
expires_in: u64,
}
Ok(HttpResponse::Ok().json(CCResp {
access_token,
token_type: "Bearer",
expires_in: 3600,
}))
}
"refresh_token" => {
let rt = match f.refresh_token {
Some(t) => t,
None => return Ok(HttpResponse::BadRequest().body("missing refresh_token")),
};
let (subject, client_id) = {
let mut s = data.write();
match s.refresh_tokens.remove(&rt) {
Some(info) => {
if OffsetDateTime::now_utc() > info.expires_at {
return Ok(HttpResponse::BadRequest().body("refresh token expired"));
}
(info.subject.clone(), info.client_id.clone())
}
None => return Ok(HttpResponse::BadRequest().body("invalid refresh_token")),
}
};
let access_token = rand_b64url(32);
let refresh_token = rand_b64url(48);
let now = OffsetDateTime::now_utc();
let at = AccessToken {
token: access_token.clone(),
subject: subject.clone(),
client_id: client_id.clone(),
expires_at: now + Duration::minutes(3600),
};
let new_rt = RefreshToken {
token: refresh_token.clone(),
subject: subject.clone(),
client_id: client_id.clone(),
expires_at: now + Duration::days(30),
};
{
let mut s = data.write();
s.access_tokens.insert(access_token.clone(), at);
s.refresh_tokens.insert(refresh_token.clone(), new_rt);
}
#[derive(Serialize)]
struct RTResp {
access_token: String,
token_type: &'static str,
expires_in: u64,
refresh_token: String,
}
Ok(HttpResponse::Ok().json(RTResp {
access_token,
token_type: "Bearer",
expires_in: 3600,
refresh_token,
}))
}
_ => Ok(HttpResponse::BadRequest().body("unsupported grant_type")),
}
}
/// Form body accepted by the `introspect` endpoint.
#[derive(Deserialize)]
pub struct IntrospectForm {
    /// The access token value to look up.
    token: String,
}
/// Token introspection endpoint.
///
/// For a known, unexpired access token the response is
/// `{"active": true, "client_id": …, "exp": …, "sub": …}` where `exp` is the expiry as a
/// Unix timestamp. For anything else — unknown *or* expired — the response is exactly
/// `{"active": false}`, so no details of an inactive token are disclosed to the caller.
/// The status is always `200 OK`.
pub async fn introspect(
    form: web::Form<IntrospectForm>,
    data: web::Data<SharedState>,
) -> Result<impl Responder> {
    #[derive(Serialize)]
    struct Active<'a> {
        active: bool,
        client_id: &'a str,
        exp: i64,
        sub: &'a str,
    }
    #[derive(Serialize)]
    struct Inactive {
        active: bool,
    }

    let s = data.read();
    let now = OffsetDateTime::now_utc();
    let response = match s.access_tokens.get(&form.token) {
        Some(at) if now < at.expires_at => HttpResponse::Ok().json(Active {
            active: true,
            client_id: &at.client_id,
            exp: at.expires_at.unix_timestamp(),
            sub: &at.subject,
        }),
        _ => HttpResponse::Ok().json(Inactive { active: false }),
    };
    Ok(response)
}
/// Userinfo endpoint: resolves the bearer token from the `Authorization` header and returns
/// `{"sub": …, "preferred_username": …}` for the token's subject.
///
/// Responds `401 Unauthorized` when the header is absent, unreadable as text or does not
/// start with `"Bearer "`; when the token is not known; or when it has expired. Subjects
/// with no entry in the user table are reported with the username `"unknown"`.
pub async fn userinfo(req: HttpRequest, data: web::Data<SharedState>) -> Result<impl Responder> {
    #[derive(Serialize)]
    struct Ui {
        sub: String,
        preferred_username: String,
    }

    let authorization = req
        .headers()
        .get("authorization")
        .and_then(|value| value.to_str().ok())
        .unwrap_or("");
    let Some(bearer) = authorization.strip_prefix("Bearer ") else {
        return Ok(HttpResponse::Unauthorized().body("missing bearer token"));
    };

    let state = data.read();
    let Some(access) = state.access_tokens.get(bearer) else {
        return Ok(HttpResponse::Unauthorized().body("invalid token"));
    };
    if OffsetDateTime::now_utc() > access.expires_at {
        return Ok(HttpResponse::Unauthorized().body("token expired"));
    }

    let preferred_username = match state.users.get(&access.subject) {
        Some(name) => name.clone(),
        None => "unknown".into(),
    };
    Ok(HttpResponse::Ok().json(Ui {
        sub: access.subject.clone(),
        preferred_username,
    }))
}
pub fn init_state() -> SharedState {
let mut clients = HashMap::new();
clients.insert(
"demo-public".into(),
Client {
client_id: "demo-public".into(),
client_secret: None,
redirect_uris: vec!["http://localhost:8081/callback".into()],
},
);
clients.insert(
"demo-confidential".into(),
Client {
client_id: "demo-confidential".into(),
client_secret: Some("topsecret".into()),
redirect_uris: vec![],
},
);
let mut users = HashMap::new();
users.insert("user-1".into(), "alice".into());
Arc::new(RwLock::new(AppState {
clients,
codes: HashMap::new(),
access_tokens: HashMap::new(),
refresh_tokens: HashMap::new(),
users,
}))
}