use crate::interface::{Cell, Column, Row};
use pgrx::{
IntoDatum,
list::List,
pg_sys::panic::{ErrorReport, ErrorReportable},
spi::Spi,
*,
};
use std::ffi::CStr;
use std::ffi::c_void;
use std::num::NonZeroUsize;
use std::ptr;
use thiserror::Error;
use tokio::runtime::{Builder, Runtime};
use uuid::Uuid;
/// Option-name fragments that mark a value as a credential.
///
/// Matching is a case-insensitive substring test (see [`is_sensitive_option`] and
/// [`mask_credentials_in_message`]), so every entry must be lowercase ASCII.
const SENSITIVE_OPTION_NAMES: &[&str] = &[
    "password",
    "secret",
    "token",
    "api_key",
    "apikey",
    "api-key",
    "auth_token",
    "access_token",
    "refresh_token",
    "private_key",
    "privatekey",
    "credentials",
    "credential",
    "aws_secret_access_key",
    "secret_access_key",
    "aws_session_token",
    "session_token",
    "service_account_key",
    "sa_key",
    "client_secret",
    "storage_key",
    "connection_string",
    "conn_string",
    "connection_str",
    "db_password",
    "stripe_api_key",
    "firebase_credentials",
    "motherduck_token",
    "jwt_secret",
    "encryption_key",
    "signing_key",
];

/// Masks a credential value, revealing at most its first four characters.
///
/// Values of five or more characters become `<first 4 chars>***`. Shorter values,
/// including the empty string, become `***` so that none of their content leaks.
/// The cut is made on a character boundary, so multi-byte UTF-8 input is safe.
#[inline]
pub fn mask_credential_value(value: &str) -> String {
    match value.char_indices().nth(4) {
        Some((cut, _)) => format!("{}***", &value[..cut]),
        None => "***".to_string(),
    }
}

/// Returns `true` if `option_name` contains (case-insensitively) any known
/// credential-like fragment, e.g. `"my_api_key_here"` or `"PASSWORD"`.
#[inline]
pub fn is_sensitive_option(option_name: &str) -> bool {
    let lower = option_name.to_lowercase();
    SENSITIVE_OPTION_NAMES.iter().any(|&s| lower.contains(s))
}

/// Masks credential values embedded in free-form text such as error messages.
///
/// For every case-insensitive occurrence of a name from `SENSITIVE_OPTION_NAMES`, the
/// nearest following `=` or `:` is treated as the separator. The value after it is
/// replaced using [`mask_credential_value`]. Supported value forms:
///
/// * `'single quoted'` or `"double quoted"`: the text between the quotes is masked. If
///   the closing quote is missing, the rest of the message is masked so that a truncated
///   message cannot leak the credential.
/// * unquoted: the value runs until whitespace or one of `,;)&#`.
///
/// Empty values are left untouched. The separator may be anywhere after the name, so an
/// unrelated value can occasionally be masked as well. Over-masking is preferred to leaking.
pub fn mask_credentials_in_message(message: &str) -> String {
    let mut result = message.to_string();
    // ASCII-only lowercasing keeps every byte offset in `lower` valid for `result`.
    // Unicode `to_lowercase` can change byte lengths (e.g. 'İ' grows by one byte), which
    // used to misalign offsets, skipping credentials or panicking on a char boundary.
    let mut lower = result.to_ascii_lowercase();
    for sensitive_name in SENSITIVE_OPTION_NAMES {
        let name = sensitive_name.to_ascii_lowercase();
        let mut search_start = 0;
        while let Some(rel_pos) = lower[search_start..].find(name.as_str()) {
            let name_end = search_start + rel_pos + name.len();
            // Any masked value starts after `name_end`, so this offset stays valid.
            search_start = name_end;
            let Some((value_start, value_end)) = credential_value_span(&result, name_end)
            else {
                continue;
            };
            let masked = mask_credential_value(&result[value_start..value_end]);
            // Splice in place, keeping `lower` byte-for-byte aligned with `result`.
            lower.replace_range(value_start..value_end, &masked.to_ascii_lowercase());
            result.replace_range(value_start..value_end, &masked);
        }
    }
    result
}

