Skip to main content

faucet_state_postgres/
store.rs

1//! PostgreSQL-backed [`StateStore`].
2
3use async_trait::async_trait;
4use faucet_core::state::{DOCTOR_SENTINEL_KEY, StateStore, validate_state_key};
5use faucet_core::util::quote_ident;
6use faucet_core::{FaucetError, Value};
7use sqlx::postgres::PgPoolOptions;
8use sqlx::{PgPool, Row};
9
/// Table name used by [`PostgresStateStore::connect`].
///
/// To store state in a different table, pass another name to
/// [`PostgresStateStore::connect_with`] or [`PostgresStateStore::from_pool`].
pub const DEFAULT_TABLE: &str = "faucet_state";
13
14/// A `StateStore` that persists each entry as a row in a single Postgres
15/// table.
16///
17/// The table layout is:
18///
19/// ```sql
20/// CREATE TABLE faucet_state (
21///     key        TEXT        PRIMARY KEY,
22///     value      JSONB       NOT NULL,
23///     updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
24/// );
25/// ```
26///
27/// Calling [`PostgresStateStore::ensure_table`] creates the table if it does
28/// not exist; integrators that manage schema separately can skip this call.
29pub struct PostgresStateStore {
30    pool: PgPool,
31    table: String,
32}
33
34impl PostgresStateStore {
35    /// Connect to a PostgreSQL server with the default `faucet_state` table.
36    pub async fn connect(connection_url: &str) -> Result<Self, FaucetError> {
37        Self::connect_with(connection_url, 5, DEFAULT_TABLE).await
38    }
39
40    /// Connect with explicit pool size and table name.
41    pub async fn connect_with(
42        connection_url: &str,
43        max_connections: u32,
44        table: &str,
45    ) -> Result<Self, FaucetError> {
46        validate_table_name(table)?;
47        let pool = PgPoolOptions::new()
48            .max_connections(max_connections)
49            .connect(connection_url)
50            .await
51            .map_err(|e| {
52                FaucetError::State(format!("PostgreSQL state-store connection failed: {e}"))
53            })?;
54        Ok(Self {
55            pool,
56            table: table.to_owned(),
57        })
58    }
59
60    /// Construct from an existing pool. Useful when integrators already
61    /// manage a shared `PgPool`.
62    pub fn from_pool(pool: PgPool, table: impl Into<String>) -> Result<Self, FaucetError> {
63        let table = table.into();
64        validate_table_name(&table)?;
65        Ok(Self { pool, table })
66    }
67
68    /// Returns the table name in use.
69    pub fn table(&self) -> &str {
70        &self.table
71    }
72
73    /// Create the state-store table if it does not already exist.
74    pub async fn ensure_table(&self) -> Result<(), FaucetError> {
75        let sql = create_table_sql(&self.table);
76        sqlx::query(&sql).execute(&self.pool).await.map_err(|e| {
77            FaucetError::State(format!("failed to ensure state table {}: {e}", self.table))
78        })?;
79        Ok(())
80    }
81}
82
83/// SQL identifiers in faucet-stream must already be safe for `quote_ident`;
84/// additionally restrict the table name to printable ASCII to keep error
85/// messages readable.
86pub(crate) fn validate_table_name(table: &str) -> Result<(), FaucetError> {
87    if table.is_empty() {
88        return Err(FaucetError::Config(
89            "state-store table name must not be empty".into(),
90        ));
91    }
92    if table.len() > 63 {
93        return Err(FaucetError::Config(format!(
94            "state-store table name '{table}' exceeds Postgres' 63-character identifier limit"
95        )));
96    }
97    for (i, c) in table.char_indices() {
98        let ok = c.is_ascii_alphanumeric() || c == '_';
99        if !ok {
100            return Err(FaucetError::Config(format!(
101                "state-store table name '{table}' contains illegal character {c:?} at byte {i}"
102            )));
103        }
104    }
105    Ok(())
106}
107
108pub(crate) fn create_table_sql(table: &str) -> String {
109    format!(
110        "CREATE TABLE IF NOT EXISTS {table_ident} (\
111            key TEXT PRIMARY KEY,\
112            value JSONB NOT NULL,\
113            updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()\
114        )",
115        table_ident = quote_ident(table)
116    )
117}
118
119pub(crate) fn select_sql(table: &str) -> String {
120    format!("SELECT value FROM {} WHERE key = $1", quote_ident(table))
121}
122
123pub(crate) fn upsert_sql(table: &str) -> String {
124    format!(
125        "INSERT INTO {tbl} (key, value, updated_at) VALUES ($1, $2, NOW()) \
126         ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, updated_at = NOW()",
127        tbl = quote_ident(table)
128    )
129}
130
131pub(crate) fn delete_sql(table: &str) -> String {
132    format!("DELETE FROM {} WHERE key = $1", quote_ident(table))
133}
134
135#[async_trait]
136impl StateStore for PostgresStateStore {
137    async fn get(&self, key: &str) -> Result<Option<Value>, FaucetError> {
138        validate_state_key(key)?;
139        let row = sqlx::query(&select_sql(&self.table))
140            .bind(key)
141            .fetch_optional(&self.pool)
142            .await
143            .map_err(|e| {
144                FaucetError::State(format!("Postgres SELECT for key '{key}' failed: {e}"))
145            })?;
146        match row {
147            None => Ok(None),
148            Some(r) => {
149                let value: Value = r.try_get(0).map_err(|e| {
150                    FaucetError::State(format!(
151                        "failed to decode JSONB column for key '{key}': {e}"
152                    ))
153                })?;
154                Ok(Some(value))
155            }
156        }
157    }
158
159    async fn put(&self, key: &str, value: &Value) -> Result<(), FaucetError> {
160        validate_state_key(key)?;
161        sqlx::query(&upsert_sql(&self.table))
162            .bind(key)
163            .bind(value)
164            .execute(&self.pool)
165            .await
166            .map_err(|e| {
167                FaucetError::State(format!("Postgres UPSERT for key '{key}' failed: {e}"))
168            })?;
169        tracing::debug!(key, table = %self.table, "state written to Postgres");
170        Ok(())
171    }
172
173    async fn delete(&self, key: &str) -> Result<(), FaucetError> {
174        validate_state_key(key)?;
175        sqlx::query(&delete_sql(&self.table))
176            .bind(key)
177            .execute(&self.pool)
178            .await
179            .map_err(|e| {
180                FaucetError::State(format!("Postgres DELETE for key '{key}' failed: {e}"))
181            })?;
182        Ok(())
183    }
184
185    async fn check(
186        &self,
187        ctx: &faucet_core::check::CheckContext,
188    ) -> Result<faucet_core::check::CheckReport, FaucetError> {
189        use faucet_core::check::{CheckReport, Probe};
190
191        // Exercise the real upsert → select → delete cycle on a sentinel key.
192        // This validates connectivity, auth, the table's existence, and
193        // read/write permissions through the actual code path and leaves no
194        // residue.
195        let start = std::time::Instant::now();
196        let probe = match tokio::time::timeout(ctx.timeout, self.sentinel_roundtrip()).await {
197            Ok(Ok(())) => Probe::pass("sentinel", start.elapsed()),
198            Ok(Err(e)) => Probe::fail_hint(
199                "sentinel",
200                start.elapsed(),
201                e.to_string(),
202                format!(
203                    "verify the server is reachable, the credentials grant read/write access, \
204                     and the '{}' table exists (call ensure_table or create it manually)",
205                    self.table
206                ),
207            ),
208            Err(_) => Probe::fail_hint(
209                "sentinel",
210                start.elapsed(),
211                format!(
212                    "round-trip timed out after {:?}; Postgres did not respond",
213                    ctx.timeout
214                ),
215                "verify the server is reachable or raise the check timeout",
216            ),
217        };
218        Ok(CheckReport::single(probe))
219    }
220}
221
222impl PostgresStateStore {
223    /// Write, read back, and delete a sentinel key — the body of the `check()`
224    /// probe, factored out so the happy path stays linear. Reuses the store's
225    /// own `put`/`get`/`delete` against the configured table.
226    async fn sentinel_roundtrip(&self) -> Result<(), FaucetError> {
227        let probe = serde_json::json!({ "faucet_doctor": true });
228        self.put(DOCTOR_SENTINEL_KEY, &probe).await?;
229        let got = self.get(DOCTOR_SENTINEL_KEY).await?;
230        // Best-effort cleanup regardless of the read result.
231        let _ = self.delete(DOCTOR_SENTINEL_KEY).await;
232        match got {
233            Some(v) if v == probe => Ok(()),
234            _ => Err(FaucetError::State(
235                "sentinel readback did not match what was written".into(),
236            )),
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_table_name_accepts_typical_values() {
        for t in [
            "faucet_state",
            "FaucetState",
            "state_v2",
            "f1",
            "abcdefghijklmnopqrstuvwxyz_0123456789_FOO",
        ] {
            validate_table_name(t).unwrap_or_else(|e| panic!("expected ok for {t:?}: {e}"));
        }
    }

    #[test]
    fn validate_table_name_accepts_63_byte_boundary() {
        // Postgres' default identifier limit is NAMEDATALEN - 1 = 63 bytes;
        // a name of exactly that length must still be accepted.
        validate_table_name(&"a".repeat(63)).expect("63-byte table name must be accepted");
    }

    #[test]
    fn validate_table_name_rejects_empty() {
        let err = validate_table_name("").unwrap_err();
        assert!(matches!(err, FaucetError::Config(_)));
    }

    #[test]
    fn validate_table_name_rejects_illegal_chars() {
        for t in [
            "table-name",
            "schema.table",
            "drop table users;--",
            "spaces in name",
        ] {
            assert!(
                matches!(validate_table_name(t), Err(FaucetError::Config(_))),
                "expected Config error for {t:?}"
            );
        }
    }

    #[test]
    fn validate_table_name_rejects_non_ascii() {
        for t in ["tāble", "état", "表"] {
            assert!(
                matches!(validate_table_name(t), Err(FaucetError::Config(_))),
                "expected Config error for {t:?}"
            );
        }
    }

    #[test]
    fn validate_table_name_rejects_over_long() {
        let t = "a".repeat(64);
        assert!(matches!(validate_table_name(&t), Err(FaucetError::Config(_))));
    }

    #[test]
    fn create_table_sql_quotes_identifier() {
        let sql = create_table_sql("faucet_state");
        assert!(sql.contains("\"faucet_state\""));
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("PRIMARY KEY"));
        assert!(sql.contains("JSONB"));
    }

    #[test]
    fn create_table_sql_exact_text() {
        // Pins the result of the backslash line continuations in the source
        // literal: no whitespace survives between the column definitions.
        assert_eq!(
            create_table_sql("faucet_state"),
            "CREATE TABLE IF NOT EXISTS \"faucet_state\" (key TEXT PRIMARY KEY,\
             value JSONB NOT NULL,updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW())"
        );
    }

    #[test]
    fn create_table_sql_escapes_embedded_quote() {
        // quote_ident escapes embedded quotes by doubling them. Verify the
        // identifier remains a single quoted token.
        let sql = create_table_sql("ab\"c");
        // The opening quote, the doubled inner quote and the closing quote
        // should appear in the output, even though the table-name validator
        // would normally reject this name.
        assert!(sql.contains("\"ab\"\"c\""));
    }

    #[test]
    fn select_sql_uses_parameter_marker() {
        let sql = select_sql("faucet_state");
        assert_eq!(sql, "SELECT value FROM \"faucet_state\" WHERE key = $1");
    }

    #[test]
    fn upsert_sql_uses_on_conflict_do_update() {
        let sql = upsert_sql("faucet_state");
        assert!(sql.contains("INSERT INTO \"faucet_state\""));
        assert!(sql.contains("ON CONFLICT (key) DO UPDATE"));
        assert!(sql.contains("value = EXCLUDED.value"));
        assert!(sql.contains("updated_at = NOW()"));
    }

    #[test]
    fn upsert_sql_exact_text() {
        assert_eq!(
            upsert_sql("faucet_state"),
            "INSERT INTO \"faucet_state\" (key, value, updated_at) VALUES ($1, $2, NOW()) \
             ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, updated_at = NOW()"
        );
    }

    #[test]
    fn delete_sql_uses_parameter_marker() {
        let sql = delete_sql("faucet_state");
        assert_eq!(sql, "DELETE FROM \"faucet_state\" WHERE key = $1");
    }

    #[tokio::test]
    async fn connect_rejects_invalid_table_name() {
        let result =
            PostgresStateStore::connect_with("postgres://localhost/does_not_matter", 5, "bad-name")
                .await;
        match result {
            Err(FaucetError::Config(_)) => {}
            Err(other) => panic!("expected Config error, got {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }
}