Skip to main content

mcp_postgres/actions/
data_io.rs

1use crate::errors::Result as MCPResult;
2use crate::validation::quote_ident;
3use futures::SinkExt;
4use futures::StreamExt;
5use serde_json::{Value, json};
6use std::time::Duration;
7use tokio_postgres::Client;
8
/// Largest HTTP response body, in bytes, that `import_from_url` will stream
/// into `COPY` before giving up (100 MiB).
const MAX_IMPORT_BYTES: usize = 100 << 20;
/// Largest CSV output, in bytes, that `export_csv` will buffer in memory
/// (100 MiB).
const MAX_EXPORT_BYTES: usize = 100 << 20;
/// Overall request timeout for the outbound HTTP fetch done by
/// `import_from_url`.
const IMPORT_FETCH_TIMEOUT: Duration = Duration::from_millis(30_000);
15
16pub async fn import_from_url(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
17    let url = params
18        .as_ref()
19        .and_then(|p| p.get("url").and_then(|v| v.as_str()))
20        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'url' parameter".into()))?;
21    let table = params
22        .as_ref()
23        .and_then(|p| p.get("table").and_then(|v| v.as_str()))
24        .ok_or_else(|| {
25            crate::errors::MCPError::InvalidParams("Missing 'table' parameter".into())
26        })?;
27    let schema = params
28        .as_ref()
29        .and_then(|p| p.get("schema").and_then(|v| v.as_str()))
30        .unwrap_or("public");
31    let delimiter = params
32        .as_ref()
33        .and_then(|p| p.get("delimiter").and_then(|v| v.as_str()))
34        .unwrap_or(",");
35    let header = params
36        .as_ref()
37        .and_then(|p| p.get("header").and_then(|v| v.as_bool()))
38        .unwrap_or(true);
39    let truncate = params
40        .as_ref()
41        .and_then(|p| p.get("truncate").and_then(|v| v.as_bool()))
42        .unwrap_or(false);
43    let columns = params
44        .as_ref()
45        .and_then(|p| p.get("columns").and_then(|v| v.as_str()));
46
47    // COPY requires a single-character delimiter; reject anything else so the
48    // value cannot smuggle extra COPY options.
49    if delimiter.chars().count() != 1 {
50        return Err(crate::errors::MCPError::InvalidParams(
51            "'delimiter' must be a single character".into(),
52        ));
53    }
54
55    // Validate the optional column list as identifiers and rebuild it quoted,
56    // instead of interpolating the raw string into the COPY statement.
57    let col_clause = match columns {
58        Some(c) => {
59            let mut quoted = Vec::new();
60            for col in c.split(',') {
61                let col = col.trim();
62                crate::validation::validate_identifier(col, "column")?;
63                quoted.push(quote_ident(col));
64            }
65            format!(" ({})", quoted.join(", "))
66        }
67        None => String::new(),
68    };
69
70    // SSRF guard: only http(s), and the host must resolve to a public address.
71    crate::ssrf::validate_import_url(url).await?;
72
73    let qualified = format!("{}.{}", quote_ident(schema), quote_ident(table));
74
75    if truncate {
76        client
77            .execute(&format!("TRUNCATE {}", qualified), &[])
78            .await?;
79    }
80
81    // Build the COPY SQL early so we can open the sink before the HTTP fetch.
82    let copy_sql = format!(
83        "COPY {} FROM STDIN (FORMAT csv, HEADER {}, DELIMITER '{}'){}",
84        qualified,
85        if header { "true" } else { "false" },
86        delimiter.replace('\'', "''"),
87        col_clause,
88    );
89
90    // Open the COPY sink first — chunks stream directly into it.
91    let mut sink = Box::pin(client.copy_in(&copy_sql).await?);
92
93    // reqwest is built with `rustls-no-provider`, so it needs a process-default
94    // rustls CryptoProvider. Ensure `ring` is installed even when this runs in a
95    // process that never opened a Postgres TLS connection (e.g. a library
96    // consumer). Idempotent.
97    crate::tls::ensure_crypto_provider();
98
99    // Disable redirects (a 3xx could redirect to a blocked internal address)
100    // and bound the request time.
101    let http = reqwest::Client::builder()
102        .redirect(reqwest::redirect::Policy::none())
103        .timeout(IMPORT_FETCH_TIMEOUT)
104        .build()
105        .map_err(|e| {
106            crate::errors::MCPError::InvalidParams(format!("Failed to build HTTP client: {e}"))
107        })?;
108
109    let resp = http.get(url).send().await.map_err(|e| {
110        crate::errors::MCPError::InvalidParams(format!("Failed to fetch URL: {}", e))
111    })?;
112    let status = resp.status();
113    if !status.is_success() {
114        return Err(crate::errors::MCPError::InvalidParams(format!(
115            "URL returned HTTP {}",
116            status
117        )));
118    }
119
120    // Stream body chunks directly into the COPY sink instead of buffering
121    // the entire file. A hard size cap still bounds worst-case memory.
122    let mut stream = resp.bytes_stream();
123    let mut total_bytes: usize = 0;
124    while let Some(chunk) = stream.next().await {
125        let chunk = chunk.map_err(|e| {
126            crate::errors::MCPError::InvalidParams(format!("Failed to read response body: {}", e))
127        })?;
128        total_bytes += chunk.len();
129        if total_bytes > MAX_IMPORT_BYTES {
130            return Err(crate::errors::MCPError::InvalidParams(format!(
131                "Response body exceeds maximum import size of {} bytes",
132                MAX_IMPORT_BYTES
133            )));
134        }
135        sink.as_mut().send(chunk).await?;
136    }
137    // finish() flushes, ends the COPY, and returns the number of rows imported.
138    let count = sink.as_mut().finish().await?;
139
140    Ok(json!({
141        "success": true,
142        "table": table,
143        "schema": schema,
144        "rows_imported": count,
145        "source": url,
146    }))
147}
148
149pub async fn export_csv(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
150    let query = params
151        .as_ref()
152        .and_then(|p| p.get("query").and_then(|v| v.as_str()));
153    let table = params
154        .as_ref()
155        .and_then(|p| p.get("table").and_then(|v| v.as_str()));
156    let schema = params
157        .as_ref()
158        .and_then(|p| p.get("schema").and_then(|v| v.as_str()))
159        .unwrap_or("public");
160    let header = params
161        .as_ref()
162        .and_then(|p| p.get("header").and_then(|v| v.as_bool()))
163        .unwrap_or(true);
164    let delimiter = params
165        .as_ref()
166        .and_then(|p| p.get("delimiter").and_then(|v| v.as_str()))
167        .unwrap_or(",");
168    let limit = params
169        .as_ref()
170        .and_then(|p| p.get("limit").and_then(|v| v.as_i64()))
171        .unwrap_or(10000)
172        .min(100000);
173
174    if delimiter.chars().count() != 1 {
175        return Err(crate::errors::MCPError::InvalidParams(
176            "'delimiter' must be a single character".into(),
177        ));
178    }
179
180    let sql = match (query, table) {
181        (Some(q), _) => {
182            crate::actions::query::validate_sql(q, "SELECT", "SELECT")?;
183            let trimmed = q.trim();
184            format!("({}) AS _export", trimmed.trim_end_matches(';'))
185        }
186        (None, Some(t)) => format!("{}.{}", quote_ident(schema), quote_ident(t)),
187        (None, None) => {
188            return Err(crate::errors::MCPError::InvalidParams(
189                "Either 'query' or 'table' is required".into(),
190            ));
191        }
192    };
193
194    let copy_sql = format!(
195        "COPY {} TO STDOUT (FORMAT csv, HEADER {}, DELIMITER '{}', LIMIT {})",
196        sql,
197        if header { "true" } else { "false" },
198        delimiter.replace('\'', "''"),
199        limit,
200    );
201
202    let stream = client.copy_out(&copy_sql).await?;
203    let mut stream = Box::pin(stream);
204    let mut output = Vec::new();
205    while let Some(chunk) = stream.next().await {
206        let chunk = chunk?;
207        if output.len() + chunk.len() > MAX_EXPORT_BYTES {
208            return Err(crate::errors::MCPError::InvalidParams(format!(
209                "Export exceeds maximum size of {} bytes; narrow the query or lower the limit",
210                MAX_EXPORT_BYTES
211            )));
212        }
213        output.extend_from_slice(&chunk);
214    }
215
216    let csv_text = String::from_utf8(output).map_err(|e| {
217        crate::errors::MCPError::InvalidParams(format!("Output is not valid UTF-8: {}", e))
218    })?;
219
220    Ok(json!({
221        "csv": csv_text,
222        "row_count": csv_text.lines().count().saturating_sub(if header { 1 } else { 0 }),
223        "format": "csv",
224    }))
225}