/// Returns the byte range of the value that follows a sensitive name ending at
/// `name_end`. Returns `None` when there is no `=`/`:` separator or the value is empty.
fn credential_value_span(text: &str, name_end: usize) -> Option<(usize, usize)> {
    let after_name = &text[name_end..];
    // Take whichever separator comes first. Preferring `=` anywhere in the rest of the
    // message would skip `token: value` whenever a later `key=value` pair exists.
    let sep_pos = after_name.find(|c: char| c == '=' || c == ':')?;
    let after_sep = &after_name[sep_pos + 1..];
    let trimmed = after_sep.trim_start();
    let value_start = name_end + sep_pos + 1 + (after_sep.len() - trimmed.len());
    let (value_start, value_len) = match trimmed.chars().next() {
        Some(quote @ ('\'' | '"')) => {
            // The opening quote is ASCII, so skipping one byte stays on a char boundary.
            let inner = &trimmed[1..];
            // Unclosed quote: mask through the end of the message rather than leak it.
            (value_start + 1, inner.find(quote).unwrap_or(inner.len()))
        }
        _ => {
            let len = trimmed
                .find(|c: char| c.is_whitespace() || ",;)&#".contains(c))
                .unwrap_or(trimmed.len());
            (value_start, len)
        }
    };
    (value_len > 0).then_some((value_start, value_start + value_len))
}

