use std::fmt::Debug;
use std::sync::Arc;
use async_trait::async_trait;
use futures::SinkExt;
use futures::sink::Sink;
use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::portal::Portal;
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response};
use pgwire::api::stmt::StoredStatement;
use pgwire::api::store::PortalStore;
use pgwire::api::{ClientInfo, ClientPortalStore};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;
use pgwire::messages::PgWireFrontendMessage;
use crate::bridge::envelope::PhysicalPlan;
use crate::config::auth::AuthMode;
use crate::control::planner::context::QueryContext;
use crate::control::security::audit::{
AuditEmitContext, AuditEmitter, AuditEvent, NoopAuditEmitter,
};
use crate::control::security::identity::{
AuthMethod, AuthenticatedIdentity, Role, required_permission, role_grants_permission,
};
use crate::control::state::SharedState;
use crate::types::{RequestId, TenantId};
use super::super::session::{SessionStore, TransactionState};
use super::super::types::notice_warning;
use super::plan::extract_collection;
use super::prepared::{NodeDbQueryParser, ParsedStatement};
/// pgwire handler that serves simple queries, extended-protocol queries and
/// startup for NodeDB connections.
pub struct NodeDbPgHandler {
    /// Shared server state: credentials, audit sink, permission/role stores and
    /// the request-id allocator are all reached through it.
    pub(crate) state: Arc<SharedState>,
    /// Planner context built from `state` via `QueryContext::for_state_with_lease`.
    pub(super) query_ctx: QueryContext,
    /// Parser handed to pgwire for the extended (Parse/Bind/Execute) protocol.
    query_parser: Arc<NodeDbQueryParser>,
    /// Selects how client identities are resolved and whether `post_startup`
    /// performs trust-mode user validation.
    auth_mode: AuthMode,
    /// Per-connection session state, keyed by the client's socket address.
    pub(crate) sessions: SessionStore,
    /// Backup restore state; created fresh in `new`. Its consumers are outside
    /// this file — NOTE(review): confirm its lifecycle there.
    pub(crate) restore_state: Arc<crate::control::backup::RestoreState>,
}
impl NodeDbPgHandler {
    /// Builds a handler over `state` that authenticates clients according to `auth_mode`.
    ///
    /// The planner query context and the extended-protocol parser are both derived from
    /// `state`; the session store and the backup restore state are freshly constructed.
    pub fn new(state: Arc<SharedState>, auth_mode: AuthMode) -> Self {
        let query_ctx = QueryContext::for_state_with_lease(&state);
        let query_parser = Arc::new(NodeDbQueryParser::new(Arc::clone(&state)));
        Self {
            state,
            query_ctx,
            query_parser,
            auth_mode,
            sessions: SessionStore::new(),
            restore_state: Arc::new(crate::control::backup::RestoreState::new()),
        }
    }

    /// Returns the next request id allocated by the shared state.
    pub(super) fn next_request_id(&self) -> RequestId {
        self.state.next_request_id()
    }

    /// Resolves the client's `user` startup parameter to an [`AuthenticatedIdentity`].
    ///
    /// A missing `user` parameter is looked up as the literal user name `"unknown"`.
    ///
    /// # Errors
    ///
    /// Returns a `FATAL` / SQLSTATE `28000` error when the credential store has no
    /// identity for the user under the configured auth mode.
    pub(crate) fn resolve_identity<C: ClientInfo>(
        &self,
        client: &C,
    ) -> PgWireResult<AuthenticatedIdentity> {
        let username = client
            .metadata()
            .get("user")
            .cloned()
            .unwrap_or_else(|| "unknown".to_string());
        match self.auth_mode {
            AuthMode::Trust => self
                .state
                .credentials
                .to_identity(&username, AuthMethod::Trust)
                .ok_or_else(|| {
                    PgWireError::UserError(Box::new(ErrorInfo::new(
                        "FATAL".to_owned(),
                        "28000".to_owned(),
                        format!("trust auth: user '{username}' does not exist"),
                    )))
                }),
            // NOTE(review): certificate-authenticated sessions are tagged as SCRAM-SHA-256
            // here; confirm whether a dedicated `AuthMethod` variant should be used.
            AuthMode::Password | AuthMode::Certificate => self
                .state
                .credentials
                .to_identity(&username, AuthMethod::ScramSha256)
                .ok_or_else(|| {
                    PgWireError::UserError(Box::new(ErrorInfo::new(
                        "FATAL".to_owned(),
                        "28000".to_owned(),
                        format!("authenticated user '{username}' not found in credential store"),
                    )))
                }),
        }
    }

