use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use sha2::{Digest, Sha256};
use crate::value::VmValue;
use super::PoolRecord;
/// Registry key identifying one shared-pool configuration.
///
/// The wrapped string is the hex encoding of a SHA-256 digest over the pool's
/// connection identity (primary URL, replica URLs, options, and the
/// single-connection flag). Only the digest is stored, so the URLs and any
/// credentials embedded in them never appear in the key in plaintext.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub(super) struct PoolKey(String);

impl PoolKey {
    /// Derives the registry key for a pool built from these connection parameters.
    ///
    /// Each section of the digest input starts with a fixed tag. Every
    /// variable-length field is length-prefixed, and collections are preceded
    /// by their element count. Option entries are hashed in ascending key
    /// order, and `None` hashes exactly like an empty options map.
    pub(super) fn new(
        primary_url: &str,
        replica_urls: &[String],
        options: Option<&BTreeMap<String, VmValue>>,
        single_connection: bool,
    ) -> Self {
        let mut digest = Sha256::new();

        // Versioned domain-separation prefix for this key format.
        digest.update(b"harn-pg-shared-pool-key\x01");

        digest.update(b"single_connection:");
        digest.update([u8::from(single_connection)]);

        digest.update(b"\x00primary:");
        hash_len_prefixed(&mut digest, primary_url.as_bytes());

        digest.update(b"\x00replicas:");
        digest.update((replica_urls.len() as u64).to_le_bytes());
        replica_urls
            .iter()
            .for_each(|replica| hash_len_prefixed(&mut digest, replica.as_bytes()));

        digest.update(b"\x00options:");
        let option_count = options.map_or(0, BTreeMap::len) as u64;
        digest.update(option_count.to_le_bytes());
        // A BTreeMap already iterates in ascending key order, which is the
        // canonical order the digest depends on; `None` yields no entries.
        for (name, value) in options.into_iter().flatten() {
            hash_len_prefixed(&mut digest, name.as_bytes());
            hash_len_prefixed(&mut digest, canonical_option_value(value).as_bytes());
        }

        Self(hex::encode(digest.finalize()))
    }
}
/// Feeds `bytes` into `hasher`, preceded by its length as a little-endian `u64`.
///
/// The length prefix fixes where each field ends, so two adjacent
/// variable-length fields cannot be re-split into a different pair that
/// hashes identically.
fn hash_len_prefixed(hasher: &mut Sha256, bytes: &[u8]) {
    let prefix = (bytes.len() as u64).to_le_bytes();
    hasher.update(prefix);
    hasher.update(bytes);
}
/// Renders an option value as a canonical, unambiguous string for key hashing.
///
/// The first character identifies the kind of value:
/// - `[` starts a list; its elements keep their original order.
/// - `{` starts a dict; its entries are emitted in ascending key order.
/// - `=` starts a scalar; the rest is the scalar's `display()` text.
///
/// Every nested element, dict key, and dict value is framed as
/// `<byte length>:<text>`. Separator characters inside strings or keys
/// (`,`, `=`, `[`, `}`, ...) therefore cannot make two different structures
/// render identically. Without this framing, `{"a": 1, "b": 2}` and
/// `{"a=1,b": 2}` both became `{a=1,b=2}`.
///
/// Scalars are compared only by their display text, so two scalars of
/// different kinds that display the same are still treated as equal.
fn canonical_option_value(value: &VmValue) -> String {
    match value {
        VmValue::List(items) => {
            let mut out = String::from("[");
            for item in items.iter() {
                push_len_framed(&mut out, &canonical_option_value(item));
            }
            out.push(']');
            out
        }
        VmValue::Dict(dict) => {
            // Sort explicitly so the encoding never depends on the dict's
            // underlying iteration order.
            let sorted: BTreeMap<&String, &VmValue> = dict.iter().collect();
            let mut out = String::from("{");
            for (key, val) in sorted {
                push_len_framed(&mut out, key);
                push_len_framed(&mut out, &canonical_option_value(val));
            }
            out.push('}');
            out
        }
        other => format!("={}", other.display()),
    }
}

