meegle/
token.rs

1use log::debug;
2use reqwest::Client as HttpClient;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use tokio::sync::Mutex as AsyncMutex;
8
9use crate::error::ApiError;
10
11#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, Serialize)]
12#[serde(into = "i32")]
13pub enum AccessTokenType {
14    Plugin = 0,
15    VirtualPlugin = 1,
16    UserPlugin = 2,
17}
18
19impl From<AccessTokenType> for i32 {
20    fn from(token_type: AccessTokenType) -> i32 {
21        token_type as i32
22    }
23}
24
25#[derive(Debug, Serialize)]
26struct UserAuthRequest {
27    code: String,
28    grant_type: String,
29}
30
31#[derive(Debug, Serialize)]
32struct UserRefreshTokenRequest {
33    refresh_token: String,
34    #[serde(rename = "type")]
35    token_type: String,
36}
37
38#[derive(Debug, Serialize)]
39struct PluginTokenRequest {
40    plugin_id: String,
41    plugin_secret: String,
42    #[serde(rename = "type")]
43    token_type: AccessTokenType,
44}
45
46#[derive(Debug, Deserialize)]
47pub struct TokenResponseError {
48    pub code: i32,
49    pub msg: String,
50}
51
52#[derive(Debug, Deserialize)]
53pub struct PluginTokenResponseData {
54    pub token: String,
55    pub expire_time: u64,
56}
57
58#[derive(Debug, Deserialize)]
59pub struct PluginTokenResponse {
60    pub data: Option<PluginTokenResponseData>,
61    pub error: TokenResponseError,
62}
63
64#[derive(Debug, Deserialize)]
65pub struct UserTokenResponseData {
66    pub token: String,
67    pub expire_time: u64,
68    pub refresh_token: String,
69    pub refresh_token_expire_time: u64,
70    pub saas_tenant_key: Option<String>,
71    pub user_key: Option<String>,
72}
73
74#[derive(Debug, Deserialize)]
75pub struct UserTokenResponse {
76    pub data: Option<UserTokenResponseData>,
77    pub error: TokenResponseError,
78}
79
80#[derive(Clone)]
81pub struct CachedToken {
82    pub token_type: AccessTokenType,
83    pub token: String,
84    pub expired_at: u64,
85    pub refresh_token: Option<String>,
86    pub refresh_token_expired_at: Option<u64>,
87}
88
89#[derive(Clone)]
90pub struct TokenConfig {
91    plugin_id: String,
92    plugin_secret: String,
93    base_url: String,
94}
95
96#[derive(Clone)]
97pub struct TokenManager {
98    config: TokenConfig,
99    http_client: HttpClient,
100    cache: Arc<Mutex<HashMap<String, CachedToken>>>,
101    refresh_locks: Arc<Mutex<HashMap<String, Arc<AsyncMutex<()>>>>>,
102}
103
104impl TokenManager {
105    pub fn new(
106        plugin_id: impl Into<String>,
107        plugin_secret: impl Into<String>,
108        base_url: impl Into<String>,
109    ) -> Self {
110        let config = TokenConfig {
111            plugin_id: plugin_id.into(),
112            plugin_secret: plugin_secret.into(),
113            base_url: base_url.into(),
114        };
115
116        let http_client = HttpClient::builder()
117            .timeout(Duration::from_secs(30))
118            .build()
119            .expect("Failed to create HTTP client");
120
121        Self {
122            config,
123            http_client,
124            cache: Arc::new(Mutex::new(HashMap::new())),
125            refresh_locks: Arc::new(Mutex::new(HashMap::new())),
126        }
127    }
128
129    pub async fn auth_user_by_code(
130        &self,
131        code: &str,
132    ) -> Result<String, Box<dyn std::error::Error>> {
133        let response: UserTokenResponse = self
134            .request(
135                "authen/user_plugin_token",
136                &UserAuthRequest {
137                    code: code.to_string(),
138                    grant_type: "authorization_code".to_string(),
139                },
140                true,
141            )
142            .await?;
143
144        if response.error.code != 0 {
145            return Err(Box::new(ApiError::TokenError(response.error.msg)));
146        }
147
148        let data = response.data.unwrap();
149        let _ = self.cache_user_token(&data.user_key.clone().unwrap(), &data);
150        println!("Auth User By Code {:?}", data.refresh_token);
151        Ok(data.token)
152    }
153
154    pub async fn get_user_token(
155        &self,
156        user_key: &str,
157    ) -> Result<String, Box<dyn std::error::Error>> {
158        {
159            let cache = self.cache.lock().unwrap();
160            if let Some(cached_token) = cache.get(user_key) {
161                if !self.is_token_expired(cached_token) {
162                    debug!("Using cached token");
163                    return Ok(cached_token.token.clone());
164                }
165            }
166        }
167
168        let refresh_lock = self.get_refresh_lock(user_key);
169        let _guard = refresh_lock.lock().await;
170
171        {
172            let cache = self.cache.lock().unwrap();
173            if let Some(cached_token) = cache.get(user_key) {
174                if !self.is_token_expired(cached_token) {
175                    debug!("Using cached token after lock");
176                    return Ok(cached_token.token.clone());
177                }
178            }
179        }
180
181        self.refresh_user_token(user_key).await
182    }
183
184    pub async fn require_plugin_token(&self) -> Result<String, Box<dyn std::error::Error>> {
185        {
186            let cache = self.cache.lock().unwrap();
187            if let Some(cached_token) = cache.get("_plugin") {
188                if !self.is_token_expired(cached_token) {
189                    debug!("Using cached token");
190                    return Ok(cached_token.token.clone());
191                }
192            }
193        }
194
195        let refresh_lock = self.get_refresh_lock("_plugin");
196        let _guard = refresh_lock.lock().await;
197
198        {
199            let cache = self.cache.lock().unwrap();
200            if let Some(cached_token) = cache.get("_plugin") {
201                if !self.is_token_expired(cached_token) {
202                    debug!("Using cached plugin token after lock");
203                    return Ok(cached_token.token.clone());
204                }
205            }
206        }
207
208        let token_response = self.fetch_plugin_token().await?;
209        let data = token_response.data.unwrap();
210        let token = data.token.clone();
211
212        {
213            let mut cache = self.cache.lock().unwrap();
214            cache.insert(
215                "_plugin".to_owned(),
216                CachedToken {
217                    token_type: AccessTokenType::Plugin,
218                    token: data.token,
219                    expired_at: Self::get_timestamp() + data.expire_time,
220                    refresh_token: None,
221                    refresh_token_expired_at: None,
222                },
223            );
224        }
225
226        Ok(token)
227    }
228
229    fn get_timestamp() -> u64 {
230        SystemTime::now()
231            .duration_since(UNIX_EPOCH)
232            .unwrap()
233            .as_secs()
234    }
235
236    fn is_token_expired(&self, token: &CachedToken) -> bool {
237        Self::get_timestamp() >= (token.expired_at - 60)
238    }
239
240    async fn fetch_plugin_token(&self) -> Result<PluginTokenResponse, Box<dyn std::error::Error>> {
241        let response: PluginTokenResponse = self
242            .request(
243                "authen/plugin_token",
244                &PluginTokenRequest {
245                    plugin_id: self.config.plugin_id.clone(),
246                    plugin_secret: self.config.plugin_secret.clone(),
247                    token_type: AccessTokenType::Plugin,
248                },
249                false,
250            )
251            .await?;
252
253        if response.error.code != 0 {
254            return Err(Box::new(ApiError::TokenError(response.error.msg)));
255        }
256
257        Ok(response)
258    }
259
260    async fn refresh_user_token(
261        &self,
262        user_key: &str,
263    ) -> Result<String, Box<dyn std::error::Error>> {
264        let refresh_token = {
265            let cache = self.cache.lock().unwrap();
266            let user_token = cache
267                .get(user_key)
268                .ok_or_else(|| ApiError::TokenError("user token not found".to_string()))?;
269
270            if Self::get_timestamp() >= user_token.refresh_token_expired_at.unwrap_or_default() {
271                return Err(Box::new(ApiError::TokenError(
272                    "refresh token expired".to_string(),
273                )));
274            }
275
276            user_token
277                .refresh_token
278                .clone()
279                .ok_or_else(|| ApiError::TokenError("refresh token not found".to_string()))?
280        };
281
282        let response: UserTokenResponse = self
283            .request(
284                "authen/refresh_token",
285                &UserRefreshTokenRequest {
286                    refresh_token,
287                    token_type: "1".to_string(),
288                },
289                true,
290            )
291            .await?;
292
293        if response.error.code != 0 {
294            return Err(Box::new(ApiError::TokenError(response.error.msg)));
295        }
296
297        let data = response
298            .data
299            .ok_or_else(|| ApiError::TokenError("no data in response".to_string()))?;
300
301        let _ = self.cache_user_token(user_key, &data);
302        Ok(data.token)
303    }
304
305    pub fn cache_user_token(
306        &self,
307        user_key: &str,
308        data: &UserTokenResponseData,
309    ) -> Result<(), Box<dyn std::error::Error>> {
310        let mut cache = self.cache.lock().unwrap();
311
312        cache.insert(
313            user_key.to_string(),
314            CachedToken {
315                token_type: AccessTokenType::UserPlugin,
316                token: data.token.clone(),
317                expired_at: Self::get_timestamp() + data.expire_time,
318                refresh_token: Some(data.refresh_token.clone()),
319                refresh_token_expired_at: Some(
320                    Self::get_timestamp() + data.refresh_token_expire_time,
321                ),
322            },
323        );
324        Ok(())
325    }
326
327    async fn request<T: Serialize, R: for<'de> Deserialize<'de>>(
328        &self,
329        path: &str,
330        body: &T,
331        need_plugin_token: bool,
332    ) -> Result<R, Box<dyn std::error::Error>> {
333        let url = format!("{}/open_api/{}", self.config.base_url, path);
334
335        let mut request = self
336            .http_client
337            .post(&url)
338            .header("Content-Type", "application/json");
339
340        if need_plugin_token {
341            let token = Box::pin(self.require_plugin_token()).await?;
342            request = request.header("X-Plugin-Token", token);
343        }
344
345        let response = request.json(body).send().await?.json().await?;
346        Ok(response)
347    }
348
349    fn get_refresh_lock(&self, key: &str) -> Arc<AsyncMutex<()>> {
350        let mut locks = self.refresh_locks.lock().unwrap();
351        locks
352            .entry(key.to_string())
353            .or_insert_with(|| Arc::new(AsyncMutex::new(())))
354            .clone()
355    }
356}