Skip to main content

heliosdb_proxy/
branch.rs

1//! Instant branch databases.
2//!
3//! Provisions copy-on-write-ish database branches through the proxy:
4//! `CREATE DATABASE <branch> TEMPLATE <base>` (vanilla PostgreSQL's fast
5//! file-clone — the proxy-layer approximation of Neon/PlanetScale branching;
6//! a Nano backend can branch natively). The proxy issues the maintenance
7//! statements over its backend PG-wire client; clients then connect to the
8//! branch simply by database name (the proxy relays the startup unchanged), so
9//! routing is native — branching here is the *provisioning + lifecycle* layer.
10
11use std::time::Duration;
12
13use crate::backend::types::TextValue;
14use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, TlsMode};
15use crate::config::BranchConfig;
16
17fn admin_cfg(cfg: &BranchConfig) -> BackendConfig {
18    BackendConfig {
19        host: cfg.backend_host.clone(),
20        port: cfg.backend_port,
21        user: cfg.admin_user.clone(),
22        password: cfg.admin_password.clone(),
23        database: Some(cfg.admin_database.clone()),
24        application_name: Some("heliosproxy-branch".to_string()),
25        tls_mode: TlsMode::Disable,
26        connect_timeout: Duration::from_secs(5),
27        query_timeout: Duration::from_secs(120),
28        tls_config: default_client_config(),
29    }
30}
31
/// Returns `true` when `name` is a plain `[A-Za-z_][A-Za-z0-9_]*` identifier
/// of 1 to 63 bytes.
///
/// Names are validated rather than escaped because `CREATE DATABASE` cannot
/// take bind parameters, so a strict allowlist is the safe way to put a name
/// into the statement text. The 63-byte cap is presumably PostgreSQL's
/// identifier length limit (confirm against the backend in use).
fn valid_ident(name: &str) -> bool {
    if name.len() > 63 {
        return false;
    }

    // Every accepted character is ASCII, so checking bytes is equivalent to
    // checking chars: any byte of a multi-byte character fails the ASCII tests.
    let mut bytes = name.bytes();
    match bytes.next() {
        // First character: letter or underscore only (no leading digit).
        Some(first) if first.is_ascii_alphabetic() || first == b'_' => {
            bytes.all(|b| b.is_ascii_alphanumeric() || b == b'_')
        }
        // Empty string or an invalid leading character.
        _ => false,
    }
}
45
46/// Create a branch database from `base` (or the configured default).
47pub async fn create(cfg: &BranchConfig, name: &str, base: Option<&str>) -> Result<(), String> {
48    let base = base.unwrap_or(&cfg.base_database);
49    if !valid_ident(name) {
50        return Err(format!(
51            "invalid branch name '{}' (use [A-Za-z_][A-Za-z0-9_]*)",
52            name
53        ));
54    }
55    if !valid_ident(base) {
56        return Err(format!("invalid base name '{}'", base));
57    }
58    let mut c = BackendClient::connect(&admin_cfg(cfg))
59        .await
60        .map_err(|e| format!("admin connect: {}", e))?;
61    let sql = format!("CREATE DATABASE \"{}\" TEMPLATE \"{}\"", name, base);
62    let r = c
63        .execute(&sql)
64        .await
65        .map_err(|e| format!("create branch: {}", e));
66    c.close().await;
67    r.map(|_| ())
68}
69
70/// Drop a branch database.
71pub async fn drop(cfg: &BranchConfig, name: &str) -> Result<(), String> {
72    if !valid_ident(name) {
73        return Err(format!("invalid branch name '{}'", name));
74    }
75    let mut c = BackendClient::connect(&admin_cfg(cfg))
76        .await
77        .map_err(|e| format!("admin connect: {}", e))?;
78    let r = c
79        .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", name))
80        .await
81        .map_err(|e| format!("drop branch: {}", e));
82    c.close().await;
83    r.map(|_| ())
84}
85
86/// List user databases (excludes templates + the built-in maintenance DBs).
87pub async fn list(cfg: &BranchConfig) -> Result<Vec<String>, String> {
88    let mut c = BackendClient::connect(&admin_cfg(cfg))
89        .await
90        .map_err(|e| format!("admin connect: {}", e))?;
91    let res = c
92        .simple_query(
93            "SELECT datname FROM pg_database \
94             WHERE datistemplate = false AND datname NOT IN ('postgres','template0','template1') \
95             ORDER BY datname",
96        )
97        .await
98        .map_err(|e| format!("list branches: {}", e));
99    c.close().await;
100    let res = res?;
101    Ok(res
102        .rows
103        .into_iter()
104        .filter_map(|row| match row.into_iter().next() {
105            Some(TextValue::Text(s)) => Some(s),
106            _ => None,
107        })
108        .collect())
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `valid_ident` against a table of names that must be accepted
    /// and a table of names that must be rejected.
    #[test]
    fn ident_validation() {
        let accepted = ["feature_x", "_b1"];
        let rejected = ["1bad", "drop;table", "a b", "", "\"inject\""];

        for name in accepted {
            assert!(valid_ident(name), "expected {:?} to be accepted", name);
        }
        for name in rejected {
            assert!(!valid_ident(name), "expected {:?} to be rejected", name);
        }
    }
}