    /// Rejects non-superusers that may not access the session's current database.
    ///
    /// A session with no current database is checked against `DatabaseId::DEFAULT`.
    /// Denials are recorded as `PermissionDenied` audit events.
    ///
    /// # Errors
    ///
    /// Returns a `FATAL` / SQLSTATE `42501` (insufficient_privilege) error on denial.
    pub(super) fn enforce_database_access(
        &self,
        identity: &AuthenticatedIdentity,
        addr: &std::net::SocketAddr,
    ) -> PgWireResult<()> {
        if identity.is_superuser {
            return Ok(());
        }
        let db = self
            .sessions
            .get_current_database(addr)
            .unwrap_or(crate::types::DatabaseId::DEFAULT);
        if identity.can_access_database(db) {
            return Ok(());
        }
        emit_permission_denied(
            &self.state,
            identity,
            &format!("database access denied: db={}", db.as_u64()),
        );
        Err(PgWireError::UserError(Box::new(ErrorInfo::new(
            "FATAL".to_owned(),
            "42501".to_owned(),
            format!(
                "permission denied for database: user '{}' does not have access",
                identity.username
            ),
        ))))
    }

    /// Verifies that `identity` may execute `plan`.
    ///
    /// Checks run in this order:
    /// 1. Superusers always pass.
    /// 2. Collections whose name starts with `_system` are refused for everyone else.
    /// 3. A grant in the permission store for the target collection passes.
    /// 4. Otherwise one of the identity's roles must grant the required permission.
    ///
    /// Denials are recorded as `PermissionDenied` audit events.
    ///
    /// # Errors
    ///
    /// Returns an `ERROR` / SQLSTATE `42501` error when access is denied.
    pub(super) fn check_permission(
        &self,
        identity: &AuthenticatedIdentity,
        plan: &PhysicalPlan,
    ) -> PgWireResult<()> {
        if identity.is_superuser {
            return Ok(());
        }
        let required = required_permission(plan);
        let collection = extract_collection(plan);

        if let Some(coll) = collection
            && coll.starts_with("_system")
        {
            emit_permission_denied(
                &self.state,
                identity,
                &format!("system catalog access denied: {coll}"),
            );
            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
                "ERROR".to_owned(),
                "42501".to_owned(),
                "permission denied: system catalog access requires superuser".to_owned(),
            ))));
        }

        // The store-level check is silent (NoopAuditEmitter): a miss here is not yet a
        // denial, because a role grant below may still allow the operation.
        if let Some(coll) = collection
            && self.state.permissions.check(
                identity,
                required,
                coll,
                &self.state.roles,
                &NoopAuditEmitter,
            )
        {
            return Ok(());
        }

        if identity
            .roles
            .iter()
            .any(|role| role_grants_permission(role, required))
        {
            return Ok(());
        }

        // Shared by the audit detail and the client-facing message.
        let target = collection
            .map(|c| format!(" on '{c}'"))
            .unwrap_or_default();
        emit_permission_denied(
            &self.state,
            identity,
            &format!("permission {required:?} denied{target}"),
        );
        Err(PgWireError::UserError(Box::new(ErrorInfo::new(
            "ERROR".to_owned(),
            "42501".to_owned(),
            format!(
                "permission denied: user '{}' lacks {required:?} permission{target}",
                identity.username
            ),
        ))))
    }
}

/// Records a `PermissionDenied` audit event attributed to `identity` with `detail` text.
fn emit_permission_denied(state: &SharedState, identity: &AuthenticatedIdentity, detail: &str) {
    let emitter = crate::control::security::audit::ArcAuditEmitter(Arc::clone(&state.audit));
    emitter.emit(
        AuditEvent::PermissionDenied,
        &identity.username,
        detail,
        AuditEmitContext::new(
            Some(identity.tenant_id),
            &identity.user_id.to_string(),
            &identity.username,
        ),
    );
}
#[async_trait]
impl SimpleQueryHandler for NodeDbPgHandler {
    /// Executes a simple-protocol query for the connected client.
    ///
    /// The steps are:
    /// 1. Resolve the client's identity and check access to the current database.
    /// 2. Emit PostgreSQL-style warnings for a redundant `BEGIN` or an unmatched
    ///    `COMMIT`/`END`.
    /// 3. Run the SQL under an audit scope.
    /// 4. Flush queued notices and pending notifications to the client.
    ///
    /// Flushing is best-effort: send failures are ignored and the execution
    /// result is returned either way.
    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
    where
        C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        let addr = client.socket_addr();
        self.sessions.ensure_session(addr);
        let identity = self.resolve_identity(client)?;
        self.enforce_database_access(&identity, &addr)?;

        // Resolving the database name costs a catalog lookup; only pay for it when
        // the debug event will actually be recorded.
        if tracing::enabled!(tracing::Level::DEBUG) {
            let current_db = self
                .sessions
                .get_current_database(&addr)
                .unwrap_or(crate::types::DatabaseId::DEFAULT);
            let db_name: String = self
                .state
                .credentials
                .catalog()
                .as_ref()
                .and_then(|cat| cat.get_database(current_db).ok().flatten())
                .map(|d| d.name.clone())
                .unwrap_or_else(|| "default".to_string());
            tracing::debug!(
                db.id = current_db.as_u64(),
                db.name = %db_name,
                user = %identity.username,
                "session query dispatch",
            );
        }

