use crate::errors::Result as MCPResult;
use crate::validation::quote_ident;
use futures::SinkExt;
use futures::StreamExt;
use serde_json::{Value, json};
use std::time::Duration;
use tokio_postgres::Client;
/// Upper bound (100 MiB) on the total size of a fetched response body streamed into an import.
const MAX_IMPORT_BYTES: usize = 100 << 20;
/// Upper bound (100 MiB) on the CSV bytes an export buffers in memory before giving up.
const MAX_EXPORT_BYTES: usize = 100 << 20;
/// Timeout configured on the HTTP client that fetches import data.
const IMPORT_FETCH_TIMEOUT: Duration = Duration::from_secs(30);
/// Imports CSV data fetched from a URL into a table using `COPY ... FROM STDIN`.
///
/// Parameters (read from the JSON `params` object):
/// - `url` (required): source URL, checked with `crate::ssrf::validate_import_url`
///   before any request is made.
/// - `table` (required): target table name.
/// - `schema` (default `"public"`): target schema.
/// - `delimiter` (default `","`): must be exactly one character.
/// - `header` (default `true`): whether the CSV's first line is a header.
/// - `truncate` (default `false`): run `TRUNCATE` on the table before loading.
/// - `columns` (optional): comma-separated column names; each is validated and quoted.
///
/// The HTTP client does not follow redirects and uses [`IMPORT_FETCH_TIMEOUT`].
/// The response body is streamed into the COPY and capped at [`MAX_IMPORT_BYTES`].
///
/// Returns `{ success, table, schema, rows_imported, source }`.
///
/// # Errors
/// Returns `MCPError::InvalidParams` in these cases:
/// - a parameter is missing or invalid;
/// - the fetch fails or the response status is not 2xx;
/// - the body cannot be read or is too large.
///
/// Database and SSRF-validation errors are propagated with `?`.
///
/// NOTE(review): `TRUNCATE` and `COPY` are separate statements with no explicit
/// transaction here. Unless the caller already runs inside one, a failure while
/// streaming after a truncate leaves the table emptied.
pub async fn import_from_url(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
    let url = params
        .as_ref()
        .and_then(|p| p.get("url").and_then(|v| v.as_str()))
        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'url' parameter".into()))?;
    let table = params
        .as_ref()
        .and_then(|p| p.get("table").and_then(|v| v.as_str()))
        .ok_or_else(|| {
            crate::errors::MCPError::InvalidParams("Missing 'table' parameter".into())
        })?;
    let schema = params
        .as_ref()
        .and_then(|p| p.get("schema").and_then(|v| v.as_str()))
        .unwrap_or("public");
    let delimiter = params
        .as_ref()
        .and_then(|p| p.get("delimiter").and_then(|v| v.as_str()))
        .unwrap_or(",");
    let header = params
        .as_ref()
        .and_then(|p| p.get("header").and_then(|v| v.as_bool()))
        .unwrap_or(true);
    let truncate = params
        .as_ref()
        .and_then(|p| p.get("truncate").and_then(|v| v.as_bool()))
        .unwrap_or(false);
    let columns = params
        .as_ref()
        .and_then(|p| p.get("columns").and_then(|v| v.as_str()));
    if delimiter.chars().count() != 1 {
        return Err(crate::errors::MCPError::InvalidParams(
            "'delimiter' must be a single character".into(),
        ));
    }
    let col_clause = match columns {
        Some(c) => {
            let mut quoted = Vec::new();
            for col in c.split(',') {
                let col = col.trim();
                crate::validation::validate_identifier(col, "column")?;
                quoted.push(quote_ident(col));
            }
            format!(" ({})", quoted.join(", "))
        }
        None => String::new(),
    };
    crate::ssrf::validate_import_url(url).await?;

    // Fetch and status-check the source *before* touching the table, so an
    // unreachable URL or a non-2xx response cannot truncate the table and then
    // load nothing.
    let http = reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .timeout(IMPORT_FETCH_TIMEOUT)
        .build()
        .map_err(|e| {
            crate::errors::MCPError::InvalidParams(format!("Failed to build HTTP client: {e}"))
        })?;
    let resp = http.get(url).send().await.map_err(|e| {
        crate::errors::MCPError::InvalidParams(format!("Failed to fetch URL: {}", e))
    })?;
    let status = resp.status();
    if !status.is_success() {
        return Err(crate::errors::MCPError::InvalidParams(format!(
            "URL returned HTTP {}",
            status
        )));
    }

    let qualified = format!("{}.{}", quote_ident(schema), quote_ident(table));
    if truncate {
        client
            .execute(&format!("TRUNCATE {}", qualified), &[])
            .await?;
    }

    // PostgreSQL's grammar is `COPY table [(cols)] FROM STDIN [(options)]`, so
    // the optional column list must come directly after the table name.
    let copy_sql = format!(
        "COPY {}{} FROM STDIN (FORMAT csv, HEADER {}, DELIMITER '{}')",
        qualified,
        col_clause,
        if header { "true" } else { "false" },
        delimiter.replace('\'', "''"),
    );
    let mut sink = Box::pin(client.copy_in(&copy_sql).await?);
    let mut stream = resp.bytes_stream();
    let mut total_bytes: usize = 0;
    while let Some(chunk) = stream.next().await {
        // On the early returns below the sink is dropped without `finish`.
        // Per tokio_postgres `CopyInSink` docs, an unfinished copy is aborted.
        let chunk = chunk.map_err(|e| {
            crate::errors::MCPError::InvalidParams(format!("Failed to read response body: {}", e))
        })?;
        total_bytes += chunk.len();
        if total_bytes > MAX_IMPORT_BYTES {
            return Err(crate::errors::MCPError::InvalidParams(format!(
                "Response body exceeds maximum import size of {} bytes",
                MAX_IMPORT_BYTES
            )));
        }
        sink.as_mut().send(chunk).await?;
    }
    let count = sink.as_mut().finish().await?;
    Ok(json!({
        "success": true,
        "table": table,
        "schema": schema,
        "rows_imported": count,
        "source": url,
    }))
}
/// Exports a table, or the result of a `SELECT` query, as CSV text via `COPY ... TO STDOUT`.
///
/// Parameters (read from the JSON `params` object):
/// - `query`: a `SELECT` statement, checked by `crate::actions::query::validate_sql`.
///   It takes precedence over `table` when both are given.
/// - `table` / `schema` (default `"public"`): the table to export when no `query` is given.
/// - `header` (default `true`): emit a header row.
/// - `delimiter` (default `","`): must be exactly one character.
/// - `limit` (default 10000, capped at 100000): maximum rows to export; must not be negative.
///
/// Returns `{ csv, row_count, format: "csv" }`. The in-memory output is capped at
/// [`MAX_EXPORT_BYTES`].
///
/// # Errors
/// Returns `MCPError::InvalidParams` in these cases:
/// - a parameter is missing or invalid;
/// - the output is too large;
/// - the output is not valid UTF-8.
///
/// Validation and database errors are propagated with `?`.
pub async fn export_csv(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
    let query = params
        .as_ref()
        .and_then(|p| p.get("query").and_then(|v| v.as_str()));
    let table = params
        .as_ref()
        .and_then(|p| p.get("table").and_then(|v| v.as_str()));
    let schema = params
        .as_ref()
        .and_then(|p| p.get("schema").and_then(|v| v.as_str()))
        .unwrap_or("public");
    let header = params
        .as_ref()
        .and_then(|p| p.get("header").and_then(|v| v.as_bool()))
        .unwrap_or(true);
    let delimiter = params
        .as_ref()
        .and_then(|p| p.get("delimiter").and_then(|v| v.as_str()))
        .unwrap_or(",");
    let limit = params
        .as_ref()
        .and_then(|p| p.get("limit").and_then(|v| v.as_i64()))
        .unwrap_or(10000);
    // PostgreSQL rejects a negative LIMIT; report it as a parameter error instead.
    if limit < 0 {
        return Err(crate::errors::MCPError::InvalidParams(
            "'limit' must not be negative".into(),
        ));
    }
    let limit = limit.min(100000);
    if delimiter.chars().count() != 1 {
        return Err(crate::errors::MCPError::InvalidParams(
            "'delimiter' must be a single character".into(),
        ));
    }
    let source = match (query, table) {
        (Some(q), _) => {
            crate::actions::query::validate_sql(q, "SELECT", "SELECT")?;
            // Drop any run of trailing `;` (with interleaved whitespace). The
            // newline before `)` stops a trailing `--` comment from swallowing
            // the closing parenthesis.
            let body = q
                .trim()
                .trim_end_matches(|c: char| c == ';' || c.is_whitespace());
            format!("({}\n) AS _export", body)
        }
        (None, Some(t)) => format!("{}.{}", quote_ident(schema), quote_ident(t)),
        (None, None) => {
            return Err(crate::errors::MCPError::InvalidParams(
                "Either 'query' or 'table' is required".into(),
            ));
        }
    };
    // COPY has no LIMIT option and does not accept an alias on `( query )`.
    // The row cap is therefore applied by wrapping the source in a SELECT.
    let copy_sql = format!(
        "COPY (SELECT * FROM {} LIMIT {}) TO STDOUT (FORMAT csv, HEADER {}, DELIMITER '{}')",
        source,
        limit,
        if header { "true" } else { "false" },
        delimiter.replace('\'', "''"),
    );
    let stream = client.copy_out(&copy_sql).await?;
    let mut stream = Box::pin(stream);
    let mut output = Vec::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        if output.len() + chunk.len() > MAX_EXPORT_BYTES {
            return Err(crate::errors::MCPError::InvalidParams(format!(
                "Export exceeds maximum size of {} bytes; narrow the query or lower the limit",
                MAX_EXPORT_BYTES
            )));
        }
        output.extend_from_slice(&chunk);
    }
    let csv_text = String::from_utf8(output).map_err(|e| {
        crate::errors::MCPError::InvalidParams(format!("Output is not valid UTF-8: {}", e))
    })?;
    let row_count = count_csv_records(&csv_text).saturating_sub(usize::from(header));
    Ok(json!({
        "csv": csv_text,
        "row_count": row_count,
        "format": "csv",
    }))
}

/// Counts CSV records in `text`, treating newlines inside double-quoted fields as data.
///
/// The COPY above sets no QUOTE option, so PostgreSQL uses `"` both to quote fields
/// and to escape quotes by doubling them. Toggling on every `"` therefore tracks
/// whether we are inside a quoted field. A final record without a trailing newline
/// still counts.
fn count_csv_records(text: &str) -> usize {
    let mut in_quotes = false;
    let mut records = 0;
    for byte in text.bytes() {
        match byte {
            b'"' => in_quotes = !in_quotes,
            b'\n' if !in_quotes => records += 1,
            _ => {}
        }
    }
    if !text.is_empty() && !text.ends_with('\n') {
        records += 1;
    }
    records
}