1use std::marker::PhantomData;
4use std::sync::Arc;
5
6use argon2::{
7 password_hash::{rand_core::OsRng, PasswordHasher, PasswordVerifier, SaltString},
8 Argon2, PasswordHash,
9};
10use async_trait::async_trait;
11use axum::extract::{FromRef, FromRequestParts};
12use axum::http::request::Parts;
13use parking_lot::RwLock;
14use serde::{de::DeserializeOwned, Deserialize, Serialize};
15use tower_sessions::Session;
16
17use crate::container::Container;
18use crate::Error;
19
/// Session key under which [`login`] stores the authenticated user's id.
///
/// The [`Auth`] and [`OptionalAuth`] extractors read the id back from this key, and
/// [`logout`] removes it.
pub const SESSION_USER_ID_KEY: &str = "_auth.user_id";
21
/// A user model that can be authenticated through this module's session-based helpers
/// ([`attempt`], [`login`], and the [`Auth`] / [`OptionalAuth`] extractors).
#[async_trait]
pub trait Authenticatable: Send + Sync + Sized + Clone + 'static {
    /// Identifier type persisted in the session under [`SESSION_USER_ID_KEY`].
    ///
    /// It is serialized by [`login`] and deserialized by the extractors, hence the serde bounds.
    type Id: Serialize + DeserializeOwned + Send + Sync + Clone + 'static;

    /// Returns this user's identifier (the value [`login`] writes into the session).
    fn id(&self) -> Self::Id;

    /// Loads a user by the identifier read from the session.
    ///
    /// `Ok(None)` is treated by [`Auth`] as unauthenticated and by [`OptionalAuth`] as
    /// "no user". Errors are propagated to the caller as-is.
    async fn find_by_id(container: &Container, id: &Self::Id) -> Result<Option<Self>, Error>;

    /// Looks up a user by a login identifier and returns it together with its stored
    /// password hash.
    ///
    /// [`attempt`] checks the returned hash with [`verify_password`]. A hash that
    /// `PasswordHash::new` cannot parse never verifies, so this should presumably be the
    /// string produced by [`hash_password`]. Return `Ok(None)` when no user matches.
    async fn find_by_credentials(
        container: &Container,
        identifier: &str,
    ) -> Result<Option<(Self, String)>, Error>;
}
44
45#[derive(Default, Clone)]
48pub struct AuthManager {
49 #[allow(dead_code)]
50 inner: Arc<RwLock<AuthInner>>,
51}
52
53#[derive(Default)]
54struct AuthInner {
55 #[allow(dead_code)]
56 hasher_pepper: Option<String>,
57}
58
59impl AuthManager {
60 pub fn new() -> Self {
61 Self::default()
62 }
63
64 pub fn with_pepper(self, pepper: impl Into<String>) -> Self {
65 self.inner.write().hasher_pepper = Some(pepper.into());
66 self
67 }
68}
69
70pub fn hash_password(plain: &str) -> Result<String, Error> {
72 let salt = SaltString::generate(&mut OsRng);
73 let argon2 = Argon2::default();
74 argon2
75 .hash_password(plain.as_bytes(), &salt)
76 .map(|h| h.to_string())
77 .map_err(|e| Error::Internal(format!("password hash failed: {e}")))
78}
79
80pub fn verify_password(plain: &str, hash: &str) -> bool {
82 let Ok(parsed) = PasswordHash::new(hash) else {
83 return false;
84 };
85 Argon2::default()
86 .verify_password(plain.as_bytes(), &parsed)
87 .is_ok()
88}
89
90pub async fn attempt<U: Authenticatable>(
94 container: &Container,
95 identifier: &str,
96 password: &str,
97) -> Result<Option<U>, Error> {
98 let Some((user, hash)) = U::find_by_credentials(container, identifier).await? else {
99 return Ok(None);
100 };
101 if verify_password(password, &hash) {
102 Ok(Some(user))
103 } else {
104 Ok(None)
105 }
106}
107
108pub async fn login<U: Authenticatable>(session: &Session, user: &U) -> Result<(), Error> {
110 let id = user.id();
111 session
112 .insert(SESSION_USER_ID_KEY, id)
113 .await
114 .map_err(|e| Error::Internal(format!("session write failed: {e}")))?;
115 Ok(())
116}
117
118pub async fn logout(session: &Session) -> Result<(), Error> {
120 session
121 .remove::<serde_json::Value>(SESSION_USER_ID_KEY)
122 .await
123 .map_err(|e| Error::Internal(format!("session clear failed: {e}")))?;
124 Ok(())
125}
126
127pub struct Auth<U: Authenticatable>(pub U);
135
136#[async_trait]
137impl<U, S> FromRequestParts<S> for Auth<U>
138where
139 U: Authenticatable,
140 Container: FromRef<S>,
141 S: Send + Sync,
142{
143 type Rejection = Error;
144
145 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
146 let session = Session::from_request_parts(parts, state)
147 .await
148 .map_err(|_| Error::Unauthenticated)?;
149 let id: Option<U::Id> = session
150 .get(SESSION_USER_ID_KEY)
151 .await
152 .map_err(|e| Error::Internal(e.to_string()))?;
153 let id = id.ok_or(Error::Unauthenticated)?;
154 let container = Container::from_ref(state);
155 let user = U::find_by_id(&container, &id)
156 .await?
157 .ok_or(Error::Unauthenticated)?;
158 Ok(Auth(user))
159 }
160}
161
162pub struct OptionalAuth<U: Authenticatable>(pub Option<U>);
166
167#[async_trait]
168impl<U, S> FromRequestParts<S> for OptionalAuth<U>
169where
170 U: Authenticatable,
171 Container: FromRef<S>,
172 S: Send + Sync,
173{
174 type Rejection = Error;
175
176 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
177 let Ok(session) = Session::from_request_parts(parts, state).await else {
178 return Ok(OptionalAuth(None));
179 };
180 let Some(id): Option<U::Id> = session
181 .get(SESSION_USER_ID_KEY)
182 .await
183 .map_err(|e| Error::Internal(e.to_string()))?
184 else {
185 return Ok(OptionalAuth(None));
186 };
187 let container = Container::from_ref(state);
188 let user = U::find_by_id(&container, &id).await?;
189 Ok(OptionalAuth(user))
190 }
191}
192
/// Authorization rule deciding whether a user of type `U` may perform an ability on a
/// subject of type `S`.
///
/// [`authorize`] consults `check` and turns a `false` result into a forbidden error.
pub trait Policy<U, S> {
    /// Returns `true` when `user` is allowed to perform `ability` on `subject`.
    fn check(user: &U, ability: &str, subject: &S) -> bool;
}
197
198pub fn authorize<P, U, S>(user: &U, ability: &str, subject: &S) -> Result<(), Error>
200where
201 P: Policy<U, S>,
202{
203 if P::check(user, ability, subject) {
204 Ok(())
205 } else {
206 Err(Error::forbidden(ability))
207 }
208}
209
/// Marker type for the "web" guard.
///
/// NOTE(review): nothing in this section dispatches on guard types; presumably this is
/// meant to select the session-based flow — confirm against callers.
pub struct WebGuard;
/// Marker type for the "api" guard.
pub struct ApiGuard;

/// Marker trait implemented by the guard types above (`WebGuard`, `ApiGuard`).
pub trait Guard: Send + Sync + 'static {}
impl Guard for WebGuard {}
impl Guard for ApiGuard {}

/// Zero-sized type parameterized by a user type `U` and a guard `G`; it stores no data,
/// only a `PhantomData` tag.
pub struct Guarded<U, G>(PhantomData<(U, G)>);