use crate::ddl_ast::statement::{CopyFormat, CopyToSource, NodedbStatement};
use crate::error::SqlError;
use super::copy_from::{
detect_format_from_path, extract_quoted_string, parse_with_clause, strip_quotes,
};
/// Recognises a `COPY ... TO ...` statement and parses it.
///
/// `upper` is used for keyword detection and `trimmed` for slicing out the
/// original text; presumably `upper` is the upper-cased form of `trimmed`
/// (verify against the caller).
///
/// Returns `None` when the input is not a `COPY ... TO` statement:
/// - it does not start with `COPY `,
/// - the query form `COPY (...)` is not followed by `TO ` after the closing
///   parenthesis (or the parenthesis is never closed),
/// - the table form has no ` TO ` keyword, or a ` FROM ` keyword appears
///   before the first ` TO `.
///
/// Otherwise returns `Some` with the result of [`parse_copy_to`].
pub(super) fn try_parse(upper: &str, trimmed: &str) -> Option<Result<NodedbStatement, SqlError>> {
    const COPY_KW: &str = "COPY ";
    if !upper.starts_with(COPY_KW) {
        return None;
    }
    let body = trimmed[COPY_KW.len()..].trim_start();
    let upper_body = upper[COPY_KW.len()..].trim_start();

    // Query form: only claim the statement when `TO` follows the closing paren.
    if body.starts_with('(') {
        let to_follows = match find_matching_paren(body) {
            Some(close) => body[close + 1..]
                .trim_start()
                .to_uppercase()
                .starts_with("TO "),
            None => false,
        };
        return to_follows.then(|| parse_copy_to(trimmed, body));
    }

    // Table form: there must be a ` TO `, and no ` FROM ` ahead of it.
    let to_at = upper_body.find(" TO ")?;
    match upper_body.find(" FROM ") {
        Some(from_at) if from_at < to_at => None,
        _ => Some(parse_copy_to(trimmed, body)),
    }
}
/// Dispatches to the query-form or table-form parser.
///
/// `after_copy` is the text following the `COPY` keyword; when it opens with
/// `(` the statement is handled by [`parse_query_form`], otherwise by
/// [`parse_table_form`]. Both receive the full `trimmed` statement.
fn parse_copy_to(trimmed: &str, after_copy: &str) -> Result<NodedbStatement, SqlError> {
    match after_copy.chars().next() {
        Some('(') => parse_query_form(trimmed),
        _ => parse_table_form(trimmed),
    }
}
/// Parses the table form `COPY <collection> TO '<path>' [options]`.
///
/// The collection name is the text between `COPY ` and the first ` TO `
/// (matched case-insensitively); it is passed through `strip_quotes` and
/// lower-cased. Everything after ` TO ` is handed to [`parse_path_and_opts`].
///
/// # Errors
///
/// Returns `SqlError::Parse` when the ` TO ` keyword is missing or when no
/// collection name precedes it, and propagates any error from
/// [`parse_path_and_opts`].
fn parse_table_form(trimmed: &str) -> Result<NodedbStatement, SqlError> {
    // ASCII-only upper-casing never changes byte lengths, so offsets found in
    // `upper` are valid offsets into `trimmed`. Full Unicode upper-casing can
    // grow or shrink the text (e.g. 'ﬁ' -> "FI") and desynchronise the two.
    let upper = trimmed.to_ascii_uppercase();
    let to_pos = upper.find(" TO ").ok_or_else(|| SqlError::Parse {
        detail: "COPY: missing TO keyword".to_string(),
    })?;

    // `get` returns `None` instead of panicking when ` TO ` starts at the
    // space directly after `COPY` (e.g. `COPY TO TO 'x'`), which would make
    // the range inverted; that case is reported as a missing collection name.
    let coll_raw = trimmed
        .get("COPY ".len()..to_pos)
        .map_or("", str::trim);
    if coll_raw.is_empty() {
        return Err(SqlError::Parse {
            detail: "COPY: missing collection name before TO".to_string(),
        });
    }
    let collection = strip_quotes(coll_raw).to_lowercase();

    let after_to = trimmed[to_pos + " TO ".len()..].trim_start();
    let (path, _rest, format, delimiter, header) = parse_path_and_opts(after_to)?;
    Ok(NodedbStatement::CopyToFile {
        source: CopyToSource::Collection(collection),
        path,
        format,
        delimiter,
        header,
    })
}
/// Parses the query form `COPY (<query>) TO '<path>' [options]`.
///
/// The query is the trimmed text between the opening parenthesis and its
/// matching close (as located by [`find_matching_paren`]). Everything after
/// the `TO` keyword is handed to [`parse_path_and_opts`].
///
/// [`parse_copy_to`] only dispatches here when the text after `COPY` begins
/// with `(`; the slicing below relies on that.
///
/// # Errors
///
/// Returns `SqlError::Parse` when the parenthesis is never closed, when the
/// query is empty, or when `TO` does not follow the closing parenthesis, and
/// propagates any error from [`parse_path_and_opts`].
fn parse_query_form(trimmed: &str) -> Result<NodedbStatement, SqlError> {
    let after_copy = trimmed["COPY ".len()..].trim_start();
    let close = find_matching_paren(after_copy).ok_or_else(|| SqlError::Parse {
        detail: "COPY: unclosed parenthesis in query form".to_string(),
    })?;

    let query = after_copy[1..close].trim().to_string();
    if query.is_empty() {
        return Err(SqlError::Parse {
            detail: "COPY: empty query in query form".to_string(),
        });
    }

    let after_paren = after_copy[close + 1..].trim_start();
    // Case-insensitive prefix check without allocating an upper-cased copy.
    let has_to = after_paren
        .get(.."TO ".len())
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("TO "));
    if !has_to {
        // Truncate by characters, not bytes: cutting at a fixed byte offset
        // can land inside a multi-byte character and panic.
        let snippet: String = after_paren.chars().take(32).collect();
        return Err(SqlError::Parse {
            detail: format!("COPY: expected TO after query, got: {snippet}"),
        });
    }

    let after_to = after_paren["TO ".len()..].trim_start();
    let (path, _rest, format, delimiter, header) = parse_path_and_opts(after_to)?;
    Ok(NodedbStatement::CopyToFile {
        source: CopyToSource::Query(query),
        path,
        format,
        delimiter,
        header,
    })
}
/// Tuple returned by [`parse_path_and_opts`]:
/// `(path, rest, format, delimiter, header)`.
type PathAndOpts<'a> = (String, &'a str, Option<CopyFormat>, Option<char>, bool);

