1use log::debug;
2use reqwest::Client as HttpClient;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use tokio::sync::Mutex as AsyncMutex;
8
9use crate::error::ApiError;
10
11#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, Serialize)]
12#[serde(into = "i32")]
13pub enum AccessTokenType {
14 Plugin = 0,
15 VirtualPlugin = 1,
16 UserPlugin = 2,
17}
18
19impl From<AccessTokenType> for i32 {
20 fn from(token_type: AccessTokenType) -> i32 {
21 token_type as i32
22 }
23}
24
25#[derive(Debug, Serialize)]
26struct UserAuthRequest {
27 code: String,
28 grant_type: String,
29}
30
31#[derive(Debug, Serialize)]
32struct UserRefreshTokenRequest {
33 refresh_token: String,
34 #[serde(rename = "type")]
35 token_type: String,
36}
37
38#[derive(Debug, Serialize)]
39struct PluginTokenRequest {
40 plugin_id: String,
41 plugin_secret: String,
42 #[serde(rename = "type")]
43 token_type: AccessTokenType,
44}
45
46#[derive(Debug, Deserialize)]
47pub struct TokenResponseError {
48 pub code: i32,
49 pub msg: String,
50}
51
52#[derive(Debug, Deserialize)]
53pub struct PluginTokenResponseData {
54 pub token: String,
55 pub expire_time: u64,
56}
57
58#[derive(Debug, Deserialize)]
59pub struct PluginTokenResponse {
60 pub data: Option<PluginTokenResponseData>,
61 pub error: TokenResponseError,
62}
63
64#[derive(Debug, Deserialize)]
65pub struct UserTokenResponseData {
66 pub token: String,
67 pub expire_time: u64,
68 pub refresh_token: String,
69 pub refresh_token_expire_time: u64,
70 pub saas_tenant_key: Option<String>,
71 pub user_key: Option<String>,
72}
73
74#[derive(Debug, Deserialize)]
75pub struct UserTokenResponse {
76 pub data: Option<UserTokenResponseData>,
77 pub error: TokenResponseError,
78}
79
80#[derive(Clone)]
81pub struct CachedToken {
82 pub token_type: AccessTokenType,
83 pub token: String,
84 pub expired_at: u64,
85 pub refresh_token: Option<String>,
86 pub refresh_token_expired_at: Option<u64>,
87}
88
/// Static settings needed to talk to the token endpoints.
#[derive(Clone)]
pub struct TokenConfig {
    /// Identifier of the plugin.
    plugin_id: String,
    /// Secret paired with `plugin_id`.
    plugin_secret: String,
    /// Base URL that endpoint paths are appended to.
    base_url: String,
}
95
96#[derive(Clone)]
97pub struct TokenManager {
98 config: TokenConfig,
99 http_client: HttpClient,
100 cache: Arc<Mutex<HashMap<String, CachedToken>>>,
101 refresh_locks: Arc<Mutex<HashMap<String, Arc<AsyncMutex<()>>>>>,
102}
103
104impl TokenManager {
105 pub fn new(
106 plugin_id: impl Into<String>,
107 plugin_secret: impl Into<String>,
108 base_url: impl Into<String>,
109 ) -> Self {
110 let config = TokenConfig {
111 plugin_id: plugin_id.into(),
112 plugin_secret: plugin_secret.into(),
113 base_url: base_url.into(),
114 };
115
116 let http_client = HttpClient::builder()
117 .timeout(Duration::from_secs(30))
118 .build()
119 .expect("Failed to create HTTP client");
120
121 Self {
122 config,
123 http_client,
124 cache: Arc::new(Mutex::new(HashMap::new())),
125 refresh_locks: Arc::new(Mutex::new(HashMap::new())),
126 }
127 }
128
129 pub async fn auth_user_by_code(
130 &self,
131 code: &str,
132 ) -> Result<String, Box<dyn std::error::Error>> {
133 let response: UserTokenResponse = self
134 .request(
135 "authen/user_plugin_token",
136 &UserAuthRequest {
137 code: code.to_string(),
138 grant_type: "authorization_code".to_string(),
139 },
140 true,
141 )
142 .await?;
143
144 if response.error.code != 0 {
145 return Err(Box::new(ApiError::TokenError(response.error.msg)));
146 }
147
148 let data = response.data.unwrap();
149 let _ = self.cache_user_token(&data.user_key.clone().unwrap(), &data);
150 println!("Auth User By Code {:?}", data.refresh_token);
151 Ok(data.token)
152 }
153
154 pub async fn get_user_token(
155 &self,
156 user_key: &str,
157 ) -> Result<String, Box<dyn std::error::Error>> {
158 {
159 let cache = self.cache.lock().unwrap();
160 if let Some(cached_token) = cache.get(user_key) {
161 if !self.is_token_expired(cached_token) {
162 debug!("Using cached token");
163 return Ok(cached_token.token.clone());
164 }
165 }
166 }
167
168 let refresh_lock = self.get_refresh_lock(user_key);
169 let _guard = refresh_lock.lock().await;
170
171 {
172 let cache = self.cache.lock().unwrap();
173 if let Some(cached_token) = cache.get(user_key) {
174 if !self.is_token_expired(cached_token) {
175 debug!("Using cached token after lock");
176 return Ok(cached_token.token.clone());
177 }
178 }
179 }
180
181 self.refresh_user_token(user_key).await
182 }
183
184 pub async fn require_plugin_token(&self) -> Result<String, Box<dyn std::error::Error>> {
185 {
186 let cache = self.cache.lock().unwrap();
187 if let Some(cached_token) = cache.get("_plugin") {
188 if !self.is_token_expired(cached_token) {
189 debug!("Using cached token");
190 return Ok(cached_token.token.clone());
191 }
192 }
193 }
194
195 let refresh_lock = self.get_refresh_lock("_plugin");
196 let _guard = refresh_lock.lock().await;
197
198 {
199 let cache = self.cache.lock().unwrap();
200 if let Some(cached_token) = cache.get("_plugin") {
201 if !self.is_token_expired(cached_token) {
202 debug!("Using cached plugin token after lock");
203 return Ok(cached_token.token.clone());
204 }
205 }
206 }
207
208 let token_response = self.fetch_plugin_token().await?;
209 let data = token_response.data.unwrap();
210 let token = data.token.clone();
211
212 {
213 let mut cache = self.cache.lock().unwrap();
214 cache.insert(
215 "_plugin".to_owned(),
216 CachedToken {
217 token_type: AccessTokenType::Plugin,
218 token: data.token,
219 expired_at: Self::get_timestamp() + data.expire_time,
220 refresh_token: None,
221 refresh_token_expired_at: None,
222 },
223 );
224 }
225
226 Ok(token)
227 }
228
229 fn get_timestamp() -> u64 {
230 SystemTime::now()
231 .duration_since(UNIX_EPOCH)
232 .unwrap()
233 .as_secs()
234 }
235
236 fn is_token_expired(&self, token: &CachedToken) -> bool {
237 Self::get_timestamp() >= (token.expired_at - 60)
238 }
239
240 async fn fetch_plugin_token(&self) -> Result<PluginTokenResponse, Box<dyn std::error::Error>> {
241 let response: PluginTokenResponse = self
242 .request(
243 "authen/plugin_token",
244 &PluginTokenRequest {
245 plugin_id: self.config.plugin_id.clone(),
246 plugin_secret: self.config.plugin_secret.clone(),
247 token_type: AccessTokenType::Plugin,
248 },
249 false,
250 )
251 .await?;
252
253 if response.error.code != 0 {
254 return Err(Box::new(ApiError::TokenError(response.error.msg)));
255 }
256
257 Ok(response)
258 }
259
260 async fn refresh_user_token(
261 &self,
262 user_key: &str,
263 ) -> Result<String, Box<dyn std::error::Error>> {
264 let refresh_token = {
265 let cache = self.cache.lock().unwrap();
266 let user_token = cache
267 .get(user_key)
268 .ok_or_else(|| ApiError::TokenError("user token not found".to_string()))?;
269
270 if Self::get_timestamp() >= user_token.refresh_token_expired_at.unwrap_or_default() {
271 return Err(Box::new(ApiError::TokenError(
272 "refresh token expired".to_string(),
273 )));
274 }
275
276 user_token
277 .refresh_token
278 .clone()
279 .ok_or_else(|| ApiError::TokenError("refresh token not found".to_string()))?
280 };
281
282 let response: UserTokenResponse = self
283 .request(
284 "authen/refresh_token",
285 &UserRefreshTokenRequest {
286 refresh_token,
287 token_type: "1".to_string(),
288 },
289 true,
290 )
291 .await?;
292
293 if response.error.code != 0 {
294 return Err(Box::new(ApiError::TokenError(response.error.msg)));
295 }
296
297 let data = response
298 .data
299 .ok_or_else(|| ApiError::TokenError("no data in response".to_string()))?;
300
301 let _ = self.cache_user_token(user_key, &data);
302 Ok(data.token)
303 }
304
305 pub fn cache_user_token(
306 &self,
307 user_key: &str,
308 data: &UserTokenResponseData,
309 ) -> Result<(), Box<dyn std::error::Error>> {
310 let mut cache = self.cache.lock().unwrap();
311
312 cache.insert(
313 user_key.to_string(),
314 CachedToken {
315 token_type: AccessTokenType::UserPlugin,
316 token: data.token.clone(),
317 expired_at: Self::get_timestamp() + data.expire_time,
318 refresh_token: Some(data.refresh_token.clone()),
319 refresh_token_expired_at: Some(
320 Self::get_timestamp() + data.refresh_token_expire_time,
321 ),
322 },
323 );
324 Ok(())
325 }
326
327 async fn request<T: Serialize, R: for<'de> Deserialize<'de>>(
328 &self,
329 path: &str,
330 body: &T,
331 need_plugin_token: bool,
332 ) -> Result<R, Box<dyn std::error::Error>> {
333 let url = format!("{}/open_api/{}", self.config.base_url, path);
334
335 let mut request = self
336 .http_client
337 .post(&url)
338 .header("Content-Type", "application/json");
339
340 if need_plugin_token {
341 let token = Box::pin(self.require_plugin_token()).await?;
342 request = request.header("X-Plugin-Token", token);
343 }
344
345 let response = request.json(body).send().await?.json().await?;
346 Ok(response)
347 }
348
349 fn get_refresh_lock(&self, key: &str) -> Arc<AsyncMutex<()>> {
350 let mut locks = self.refresh_locks.lock().unwrap();
351 locks
352 .entry(key.to_string())
353 .or_insert_with(|| Arc::new(AsyncMutex::new(())))
354 .clone()
355 }
356}