1use async_trait::async_trait;
4use faucet_core::state::{DOCTOR_SENTINEL_KEY, StateStore, validate_state_key};
5use faucet_core::util::quote_ident;
6use faucet_core::{FaucetError, Value};
7use sqlx::postgres::PgPoolOptions;
8use sqlx::{PgPool, Row};
9
10pub const DEFAULT_TABLE: &str = "faucet_state";
13
14pub struct PostgresStateStore {
30 pool: PgPool,
31 table: String,
32}
33
34impl PostgresStateStore {
35 pub async fn connect(connection_url: &str) -> Result<Self, FaucetError> {
37 Self::connect_with(connection_url, 5, DEFAULT_TABLE).await
38 }
39
40 pub async fn connect_with(
42 connection_url: &str,
43 max_connections: u32,
44 table: &str,
45 ) -> Result<Self, FaucetError> {
46 validate_table_name(table)?;
47 let pool = PgPoolOptions::new()
48 .max_connections(max_connections)
49 .connect(connection_url)
50 .await
51 .map_err(|e| {
52 FaucetError::State(format!("PostgreSQL state-store connection failed: {e}"))
53 })?;
54 Ok(Self {
55 pool,
56 table: table.to_owned(),
57 })
58 }
59
60 pub fn from_pool(pool: PgPool, table: impl Into<String>) -> Result<Self, FaucetError> {
63 let table = table.into();
64 validate_table_name(&table)?;
65 Ok(Self { pool, table })
66 }
67
68 pub fn table(&self) -> &str {
70 &self.table
71 }
72
73 pub async fn ensure_table(&self) -> Result<(), FaucetError> {
75 let sql = create_table_sql(&self.table);
76 sqlx::query(&sql).execute(&self.pool).await.map_err(|e| {
77 FaucetError::State(format!("failed to ensure state table {}: {e}", self.table))
78 })?;
79 Ok(())
80 }
81}
82
83pub(crate) fn validate_table_name(table: &str) -> Result<(), FaucetError> {
87 if table.is_empty() {
88 return Err(FaucetError::Config(
89 "state-store table name must not be empty".into(),
90 ));
91 }
92 if table.len() > 63 {
93 return Err(FaucetError::Config(format!(
94 "state-store table name '{table}' exceeds Postgres' 63-character identifier limit"
95 )));
96 }
97 for (i, c) in table.char_indices() {
98 let ok = c.is_ascii_alphanumeric() || c == '_';
99 if !ok {
100 return Err(FaucetError::Config(format!(
101 "state-store table name '{table}' contains illegal character {c:?} at byte {i}"
102 )));
103 }
104 }
105 Ok(())
106}
107
108pub(crate) fn create_table_sql(table: &str) -> String {
109 format!(
110 "CREATE TABLE IF NOT EXISTS {table_ident} (\
111 key TEXT PRIMARY KEY,\
112 value JSONB NOT NULL,\
113 updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()\
114 )",
115 table_ident = quote_ident(table)
116 )
117}
118
119pub(crate) fn select_sql(table: &str) -> String {
120 format!("SELECT value FROM {} WHERE key = $1", quote_ident(table))
121}
122
123pub(crate) fn upsert_sql(table: &str) -> String {
124 format!(
125 "INSERT INTO {tbl} (key, value, updated_at) VALUES ($1, $2, NOW()) \
126 ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, updated_at = NOW()",
127 tbl = quote_ident(table)
128 )
129}
130
131pub(crate) fn delete_sql(table: &str) -> String {
132 format!("DELETE FROM {} WHERE key = $1", quote_ident(table))
133}
134
135#[async_trait]
136impl StateStore for PostgresStateStore {
137 async fn get(&self, key: &str) -> Result<Option<Value>, FaucetError> {
138 validate_state_key(key)?;
139 let row = sqlx::query(&select_sql(&self.table))
140 .bind(key)
141 .fetch_optional(&self.pool)
142 .await
143 .map_err(|e| {
144 FaucetError::State(format!("Postgres SELECT for key '{key}' failed: {e}"))
145 })?;
146 match row {
147 None => Ok(None),
148 Some(r) => {
149 let value: Value = r.try_get(0).map_err(|e| {
150 FaucetError::State(format!(
151 "failed to decode JSONB column for key '{key}': {e}"
152 ))
153 })?;
154 Ok(Some(value))
155 }
156 }
157 }
158
159 async fn put(&self, key: &str, value: &Value) -> Result<(), FaucetError> {
160 validate_state_key(key)?;
161 sqlx::query(&upsert_sql(&self.table))
162 .bind(key)
163 .bind(value)
164 .execute(&self.pool)
165 .await
166 .map_err(|e| {
167 FaucetError::State(format!("Postgres UPSERT for key '{key}' failed: {e}"))
168 })?;
169 tracing::debug!(key, table = %self.table, "state written to Postgres");
170 Ok(())
171 }
172
173 async fn delete(&self, key: &str) -> Result<(), FaucetError> {
174 validate_state_key(key)?;
175 sqlx::query(&delete_sql(&self.table))
176 .bind(key)
177 .execute(&self.pool)
178 .await
179 .map_err(|e| {
180 FaucetError::State(format!("Postgres DELETE for key '{key}' failed: {e}"))
181 })?;
182 Ok(())
183 }
184
185 async fn check(
186 &self,
187 ctx: &faucet_core::check::CheckContext,
188 ) -> Result<faucet_core::check::CheckReport, FaucetError> {
189 use faucet_core::check::{CheckReport, Probe};
190
191 let start = std::time::Instant::now();
196 let probe = match tokio::time::timeout(ctx.timeout, self.sentinel_roundtrip()).await {
197 Ok(Ok(())) => Probe::pass("sentinel", start.elapsed()),
198 Ok(Err(e)) => Probe::fail_hint(
199 "sentinel",
200 start.elapsed(),
201 e.to_string(),
202 format!(
203 "verify the server is reachable, the credentials grant read/write access, \
204 and the '{}' table exists (call ensure_table or create it manually)",
205 self.table
206 ),
207 ),
208 Err(_) => Probe::fail_hint(
209 "sentinel",
210 start.elapsed(),
211 format!(
212 "round-trip timed out after {:?}; Postgres did not respond",
213 ctx.timeout
214 ),
215 "verify the server is reachable or raise the check timeout",
216 ),
217 };
218 Ok(CheckReport::single(probe))
219 }
220}
221
222impl PostgresStateStore {
223 async fn sentinel_roundtrip(&self) -> Result<(), FaucetError> {
227 let probe = serde_json::json!({ "faucet_doctor": true });
228 self.put(DOCTOR_SENTINEL_KEY, &probe).await?;
229 let got = self.get(DOCTOR_SENTINEL_KEY).await?;
230 let _ = self.delete(DOCTOR_SENTINEL_KEY).await;
232 match got {
233 Some(v) if v == probe => Ok(()),
234 _ => Err(FaucetError::State(
235 "sentinel readback did not match what was written".into(),
236 )),
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_table_name_accepts_typical_values() {
        for t in [
            "faucet_state",
            "FaucetState",
            "state_v2",
            "f1",
            "abcdefghijklmnopqrstuvwxyz_0123456789_FOO",
        ] {
            validate_table_name(t).unwrap_or_else(|e| panic!("expected ok for {t:?}: {e}"));
        }
    }

    // Boundary: 63 bytes is the longest name Postgres accepts without truncation.
    #[test]
    fn validate_table_name_accepts_exactly_63_chars() {
        let t = "a".repeat(63);
        validate_table_name(&t).unwrap_or_else(|e| panic!("expected ok for 63-char name: {e}"));
    }

    #[test]
    fn validate_table_name_rejects_empty() {
        let err = validate_table_name("").unwrap_err();
        assert!(matches!(err, FaucetError::Config(_)));
    }

    #[test]
    fn validate_table_name_rejects_illegal_chars() {
        for t in [
            "table-name",
            "schema.table",
            "drop table users;--",
            "spaces in name",
            // Non-ASCII letters must be rejected even though they are alphanumeric.
            "tablé",
            "état",
        ] {
            // Explicit match instead of `expect_err(&format!(..))`, which would
            // build the message even when the assertion passes.
            match validate_table_name(t) {
                Err(FaucetError::Config(_)) => {}
                Err(other) => panic!("expected Config error for {t:?}, got {other:?}"),
                Ok(()) => panic!("expected error for {t:?}"),
            }
        }
    }

    #[test]
    fn validate_table_name_rejects_over_long() {
        let t = "a".repeat(64);
        assert!(matches!(validate_table_name(&t), Err(FaucetError::Config(_))));
    }

    #[test]
    fn create_table_sql_quotes_identifier() {
        let sql = create_table_sql("faucet_state");
        assert!(sql.contains("\"faucet_state\""));
        assert!(sql.contains("CREATE TABLE IF NOT EXISTS"));
        assert!(sql.contains("PRIMARY KEY"));
        assert!(sql.contains("JSONB"));
    }

    // Pins the full DDL so accidental schema or formatting changes are caught.
    #[test]
    fn create_table_sql_exact_statement() {
        assert_eq!(
            create_table_sql("faucet_state"),
            "CREATE TABLE IF NOT EXISTS \"faucet_state\" (key TEXT PRIMARY KEY,\
             value JSONB NOT NULL,updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW())"
        );
    }

    #[test]
    fn create_table_sql_escapes_embedded_quote() {
        let sql = create_table_sql("ab\"c");
        assert!(sql.contains("\"ab\"\"c\""));
    }

    #[test]
    fn select_sql_uses_parameter_marker() {
        let sql = select_sql("faucet_state");
        assert_eq!(sql, "SELECT value FROM \"faucet_state\" WHERE key = $1");
    }

    #[test]
    fn upsert_sql_uses_on_conflict_do_update() {
        let sql = upsert_sql("faucet_state");
        assert!(sql.contains("INSERT INTO \"faucet_state\""));
        assert!(sql.contains("ON CONFLICT (key) DO UPDATE"));
        assert!(sql.contains("value = EXCLUDED.value"));
        assert!(sql.contains("updated_at = NOW()"));
    }

    #[test]
    fn upsert_sql_exact_statement() {
        assert_eq!(
            upsert_sql("faucet_state"),
            "INSERT INTO \"faucet_state\" (key, value, updated_at) VALUES ($1, $2, NOW()) \
             ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, updated_at = NOW()"
        );
    }

    #[test]
    fn delete_sql_uses_parameter_marker() {
        let sql = delete_sql("faucet_state");
        assert_eq!(sql, "DELETE FROM \"faucet_state\" WHERE key = $1");
    }

    #[tokio::test]
    async fn connect_rejects_invalid_table_name() {
        let result =
            PostgresStateStore::connect_with("postgres://localhost/does_not_matter", 5, "bad-name")
                .await;
        match result {
            Err(FaucetError::Config(_)) => {}
            Err(other) => panic!("expected Config error, got {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }
}