// dbpulse/queries/postgres.rs

1use anyhow::{anyhow, Context, Result};
2use chrono::prelude::*;
3use chrono::{DateTime, Utc};
4use dsn::DSN;
5use rand::Rng;
6use sqlx::{
7    postgres::{PgConnectOptions, PgDatabaseError},
8    ConnectOptions, Connection,
9};
10use uuid::Uuid;
11
12pub async fn test_rw(dsn: &DSN, now: DateTime<Utc>, range: u32) -> Result<String> {
13    let mut options = PgConnectOptions::new()
14        .username(dsn.username.clone().unwrap_or_default().as_ref())
15        .password(dsn.password.clone().unwrap_or_default().as_str())
16        .database(dsn.database.clone().unwrap_or_default().as_ref());
17
18    if let Some(host) = &dsn.host {
19        options = options.host(host.as_str()).port(dsn.port.unwrap_or(5432));
20    } else if let Some(socket) = &dsn.socket {
21        options = options.socket(socket.as_str());
22    }
23
24    let mut conn = match options.connect().await {
25        Ok(conn) => conn,
26        Err(err) => match err {
27            sqlx::Error::Database(db_err) => {
28                if db_err
29                    .as_error()
30                    .downcast_ref::<PgDatabaseError>()
31                    .map(PgDatabaseError::code)
32                    == Some("3D000")
33                {
34                    let tmp_options = options.clone().database("postgres");
35                    let mut tmp_conn = tmp_options.connect().await?;
36                    sqlx::query(&format!(
37                        "CREATE DATABASE {}",
38                        dsn.database.clone().unwrap_or_default()
39                    ))
40                    .execute(&mut tmp_conn)
41                    .await?;
42                    drop(tmp_conn);
43                    options.connect().await?
44                } else {
45                    return Err(db_err.into());
46                }
47            }
48            _ => return Err(err.into()),
49        },
50    };
51
52    // Get database version
53    let version: Option<String> = sqlx::query_scalar("SHOW server_version")
54        .fetch_optional(&mut conn)
55        .await
56        .context("Failed to fetch database version")?;
57
58    // Query to check if the database is in recovery (read-only)
59    let is_in_recovery: (bool,) = sqlx::query_as("SELECT pg_is_in_recovery();")
60        .fetch_one(&mut conn)
61        .await?;
62
63    // can't write to a read-only database
64    if is_in_recovery.0 {
65        return Ok(format!(
66            "{} - Database is in recovery mode",
67            version.unwrap_or_default()
68        ));
69    }
70
71    // for UUID
72    sqlx::query("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")
73        .execute(&mut conn)
74        .await?;
75
76    // create table
77    let create_table_sql = r#"
78        CREATE TABLE IF NOT EXISTS dbpulse_rw (
79            id SERIAL PRIMARY KEY,
80            t1 BIGINT NOT NULL,
81            t2 TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
82            uuid UUID NOT NULL,
83            CONSTRAINT uuid_unique UNIQUE (uuid)
84        )
85    "#;
86
87    sqlx::query(create_table_sql).execute(&mut conn).await?;
88
89    // write into table
90    let id: u32 = rand::thread_rng().gen_range(0..range);
91    let uuid = Uuid::new_v4();
92
93    // SQL Query
94    sqlx::query(
95        r#"
96        INSERT INTO dbpulse_rw (id, t1, uuid)
97        VALUES ($1, $2, $3)
98        ON CONFLICT (id)
99        DO UPDATE SET t1 = EXCLUDED.t1, uuid = EXCLUDED.uuid
100        "#,
101    )
102    .bind(id as i32)
103    .bind(now.timestamp())
104    .bind(uuid)
105    .execute(&mut conn) // Ensure we're using PgConnection here
106    .await?;
107
108    // Check if stored record matches
109    let row: Option<(i64, Uuid)> = sqlx::query_as(
110        r#"
111        SELECT t1, uuid
112        FROM dbpulse_rw
113        WHERE id = $1
114        "#,
115    )
116    .bind(id as i32)
117    .fetch_optional(&mut conn)
118    .await?;
119
120    // Ensure the row exists and matches
121    let (t1, v4) = row.context("Expected records")?;
122    if now.timestamp() != t1 || uuid != v4 {
123        return Err(anyhow!(
124            "Records don't match: expected ({}, {}), got ({}, {})",
125            now.timestamp(),
126            uuid,
127            t1,
128            v4
129        ));
130    }
131
132    // Start a transaction to set all `t1` records to 0
133    let mut tx = conn.begin().await?;
134    sqlx::query("UPDATE dbpulse_rw SET t1 = $1")
135        .bind(0)
136        .execute(tx.as_mut())
137        .await?;
138    let rows: Vec<i64> = sqlx::query_scalar("SELECT t1 FROM dbpulse_rw")
139        .fetch_all(tx.as_mut())
140        .await?;
141
142    for row in rows {
143        if row != 0 {
144            return Err(anyhow!("Records don't match: {} != {}", row, 0));
145        }
146    }
147
148    // Roll back this transaction
149    tx.rollback().await?;
150
151    // Start a new transaction to update record 0 with current timestamp
152    let mut tx = conn.begin().await?;
153    sqlx::query(
154        r#"
155        INSERT INTO dbpulse_rw (id, t1, uuid)
156        VALUES (0, $1, UUID_GENERATE_V4())
157        ON CONFLICT (id)
158        DO UPDATE SET t1 = EXCLUDED.t1
159        "#,
160    )
161    .bind(now.timestamp())
162    .execute(tx.as_mut())
163    .await
164    .context("Failed to insert or update record")?;
165    tx.commit().await?;
166
167    // Drop the table conditionally
168    if now.minute() == id {
169        sqlx::query("DROP TABLE dbpulse_rw")
170            .execute(&mut conn)
171            .await
172            .context("Failed to drop table")?;
173    }
174
175    drop(conn);
176
177    version.context("Expected database version")
178}