Skip to main content

tibba_session/
session.rs

1// Copyright 2026 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{Error, LOG_TARGET};
16use axum::Json;
17use axum::extract::FromRequestParts;
18use axum::http::request::Parts;
19use axum::response::{IntoResponse, Response};
20use axum_extra::extract::cookie::{Key, SignedCookieJar};
21use cookie::CookieBuilder;
22use serde::{Deserialize, Serialize};
23use std::sync::Arc;
24use std::time::Duration;
25use tibba_cache::RedisCache;
26use tibba_state::CTX;
27use tibba_util::{from_timestamp, timestamp, uuid};
28use tracing::debug;
29
30type Result<T, E = tibba_error::Error> = std::result::Result<T, E>;
31
32/// 用户角色枚举,支持内置角色(Admin / SuperAdmin)和自定义角色。
33#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
34pub enum Role {
35    Admin,
36    SuperAdmin,
37    Custom(String),
38}
39
40impl From<&str> for Role {
41    fn from(s: &str) -> Self {
42        match s {
43            "admin" => Role::Admin,
44            "su" => Role::SuperAdmin,
45            _ => Role::Custom(s.to_string()),
46        }
47    }
48}
49
50/// Session 配置参数,包含签名密钥、Cookie 名称、TTL 和最大续期次数。
51#[derive(Debug, Clone, Serialize)]
52pub struct SessionParams {
53    /// Cookie 签名密钥,序列化时跳过(不暴露到外部)
54    #[serde(skip)]
55    key: Key,
56    /// 存储 Session ID 的 Cookie 名称
57    cookie: String,
58    /// Session 有效期(秒),默认 86400(24 小时)
59    ttl: i64,
60    /// 允许续期的最大次数,0 表示不允许续期
61    max_renewal: u8,
62}
63
64impl SessionParams {
65    /// 以签名密钥创建 SessionParams,其余字段使用默认值(TTL 24h,不允许续期)。
66    pub fn new(key: Key) -> Self {
67        Self {
68            key,
69            cookie: String::new(),
70            ttl: 24 * 60 * 60,
71            max_renewal: 0,
72        }
73    }
74
75    /// 设置存储 Session ID 的 Cookie 名称,支持链式调用。
76    #[must_use]
77    pub fn with_cookie(mut self, cookie: impl Into<String>) -> Self {
78        self.cookie = cookie.into();
79        self
80    }
81
82    /// 设置 Session 有效期(秒),支持链式调用。
83    #[must_use]
84    pub fn with_ttl(mut self, ttl: i64) -> Self {
85        self.ttl = ttl;
86        self
87    }
88
89    /// 设置允许续期的最大次数,支持链式调用。
90    #[must_use]
91    pub fn with_max_renewal(mut self, max_renewal: u8) -> Self {
92        self.max_renewal = max_renewal;
93        self
94    }
95}
96
97/// Session 的内部数据,序列化后存入 Redis。
98#[derive(Serialize, Deserialize, Default, Clone)]
99struct SessionData {
100    /// 用户 ID
101    user_id: i64,
102    /// Session 唯一标识(UUID)
103    id: String,
104    /// 签发时间戳(Unix 秒)
105    iat: i64,
106    /// 用户账号
107    account: String,
108    /// 已续期次数
109    renewal_count: u8,
110    /// 角色列表
111    roles: Vec<String>,
112    /// 用户组列表
113    groups: Vec<String>,
114}
115
116/// HTTP Session,持有 Redis 缓存引用、配置参数和当前会话数据。
117/// 实现了 axum `FromRequestParts`,可直接作为 handler 参数提取。
118#[derive(Clone)]
119pub struct Session {
120    cache: &'static RedisCache,
121    params: Arc<SessionParams>,
122    data: SessionData,
123}
124
125impl Session {
126    /// 创建未登录的空 Session,数据从下一次请求中按需加载。
127    pub fn new(cache: &'static RedisCache, params: Arc<SessionParams>) -> Self {
128        Self {
129            cache,
130            params,
131            data: SessionData::default(),
132        }
133    }
134
135    /// 生成 Redis 存储键,格式为 `ss:{session_id}`。
136    fn get_key(id: &str) -> String {
137        format!("ss:{id}")
138    }
139
140    /// 校验用户是否已登录,未登录时返回 401 错误。
141    fn validate_login(&self) -> Result<()> {
142        if !self.is_login() {
143            return Err(Error::UserNotLogin.into());
144        }
145        Ok(())
146    }
147
148    /// 返回 `true` 表示用户已登录(account 非空)。
149    pub fn is_login(&self) -> bool {
150        !self.data.account.is_empty()
151    }
152
153    /// 返回 `true` 表示 Session 尚未达到最大续期次数,可以续期。
154    pub fn can_renew(&self) -> bool {
155        self.data.renewal_count < self.params.max_renewal
156    }
157
158    /// 设置账号和用户 ID,账号变更时自动生成新的 Session ID,支持链式调用。
159    #[must_use]
160    pub fn with_account(mut self, account: impl Into<String>, user_id: i64) -> Self {
161        let account = account.into();
162        if self.data.id.is_empty() || self.data.account != account {
163            self.data.id = uuid();
164        }
165        self.data.account = account;
166        self.data.user_id = user_id;
167        self.data.iat = timestamp();
168        self
169    }
170
171    /// 设置角色列表,支持链式调用。
172    #[must_use]
173    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
174        self.data.roles = roles;
175        self
176    }
177
178    /// 设置用户组列表,支持链式调用。
179    #[must_use]
180    pub fn with_groups(mut self, groups: Vec<String>) -> Self {
181        self.data.groups = groups;
182        self
183    }
184
185    /// 续期:累加续期计数并更新签发时间戳。
186    pub fn refresh(&mut self) {
187        self.data.renewal_count += 1;
188        self.data.iat = timestamp();
189    }
190
191    /// 返回当前登录的用户账号。
192    pub fn get_account(&self) -> &str {
193        &self.data.account
194    }
195
196    /// 返回当前登录的用户 ID。
197    pub fn get_user_id(&self) -> i64 {
198        self.data.user_id
199    }
200
201    /// 返回 Session 过期时间的格式化字符串。
202    pub fn get_expired_at(&self) -> String {
203        from_timestamp(self.data.iat + self.params.ttl, 0)
204    }
205
206    /// 返回 `true` 表示 Session 将在 1 小时内过期。
207    pub fn is_will_expired(&self) -> bool {
208        self.data.iat + self.params.ttl - timestamp() < 3600
209    }
210
211    /// 返回 Session 签发时间的格式化字符串。
212    pub fn get_issued_at(&self) -> String {
213        from_timestamp(self.data.iat, 0)
214    }
215
216    /// 返回 `true` 表示 Session 已超过 TTL 过期。
217    pub fn is_expired(&self) -> bool {
218        self.data.iat + self.params.ttl < timestamp()
219    }
220
221    /// 重置 Session(登出),清除 ID 和账号信息。
222    pub fn reset(&mut self) {
223        self.data.id = String::new();
224        self.data.account = String::new();
225    }
226
227    /// 将当前 Session 数据持久化到 Redis,TTL 与配置一致。
228    pub async fn save(&self) -> Result<()> {
229        if self.data.id.is_empty() {
230            return Err(Error::SessionIdEmpty.into());
231        }
232        self.cache
233            .set_struct(
234                &Self::get_key(&self.data.id),
235                &self.data,
236                Some(Duration::from_secs(self.params.ttl as u64)),
237            )
238            .await?;
239        Ok(())
240    }
241}
242
243/// 将 Session 转换为携带签名 Cookie 的 `SignedCookieJar`。
244/// Session ID 为空时将 Cookie max-age 设为 0(即删除 Cookie)。
245impl TryFrom<&Session> for SignedCookieJar {
246    type Error = tibba_error::Error;
247
248    fn try_from(se: &Session) -> Result<Self, Self::Error> {
249        let mut c = CookieBuilder::new(se.params.cookie.clone(), se.data.id.clone())
250            .path("/")
251            .http_only(true)
252            .max_age(time::Duration::seconds(se.params.ttl));
253
254        if se.data.id.is_empty() {
255            // ID 为空表示登出,将 max-age 置 0 以清除客户端 Cookie
256            c = c.max_age(time::Duration::days(0));
257        }
258
259        Ok(SignedCookieJar::new(se.params.key.clone()).add(c))
260    }
261}
262
263/// Session 登出/刷新接口的响应体。
264#[derive(Debug, Serialize, Deserialize, Default)]
265struct SessionResp {
266    account: String,
267    renewal_count: u8,
268}
269
270/// 将 Session 序列化为 HTTP 响应:设置签名 Cookie + JSON 账号信息。
271impl IntoResponse for Session {
272    fn into_response(self) -> Response {
273        let result: Result<SignedCookieJar, _> = (&self).try_into();
274        match result {
275            Ok(jar) => (
276                jar,
277                Json(SessionResp {
278                    account: self.data.account,
279                    renewal_count: self.data.renewal_count,
280                }),
281            )
282                .into_response(),
283            Err(err) => err.into_response(),
284        }
285    }
286}
287
288/// 将 Session 和额外数据一起序列化为 HTTP 响应,同时设置签名 Cookie。
289pub struct SessionResponse<T>(pub Session, pub T);
290
291impl<T> IntoResponse for SessionResponse<T>
292where
293    T: IntoResponse,
294{
295    fn into_response(self) -> Response {
296        let result: Result<SignedCookieJar, _> = (&self.0).try_into();
297        match result {
298            Ok(jar) => (jar, self.1).into_response(),
299            Err(err) => err.into_response(),
300        }
301    }
302}
303
304/// axum extractor:从请求扩展中提取 Session,按需从 Redis 加载数据。
305/// 若 Cookie 中存在有效 Session ID 且 Redis 中有对应数据,则填充 SessionData。
306impl<S> FromRequestParts<S> for Session
307where
308    S: Send + Sync,
309{
310    type Rejection = tibba_error::Error;
311
312    async fn from_request_parts(
313        parts: &mut Parts,
314        _state: &S,
315    ) -> std::result::Result<Self, Self::Rejection> {
316        let mut se = parts
317            .extensions
318            .get::<Session>()
319            .ok_or::<Error>(Error::SessionNotFound)?
320            .clone();
321        debug!(
322            target: LOG_TARGET,
323            id = se.data.id,
324            iat = se.data.iat,
325            "from_request_parts"
326        );
327        // iat == 0 表示本次请求尚未从 Redis 加载过数据
328        if se.data.iat == 0 {
329            let jar = SignedCookieJar::from_headers(&parts.headers, se.params.key.clone());
330            let Some(c) = jar.get(&se.params.cookie) else {
331                return Ok(se);
332            };
333            let session_id = c.value();
334            if session_id.len() < 36 {
335                return Err(Error::SessionIdInvalid.into());
336            }
337            if let Some(data) = se
338                .cache
339                .get_struct::<SessionData>(&Session::get_key(session_id))
340                .await?
341            {
342                debug!(
343                    target: LOG_TARGET,
344                    id = data.id,
345                    iat = data.iat,
346                    "load from cache"
347                );
348                se.data = data;
349                // 回写到扩展,同一请求内后续提取无需再查 Redis
350                parts.extensions.insert(se.clone());
351                if se.is_login() {
352                    CTX.get().set_account(se.get_account());
353                }
354
355                return Ok(se);
356            }
357        }
358        Ok(se)
359    }
360}
361
362/// axum extractor:要求用户已登录,否则返回 401。
363/// 通过 `Deref`/`DerefMut` 可直接访问内部 `Session` 的所有方法。
364pub struct UserSession(Session);
365
366impl From<UserSession> for Session {
367    fn from(se: UserSession) -> Self {
368        se.0
369    }
370}
371
372impl<S> FromRequestParts<S> for UserSession
373where
374    S: Send + Sync,
375{
376    type Rejection = tibba_error::Error;
377
378    async fn from_request_parts(
379        parts: &mut Parts,
380        _state: &S,
381    ) -> std::result::Result<Self, Self::Rejection> {
382        let se = Session::from_request_parts(parts, _state).await?;
383        se.validate_login()?;
384        Ok(UserSession(se))
385    }
386}
387
388impl std::ops::Deref for UserSession {
389    type Target = Session;
390
391    fn deref(&self) -> &Self::Target {
392        &self.0
393    }
394}
395
396impl std::ops::DerefMut for UserSession {
397    fn deref_mut(&mut self) -> &mut Self::Target {
398        &mut self.0
399    }
400}
401
402/// axum extractor:要求用户已登录且具有 Admin 或 SuperAdmin 角色,否则返回 401/403。
403/// 通过 `Deref`/`DerefMut` 可直接访问内部 `Session` 的所有方法。
404pub struct AdminSession(Session);
405
406impl From<AdminSession> for Session {
407    fn from(se: AdminSession) -> Self {
408        se.0
409    }
410}
411
412impl<S> FromRequestParts<S> for AdminSession
413where
414    S: Send + Sync,
415{
416    type Rejection = tibba_error::Error;
417
418    async fn from_request_parts(
419        parts: &mut Parts,
420        _state: &S,
421    ) -> std::result::Result<Self, Self::Rejection> {
422        let se = Session::from_request_parts(parts, _state).await?;
423        se.validate_login()?;
424        if !se.data.roles.iter().any(|role| {
425            let r = Role::from(role.as_str());
426            r == Role::Admin || r == Role::SuperAdmin
427        }) {
428            return Err(Error::UserNotAdmin.into());
429        }
430        Ok(AdminSession(se))
431    }
432}
433
434impl std::ops::Deref for AdminSession {
435    type Target = Session;
436
437    fn deref(&self) -> &Self::Target {
438        &self.0
439    }
440}
441
442impl std::ops::DerefMut for AdminSession {
443    fn deref_mut(&mut self) -> &mut Self::Target {
444        &mut self.0
445    }
446}