1use async_trait::async_trait;
2use axum_core::{
3 extract::FromRequestParts,
4 response::{IntoResponse, Response},
5};
6use http::{request::Parts, HeaderMap, StatusCode};
7use std::{num::ParseIntError, str::FromStr};
8
/// Role group a user belongs to, as carried by the `X-Identity-Group` header.
///
/// Parsed case-insensitively through its [`FromStr`] implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Group {
    Participant,
    Sponsor,
    Mlh,
    Organizer,
    Director,
}

impl FromStr for Group {
    /// The lowercased input that did not name a known group.
    type Err = String;

    /// Parses a group name, ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns the lowercased input when it is not one of `participant`, `sponsor`,
    /// `mlh`, `organizer` or `director`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let uniform = s.to_lowercase();
        match uniform.as_str() {
            "participant" => Ok(Self::Participant),
            "sponsor" => Ok(Self::Sponsor),
            "mlh" => Ok(Self::Mlh),
            "organizer" => Ok(Self::Organizer),
            "director" => Ok(Self::Director),
            _ => Err(uniform),
        }
    }
}
34
35#[derive(Debug)]
37pub struct CurrentUser {
38 pub id: i32,
40 pub first_name: String,
42 pub last_name: String,
44 pub email: String,
46 pub group: Group,
49 pub context: String,
51}
52
53impl CurrentUser {
54 pub fn is_admin(&self) -> bool {
56 self.context == "admin"
57 }
58
59 pub fn is_registering(&self) -> bool {
61 self.context == "manage"
62 }
63}
64
65#[async_trait]
66impl<S> FromRequestParts<S> for CurrentUser
67where
68 S: Send + Sync,
69{
70 type Rejection = InvalidRequestUser;
71
72 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
73 let id = get_header(&parts.headers, "X-Identity-ID")?.parse()?;
74 let first_name = get_header(&parts.headers, "X-Identity-First-Name")?.to_owned();
75 let last_name = get_header(&parts.headers, "X-Identity-Last-Name")?.to_owned();
76 let email = get_header(&parts.headers, "X-Identity-Email")?.to_owned();
77 let group = get_header(&parts.headers, "X-Identity-Group")?;
78 let context = get_header(&parts.headers, "X-Identity-Context")?.to_owned();
79
80 Ok(CurrentUser {
81 id,
82 first_name,
83 last_name,
84 email,
85 group: Group::from_str(group).map_err(InvalidRequestUser::InvalidGroup)?,
86 context,
87 })
88 }
89}
90
91fn get_header<'h>(
92 headers: &'h HeaderMap,
93 header: &'static str,
94) -> Result<&'h str, InvalidRequestUser> {
95 headers
96 .get(header)
97 .ok_or_else(|| InvalidRequestUser::MissingHeader(header))?
98 .to_str()
99 .map_err(|_| InvalidRequestUser::InvalidHeader(header))
100}
101
/// Reasons the [`CurrentUser`] extractor can reject a request.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InvalidRequestUser {
    /// The named header was not present on the request.
    MissingHeader(&'static str),
    /// The named header was present but its value could not be read as a string.
    InvalidHeader(&'static str),
    /// `X-Identity-Group` did not name a known [`Group`]; holds the lowercased value.
    InvalidGroup(String),
    /// `X-Identity-ID` could not be parsed as an integer.
    InvalidId,
}

/// Lets `?` turn an integer parse failure (of the `X-Identity-ID` header) into
/// [`InvalidRequestUser::InvalidId`]; the parse error's details are discarded.
impl From<ParseIntError> for InvalidRequestUser {
    fn from(_: ParseIntError) -> Self {
        Self::InvalidId
    }
}
116
117impl IntoResponse for InvalidRequestUser {
118 fn into_response(self) -> Response {
119 let message = match self {
120 Self::MissingHeader(header) => format!("missing header {header:?}"),
121 Self::InvalidHeader(header) => format!("invalid header value for {header:?}"),
122 Self::InvalidGroup(group) => format!("invalid group {group:?}"),
123 Self::InvalidId => format!("user id must be an integer"),
124 };
125
126 (StatusCode::UNAUTHORIZED, message).into_response()
127 }
128}