use distributed_lock_core::error::{LockError, LockResult};
use sha1::{Digest, Sha1};
/// Identifies a PostgreSQL advisory lock.
///
/// PostgreSQL's advisory-lock functions accept either one 64-bit (`bigint`) key or two 32-bit
/// (`integer`) keys; each variant models one of those argument shapes.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PostgresAdvisoryLockKey {
    /// A single 64-bit key.
    Single(i64),
    /// A pair of 32-bit keys.
    Pair(i32, i32),
}
impl PostgresAdvisoryLockKey {
    /// Maximum name length for the ASCII encoding: 9 characters x 7 bits + 1 marker bit = 64 bits.
    const MAX_ASCII_LENGTH: usize = 9;
    /// Bits used per character in the ASCII encoding.
    const ASCII_CHAR_BITS: u32 = 7;
    /// Largest character value accepted by the ASCII encoding; also the all-ones padding group.
    const MAX_ASCII_VALUE: u32 = (1 << Self::ASCII_CHAR_BITS) - 1;
    /// Number of hex digits in the single-key form (16 hex digits = 64 bits).
    const HASH_STRING_LENGTH: usize = 16;
    /// Number of hex digits in each half of the pair form (8 hex digits = 32 bits).
    const HASH_PART_LENGTH: usize = 8;
    /// Separator between the two halves of the pair form.
    const HASH_STRING_SEPARATOR: char = ',';

    /// Converts a lock name into an advisory-lock key.
    ///
    /// Encodings are tried in this order:
    /// 1. a 1-9 character ASCII string, packed into a single `i64`;
    /// 2. exactly 16 hex digits, read as the two's-complement bits of a single `i64`;
    /// 3. exactly 8 hex digits, `,`, exactly 8 hex digits, each half read as the
    ///    two's-complement bits of an `i32`;
    /// 4. if `allow_hashing` is set, the first 8 bytes of the name's SHA-1 digest
    ///    (little-endian) as a single `i64`.
    ///
    /// # Errors
    /// Returns [`LockError::InvalidName`] if `name` is empty, or if it matches none of the
    /// formats above and `allow_hashing` is `false`.
    pub fn from_name(name: &str, allow_hashing: bool) -> LockResult<Self> {
        if name.is_empty() {
            return Err(LockError::InvalidName(
                "lock name cannot be empty".to_string(),
            ));
        }
        if let Some(key) = Self::try_encode_ascii(name) {
            return Ok(Self::Single(key));
        }
        if let Some(key) = Self::try_parse_hex_string(name) {
            return Ok(key);
        }
        if let Some(key) = Self::try_parse_pair_string(name) {
            return Ok(key);
        }
        if allow_hashing {
            return Ok(Self::Single(Self::hash_string(name)));
        }
        // Empty names are rejected above, so the ASCII form accepts 1 (not 0) to 9 characters.
        Err(LockError::InvalidName(format!(
            "Name '{}' could not be encoded as a PostgresAdvisoryLockKey. Please specify allow_hashing or use one of the following formats: (1) a 1-{} character string using only ASCII characters, (2) a {} character hex string, or (3) a 2-part, {} character string of the form XXXXXXXX{}XXXXXXXX",
            name,
            Self::MAX_ASCII_LENGTH,
            Self::HASH_STRING_LENGTH,
            Self::HASH_PART_LENGTH * 2 + 1,
            Self::HASH_STRING_SEPARATOR
        )))
    }

    /// Packs a name of at most [`Self::MAX_ASCII_LENGTH`] ASCII characters into an `i64`.
    ///
    /// Layout, most significant bit first: 7 bits per character, a single `0` marker bit, then
    /// one all-ones 7-bit group per unused character slot. The marker bit records where the
    /// characters end, so names of different lengths produce different keys.
    ///
    /// Returns `None` if the name is too long or contains a non-ASCII character.
    fn try_encode_ascii(name: &str) -> Option<i64> {
        if name.len() > Self::MAX_ASCII_LENGTH {
            return None;
        }
        let mut result = 0i64;
        // Every non-ASCII char encodes to UTF-8 bytes >= 0x80, so a byte-wise check rejects it;
        // for accepted names the byte length therefore equals the character count.
        for byte in name.bytes() {
            if u32::from(byte) > Self::MAX_ASCII_VALUE {
                return None;
            }
            // Bits shifted out of the top are discarded; at most 64 bits are ever used.
            result = (result << Self::ASCII_CHAR_BITS) | i64::from(byte);
        }
        // The zero marker bit that terminates the character data.
        result <<= 1;
        for _ in name.len()..Self::MAX_ASCII_LENGTH {
            result = (result << Self::ASCII_CHAR_BITS) | i64::from(Self::MAX_ASCII_VALUE);
        }
        Some(result)
    }

