// trojan_auth/reloadable.rs
use std::mem;
use std::sync::Arc;

use async_trait::async_trait;
use parking_lot::RwLock;

use crate::error::AuthError;
use crate::result::AuthResult;
use crate::traits::AuthBackend;
12pub struct ReloadableAuth {
27 inner: RwLock<Arc<dyn AuthBackend>>,
28}
29
30impl ReloadableAuth {
31 pub fn new<A: AuthBackend + 'static>(auth: A) -> Self {
33 Self {
34 inner: RwLock::new(Arc::new(auth)),
35 }
36 }
37
38 pub fn reload<A: AuthBackend + 'static>(&self, auth: A) {
43 let mut inner = self.inner.write();
44 *inner = Arc::new(auth);
45 }
46
47 pub fn reload_arc(&self, auth: Arc<dyn AuthBackend>) {
49 let mut inner = self.inner.write();
50 *inner = auth;
51 }
52
53 #[inline]
58 pub fn get(&self) -> Arc<dyn AuthBackend> {
59 self.inner.read().clone()
60 }
61}
62
63impl std::fmt::Debug for ReloadableAuth {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("ReloadableAuth").finish_non_exhaustive()
67 }
68}
69
70#[async_trait]
71impl AuthBackend for ReloadableAuth {
72 async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
73 let backend = self.get();
75 backend.verify(hash).await
76 }
77
78 async fn record_traffic(&self, user_id: &str, bytes: u64) -> Result<(), AuthError> {
79 let backend = self.get();
80 backend.record_traffic(user_id, bytes).await
81 }
82}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::sha224_hex;
    use crate::memory::MemoryAuth;

    /// `reload` switches which credentials the wrapper accepts.
    #[tokio::test]
    async fn test_reload() {
        let auth = ReloadableAuth::new(MemoryAuth::from_passwords(["old_password"]));

        let old_hash = sha224_hex("old_password");
        let new_hash = sha224_hex("new_password");

        assert!(auth.verify(&old_hash).await.is_ok());
        assert!(auth.verify(&new_hash).await.is_err());

        auth.reload(MemoryAuth::from_passwords(["new_password"]));

        assert!(auth.verify(&old_hash).await.is_err());
        assert!(auth.verify(&new_hash).await.is_ok());
    }

    /// `reload_arc` installs a backend that is already behind an `Arc`.
    #[tokio::test]
    async fn test_reload_arc() {
        let auth = ReloadableAuth::new(MemoryAuth::from_passwords(["old_password"]));

        let old_hash = sha224_hex("old_password");
        let new_hash = sha224_hex("new_password");

        auth.reload_arc(Arc::new(MemoryAuth::from_passwords(["new_password"])));

        assert!(auth.verify(&old_hash).await.is_err());
        assert!(auth.verify(&new_hash).await.is_ok());
    }

    /// A snapshot taken with `get` before a reload keeps the old backend.
    #[tokio::test]
    async fn test_snapshot_survives_reload() {
        let auth = ReloadableAuth::new(MemoryAuth::from_passwords(["old_password"]));

        let old_hash = sha224_hex("old_password");
        let new_hash = sha224_hex("new_password");

        let snapshot = auth.get();
        auth.reload(MemoryAuth::from_passwords(["new_password"]));

        assert!(snapshot.verify(&old_hash).await.is_ok());
        assert!(snapshot.verify(&new_hash).await.is_err());
        assert!(auth.verify(&new_hash).await.is_ok());
    }
}