/// Removes credential values from an error message before it is surfaced to users.
///
/// Currently an alias for [`mask_credentials_in_message`].
#[inline]
pub fn sanitize_error_message(message: &str) -> String {
    mask_credentials_in_message(message)
}
/// Writes `msg` to the server log at DEBUG1, prefixed with `wrappers: `.
#[inline]
pub fn log_debug1(msg: &str) {
    debug1!("wrappers: {msg}");
}
/// Emits `msg` at INFO level with SQLSTATE `ERRCODE_SUCCESSFUL_COMPLETION`.
///
/// The extra `"Wrappers"` argument is forwarded to pgrx's `ereport!` as its fourth
/// parameter (presumably the DETAIL field — confirm against the pgrx version in use).
#[inline]
pub fn report_info(msg: &str) {
    ereport!(
        PgLogLevel::INFO,
        PgSqlErrorCode::ERRCODE_SUCCESSFUL_COMPLETION,
        msg,
        "Wrappers"
    );
}
/// Emits `msg` at NOTICE level with SQLSTATE `ERRCODE_SUCCESSFUL_COMPLETION`.
///
/// The extra `"Wrappers"` argument is forwarded to pgrx's `ereport!` as its fourth
/// parameter (presumably the DETAIL field — confirm against the pgrx version in use).
#[inline]
pub fn report_notice(msg: &str) {
    ereport!(
        PgLogLevel::NOTICE,
        PgSqlErrorCode::ERRCODE_SUCCESSFUL_COMPLETION,
        msg,
        "Wrappers"
    );
}
/// Emits `msg` at WARNING level with SQLSTATE `ERRCODE_WARNING`.
///
/// The extra `"Wrappers"` argument is forwarded to pgrx's `ereport!` as its fourth
/// parameter (presumably the DETAIL field — confirm against the pgrx version in use).
#[inline]
pub fn report_warning(msg: &str) {
    ereport!(
        PgLogLevel::WARNING,
        PgSqlErrorCode::ERRCODE_WARNING,
        msg,
        "Wrappers"
    );
}
/// Raises `msg` at ERROR level with the caller-supplied SQLSTATE `code`.
///
/// NOTE(review): an ERROR-level report aborts the current statement in Postgres, so this
/// is presumably not expected to return normally. Callers still produce a fallback value
/// afterwards only to satisfy the type checker — confirm against pgrx `ereport!`.
#[inline]
pub fn report_error(code: PgSqlErrorCode, msg: &str) {
    ereport!(PgLogLevel::ERROR, code, msg, "Wrappers");
}
/// Errors that can occur while building the async runtime used by the wrappers.
#[derive(Error, Debug)]
pub enum CreateRuntimeError {
    /// The runtime builder failed with an I/O error; `#[from]` lets `?` convert
    /// `std::io::Error` into this variant automatically.
    #[error("failed to create async runtime: {0}")]
    FailedToCreateAsyncRuntime(#[from] std::io::Error),
}
/// Turns a runtime-creation failure into an `ERRCODE_FDW_ERROR` report whose message is
/// the error's `Display` text and whose third argument is empty.
impl From<CreateRuntimeError> for ErrorReport {
    fn from(value: CreateRuntimeError) -> Self {
        ErrorReport::new(PgSqlErrorCode::ERRCODE_FDW_ERROR, value.to_string(), "")
    }
}
#[inline]
pub fn create_async_runtime() -> Result<Runtime, CreateRuntimeError> {
Ok(Builder::new_current_thread().enable_all().build()?)
}
/// Looks up a decrypted secret in `vault.decrypted_secrets` by UUID.
///
/// `secret_id` must parse as a UUID and is matched against both the `id` and `key_id`
/// columns. A parse failure or a failed query is reported at ERROR level via
/// [`report_error`]. `None` is returned after such a report, and also when the query
/// yields a NULL value.
pub fn get_vault_secret(secret_id: &str) -> Option<String> {
    let sid = match Uuid::try_parse(secret_id) {
        Ok(parsed) => parsed.into_bytes(),
        Err(err) => {
            report_error(
                PgSqlErrorCode::ERRCODE_FDW_ERROR,
                &format!("invalid secret id \"{secret_id}\": {err}"),
            );
            return None;
        }
    };
    Spi::get_one_with_args::<String>(
        "select decrypted_secret from vault.decrypted_secrets where id = $1 or key_id = $1",
        &[pgrx::Uuid::from_bytes(sid).into()],
    )
    .unwrap_or_else(|err| {
        report_error(
            PgSqlErrorCode::ERRCODE_FDW_ERROR,
            &format!("query vault failed \"{secret_id}\": {err}"),
        );
        None
    })
}
/// Looks up a decrypted secret in `vault.decrypted_secrets` by its `name`.
///
/// A failed query is reported at ERROR level via [`report_error`]. `None` is returned
/// after such a report, and also when the query yields a NULL value.
pub fn get_vault_secret_by_name(secret_name: &str) -> Option<String> {
    let lookup = Spi::get_one_with_args::<String>(
        "select decrypted_secret from vault.decrypted_secrets where name = $1",
        &[secret_name.into()],
    );
    lookup.unwrap_or_else(|err| {
        report_error(
            PgSqlErrorCode::ERRCODE_FDW_ERROR,
            &format!("query vault failed \"{secret_name}\": {err}"),
        );
        None
    })
}
/// Converts the tuple held by `slot` into a [`Row`] of named [`Cell`]s.
///
/// Dropped attributes are skipped. Every live attribute is read by its *physical*
/// attribute number, so columns that follow a dropped column are not shifted onto
/// the wrong value.
///
/// # Safety
///
/// `slot` must be a valid, non-null `TupleTableSlot` pointer.
pub(super) unsafe fn tuple_table_slot_to_row(slot: *mut pg_sys::TupleTableSlot) -> Row {
    let tup_desc = unsafe { PgTupleDesc::from_pg_copy((*slot).tts_tupleDescriptor) };
    // NOTE(review): `should_free` is never consulted; a copied tuple is left for the
    // current memory context to reclaim.
    let mut should_free = false;
    let htup = unsafe { pg_sys::ExecFetchSlotHeapTuple(slot, false, &mut should_free) };
    let htup = unsafe { PgBox::from_pg(htup) };
    let mut row = Row::new();
    // Enumerate *before* filtering: `heap_getattr` takes the 1-based physical attribute
    // number, and dropped columns still occupy a position in the tuple descriptor.
    for (att_idx, attr) in tup_desc
        .iter()
        .enumerate()
        .filter(|(_, attr)| !attr.attisdropped)
    {
        let col = pgrx::name_data_to_str(&attr.attname);
        let attno = NonZeroUsize::new(att_idx + 1).expect("attribute numbers are 1-based");
        let cell: Option<Cell> = pgrx::htup::heap_getattr(&htup, attno, &tup_desc);
        row.push(col, cell);
    }
    row
}
/// Collects the columns of `baserel` that the query references.
///
/// Var nodes are pulled from the relation's target list (`reltarget.exprs`) and from its
/// restriction clauses (`baserestrictinfo`). They are merged with `list_union` so that
/// equal Vars appear once, then resolved to a [`Column`] (name, attribute number, type OID).
///
/// Attributes whose name lookup returns null (`get_attname(..., missing_ok = true)`) are
/// skipped. Generated columns are skipped with a warning.
///
/// # Safety
///
/// `root` and `baserel` must be valid planner pointers for the relation being planned.
///
/// NOTE(review): `attno as usize` wraps for negative (system) attribute numbers, and a
/// whole-row Var (attno 0) presumably has no name and is silently skipped. Confirm that
/// callers never need either case.
pub(super) unsafe fn extract_target_columns(
    root: *mut pg_sys::PlannerInfo,
    baserel: *mut pg_sys::RelOptInfo,
) -> Vec<Column> {
    unsafe {
        let mut ret = Vec::new();
        // Accumulated list of Var nodes referenced by the relation.
        let mut col_vars: *mut pg_sys::List = ptr::null_mut();
        memcx::current_context(|mcx| {
            // Vars referenced by the target list.
            if let Some(tgt_list) =
                List::<*mut c_void>::downcast_ptr_in_memcx((*(*baserel).reltarget).exprs, mcx)
            {
                for tgt in tgt_list.iter() {
                    // Recurse into aggregates and placeholders so their inner Vars are found.
                    let tgt_cols = pg_sys::pull_var_clause(
                        *tgt as _,
                        (pg_sys::PVC_RECURSE_AGGREGATES | pg_sys::PVC_RECURSE_PLACEHOLDERS)
                            .try_into()
                            .unwrap(),
                    );
                    col_vars = pg_sys::list_union(col_vars, tgt_cols);
                }
            }
            // Vars referenced by the relation's restriction clauses.
            if let Some(conds) =
                List::<*mut c_void>::downcast_ptr_in_memcx((*baserel).baserestrictinfo, mcx)
            {
                for cond in conds.iter() {
                    let expr = (*(*cond as *mut pg_sys::RestrictInfo)).clause;
                    let tgt_cols = pg_sys::pull_var_clause(
                        expr as _,
                        (pg_sys::PVC_RECURSE_AGGREGATES | pg_sys::PVC_RECURSE_PLACEHOLDERS)
                            .try_into()
                            .unwrap(),
                    );
                    col_vars = pg_sys::list_union(col_vars, tgt_cols);
                }
            }
            // Resolve each collected Var to its catalog name and type.
            if let Some(col_vars) = List::<*mut c_void>::downcast_ptr_in_memcx(col_vars, mcx) {
                for var in col_vars.iter() {
                    let var: pg_sys::Var = *(*var as *mut pg_sys::Var);
                    let rte = pg_sys::planner_rt_fetch(var.varno as _, root);
                    let attno = var.varattno;
                    let attname = pg_sys::get_attname((*rte).relid, attno, true);
                    if !attname.is_null() {
                        // Non-zero `attgenerated` marks a generated column; these are skipped.
                        if pg_sys::get_attgenerated((*rte).relid, attno) > 0 {
                            report_warning("generated column is not supported");
                            continue;
                        }
                        let type_oid = pg_sys::get_atttype((*rte).relid, attno);
                        ret.push(Column {
                            // Panics if the attribute name is not valid UTF-8.
                            name: CStr::from_ptr(attname).to_str().unwrap().to_owned(),
                            num: attno as usize,
                            type_oid,
                        });
                    }
                }
            }
        });
        ret
    }
}
/// Passes a boxed value through a Postgres `List` (e.g. a planner private list).
///
/// The encoding is a single-element list holding an `int8` `Const` whose value is the
/// raw pointer obtained from `PgBox::into_pg`. The value itself is not copied.
pub(super) trait SerdeList {
    /// Encodes `state` as a one-element list containing its raw pointer as an `int8` `Const`.
    ///
    /// # Safety
    ///
    /// Only the pointer is stored, so the pointed-to allocation must stay alive for as
    /// long as the returned list may be passed to `deserialize_from_list`.
    /// NOTE(review): presumably it lives in a memory context that outlives the plan —
    /// confirm at the call sites.
    unsafe fn serialize_to_list(state: PgBox<Self>) -> *mut pg_sys::List
    where
        Self: Sized,
    {
        unsafe {
            memcx::current_context(|mcx| {
                let mut ret = List::<*mut c_void>::Nil;
                // Hand ownership to Postgres and keep only the address.
                let val = state.into_pg() as i64;
                // makeConst(type, typmod, collation, len, value, isnull, byval):
                // an 8-byte, pass-by-value, non-null int8 with no collation.
                let cst: *mut pg_sys::Const = pg_sys::makeConst(
                    pg_sys::INT8OID,
                    -1,
                    pg_sys::InvalidOid,
                    8,
                    val.into_datum().unwrap(),
                    false,
                    true,
                );
                ret.unstable_push_in_context(cst as _, mcx);
                ret.into_ptr()
            })
        }
    }