    /// Parses the single-key form: exactly [`Self::HASH_STRING_LENGTH`] hex digits.
    ///
    /// The digits are parsed as unsigned and reinterpreted as `i64`. Keys with the high bit set
    /// (e.g. `FFFFFFFFFFFFFFFF` == -1) therefore round-trip instead of failing with overflow.
    fn try_parse_hex_string(name: &str) -> Option<Self> {
        let bits = Self::parse_exact_hex(name, Self::HASH_STRING_LENGTH)?;
        // Intentional two's-complement reinterpretation of all 64 bits.
        Some(Self::Single(bits as i64))
    }

    /// Parses the pair form: exactly [`Self::HASH_PART_LENGTH`] hex digits,
    /// [`Self::HASH_STRING_SEPARATOR`], then exactly [`Self::HASH_PART_LENGTH`] hex digits.
    ///
    /// Each half is reinterpreted as the two's-complement bits of an `i32`.
    fn try_parse_pair_string(name: &str) -> Option<Self> {
        // Any extra separator lands in `lower` and fails the hex-digit check below.
        let (upper, lower) = name.split_once(Self::HASH_STRING_SEPARATOR)?;
        let upper = Self::parse_exact_hex(upper, Self::HASH_PART_LENGTH)?;
        let lower = Self::parse_exact_hex(lower, Self::HASH_PART_LENGTH)?;
        // Each value has at most 8 hex digits, so `as u32` is lossless; `as i32` reinterprets.
        Some(Self::Pair(upper as u32 as i32, lower as u32 as i32))
    }

