1use std::collections::HashMap;
7use std::time::SystemTime;
8
9use parking_lot::Mutex;
10use rand::Rng;
11use sha2::{Digest, Sha256};
12
/// Number of random bytes in a generated token; its hex encoding is twice as long.
const TOKEN_LENGTH: usize = 32;

/// How long a one-time token stays redeemable, in seconds (ten minutes).
const ONE_TIME_EXPIRATION_SECS: i64 = 60 * 10;

/// How long an IP ban lasts, in seconds (ten minutes).
const BAN_DURATION_SECS: i64 = 60 * 10;
21
/// Payload stored for a pending one-time token.
struct OneTimeToken {
    /// Unix timestamp (seconds) after which the token can no longer be redeemed.
    expires_at: i64,
    /// User the token was generated for.
    user_id: i64,
}
27
28pub struct TokenManager {
33 one_time_tokens: Mutex<HashMap<String, OneTimeToken>>,
34 blocked_ips: Mutex<HashMap<String, i64>>,
35 tokens_salt: String,
36 tokens_dir: String,
37}
38
39impl TokenManager {
40 pub fn new(tokens_dir: String, tokens_salt: String) -> Self {
42 Self {
43 one_time_tokens: Mutex::new(HashMap::new()),
44 blocked_ips: Mutex::new(HashMap::new()),
45 tokens_dir,
46 tokens_salt,
47 }
48 }
49
50 pub fn gen_one_time_token(&self, user_id: i64) -> String {
52 let token = gen_token();
53 let expires_at = now_timestamp() + ONE_TIME_EXPIRATION_SECS;
54 self.one_time_tokens.lock().insert(
55 token.clone(),
56 OneTimeToken {
57 user_id,
58 expires_at,
59 },
60 );
61 token
62 }
63
64 pub fn issue_permanent_token(&self, one_time_token: &str) -> Option<String> {
66 let user_id = {
67 let tokens = self.one_time_tokens.lock();
68 let data = tokens.get(one_time_token)?;
69 if now_timestamp() > data.expires_at {
70 return None;
71 }
72 data.user_id
73 };
74 self.one_time_tokens.lock().remove(one_time_token);
75
76 let permanent = gen_token();
77 let hashed = self.hash_token(&permanent);
78
79 let path = std::path::Path::new(&self.tokens_dir).join(&hashed);
81 if let Some(parent) = path.parent() {
82 let _ = std::fs::create_dir_all(parent);
83 }
84 let _ = std::fs::write(&path, user_id.to_string());
85
86 Some(permanent)
87 }
88
89 pub fn find_user_id(&self, token: &str) -> Option<i64> {
91 let hashed = self.hash_token(token);
92 let path = std::path::Path::new(&self.tokens_dir).join(&hashed);
93 let data = std::fs::read_to_string(&path).ok()?;
94 data.parse().ok()
95 }
96
97 pub fn is_ip_blocked(&self, ip: &str) -> bool {
99 let blocked = self.blocked_ips.lock();
100 if let Some(unblock_time) = blocked.get(ip) {
101 now_timestamp() < *unblock_time
102 } else {
103 false
104 }
105 }
106
107 pub fn block_ip(&self, ip: &str) {
109 self.blocked_ips
110 .lock()
111 .insert(ip.to_string(), now_timestamp() + BAN_DURATION_SECS);
112 }
113
114 pub fn get_ip_from_remote_addr(remote_addr: &str) -> String {
116 remote_addr
117 .rsplit_once(':')
118 .map(|(host, _)| host.to_string())
119 .unwrap_or(remote_addr.to_string())
120 }
121
122 fn hash_token(&self, token: &str) -> String {
123 let mut hasher = Sha256::new();
124 hasher.update(token.as_bytes());
125 hasher.update(self.tokens_salt.as_bytes());
126 hex::encode(hasher.finalize())
127 }
128}
129
130fn gen_token() -> String {
131 let mut rng = rand::thread_rng();
132 let bytes: [u8; TOKEN_LENGTH] = rng.gen();
133 hex::encode(bytes)
134}
135
/// Current Unix time in whole seconds; `0` if the system clock reads earlier
/// than the Unix epoch.
fn now_timestamp() -> i64 {
    let since_epoch = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_secs() as i64,
        // Clock set before 1970: fall back to zero rather than panicking.
        Err(_) => 0,
    }
}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokens are `TOKEN_LENGTH` random bytes hex-encoded: 32 bytes -> 64 chars.
    #[test]
    fn test_gen_token_length() {
        let token = gen_token();
        assert_eq!(token.len(), 64);
    }

    /// The port suffix is stripped; a bare address passes through unchanged.
    #[test]
    fn test_ip_extraction() {
        let cases = [("1.2.3.4:8080", "1.2.3.4"), ("1.2.3.4", "1.2.3.4")];
        for &(remote_addr, expected) in &cases {
            assert_eq!(TokenManager::get_ip_from_remote_addr(remote_addr), expected);
        }
    }

    /// An IP is unblocked until `block_ip` is called, and blocked afterwards.
    #[test]
    fn test_block_unblock() {
        let ip = "1.2.3.4";
        let manager = TokenManager::new("/tmp/test_tokens".into(), "salt".into());
        assert!(!manager.is_ip_blocked(ip));
        manager.block_ip(ip);
        assert!(manager.is_ip_blocked(ip));
    }
}
168}