/// Parses the destination path and the optional clause that follows it.
///
/// The path is obtained via `extract_quoted_string`. If any non-whitespace
/// text remains it is parsed with `parse_with_clause`; otherwise no format or
/// delimiter is set and `header` is `true`. When no format was given
/// explicitly, `detect_format_from_path` is consulted with the path.
///
/// The returned `rest` is the trimmed text that followed the path.
///
/// # Errors
///
/// Propagates errors from `extract_quoted_string`, `parse_with_clause` and
/// `detect_format_from_path`.
fn parse_path_and_opts(s: &str) -> Result<PathAndOpts<'_>, SqlError> {
    let (path, tail) = extract_quoted_string(s)?;
    let tail = tail.trim();

    let (explicit_format, delimiter, header) = match tail {
        "" => (None, None, true),
        clause => parse_with_clause(clause)?,
    };

    // An explicitly requested format wins over anything inferred from the path.
    let format = if explicit_format.is_some() {
        explicit_format
    } else {
        detect_format_from_path(&path)?
    };

    Ok((path, tail, format, delimiter, header))
}
/// Returns the byte index of the `)` that closes the first `(` in `s`.
///
/// Parentheses inside single-quoted sections are ignored; every `'` toggles
/// the quoted state. Returns `None` when the first `(` is never closed, when
/// `s` contains no parentheses, or when a `)` appears before any `(` has been
/// opened (unbalanced input).
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut in_single_quote = false;
    for (i, ch) in s.char_indices() {
        match ch {
            '\'' => in_single_quote = !in_single_quote,
            _ if in_single_quote => {}
            '(' => depth += 1,
            ')' => {
                // A closer with nothing open is unbalanced input, not a
                // match; bail out instead of clamping the depth at zero.
                depth = depth.checked_sub(1)?;
                if depth == 0 {
                    return Some(i);
                }
            }
            _ => {}
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ddl_ast::parse::dispatch::parse;

    /// Parses `sql` through the dispatcher and unwraps the successful statement.
    fn ok(sql: &str) -> NodedbStatement {
        let outcome = parse(sql).expect("expected Some");
        outcome.expect("expected Ok")
    }

    /// Parses `sql` through the dispatcher and unwraps the parse error.
    fn err(sql: &str) -> SqlError {
        let outcome = parse(sql).expect("expected Some");
        outcome.expect_err("expected Err")
    }

    /// Fields of a parsed `CopyToFile` statement.
    struct Export {
        source: CopyToSource,
        path: String,
        format: Option<CopyFormat>,
        delimiter: Option<char>,
        header: bool,
    }

    /// Parses `sql` and requires the result to be a `CopyToFile` statement.
    fn export(sql: &str) -> Export {
        match ok(sql) {
            NodedbStatement::CopyToFile {
                source,
                path,
                format,
                delimiter,
                header,
            } => Export {
                source,
                path,
                format,
                delimiter,
                header,
            },
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn basic_ndjson_by_extension() {
        let parsed = export("COPY users TO '/tmp/out.ndjson'");
        assert_eq!(parsed.source, CopyToSource::Collection("users".to_string()));
        assert_eq!(parsed.path, "/tmp/out.ndjson");
        assert_eq!(parsed.format, Some(CopyFormat::Ndjson));
    }

    #[test]
    fn json_array_by_extension() {
        let parsed = export("COPY users TO '/tmp/out.json'");
        assert_eq!(parsed.format, Some(CopyFormat::JsonArray));
    }

    #[test]
    fn csv_by_extension() {
        let parsed = export("COPY users TO '/tmp/out.csv'");
        assert_eq!(parsed.format, Some(CopyFormat::Csv));
    }

    #[test]
    fn explicit_format_with_clause() {
        // The WITH clause takes precedence over the `.csv` extension.
        let parsed = export("COPY users TO '/tmp/out.csv' WITH (FORMAT ndjson)");
        assert_eq!(parsed.format, Some(CopyFormat::Ndjson));
    }

    #[test]
    fn csv_with_delimiter() {
        let parsed = export("COPY users TO '/tmp/out.csv' WITH (FORMAT csv, DELIMITER ';')");
        assert_eq!(parsed.format, Some(CopyFormat::Csv));
        assert_eq!(parsed.delimiter, Some(';'));
    }

    #[test]
    fn header_false() {
        let parsed = export("COPY users TO '/tmp/out.csv' WITH (FORMAT csv, HEADER false)");
        assert!(!parsed.header);
    }

    #[test]
    fn query_form() {
        let parsed =
            export("COPY (SELECT * FROM users WHERE active = true) TO '/tmp/out.ndjson'");
        assert!(matches!(parsed.source, CopyToSource::Query(_)));
        assert_eq!(parsed.path, "/tmp/out.ndjson");
        if let CopyToSource::Query(q) = parsed.source {
            assert!(q.to_uppercase().starts_with("SELECT"));
        }
    }

    #[test]
    fn copy_from_not_intercepted() {
        // A `COPY ... FROM` statement must still reach the import parser.
        let stmt = ok("COPY users FROM '/tmp/data.ndjson'");
        assert!(matches!(stmt, NodedbStatement::CopyFromFile { .. }));
    }

    #[test]
    fn unknown_extension_is_err() {
        let e = err("COPY users TO '/tmp/out.parquet'");
        assert!(matches!(e, SqlError::Parse { .. }));
        assert!(e.to_string().contains("cannot infer format"));
    }
}