Skip to main content

apfsds_storage/
postgres.rs

1use serde::{Deserialize, Serialize};
2use sqlx::{Pool, Postgres, Row, postgres::PgPoolOptions};
3use std::time::Duration;
4use thiserror::Error;
5
6#[derive(Error, Debug)]
7pub enum PgError {
8    #[error("Database error: {0}")]
9    DbError(#[from] sqlx::Error),
10}
11
12/// User Group definition (e.g., "Premium Asia", "Free US")
13#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
14pub struct ExitGroup {
15    pub id: i32,
16    pub name: String,
17    pub description: Option<String>,
18}
19
20/// User definition
21#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
22pub struct User {
23    pub id: i64,
24    pub username: String,
25    pub token_hash: String,
26    pub group_id: i32,
27    pub balance: i64, // simplified billing
28}
29
30/// Postgres Client helper
31#[derive(Clone)]
32pub struct PgClient {
33    pool: Pool<Postgres>,
34}
35
36impl PgClient {
37    pub async fn new(url: &str) -> Result<Self, PgError> {
38        let pool = PgPoolOptions::new()
39            .max_connections(20)
40            .acquire_timeout(Duration::from_secs(3))
41            .connect(url)
42            .await?;
43
44        Ok(Self { pool })
45    }
46
47    /// Initialize schema
48    pub async fn migrate(&self) -> Result<(), PgError> {
49        sqlx::query(
50            r#"
51            CREATE TABLE IF NOT EXISTS exit_groups (
52                id SERIAL PRIMARY KEY,
53                name VARCHAR(50) NOT NULL UNIQUE,
54                description TEXT
55            );
56
57            CREATE TABLE IF NOT EXISTS users (
58                id BIGSERIAL PRIMARY KEY,
59                username VARCHAR(100) NOT NULL UNIQUE,
60                token_hash VARCHAR(255) NOT NULL,
61                group_id INT REFERENCES exit_groups(id),
62                balance BIGINT DEFAULT 0,
63                created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
64            );
65
66            CREATE TABLE IF NOT EXISTS billing_logs (
67                id BIGSERIAL PRIMARY KEY,
68                user_id BIGINT REFERENCES users(id),
69                bytes_used BIGINT NOT NULL,
70                timestamp TIMESTAMP WITH TIME ZONE DEFAULT NOW()
71            );
72            "#,
73        )
74        .execute(&self.pool)
75        .await?;
76
77        // Seed default group if empty
78        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM exit_groups")
79            .fetch_one(&self.pool)
80            .await?;
81
82        if count == 0 {
83            sqlx::query(
84                "INSERT INTO exit_groups (name, description) VALUES ('default', 'Default Group')",
85            )
86            .execute(&self.pool)
87            .await?;
88        }
89
90        Ok(())
91    }
92
93    pub async fn get_user_by_token(&self, token: &str) -> Result<Option<User>, PgError> {
94        // Note: In production, use bcrypt/argon2 to verify token_hash
95        // Current implementation does direct hash comparison for simplicity
96        sqlx::query_as::<_, User>("SELECT * FROM users WHERE token_hash = $1")
97            .bind(token)
98            .fetch_optional(&self.pool)
99            .await
100            .map_err(Into::into)
101    }
102
103    pub async fn record_usage(&self, user_id: i64, bytes: u64) -> Result<(), PgError> {
104        sqlx::query("INSERT INTO billing_logs (user_id, bytes_used) VALUES ($1, $2)")
105            .bind(user_id)
106            .bind(bytes as i64)
107            .execute(&self.pool)
108            .await?;
109        Ok(())
110    }
111}