use pgwire::api::results::{Response, Tag};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use crate::control::planner::physical::PhysicalTask;
use crate::control::security::identity::AuthenticatedIdentity;
use crate::types::{ReadConsistency, TenantId};
use super::super::types::{error_to_sqlstate, response_status_to_sqlstate};
use super::core::NodeDbPgHandler;
use super::plan::{PlanKind, describe_plan, extract_collection, payload_to_response};
impl NodeDbPgHandler {
/// Plans and executes `sql` for `tenant_id` on behalf of the caller's
/// identity, returning one pgwire `Response` per physical task.
///
/// Flow:
/// 1. Resolve a cached per-session auth context (or build a fresh one).
/// 2. Plan the statement with row-level security applied.
/// 3. If every task routes to a single *remote* Raft leader, forward the
///    whole statement to that leader instead of executing locally.
/// 4. Otherwise dispatch each task locally, enforcing tenant isolation and
///    per-plan permissions; inside an explicit transaction block, write
///    tasks are buffered on the session rather than dispatched.
///
/// Errors are surfaced as `PgWireError::UserError` carrying a SQLSTATE.
pub(super) async fn execute_planned_sql(
&self,
identity: &AuthenticatedIdentity,
sql: &str,
tenant_id: TenantId,
addr: &std::net::SocketAddr,
) -> PgWireResult<Vec<Response>> {
// Reuse the auth context cached under the session's "nodedb.auth_session"
// parameter when it resolves; otherwise build one from the authenticated
// identity. NOTE(review): whether `build_auth_context_with_session` also
// caches the new context is not visible here — confirm in `session_auth`.
let mut auth_ctx = if let Some(handle) =
self.sessions.get_parameter(addr, "nodedb.auth_session")
&& let Some(cached) = self.state.session_handles.resolve(&handle)
{
cached
} else {
crate::control::server::session_auth::build_auth_context_with_session(
identity,
&self.sessions,
addr,
)
};
// Strip any ON-DENY-style directive from the SQL text and fold it into the
// auth context before planning (exact directive syntax lives in
// `session_auth`, not visible here).
let clean_sql =
crate::control::server::session_auth::extract_and_apply_on_deny(sql, &mut auth_ctx);
let tasks = self
.query_ctx
.plan_sql_with_rls(&clean_sql, tenant_id, &auth_ctx, &self.state.rls)
.await
.map_err(|e| {
// Map planner errors to (severity, SQLSTATE, message) for the client.
let (severity, code, message) = error_to_sqlstate(&e);
PgWireError::UserError(Box::new(ErrorInfo::new(
severity.to_owned(),
code.to_owned(),
message,
)))
})?;
// A statement that plans to zero tasks still completes successfully.
if tasks.is_empty() {
return Ok(vec![Response::Execution(Tag::new("OK"))]);
}
let consistency = self.consistency_for_tasks(&tasks);
// If all tasks agree on one remote leader, hand the statement off whole.
// NOTE(review): the *original* `sql` is forwarded, not `clean_sql` —
// presumably so the leader re-runs `extract_and_apply_on_deny` itself;
// confirm the remote path does that.
if let Some(leader) = self.remote_leader_for_tasks(&tasks, consistency) {
return self.forward_sql(sql, tenant_id, leader).await;
}
let mut responses = Vec::with_capacity(tasks.len());
for task in tasks {
// Defense in depth: the planner must never emit a task targeting a
// tenant other than the one this connection is bound to.
if task.tenant_id != tenant_id {
tracing::error!(
expected = %tenant_id,
actual = %task.tenant_id,
"SECURITY: task tenant_id mismatch — rejecting"
);
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"42501".to_owned(), // insufficient_privilege
"tenant isolation violation: task targets wrong tenant".to_owned(),
))));
}
self.check_permission(identity, &task.plan)?;
// Inside an explicit BEGIN...block, defer writes: buffer them on the
// session and acknowledge with OK; they are presumably flushed at
// COMMIT elsewhere.
if self.sessions.transaction_state(addr)
== crate::control::server::pgwire::session::TransactionState::InBlock
{
// A plan is a "write" iff it converts to a replicated WAL entry.
let is_write = crate::control::wal_replication::to_replicated_entry(
task.tenant_id,
task.vshard_id,
&task.plan,
)
.is_some();
if is_write {
self.sessions.buffer_write(addr, task);
responses.push(Response::Execution(Tag::new("OK")));
continue;
}
}
// Capture plan metadata before `task` is moved into `dispatch_task`.
let plan_kind = describe_plan(&task.plan);
let collection_for_si = extract_collection(&task.plan).map(String::from);
let resp = self.dispatch_task(task).await.map_err(|e| {
let (severity, code, message) = error_to_sqlstate(&e);
PgWireError::UserError(Box::new(ErrorInfo::new(
severity.to_owned(),
code.to_owned(),
message,
)))
})?;
// A dispatch may succeed at the transport level yet carry an
// application-level error status; translate that to a SQLSTATE too.
if let Some((severity, code, message)) =
response_status_to_sqlstate(resp.status, &resp.error_code)
{
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
severity.to_owned(),
code.to_owned(),
message,
))));
}
// Track reads performed inside a transaction block (snapshot-isolation
// bookkeeping via the watermark LSN). NOTE(review): the empty string
// appears to be a placeholder read key — confirm intended semantics.
if self.sessions.transaction_state(addr)
== crate::control::server::pgwire::session::TransactionState::InBlock
&& let Some(collection) = collection_for_si
{
self.sessions
.record_read(addr, collection, String::new(), resp.watermark_lsn);
}
responses.push(payload_to_response(&resp.payload, plan_kind));
}
Ok(responses)
}
/// Picks the read consistency for a batch: `Strong` when any task is a
/// write (i.e. converts to a replicated WAL entry), otherwise reads may be
/// served with up to 5 seconds of bounded staleness.
fn consistency_for_tasks(&self, tasks: &[PhysicalTask]) -> ReadConsistency {
let has_writes = tasks.iter().any(|t| {
crate::control::wal_replication::to_replicated_entry(t.tenant_id, t.vshard_id, &t.plan)
.is_some()
});
if has_writes {
ReadConsistency::Strong
} else {
ReadConsistency::BoundedStaleness(std::time::Duration::from_secs(5))
}
}
/// Returns `Some(node_id)` when every task's vshard group is led by the
/// same remote node, meaning the whole statement should be forwarded there.
/// Returns `None` — i.e. "execute locally" — when: routing is unavailable,
/// we are the leader for any group, a relaxed-consistency read can be
/// served by this node as a group member, the leader is unknown, or tasks
/// span different leaders.
fn remote_leader_for_tasks(
&self,
tasks: &[PhysicalTask],
consistency: ReadConsistency,
) -> Option<u64> {
let routing = self.state.cluster_routing.as_ref()?;
// Recover from a poisoned lock: routing data is read-only here, so a
// writer panic does not invalidate what we read.
let routing = routing.read().unwrap_or_else(|p| p.into_inner());
let my_node = self.state.node_id;
let mut remote_leader: Option<u64> = None;
for task in tasks {
let vshard_id = task.vshard_id.as_u16();
let group_id = routing.group_for_vshard(vshard_id).ok()?;
let info = routing.group_info(group_id)?;
let leader = info.leader;
// We lead this group — execute the whole batch locally.
if leader == my_node {
return None;
}
// Relaxed consistency: a follower replica of the group can serve it.
if !consistency.requires_leader() && info.members.contains(&my_node) {
return None;
}
// NOTE(review): 0 presumably means "no elected leader" — fall back to
// local execution rather than forwarding into the void.
if leader == 0 {
return None;
}
// All tasks must agree on one remote leader; otherwise run locally.
match remote_leader {
None => remote_leader = Some(leader),
Some(prev) if prev != leader => return None,
_ => {}
}
}
remote_leader
}
/// Forwards `sql` to the remote Raft `leader` over the cluster transport
/// and converts its `ForwardResponse` payloads into pgwire responses.
///
/// Error SQLSTATEs: 55000 when no transport is configured, 01R01 (with a
/// redirect hint) when the RPC itself fails, XX000 for remote execution
/// failures or an unexpected RPC variant.
async fn forward_sql(
&self,
sql: &str,
tenant_id: TenantId,
leader: u64,
) -> PgWireResult<Vec<Response>> {
let transport = match &self.state.cluster_transport {
Some(t) => t,
None => {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"55000".to_owned(), // object_not_in_prerequisite_state
"cluster transport not available".to_owned(),
))));
}
};
let req = nodedb_cluster::rpc_codec::RaftRpc::ForwardRequest(
nodedb_cluster::rpc_codec::ForwardRequest {
sql: sql.to_owned(),
tenant_id: tenant_id.as_u32(),
// Budget the remote call with the configured default deadline.
deadline_remaining_ms: std::time::Duration::from_secs(
self.state.tuning.network.default_deadline_secs,
)
.as_millis() as u64,
trace_id: 0, // NOTE(review): tracing not propagated yet — confirm intent
},
);
// Best-effort resolution of the leader's address, used only in the error
// message below; falls back to a synthetic "node-<id>" label.
let leader_addr = self
.state
.cluster_topology
.as_ref()
.and_then(|t| {
// Recover from a poisoned lock; topology is read-only here.
let topo = t.read().unwrap_or_else(|p| p.into_inner());
topo.get_node(leader).map(|n| n.addr.clone())
})
.unwrap_or_else(|| format!("node-{leader}"));
let resp = transport.send_rpc(leader, req).await.map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"01R01".to_owned(),
format!("not leader; redirect to {leader_addr} (forward failed: {e})"),
)))
})?;
match resp {
nodedb_cluster::rpc_codec::RaftRpc::ForwardResponse(fwd) => {
if !fwd.success {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(), // internal_error
format!("remote execution failed: {}", fwd.error_message),
))));
}
let mut responses = Vec::with_capacity(fwd.payloads.len());
// Payload row-shape is unknown on this side, so decode every
// payload as a generic multi-row result.
for payload in &fwd.payloads {
responses.push(payload_to_response(payload, PlanKind::MultiRow));
}
// A successful forward with no payloads still acknowledges OK.
if responses.is_empty() {
responses.push(Response::Execution(Tag::new("OK")));
}
Ok(responses)
}
// Any other RPC variant is a protocol violation.
other => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
format!("unexpected response from leader: {other:?}"),
)))),
}
}
}