        if let Some(keyword) = transaction_keyword(query) {
            let warning = if matches!(
                keyword.as_str(),
                "BEGIN" | "BEGIN TRANSACTION" | "START TRANSACTION"
            ) && self.sessions.transaction_state(&addr) == TransactionState::InBlock
            {
                Some("there is already a transaction in progress")
            } else if matches!(keyword.as_str(), "COMMIT" | "END")
                && self.sessions.transaction_state(&addr) == TransactionState::Idle
            {
                Some("there is no transaction in progress")
            } else {
                None
            };
            if let Some(text) = warning {
                let _ = client
                    .send(PgWireBackendMessage::NoticeResponse(notice_warning(text)))
                    .await;
            }
        }

        let _audit_scope = crate::control::server::pgwire::session::audit_context::AuditScope::new(
            crate::control::server::pgwire::session::audit_context::AuditCtx {
                auth_user_id: identity.user_id.to_string(),
                auth_user_name: identity.username.clone(),
                sql_text: query.to_string(),
            },
        );
        let result = self.execute_sql(&identity, &addr, query).await;

        for message in self.sessions.drain_notices(&addr) {
            let notice = notice_warning(&message);
            let _ = client
                .send(PgWireBackendMessage::NoticeResponse(notice))
                .await;
        }
        if self.sessions.has_live_subscriptions(&addr) {
            for (channel, payload) in self.sessions.drain_live_notifications(&addr) {
                let notification = pgwire::messages::response::NotificationResponse::new(
                    0, channel, payload,
                );
                let _ = client
                    .send(PgWireBackendMessage::NotificationResponse(notification))
                    .await;
            }
        }
        if self.sessions.has_listen_subscriptions(&addr) {
            for n in self.sessions.drain_listen_notifications(&addr) {
                let notification = pgwire::messages::response::NotificationResponse::new(
                    n.pid, n.channel, n.payload,
                );
                let _ = client
                    .send(PgWireBackendMessage::NotificationResponse(notification))
                    .await;
            }
        }
        result
    }
}

/// Longest trimmed statement considered for transaction-keyword matching. Anything
/// longer cannot be a bare transaction-control statement, so it is never copied or
/// upper-cased.
const MAX_TXN_KEYWORD_LEN: usize = 64;

/// Normalizes `query` for matching bare transaction-control keywords.
///
/// Surrounding whitespace and trailing semicolons are dropped. Internal whitespace
/// runs collapse to a single space and ASCII letters are upper-cased, so
/// `"begin ;"` becomes `"BEGIN"` and `"start\n  transaction;"` becomes
/// `"START TRANSACTION"`. Returns `None` for empty or over-long input.
fn transaction_keyword(query: &str) -> Option<String> {
    let stmt = query.trim().trim_end_matches(';').trim_end();
    if stmt.is_empty() || stmt.len() > MAX_TXN_KEYWORD_LEN {
        return None;
    }
    Some(
        stmt.split_whitespace()
            .map(str::to_ascii_uppercase)
            .collect::<Vec<_>>()
            .join(" "),
    )
}
#[async_trait]
impl ExtendedQueryHandler for NodeDbPgHandler {
type Statement = ParsedStatement;
type QueryParser = NodeDbQueryParser;
fn query_parser(&self) -> Arc<Self::QueryParser> {
self.query_parser.clone()
}
async fn do_query<C>(
&self,
client: &mut C,
portal: &Portal<Self::Statement>,
max_rows: usize,
) -> PgWireResult<Response>
where
C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::PortalStore: PortalStore<Statement = Self::Statement>,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let result = self.execute_prepared(client, portal, max_rows).await;
let addr = client.socket_addr();
for message in self.sessions.drain_notices(&addr) {
let notice = notice_warning(&message);
let _ = client
.send(PgWireBackendMessage::NoticeResponse(notice))
.await;
}
result
}
async fn do_describe_statement<C>(
&self,
client: &mut C,
target: &StoredStatement<Self::Statement>,
) -> PgWireResult<DescribeStatementResponse>
where
C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::PortalStore: PortalStore<Statement = Self::Statement>,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
self.describe_statement_impl(client, target).await
}
async fn do_describe_portal<C>(
&self,
client: &mut C,
target: &Portal<Self::Statement>,
) -> PgWireResult<DescribePortalResponse>
where
C: ClientInfo + ClientPortalStore + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::PortalStore: PortalStore<Statement = Self::Statement>,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
self.describe_portal_impl(client, target).await
}
}
#[async_trait]
impl NoopStartupHandler for NodeDbPgHandler {
async fn post_startup<C>(
&self,
client: &mut C,
_message: PgWireFrontendMessage,
) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
if !matches!(self.auth_mode, AuthMode::Trust) {
return Ok(());
}
let username = client
.metadata()
.get("user")
.cloned()
.unwrap_or_else(|| "unknown".to_string());
if self
.state
.credentials
.to_identity(&username, AuthMethod::Trust)
.is_some()
{
return Ok(());
}
if self.state.credentials.is_empty() {
let _ = self.state.credentials.create_user(
&username,
"",
TenantId::new(1),
vec![Role::Superuser],
);
return Ok(());
}
let source = client.socket_addr().to_string();
self.state.audit_record(
AuditEvent::AuthFailure,
None,
&source,
&format!("trust auth: user '{username}' does not exist"),
);
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_owned(),
"28000".to_owned(),
format!("trust auth: user '{username}' does not exist"),
))))
}
}