use nodedb_types::DatabaseId;
use std::path::Path;
use pgwire::api::results::{Response, Tag};
use pgwire::error::PgWireResult;
use sonic_rs;
use nodedb_sql::ddl_ast::statement::{CopyFormat, CopyToSource};
use crate::control::security::identity::AuthenticatedIdentity;
use crate::control::server::pgwire::types::sqlstate_error;
use crate::control::state::SharedState;
use crate::types::TraceId;
use super::format::serialize_rows;
pub async fn copy_to_file(
state: &SharedState,
identity: &AuthenticatedIdentity,
source: &CopyToSource,
path: &str,
format: Option<&CopyFormat>,
delimiter: Option<char>,
header: bool,
) -> PgWireResult<Vec<Response>> {
validate_path(path)?;
let resolved_format = format.ok_or_else(|| {
sqlstate_error(
"42601",
&format!(
"COPY TO: cannot infer format for '{path}'; \
add WITH (FORMAT ndjson|json|csv)"
),
)
})?;
let select_sql = build_select_sql(source)?;
if let CopyToSource::Collection(coll) = source {
check_collection_exists(state, identity, coll)?;
}
let rows = execute_and_collect(state, identity, &select_sql).await?;
let bytes = serialize_rows(&rows, resolved_format, delimiter.unwrap_or(','), header)?;
let tmp_path = format!("{path}.tmp");
tokio::fs::write(&tmp_path, &bytes).await.map_err(|e| {
sqlstate_error(
"58030",
&format!("COPY TO: cannot write to '{tmp_path}': {e}"),
)
})?;
tokio::fs::rename(&tmp_path, path).await.map_err(|e| {
let _ = std::fs::remove_file(&tmp_path);
sqlstate_error(
"58030",
&format!("COPY TO: cannot rename '{tmp_path}' to '{path}': {e}"),
)
})?;
let row_count = rows.len();
Ok(vec![Response::Execution(Tag::new(&format!(
"COPY {row_count}"
)))])
}
fn build_select_sql(source: &CopyToSource) -> PgWireResult<String> {
match source {
CopyToSource::Collection(coll) => Ok(format!("SELECT * FROM {coll}")),
CopyToSource::Query(q) => Ok(q.clone()),
}
}
fn check_collection_exists(
state: &SharedState,
identity: &AuthenticatedIdentity,
collection: &str,
) -> PgWireResult<()> {
let catalog = match state.credentials.catalog() {
Some(c) => c,
None => return Ok(()), };
match catalog.get_collection(DatabaseId::DEFAULT, identity.tenant_id.as_u64(), collection) {
Ok(Some(_)) => Ok(()),
Ok(None) => Err(sqlstate_error(
"42P01",
&format!("COPY TO: collection \"{collection}\" does not exist"),
)),
Err(e) => Err(sqlstate_error(
"XX000",
&format!("COPY TO: catalog lookup failed: {e}"),
)),
}
}
async fn execute_and_collect(
state: &SharedState,
identity: &AuthenticatedIdentity,
select_sql: &str,
) -> PgWireResult<Vec<serde_json::Value>> {
let query_ctx = crate::control::planner::context::QueryContext::for_state(state);
let tasks = query_ctx
.plan_sql(
select_sql,
identity.tenant_id,
crate::types::DatabaseId::DEFAULT,
)
.await
.map_err(|e| sqlstate_error("42601", &format!("COPY TO: query planning failed: {e}")))?;
let mut all_rows: Vec<serde_json::Value> = Vec::new();
for task in tasks {
let resp = crate::control::server::dispatch_utils::dispatch_to_data_plane(
state,
task.tenant_id,
task.vshard_id,
task.plan,
TraceId::ZERO,
)
.await
.map_err(|e| sqlstate_error("XX000", &format!("COPY TO: dispatch failed: {e}")))?;
if resp.payload.is_empty() {
continue;
}
let json = crate::data::executor::response_codec::decode_payload_to_json(&resp.payload);
extract_json_rows(&json, &mut all_rows)?;
}
Ok(all_rows)
}
fn extract_json_rows(json: &str, out: &mut Vec<serde_json::Value>) -> PgWireResult<()> {
if json.is_empty() {
return Ok(());
}
let parsed: serde_json::Value = sonic_rs::from_str(json).map_err(|e| {
sqlstate_error(
"XX000",
&format!("COPY TO: failed to decode result rows: {e}"),
)
})?;
match parsed {
serde_json::Value::Array(items) => {
out.extend(items);
}
obj @ serde_json::Value::Object(_) => {
out.push(obj);
}
_ => {} }
Ok(())
}
fn validate_path(path: &str) -> PgWireResult<()> {
if !path.starts_with('/') {
return Err(sqlstate_error(
"42601",
&format!(
"COPY TO: path '{path}' is not absolute; \
only absolute server-side paths are accepted"
),
));
}
let p = Path::new(path);
for component in p.components() {
use std::path::Component;
if matches!(component, Component::ParentDir) {
return Err(sqlstate_error(
"42501",
&format!(
"COPY TO: path '{path}' contains '..'; \
directory traversal is not permitted"
),
));
}
}
Ok(())
}