use crate::connection::Connection;
use crate::error::{Error, Result};
/// Size in bytes (1 MiB) of the buffer used to read import data when the
/// supplied `CopyOptions` carries no explicit chunk size.
const DEFAULT_IMPORT_CHUNK_SIZE: usize = 1 << 20;
/// Options describing how a `COPY` statement formats the data it reads or
/// writes.
///
/// Values are created with `CopyOptions::csv()`, `CopyOptions::tsv()` or
/// `CopyOptions::text()` and then refined with the `with_*` builder methods.
/// Every `Option` field left as `None` is simply omitted from the generated
/// `WITH (...)` clause.
#[derive(Debug, Clone)]
pub struct CopyOptions {
/// Data format; rendered as `FORMAT csv` or `FORMAT text`.
format: CopyFormat,
/// When `true`, `HEADER true` is added to the generated options.
header: bool,
/// Field delimiter byte; rendered as a `DELIMITER E'\xNN'` hex escape.
delimiter: Option<u8>,
/// String representing NULL; rendered as `NULL '...'` with single quotes doubled.
null_string: Option<String>,
/// Quote byte; rendered as `QUOTE E'\xNN'`. Rejected by `validate` for the text format.
quote: Option<u8>,
/// Escape byte; rendered as `ESCAPE E'\xNN'`. Rejected by `validate` for the text format.
escape: Option<u8>,
/// Size in bytes of the read buffer used while importing; `None` falls back
/// to `DEFAULT_IMPORT_CHUNK_SIZE`.
chunk_size: Option<usize>,
}
/// Data format named in the `FORMAT` option of a generated `COPY` statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CopyFormat {
/// Rendered as `FORMAT csv`; the only format for which QUOTE and ESCAPE pass validation.
Csv,
/// Rendered as `FORMAT text`; `CopyOptions::validate` rejects QUOTE and ESCAPE for it.
Text,
}
impl CopyOptions {
#[must_use]
pub fn csv() -> Self {
CopyOptions {
format: CopyFormat::Csv,
header: false,
delimiter: None,
null_string: None,
quote: None,
escape: None,
chunk_size: None,
}
}
#[must_use]
pub fn tsv() -> Self {
CopyOptions {
format: CopyFormat::Text,
header: false,
delimiter: Some(b'\t'),
null_string: None,
quote: None,
escape: None,
chunk_size: None,
}
}
#[must_use]
pub fn text() -> Self {
CopyOptions {
format: CopyFormat::Text,
header: false,
delimiter: None,
null_string: None,
quote: None,
escape: None,
chunk_size: None,
}
}
#[must_use]
pub fn with_header(mut self, header: bool) -> Self {
self.header = header;
self
}
#[must_use]
pub fn with_delimiter(mut self, delimiter: u8) -> Self {
self.delimiter = Some(delimiter);
self
}
#[must_use]
pub fn with_null(mut self, null_string: impl Into<String>) -> Self {
self.null_string = Some(null_string.into());
self
}
#[must_use]
pub fn with_quote(mut self, quote: u8) -> Self {
self.quote = Some(quote);
self
}
#[must_use]
pub fn with_escape(mut self, escape: u8) -> Self {
self.escape = Some(escape);
self
}
#[must_use]
pub fn with_chunk_size(mut self, size: usize) -> Self {
assert!(size > 0, "chunk size must be > 0");
self.chunk_size = Some(size);
self
}
fn validate(&self) -> Result<()> {
if self.format == CopyFormat::Text {
if self.quote.is_some() {
return Err(Error::new(
"QUOTE option is only supported with CSV format. \
Use CopyOptions::csv() instead of CopyOptions::text().",
));
}
if self.escape.is_some() {
return Err(Error::new(
"ESCAPE option is only supported with CSV format. \
Use CopyOptions::csv() instead of CopyOptions::text().",
));
}
}
Ok(())
}
fn to_copy_out_options(&self) -> String {
let mut parts = Vec::new();
match self.format {
CopyFormat::Csv => parts.push("FORMAT csv".to_string()),
CopyFormat::Text => parts.push("FORMAT text".to_string()),
}
if self.header {
parts.push("HEADER true".to_string());
}
if let Some(d) = self.delimiter {
parts.push(format!("DELIMITER E'\\x{d:02x}'"));
}
if let Some(ref n) = self.null_string {
parts.push(format!("NULL '{}'", n.replace('\'', "''")));
}
if let Some(q) = self.quote {
parts.push(format!("QUOTE E'\\x{q:02x}'"));
}
if let Some(e) = self.escape {
parts.push(format!("ESCAPE E'\\x{e:02x}'"));
}
format!("WITH ({})", parts.join(", "))
}
fn to_copy_in_options(&self) -> String {
self.to_copy_out_options()
}
}
impl Connection {
pub fn export_csv(&self, select_query: &str, writer: &mut dyn std::io::Write) -> Result<u64> {
let opts = CopyOptions::csv().with_header(true);
self.export_text(select_query, &opts, writer)
}
pub fn export_text(
&self,
select_query: &str,
options: &CopyOptions,
writer: &mut dyn std::io::Write,
) -> Result<u64> {
options.validate()?;
let copy_query = format!(
"COPY ({}) TO STDOUT {}",
select_query,
options.to_copy_out_options()
);
let client = self.tcp_client().ok_or_else(|| {
Error::new(
"CSV export requires a TCP connection. gRPC does not support COPY operations.",
)
})?;
Ok(client.copy_out_to_writer(©_query, writer)?)
}
pub fn export_csv_string(&self, select_query: &str) -> Result<String> {
let mut buf = Vec::new();
self.export_csv(select_query, &mut buf)?;
String::from_utf8(buf)
.map_err(|e| Error::new(format!("CSV output is not valid UTF-8: {e}")))
}
pub fn import_csv(&self, table_name: &str, reader: impl std::io::Read) -> Result<u64> {
let opts = CopyOptions::csv();
self.import_text(table_name, &opts, reader)
}
pub fn import_csv_with_header(
&self,
table_name: &str,
reader: impl std::io::Read,
) -> Result<u64> {
let opts = CopyOptions::csv().with_header(true);
self.import_text(table_name, &opts, reader)
}
pub fn import_text(
&self,
table_name: &str,
options: &CopyOptions,
mut reader: impl std::io::Read,
) -> Result<u64> {
options.validate()?;
let escaped_table = table_name.replace('"', "\"\"");
let copy_query = format!(
"COPY \"{}\" FROM STDIN {}",
escaped_table,
options.to_copy_in_options()
);
let client = self.tcp_client().ok_or_else(|| {
Error::new(
"CSV import requires a TCP connection. gRPC does not support COPY operations.",
)
})?;
let mut writer = client.copy_in_raw(©_query)?;
let chunk_size = options.chunk_size.unwrap_or(DEFAULT_IMPORT_CHUNK_SIZE);
let mut buf = vec![0u8; chunk_size];
loop {
let n = reader
.read(&mut buf)
.map_err(|e| Error::with_cause("Failed to read import data", e))?;
if n == 0 {
break;
}
writer.send(&buf[..n])?;
}
Ok(writer.finish()?)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `opts`, expecting a rejection, and returns the error text.
    fn rejection_message(opts: &CopyOptions) -> String {
        opts.validate().unwrap_err().to_string()
    }

    /// CSV accepts both a quote and an escape byte.
    #[test]
    fn test_csv_options_valid() {
        let csv = CopyOptions::csv().with_quote(b'"').with_escape(b'\\');
        assert!(csv.validate().is_ok());
    }

    /// The text format rejects QUOTE and points the caller at CSV.
    #[test]
    fn test_text_quote_rejected() {
        let message = rejection_message(&CopyOptions::text().with_quote(b'"'));
        assert!(message.contains("QUOTE"));
        assert!(message.contains("CSV format"));
    }

    /// The text format rejects ESCAPE and points the caller at CSV.
    #[test]
    fn test_text_escape_rejected() {
        let message = rejection_message(&CopyOptions::text().with_escape(b'\\'));
        assert!(message.contains("ESCAPE"));
        assert!(message.contains("CSV format"));
    }

    /// TSV is text format underneath, so QUOTE is rejected there too.
    #[test]
    fn test_tsv_quote_rejected() {
        let message = rejection_message(&CopyOptions::tsv().with_quote(b'"'));
        assert!(message.contains("QUOTE"));
    }

    /// Header and delimiter are accepted with the text format.
    #[test]
    fn test_text_without_csv_options_valid() {
        let text = CopyOptions::text().with_header(true).with_delimiter(b'|');
        assert!(text.validate().is_ok());
    }

    /// An explicit chunk size is stored as given.
    #[test]
    fn test_chunk_size_custom() {
        let four_mib = 4 * 1024 * 1024;
        let configured = CopyOptions::csv().with_chunk_size(four_mib);
        assert_eq!(configured.chunk_size, Some(four_mib));
    }

    /// Without an explicit chunk size the 1 MiB default applies.
    #[test]
    fn test_chunk_size_default() {
        let effective = CopyOptions::csv()
            .chunk_size
            .unwrap_or(DEFAULT_IMPORT_CHUNK_SIZE);
        assert_eq!(effective, 1024 * 1024);
    }

    /// A zero chunk size is a programming error and panics.
    #[test]
    #[should_panic(expected = "chunk size must be > 0")]
    fn test_chunk_size_zero_panics() {
        let _ = CopyOptions::csv().with_chunk_size(0);
    }

    /// CSV import options mention format, header and delimiter.
    #[test]
    fn test_copy_in_options_csv() {
        let rendered = CopyOptions::csv()
            .with_header(true)
            .with_delimiter(b'|')
            .to_copy_in_options();
        assert!(rendered.contains("FORMAT csv"));
        assert!(rendered.contains("HEADER true"));
        assert!(rendered.contains("DELIMITER"));
    }

    /// Text import options name the text format.
    #[test]
    fn test_copy_in_options_text() {
        let rendered = CopyOptions::text().to_copy_in_options();
        assert!(rendered.contains("FORMAT text"));
    }
}