    /// Parses `digits` as an unsigned hex number made of exactly `width` hex digits.
    ///
    /// Unlike a bare `from_str_radix`, this rejects a leading `+`/`-` and any other length.
    /// Otherwise, e.g., `+000000000000001` would map to the same key as `0000000000000001`.
    fn parse_exact_hex(digits: &str, width: usize) -> Option<u64> {
        if digits.len() != width || !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
            return None;
        }
        u64::from_str_radix(digits, 16).ok()
    }

    /// Hashes `name` (as UTF-8) to an `i64`: the first 8 bytes of its SHA-1 digest, little-endian.
    // NOTE(review): presumably one of the SHA-1 calls is on the project's clippy
    // `disallowed_methods` list; here the digest is only used to derive a lock key.
    #[allow(clippy::disallowed_methods)]
    fn hash_string(name: &str) -> i64 {
        let mut hasher = Sha1::new();
        hasher.update(name.as_bytes());
        let hash_bytes = hasher.finalize();
        let mut prefix = [0u8; 8];
        prefix.copy_from_slice(&hash_bytes[..8]);
        i64::from_le_bytes(prefix)
    }

    /// Returns `true` if this is a [`Self::Single`] key.
    pub fn has_single_key(&self) -> bool {
        matches!(self, Self::Single(_))
    }

    /// Returns the 64-bit key of a [`Self::Single`] key.
    ///
    /// # Panics
    /// Panics if called on [`Self::Pair`]. Check [`Self::has_single_key`] first, or use
    /// [`Self::keys`], which is defined for both variants.
    pub fn key(&self) -> i64 {
        match self {
            Self::Single(k) => *k,
            Self::Pair(_, _) => panic!("key() called on Pair variant"),
        }
    }

    /// Returns the key as two `i32`s.
    ///
    /// For [`Self::Single`], this is the upper and lower 32 bits of the `i64`, each
    /// reinterpreted as signed. For [`Self::Pair`], it is the two stored values.
    pub fn keys(&self) -> (i32, i32) {
        match *self {
            // Arithmetic shift then truncation keeps the upper half's bits exactly.
            // `k as i32` truncates to the low 32 bits.
            Self::Single(k) => ((k >> 32) as i32, k as i32),
            Self::Pair(k1, k2) => (k1, k2),
        }
    }

    /// Renders the key as a comma-separated SQL argument list: `"<key>"` or `"<key1>, <key2>"`.
    ///
    /// Only integers are formatted, so the output contains nothing but digits, `-`, `,` and spaces.
    // `to_*` on a `Copy` type would normally take `self`; `&self` is kept for signature compatibility.
    #[allow(clippy::wrong_self_convention)]
    pub fn to_sql_args(&self) -> String {
        match self {
            Self::Single(k) => k.to_string(),
            Self::Pair(k1, k2) => format!("{k1}, {k2}"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ascii_encoding() {
        let key = PostgresAdvisoryLockKey::from_name("test", false).unwrap();
        assert!(key.has_single_key());
        // 't','e','s','t' as 7-bit groups, one 0 marker bit, then five all-ones padding groups.
        assert_eq!(key.key(), 0xE997_9F47_FFFF_FFFFu64 as i64);
    }

    #[test]
    fn test_ascii_encoding_distinguishes_lengths() {
        let short = PostgresAdvisoryLockKey::from_name("a", false).unwrap();
        let longer = PostgresAdvisoryLockKey::from_name("a\u{7f}", false).unwrap();
        assert_ne!(short, longer);
    }

    #[test]
    fn test_empty_name_rejected() {
        assert!(PostgresAdvisoryLockKey::from_name("", true).is_err());
        assert!(PostgresAdvisoryLockKey::from_name("", false).is_err());
    }

    #[test]
    fn test_hex_encoding() {
        let hex_str = "0000000000000001";
        let key = PostgresAdvisoryLockKey::from_name(hex_str, false).unwrap();
        assert!(key.has_single_key());
        assert_eq!(key.key(), 1);
        assert_eq!(key.keys(), (0, 1));
    }

    #[test]
    fn test_pair_encoding() {
        let pair_str = "00000001,00000002";
        let key = PostgresAdvisoryLockKey::from_name(pair_str, false).unwrap();
        match key {
            PostgresAdvisoryLockKey::Pair(k1, k2) => {
                assert_eq!(k1, 1);
                assert_eq!(k2, 2);
            }
            _ => panic!("Expected Pair variant"),
        }
        assert!(!key.has_single_key());
        assert_eq!(key.to_sql_args(), "1, 2");
    }

    #[test]
    fn test_hash_encoding() {
        let long_name = "this is a very long lock name that needs hashing";
        let key = PostgresAdvisoryLockKey::from_name(long_name, true).unwrap();
        assert!(key.has_single_key());
        assert_eq!(key.key(), PostgresAdvisoryLockKey::hash_string(long_name));
        assert!(PostgresAdvisoryLockKey::from_name(long_name, false).is_err());
    }

    #[test]
    fn test_hash_string_is_little_endian_sha1_prefix() {
        // SHA-1("abc") begins a9 99 3e 36 47 06 81 6a; read little-endian.
        assert_eq!(
            PostgresAdvisoryLockKey::hash_string("abc"),
            0x6a81_0647_363e_99a9
        );
    }

    #[test]
    fn test_keys_splits_single() {
        assert_eq!(PostgresAdvisoryLockKey::Single(-1).keys(), (-1, -1));
        assert_eq!(PostgresAdvisoryLockKey::Single(0x1_0000_0002).keys(), (1, 2));
        assert_eq!(PostgresAdvisoryLockKey::Pair(7, -8).keys(), (7, -8));
    }

    #[test]
    fn test_to_sql_args() {
        assert_eq!(PostgresAdvisoryLockKey::Single(-5).to_sql_args(), "-5");
        assert_eq!(PostgresAdvisoryLockKey::Pair(1, -2).to_sql_args(), "1, -2");
    }

    #[test]
    #[should_panic]
    fn test_key_panics_on_pair() {
        let _ = PostgresAdvisoryLockKey::Pair(1, 2).key();
    }
}