Skip to main content

openlark_auth/
token_provider.rs

//! openlark-auth 的 TokenProvider 实现
//!
//! `openlark-core` 通过 `TokenProvider` 抽象获取 token,而不关心具体获取/刷新/缓存策略。
//! 这里提供一个带缓存的实现:缓存 token 并在过期前复用。

6use openlark_core::{
7    SDKResult,
8    auth::{TokenProvider, TokenRequest},
9    config::Config,
10    constants::{AccessTokenType, AppType},
11    error::{api_error, configuration_error},
12};
13use serde_json::{Value, json};
14use std::collections::HashMap;
15use std::future::Future;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::time::{SystemTime, UNIX_EPOCH};
19use tokio::sync::RwLock;
20
/// One cache entry: a token string plus the wall-clock second at which it stops being served.
#[derive(Clone)]
struct CachedToken {
    /// The token value returned by the auth endpoint.
    token: String,
    /// Expiry as a Unix timestamp in seconds; the entry is considered stale from this instant on.
    expires_at: i64,
}
29
30impl std::fmt::Debug for CachedToken {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("CachedToken")
33            .field("token", &"***")
34            .field("expires_at", &self.expires_at)
35            .finish()
36    }
37}
38
39impl CachedToken {
40    fn now_epoch_secs() -> i64 {
41        SystemTime::now()
42            .duration_since(UNIX_EPOCH)
43            .map(|d| d.as_secs() as i64)
44            .unwrap_or(0)
45    }
46
47    /// 创建新的缓存 token
48    fn new(token: String, expires_in_seconds: i64) -> Self {
49        let now = Self::now_epoch_secs();
50        // 提前 60 秒过期,避免临界情况(小于 60 秒则视为立即过期)
51        let expires_at = now.saturating_add(expires_in_seconds.saturating_sub(60));
52
53        Self { token, expires_at }
54    }
55
56    /// 检查 token 是否已过期
57    fn is_expired(&self) -> bool {
58        Self::now_epoch_secs() >= self.expires_at
59    }
60}
61
62/// 基于 openlark-auth API 的 TokenProvider(带缓存)
63#[derive(Debug)]
64pub struct AuthTokenProvider {
65    config: Config,
66    /// token 缓存:key 为 token 类型字符串,value 为缓存的 token
67    cache: Arc<RwLock<HashMap<String, CachedToken>>>,
68}
69
70impl Clone for AuthTokenProvider {
71    fn clone(&self) -> Self {
72        Self {
73            config: self.config.clone(),
74            cache: Arc::clone(&self.cache),
75        }
76    }
77}
78
79impl AuthTokenProvider {
80    /// 创建基于 openlark-auth API 的 TokenProvider 实例
81    ///
82    /// # 参数
83    /// - `config`: SDK 配置信息
84    pub fn new(config: Config) -> Self {
85        Self {
86            config,
87            cache: Arc::new(RwLock::new(HashMap::new())),
88        }
89    }
90
91    /// 生成缓存键
92    fn cache_key(
93        token_type: &AccessTokenType,
94        app_type: &AppType,
95        request: &TokenRequest,
96    ) -> String {
97        match token_type {
98            AccessTokenType::Tenant => {
99                let tenant_key = request.tenant_key.as_deref().unwrap_or("default");
100                format!("{token_type:?}_{app_type:?}_{tenant_key}")
101            }
102            AccessTokenType::App if app_type == &AppType::Marketplace => {
103                let app_ticket = request.app_ticket.as_deref().unwrap_or("default");
104                format!("{token_type:?}_{app_type:?}_{app_ticket}")
105            }
106            _ => format!("{token_type:?}_{app_type:?}"),
107        }
108    }
109
110    async fn get_cached(&self, cache_key: &str) -> Option<String> {
111        let cache = self.cache.read().await;
112        cache
113            .get(cache_key)
114            .filter(|cached| !cached.is_expired())
115            .map(|cached| cached.token.clone())
116    }
117
118    async fn set_cached(&self, cache_key: String, token: String, expires_in_seconds: i64) {
119        let cached = CachedToken::new(token, expires_in_seconds);
120        self.cache.write().await.insert(cache_key, cached);
121    }
122
123    async fn get_or_fetch<F, Fut>(&self, cache_key: String, fetch: F) -> SDKResult<String>
124    where
125        F: FnOnce() -> Fut,
126        Fut: Future<Output = SDKResult<(String, i64)>>,
127    {
128        if let Some(token) = self.get_cached(&cache_key).await {
129            return Ok(token);
130        }
131
132        let (token, expires_in_seconds) = fetch().await?;
133        self.set_cached(cache_key, token.clone(), expires_in_seconds)
134            .await;
135        Ok(token)
136    }
137
138    async fn fetch_token_via_http(
139        &self,
140        endpoint: &str,
141        payload: Value,
142        token_field: &str,
143    ) -> SDKResult<(String, i64)> {
144        let url = format!(
145            "{}/{}",
146            self.config.base_url().trim_end_matches('/'),
147            endpoint.trim_start_matches('/')
148        );
149
150        let response = self
151            .config
152            .http_client()
153            .post(&url)
154            .json(&payload)
155            .send()
156            .await
157            .map_err(|e| api_error(500, endpoint, format!("请求飞书认证接口失败: {e}"), None))?;
158
159        let status = response.status().as_u16();
160        let body: Value = response
161            .json()
162            .await
163            .map_err(|e| api_error(status, endpoint, format!("解析飞书认证响应失败: {e}"), None))?;
164
165        let code = body.get("code").and_then(Value::as_i64).unwrap_or(-1);
166        if code != 0 {
167            let msg = body
168                .get("msg")
169                .and_then(Value::as_str)
170                .unwrap_or("未知错误");
171            return Err(api_error(
172                status,
173                endpoint,
174                format!("飞书认证接口返回错误: code={code}, msg={msg}"),
175                None,
176            ));
177        }
178
179        let token = body
180            .get(token_field)
181            .and_then(Value::as_str)
182            .ok_or_else(|| configuration_error(format!("飞书认证响应缺少字段: {token_field}")))?
183            .to_string();
184
185        let expires_in = body.get("expire").and_then(Value::as_i64).unwrap_or(7200);
186
187        Ok((token, expires_in))
188    }
189}
190
191impl TokenProvider for AuthTokenProvider {
192    fn get_token(
193        &self,
194        request: TokenRequest,
195    ) -> Pin<Box<dyn Future<Output = SDKResult<String>> + Send + '_>> {
196        Box::pin(async move {
197            match request.token_type {
198                AccessTokenType::App => {
199                    let cache_key =
200                        Self::cache_key(&AccessTokenType::App, &self.config.app_type(), &request);
201                    self.get_or_fetch(cache_key, || async {
202                        let (token, expires_in) = match self.config.app_type() {
203                            AppType::SelfBuild => {
204                                self.fetch_token_via_http(
205                                    "/open-apis/auth/v3/app_access_token/internal",
206                                    json!({
207                                        "app_id": self.config.app_id(),
208                                        "app_secret": self.config.app_secret(),
209                                    }),
210                                    "app_access_token",
211                                )
212                                .await?
213                            }
214                            AppType::Marketplace => {
215                                self.fetch_token_via_http(
216                                    "/open-apis/auth/v3/app_access_token",
217                                    json!({
218                                        "app_id": self.config.app_id(),
219                                        "app_secret": self.config.app_secret(),
220                                    }),
221                                    "app_access_token",
222                                )
223                                .await?
224                            }
225                        };
226                        Ok((token, expires_in))
227                    })
228                    .await
229                }
230                AccessTokenType::Tenant => {
231                    let cache_key = Self::cache_key(
232                        &AccessTokenType::Tenant,
233                        &self.config.app_type(),
234                        &request,
235                    );
236                    self.get_or_fetch(cache_key, || async {
237                    let (token, expires_in) = match self.config.app_type() {
238                        AppType::SelfBuild => {
239                            self.fetch_token_via_http(
240                                "/open-apis/auth/v3/tenant_access_token/internal",
241                                json!({
242                                    "app_id": self.config.app_id(),
243                                    "app_secret": self.config.app_secret(),
244                                }),
245                                "tenant_access_token",
246                            )
247                            .await?
248                        }
249                        AppType::Marketplace => {
250                            let app_ticket = request.app_ticket.clone().ok_or_else(|| {
251                                configuration_error(
252                                    "token_provider: marketplace app requires app_ticket to fetch tenant_access_token",
253                                )
254                            })?;
255
256                            self.fetch_token_via_http(
257                                "/open-apis/auth/v3/tenant_access_token",
258                                json!({
259                                    "app_id": self.config.app_id(),
260                                    "app_secret": self.config.app_secret(),
261                                    "app_ticket": app_ticket,
262                                }),
263                                "tenant_access_token",
264                            )
265                            .await?
266                        }
267                    };
268                    Ok((token, expires_in))
269                })
270                .await
271                }
272                AccessTokenType::User => Err(configuration_error(
273                    "token_provider: user token 不应由 core 自动获取,请在 RequestOption 中显式传入 user_access_token(或由上层自行实现 TokenProvider 扩展)。",
274                )),
275                AccessTokenType::None => Err(configuration_error(
276                    "token_provider: AccessTokenType::None 不应触发 token 获取",
277                )),
278            }
279        })
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::AuthTokenProvider;
    use openlark_core::{
        auth::{TokenProvider, TokenRequest},
        config::Config,
        constants::{AccessTokenType, AppType},
    };

    /// Builds a config with test credentials whose base URL points at local port 9, where no
    /// auth server is expected to listen, so any HTTP token fetch fails.
    fn unreachable_config() -> Config {
        Config::builder()
            .app_id("test_app_id")
            .app_secret("test_app_secret")
            .base_url("http://127.0.0.1:9")
            .build()
    }

    #[tokio::test]
    async fn tenant_token_fetch_no_longer_uses_noop_provider() {
        let provider = AuthTokenProvider::new(unreachable_config());
        let err = provider
            .get_token(TokenRequest::tenant())
            .await
            .expect_err("should fail on unreachable test endpoint");

        assert!(!err.to_string().contains("NoOpTokenProvider"));
    }

    // Cache-key derivation is synchronous, so these use plain #[test] (no async runtime).
    #[test]
    fn tenant_cache_key_should_include_tenant_key() {
        let request = TokenRequest::tenant().tenant_key("tenant_key_001");

        let key =
            AuthTokenProvider::cache_key(&AccessTokenType::Tenant, &AppType::SelfBuild, &request);

        assert_eq!(key, "Tenant_SelfBuild_tenant_key_001");
    }

    #[test]
    fn tenant_cache_key_falls_back_to_default_without_tenant_key() {
        let key = AuthTokenProvider::cache_key(
            &AccessTokenType::Tenant,
            &AppType::SelfBuild,
            &TokenRequest::tenant(),
        );

        assert_eq!(key, "Tenant_SelfBuild_default");
    }

    #[test]
    fn app_cache_key_should_include_app_ticket_for_marketplace() {
        let request = TokenRequest::app().app_ticket("ticket_001");

        let key =
            AuthTokenProvider::cache_key(&AccessTokenType::App, &AppType::Marketplace, &request);

        assert_eq!(key, "App_Marketplace_ticket_001");
    }

    #[test]
    fn self_build_app_cache_key_ignores_app_ticket() {
        let request = TokenRequest::app().app_ticket("ticket_001");

        let key =
            AuthTokenProvider::cache_key(&AccessTokenType::App, &AppType::SelfBuild, &request);

        assert_eq!(key, "App_SelfBuild");
    }
}