    /// Decodes a pointer previously stored by `serialize_to_list`.
    ///
    /// Returns a null `PgBox` when `list` cannot be read as a list or has no first element.
    /// Panics (via `unwrap`) if the stored `Const` is NULL.
    ///
    /// # Safety
    ///
    /// `list` must have been produced by `serialize_to_list` for the same `Self` type, and
    /// the pointed-to value must still be alive.
    unsafe fn deserialize_from_list(list: *mut pg_sys::List) -> PgBox<Self>
    where
        Self: Sized,
    {
        unsafe {
            memcx::current_context(|mcx| {
                if let Some(list) = List::<*mut c_void>::downcast_ptr_in_memcx(list, mcx)
                    && let Some(cst) = list.get(0)
                {
                    let cst = *(*cst as *mut pg_sys::Const);
                    let ptr = i64::from_datum(cst.constvalue, cst.constisnull).unwrap();
                    return PgBox::<Self>::from_pg(ptr as _);
                }
                PgBox::<Self>::null()
            })
        }
    }
}
/// Unwraps a value, turning a failure into a Postgres error report.
pub(crate) trait ReportableError {
    /// The value produced on success.
    type Output;
    /// Returns the success value, or reports the failure through pgrx's error reporting.
    fn report_unwrap(self) -> Self::Output;
}
/// Any `Result` whose error converts into an [`ErrorReport`] can be unwrapped by
/// reporting that error.
impl<T, E: Into<ErrorReport>> ReportableError for Result<T, E> {
    type Output = T;

