use std::fmt::Debug;
use std::net::SocketAddr;
use std::sync::Arc;
use nodedb_types::error::sqlstate as ss;
use async_trait::async_trait;
use futures::stream;
use futures::{Sink, SinkExt};
use pgwire::api::copy::CopyHandler;
use pgwire::api::results::{CopyResponse, Response, Tag};
use pgwire::api::{ClientInfo, PgWireConnectionState};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;
use pgwire::messages::copy::{CopyData, CopyDone};
use crate::control::backup;
use crate::control::backup::CopyIntent;
use crate::control::backup::state::AppendError;
use crate::control::security::audit::ArcAuditEmitter;
use crate::control::security::identity::{AuthenticatedIdentity, Permission};
use crate::control::state::SharedState;
use crate::types::TenantId;
use super::core::NodeDbPgHandler;
/// Cap handed to `backup::RestorePending::new` for every COPY-IN restore.
///
/// `16 << 30` equals 16 * 1024^3; presumably a byte count (16 GiB) — confirm
/// against `RestorePending` / `AppendError::OverCap`.
const COPY_IN_CAP: u64 = 16 << 30;
impl NodeDbPgHandler {
pub(super) async fn intent_to_response(
&self,
identity: &AuthenticatedIdentity,
addr: SocketAddr,
intent: CopyIntent,
) -> PgWireResult<Response> {
let tenant_id = match &intent {
CopyIntent::BackupTenant { tenant_id } => *tenant_id,
CopyIntent::RestoreTenant { tenant_id, .. } => *tenant_id,
};
if !identity.is_superuser {
let emitter = ArcAuditEmitter(Arc::clone(&self.state.audit));
let allowed = self.state.permissions.check_tenant(
identity,
Permission::Backup,
TenantId::new(tenant_id),
&self.state.roles,
&emitter,
);
if !allowed {
return Err(sqlstate(
ss::INSUFFICIENT_PRIVILEGE,
"permission denied: BACKUP permission on the tenant required",
));
}
}
match intent {
CopyIntent::BackupTenant { tenant_id } => {
let bytes = backup::backup_tenant(&self.state, tenant_id)
.await
.map_err(internal)?;
let copy_data = Ok(CopyData::new(bytes));
let stream = stream::once(async move { copy_data });
Ok(Response::CopyOut(CopyResponse::new(0, 0, stream)))
}
CopyIntent::RestoreTenant { tenant_id, dry_run } => {
self.restore_state.begin(
conn_id(&addr),
backup::RestorePending::new(tenant_id, dry_run, COPY_IN_CAP),
);
let empty = stream::empty();
Ok(Response::CopyIn(CopyResponse::new(0, 0, empty)))
}
}
}
}
/// pgwire copy-in handler for tenant restores.
///
/// It appends each incoming `CopyData` frame to the connection's pending restore.
/// On `CopyDone` it runs `backup::restore_tenant` over the accumulated bytes.
pub struct NodeDbCopyHandler {
    /// Shared server state handed to `backup::restore_tenant`.
    pub state: Arc<SharedState>,
    /// Pending restores keyed by `conn_id` of the client's socket address.
    /// NOTE(review): presumably the same `RestoreState` instance that
    /// `NodeDbPgHandler::intent_to_response` calls `begin` on — confirm at the
    /// construction site.
    pub restore_state: Arc<backup::RestoreState>,
}
#[async_trait]
impl CopyHandler for NodeDbCopyHandler {
async fn on_copy_data<C>(&self, client: &mut C, copy_data: CopyData) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let id = conn_id(&client.socket_addr());
match self.restore_state.append(id, ©_data.data) {
Ok(()) => Ok(()),
Err(e @ AppendError::NotPending) => {
Err(sqlstate(ss::FEATURE_NOT_SUPPORTED, &e.to_string()))
}
Err(e @ AppendError::OverCap { .. }) => {
self.restore_state.cancel(id);
Err(sqlstate(ss::PROGRAM_LIMIT_EXCEEDED, &e.to_string()))
}
}
}
async fn on_copy_done<C>(&self, client: &mut C, _done: CopyDone) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let id = conn_id(&client.socket_addr());
let pending = self.restore_state.take(id).ok_or_else(|| {
sqlstate(
ss::FEATURE_NOT_SUPPORTED,
"no restore pending on this connection",
)
})?;
let stats = backup::restore_tenant(
&self.state,
pending.tenant_id,
&pending.bytes,
pending.dry_run,
)
.await
.map_err(internal)?;
let rows =
stats.documents + stats.kv_tables + stats.vectors + stats.timeseries + stats.edges;
let tag = Tag::new("RESTORE TENANT").with_rows(rows);
client
.send(PgWireBackendMessage::CommandComplete(tag.into()))
.await
.map_err(|e| {
sqlstate(
ss::INTERNAL_ERROR,
&format!("CommandComplete send failed: {e:?}"),
)
})?;
client.set_state(PgWireConnectionState::AwaitingSync);
Ok(())
}
}
/// Derives the key used to index per-connection restore state from the peer address.
///
/// The address is hashed with a freshly created `DefaultHasher`, so equal
/// addresses always map to the same key within a process. As with any 64-bit
/// hash, distinct addresses could in principle collide.
fn conn_id(addr: &SocketAddr) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::default();
    Hash::hash(addr, &mut state);
    state.finish()
}
fn sqlstate(code: &str, message: &str) -> PgWireError {
PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".into(),
code.into(),
message.into(),
)))
}
/// Maps a crate error to an `INTERNAL_ERROR` pgwire error carrying its `Display` text.
fn internal(e: crate::Error) -> PgWireError {
    let detail = e.to_string();
    sqlstate(ss::INTERNAL_ERROR, &detail)
}