/// Appends `text` to `out` as `<byte length>:<text>`, so that where the
/// element ends is unambiguous whatever characters `text` contains.
fn push_len_framed(out: &mut String, text: &str) {
    out.push_str(&text.len().to_string());
    out.push(':');
    out.push_str(text);
}
/// Map from pool key to the pool record registered under it, guarded by a mutex.
type SharedRegistry = Mutex<HashMap<PoolKey, Arc<PoolRecord>>>;
/// Process-wide shared pool registry. It stays uninitialized until
/// `install_shared_pool_registry` runs. While uninitialized, `get` returns
/// `None` and `get_or_insert` hands back the caller's record without
/// registering it.
static SHARED_POOLS: OnceLock<SharedRegistry> = OnceLock::new();
pub fn install_shared_pool_registry() {
let _ = SHARED_POOLS.get_or_init(|| Mutex::new(HashMap::new()));
}
/// Reports whether the shared pool registry has been initialized in this process.
pub(super) fn is_installed() -> bool {
    matches!(SHARED_POOLS.get(), Some(_))
}
/// Looks up the pool registered under `key`.
///
/// Returns `None` when the registry is not installed or nothing is registered
/// under `key`; otherwise returns a new `Arc` handle to the shared record.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
pub(super) fn get(key: &PoolKey) -> Option<Arc<PoolRecord>> {
    SHARED_POOLS.get().and_then(|registry| {
        registry
            .lock()
            .expect("shared pg pool registry poisoned")
            .get(key)
            .cloned()
    })
}
/// Returns the pool registered under `key`, registering `record` first if none is.
///
/// - If the registry is not installed, sharing is disabled and `record` is
///   returned unchanged.
/// - If another caller already registered a pool under `key`, that pool is
///   returned and `record` is discarded.
///
/// A discarded `record` is dropped only after the registry lock has been
/// released. Any teardown it performs therefore never runs while the lock is
/// held, and cannot poison the registry or block other callers.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
pub(super) fn get_or_insert(key: PoolKey, record: Arc<PoolRecord>) -> Arc<PoolRecord> {
    use std::collections::hash_map::Entry;

    let Some(registry) = SHARED_POOLS.get() else {
        return record;
    };
    // Scope the guard so it is released before `record` (if it lost the race
    // and was not moved into the map) is dropped at the end of the function.
    let winner = {
        let mut guard = registry.lock().expect("shared pg pool registry poisoned");
        match guard.entry(key) {
            Entry::Occupied(slot) => Arc::clone(slot.get()),
            Entry::Vacant(slot) => Arc::clone(slot.insert(record)),
        }
    };
    winner
}
/// Test-only: removes every entry from the shared registry.
/// Does nothing if the registry has not been installed.
#[cfg(test)]
pub(super) fn clear_for_test() {
    let Some(registry) = SHARED_POOLS.get() else {
        return;
    };
    let mut pools = registry.lock().expect("shared pg pool registry poisoned");
    pools.clear();
}
/// Test-only: number of pools currently registered, or `0` if the registry
/// has not been installed.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned. The message matches the other
/// registry lock sites.
#[cfg(test)]
pub(super) fn len_for_test() -> usize {
    match SHARED_POOLS.get() {
        Some(registry) => registry
            .lock()
            .expect("shared pg pool registry poisoned")
            .len(),
        None => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a VM string value.
    fn s(value: &str) -> VmValue {
        VmValue::String(std::sync::Arc::from(value))
    }

    /// Builds an options map from `(key, value)` pairs.
    fn opts(pairs: &[(&str, VmValue)]) -> BTreeMap<String, VmValue> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), v.clone()))
            .collect()
    }

    const URL_A: &str = "postgres://app:secret@db.internal:5432/tenants";
    const URL_B: &str = "postgres://app:secret@db.internal:5432/other";
    const URL_A_DIFF_PW: &str = "postgres://app:HUNTER2@db.internal:5432/tenants";

    #[test]
    fn same_identity_same_key() {
        let o = opts(&[("max_connections", VmValue::Int(5))]);
        let k1 = PoolKey::new(URL_A, &[], Some(&o), false);
        let k2 = PoolKey::new(URL_A, &[], Some(&o), false);
        assert_eq!(k1, k2);
    }

    #[test]
    fn different_database_different_key() {
        let k1 = PoolKey::new(URL_A, &[], None, false);
        let k2 = PoolKey::new(URL_B, &[], None, false);
        assert_ne!(k1, k2);
    }

    #[test]
    fn different_credentials_different_key() {
        let k1 = PoolKey::new(URL_A, &[], None, false);
        let k2 = PoolKey::new(URL_A_DIFF_PW, &[], None, false);
        assert_ne!(k1, k2);
    }

    #[test]
    fn single_connection_flag_distinguishes_key() {
        let k_pool = PoolKey::new(URL_A, &[], None, false);
        let k_conn = PoolKey::new(URL_A, &[], None, true);
        assert_ne!(k_pool, k_conn);
    }

    #[test]
    fn option_order_is_canonical() {
        let o1 = opts(&[
            ("max_connections", VmValue::Int(5)),
            ("application_name", s("svc")),
        ]);
        let o2 = opts(&[
            ("application_name", s("svc")),
            ("max_connections", VmValue::Int(5)),
        ]);
        assert_eq!(
            PoolKey::new(URL_A, &[], Some(&o1), false),
            PoolKey::new(URL_A, &[], Some(&o2), false)
        );
    }

    #[test]
    fn absent_options_match_empty_options() {
        let empty = opts(&[]);
        assert_eq!(
            PoolKey::new(URL_A, &[], None, false),
            PoolKey::new(URL_A, &[], Some(&empty), false)
        );
    }

    #[test]
    fn different_pool_shape_different_key() {
        let o1 = opts(&[("max_connections", VmValue::Int(5))]);
        let o2 = opts(&[("max_connections", VmValue::Int(20))]);
        assert_ne!(
            PoolKey::new(URL_A, &[], Some(&o1), false),
            PoolKey::new(URL_A, &[], Some(&o2), false)
        );
    }

    #[test]
    fn different_application_name_different_key() {
        let o1 = opts(&[("application_name", s("svc-a"))]);
        let o2 = opts(&[("application_name", s("svc-b"))]);
        assert_ne!(
            PoolKey::new(URL_A, &[], Some(&o1), false),
            PoolKey::new(URL_A, &[], Some(&o2), false)
        );
    }

    #[test]
    fn replica_set_is_part_of_key() {
        let r1 = vec![URL_B.to_string()];
        let r2 = vec![URL_B.to_string(), URL_A.to_string()];
        assert_ne!(
            PoolKey::new(URL_A, &[], None, false),
            PoolKey::new(URL_A, &r1, None, false)
        );
        assert_ne!(
            PoolKey::new(URL_A, &r1, None, false),
            PoolKey::new(URL_A, &r2, None, false)
        );
    }

    #[test]
    fn nested_option_dicts_affect_key() {
        let cb1 = VmValue::Dict(std::sync::Arc::new(opts(&[(
            "failure_threshold",
            VmValue::Int(3),
        )])));
        let cb2 = VmValue::Dict(std::sync::Arc::new(opts(&[(
            "failure_threshold",
            VmValue::Int(9),
        )])));
        let o1 = opts(&[("circuit_breaker", cb1)]);
        let o2 = opts(&[("circuit_breaker", cb2)]);
        assert_ne!(
            PoolKey::new(URL_A, &[], Some(&o1), false),
            PoolKey::new(URL_A, &[], Some(&o2), false)
        );
    }

    #[test]
    fn key_does_not_leak_plaintext_credentials() {
        let key = PoolKey::new(URL_A_DIFF_PW, &[], None, false);
        assert!(!key.0.contains("HUNTER2"));
        // A SHA-256 digest is 32 bytes, i.e. 64 hex characters.
        assert_eq!(key.0.len(), 64);
        assert!(key.0.chars().all(|c| c.is_ascii_hexdigit()));
    }

    /// The registry is process-global and other tests may install it, so this
    /// test makes no assumption about installation state. It only checks that
    /// a key nobody ever inserted is absent either way.
    #[test]
    fn unknown_key_returns_none() {
        let key = PoolKey::new(
            "postgres://nobody@nowhere/db_never_inserted",
            &[],
            None,
            true,
        );
        assert!(get(&key).is_none());
    }
}