Skip to main content

rs_auth_postgres/
session.rs

1use async_trait::async_trait;
2use sqlx::Row;
3
4use rs_auth_core::error::AuthError;
5use rs_auth_core::store::SessionStore;
6use rs_auth_core::types::{NewSession, Session};
7
8use crate::db::AuthDb;
9
10#[async_trait]
11impl SessionStore for AuthDb {
12    async fn create_session(&self, session: NewSession) -> Result<Session, AuthError> {
13        sqlx::query(
14            r#"
15            INSERT INTO sessions (token_hash, user_id, expires_at, ip_address, user_agent)
16            VALUES ($1, $2, $3, $4, $5)
17            RETURNING id, token_hash, user_id, expires_at, ip_address, user_agent, created_at, updated_at
18            "#,
19        )
20        .bind(session.token_hash)
21        .bind(session.user_id)
22        .bind(session.expires_at)
23        .bind(session.ip_address)
24        .bind(session.user_agent)
25        .fetch_one(&self.pool)
26        .await
27        .map(|row| Session {
28            id: row.get("id"),
29            token_hash: row.get("token_hash"),
30            user_id: row.get("user_id"),
31            expires_at: row.get("expires_at"),
32            ip_address: row.get("ip_address"),
33            user_agent: row.get("user_agent"),
34            created_at: row.get("created_at"),
35            updated_at: row.get("updated_at"),
36        })
37        .map_err(|error| AuthError::Store(error.to_string()))
38    }
39
40    async fn find_by_token_hash(&self, token_hash: &str) -> Result<Option<Session>, AuthError> {
41        sqlx::query(
42            r#"
43            SELECT id, token_hash, user_id, expires_at, ip_address, user_agent, created_at, updated_at
44            FROM sessions
45            WHERE token_hash = $1
46            "#,
47        )
48        .bind(token_hash)
49        .fetch_optional(&self.pool)
50        .await
51        .map(|row| {
52            row.map(|row| Session {
53                id: row.get("id"),
54                token_hash: row.get("token_hash"),
55                user_id: row.get("user_id"),
56                expires_at: row.get("expires_at"),
57                ip_address: row.get("ip_address"),
58                user_agent: row.get("user_agent"),
59                created_at: row.get("created_at"),
60                updated_at: row.get("updated_at"),
61            })
62        })
63        .map_err(|error| AuthError::Store(error.to_string()))
64    }
65
66    async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
67        sqlx::query(
68            r#"
69            SELECT id, token_hash, user_id, expires_at, ip_address, user_agent, created_at, updated_at
70            FROM sessions
71            WHERE user_id = $1 AND expires_at > now()
72            ORDER BY created_at DESC
73            "#,
74        )
75        .bind(user_id)
76        .fetch_all(&self.pool)
77        .await
78        .map(|rows| {
79            rows.into_iter()
80                .map(|row| Session {
81                    id: row.get("id"),
82                    token_hash: row.get("token_hash"),
83                    user_id: row.get("user_id"),
84                    expires_at: row.get("expires_at"),
85                    ip_address: row.get("ip_address"),
86                    user_agent: row.get("user_agent"),
87                    created_at: row.get("created_at"),
88                    updated_at: row.get("updated_at"),
89                })
90                .collect()
91        })
92        .map_err(|error| AuthError::Store(error.to_string()))
93    }
94
95    async fn delete_session(&self, id: i64) -> Result<(), AuthError> {
96        sqlx::query(r#"DELETE FROM sessions WHERE id = $1"#)
97            .bind(id)
98            .execute(&self.pool)
99            .await
100            .map(|_| ())
101            .map_err(|error| AuthError::Store(error.to_string()))
102    }
103
104    async fn delete_by_user_id(&self, user_id: i64) -> Result<(), AuthError> {
105        sqlx::query(r#"DELETE FROM sessions WHERE user_id = $1"#)
106            .bind(user_id)
107            .execute(&self.pool)
108            .await
109            .map(|_| ())
110            .map_err(|error| AuthError::Store(error.to_string()))
111    }
112
113    async fn delete_expired(&self) -> Result<u64, AuthError> {
114        sqlx::query(r#"DELETE FROM sessions WHERE expires_at < now()"#)
115            .execute(&self.pool)
116            .await
117            .map(|result| result.rows_affected())
118            .map_err(|error| AuthError::Store(error.to_string()))
119    }
120}