    /// Yields the `Ok` value; otherwise converts the error to an [`ErrorReport`] and
    /// raises it via `unwrap_or_report`.
    fn report_unwrap(self) -> Self::Output {
        let converted: Result<T, ErrorReport> = self.map_err(Into::into);
        converted.unwrap_or_report()
    }
}
/// Unit tests for the credential-masking helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mask_credential_value_long_value() {
        assert_eq!(mask_credential_value("wJalrXUtnFEMI/EXAMPLEKEY"), "wJal***");
        assert_eq!(mask_credential_value("12345678"), "1234***");
        assert_eq!(mask_credential_value("abcde"), "abcd***");
    }

    #[test]
    fn test_mask_credential_value_short_value() {
        assert_eq!(mask_credential_value("abc"), "***");
        assert_eq!(mask_credential_value("abcd"), "***");
        assert_eq!(mask_credential_value("a"), "***");
        assert_eq!(mask_credential_value(""), "***");
    }

    #[test]
    fn test_mask_credential_value_utf8() {
        assert_eq!(mask_credential_value("pässwörd"), "päss***");
        assert_eq!(mask_credential_value("短い"), "***");
    }

    #[test]
    fn test_is_sensitive_option_generic_patterns() {
        assert!(is_sensitive_option("password"));
        assert!(is_sensitive_option("secret"));
        assert!(is_sensitive_option("token"));
        assert!(is_sensitive_option("api_key"));
        assert!(is_sensitive_option("credentials"));
    }

    #[test]
    fn test_is_sensitive_option_aws_patterns() {
        assert!(is_sensitive_option("aws_secret_access_key"));
        assert!(is_sensitive_option("secret_access_key"));
        assert!(is_sensitive_option("aws_session_token"));
    }

    #[test]
    fn test_is_sensitive_option_case_insensitive() {
        assert!(is_sensitive_option("PASSWORD"));
        assert!(is_sensitive_option("API_KEY"));
        assert!(is_sensitive_option("AWS_SECRET_ACCESS_KEY"));
        assert!(is_sensitive_option("Token"));
    }

    #[test]
    fn test_is_sensitive_option_non_sensitive() {
        assert!(!is_sensitive_option("region"));
        assert!(!is_sensitive_option("bucket"));
        assert!(!is_sensitive_option("endpoint_url"));
        assert!(!is_sensitive_option("table_name"));
    }

    #[test]
    fn test_is_sensitive_option_partial_match() {
        assert!(is_sensitive_option("my_api_key_here"));
        assert!(is_sensitive_option("stripe_api_key"));
    }

    #[test]
    fn test_mask_credentials_sql_style_single_quotes() {
        let msg = "Error: aws_secret_access_key = 'wJalrXUtnFEMI/EXAMPLEKEY' is invalid";
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("wJalrXUtnFEMI"));
        assert!(masked.contains("wJal***"));
    }

    #[test]
    fn test_mask_credentials_sql_style_double_quotes() {
        let msg = r#"Error: password = "mysecretpassword" for user"#;
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("mysecretpassword"));
        assert!(masked.contains("myse***"));
    }

    #[test]
    fn test_mask_credentials_json_style() {
        let msg = r#"{"api_key": "sk_live_12345678abcd"}"#;
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("sk_live_12345678abcd"));
        assert!(masked.contains("sk_l***"));
    }

    #[test]
    fn test_mask_credentials_unquoted() {
        let msg = "Error: token=abc123xyz is expired";
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("abc123xyz"));
        assert!(masked.contains("abc1***"));
    }

    #[test]
    fn test_mask_credentials_url_params() {
        // `&` terminates an unquoted value, so the following query parameter survives.
        let msg = "Error: api_key=sk_live_123&region=us-west-2";
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("sk_live_123"));
        assert!(masked.contains("sk_l***"));
        assert!(masked.contains("&region=us-west-2"));
    }

    #[test]
    fn test_mask_credentials_url_with_fragment() {
        let msg = "Error: api_key=secret_token_123#section";
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("secret_token_123"));
        assert!(masked.contains("secr***"));
        assert!(masked.contains("#section"));
    }

    #[test]
    fn test_mask_credentials_multiple_occurrences() {
        let msg = "password='first123' and api_key='second456'";
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("first123"));
        assert!(!masked.contains("second456"));
        assert!(masked.contains("firs***"));
        assert!(masked.contains("seco***"));
    }

    #[test]
    fn test_mask_credentials_no_sensitive_data() {
        let msg = "Error: region = 'us-west-2' is not available";
        let masked = mask_credentials_in_message(msg);
        assert_eq!(masked, msg);
    }

    #[test]
    fn test_mask_credentials_case_insensitive() {
        let msg = "Error: PASSWORD = 'secret123' failed";
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("secret123"));
        assert!(masked.contains("secr***"));
    }

    #[test]
    fn test_mask_credentials_empty_value() {
        let msg = "Error: password = '' is empty";
        let masked = mask_credentials_in_message(msg);
        assert!(masked.contains("''"));
    }

    #[test]
    fn test_mask_credentials_no_value_after_key() {
        let msg = "password option is missing";
        let masked = mask_credentials_in_message(msg);
        assert_eq!(masked, msg);
    }

    #[test]
    fn test_mask_credentials_unclosed_quote() {
        let msg = "Error: password = 'unclosed";
        let masked = mask_credentials_in_message(msg);
        assert!(masked.contains("password"));
    }

    #[test]
    fn test_mask_credentials_service_specific() {
        let msg = "stripe_api_key = 'sk_test_1234567890abcdef'";
        let masked = mask_credentials_in_message(msg);
        assert!(!masked.contains("sk_test_1234567890abcdef"));
        assert!(masked.contains("sk_t***"));
    }

    #[test]
    fn test_sanitize_error_message() {
        let error = "Connection failed with password='secret123' for user 'admin'";
        let safe_error = sanitize_error_message(error);
        assert!(!safe_error.contains("secret123"));
        assert!(safe_error.contains("secr***"));
        assert!(safe_error.contains("admin"));
    }

    #[test]
    fn test_sanitize_error_message_complex() {
        let error =
            "Failed: aws_secret_access_key='AKIAIOSFODNN7EXAMPLE', api_key=\"test_key_123\"";
        let safe_error = sanitize_error_message(error);
        assert!(!safe_error.contains("AKIAIOSFODNN7EXAMPLE"));
        assert!(!safe_error.contains("test_key_123"));
    }
}