use crate::{bytes::ByteString, eval::Schema, value::Value};
use chrono::{DateTime, Datelike, TimeZone, Timelike};
use memchr::{memchr2_iter, memchr3_iter, memchr_iter};
use rand_regex::Encoding;
use std::{
io::{Error, Write},
slice,
};
use tzfile::ArcTz;
pub trait Format {
fn write_value(&self, writer: &mut dyn Write, value: &Value) -> Result<(), Error>;
fn write_file_header(&self, writer: &mut dyn Write, schema: &Schema<'_>) -> Result<(), Error>;
fn write_header(&self, writer: &mut dyn Write, schema: &Schema<'_>) -> Result<(), Error>;
fn write_value_separator(&self, writer: &mut dyn Write) -> Result<(), Error>;
fn write_row_separator(&self, writer: &mut dyn Write) -> Result<(), Error>;
fn write_trailer(&self, writer: &mut dyn Write) -> Result<(), Error>;
}
#[derive(Debug, Default)]
pub struct SqlFormat {
pub headers: bool,
pub escape_backslash: bool,
}
#[derive(Debug)]
pub struct CsvFormat {
pub headers: bool,
pub escape_backslash: bool,
}
fn write_timestamp(writer: &mut dyn Write, quote: &str, timestamp: &DateTime<ArcTz>) -> Result<(), Error> {
write!(
writer,
"{}{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
quote,
timestamp.year(),
timestamp.month(),
timestamp.day(),
timestamp.hour(),
timestamp.minute(),
timestamp.second(),
)?;
let ns = timestamp.nanosecond();
if ns != 0 {
write!(writer, ".{:06}", ns / 1000)?;
}
writer.write_all(quote.as_bytes())
}
fn write_interval(writer: &mut dyn Write, quote: &str, mut interval: i64) -> Result<(), Error> {
writer.write_all(quote.as_bytes())?;
if interval == i64::min_value() {
return write!(writer, "-106751991 04:00:54.775808{}", quote);
} else if interval < 0 {
interval = -interval;
writer.write_all(b"-")?;
}
let seconds = interval / 1_000_000;
let microseconds = interval % 1_000_000;
let minutes = seconds / 60;
let seconds = seconds % 60;
let hours = minutes / 60;
let minutes = minutes % 60;
let days = hours / 24;
let hours = hours % 24;
if days > 0 {
write!(writer, "{} ", days)?;
}
write!(writer, "{:02}:{:02}:{:02}", hours, minutes, seconds)?;
if microseconds > 0 {
write!(writer, ".{:06}", microseconds)?;
}
writer.write_all(quote.as_bytes())
}
#[derive(Debug, Copy, Clone)]
#[allow(variant_size_differences)]
enum EscapeRule {
Escape(&'static [u8]),
Unescape(u8),
}
#[derive(Debug, Default)]
struct EscapeState {
prev_end: usize,
prev_byte: u8,
cur_start: usize,
cur_byte: u8,
unescape_ready: bool,
}
impl EscapeState {
fn set_cur_and_read_prev<'b>(&mut self, bytes: &'b [u8], cur: usize) -> &'b [u8] {
self.cur_start = cur;
self.cur_byte = bytes[cur];
&bytes[self.prev_end..cur]
}
fn apply_rule<'b>(&mut self, rule: &'b EscapeRule) -> &'b [u8] {
let ret = match rule {
EscapeRule::Escape(replacement) => {
self.unescape_ready = false;
*replacement
}
EscapeRule::Unescape(b) => {
if self.unescape_ready && *b == self.prev_byte && self.cur_start == self.prev_end {
self.unescape_ready = false;
&[]
} else {
self.unescape_ready = true;
slice::from_ref(b)
}
}
};
self.prev_end = self.cur_start + 1;
self.prev_byte = self.cur_byte;
ret
}
}
fn write_with_escape(writer: &mut dyn Write, bytes: &[u8], rules: &[(u8, EscapeRule)]) -> Result<(), Error> {
let mut state = EscapeState::default();
match *rules {
[] => {}
[(s1, r1)] => {
for cur in memchr_iter(s1, bytes) {
writer.write_all(state.set_cur_and_read_prev(bytes, cur))?;
writer.write_all(state.apply_rule(&r1))?;
}
}
[(s1, r1), (s2, r2)] => {
for cur in memchr2_iter(s1, s2, bytes) {
writer.write_all(state.set_cur_and_read_prev(bytes, cur))?;
let rule = if state.cur_byte == s1 { r1 } else { r2 };
writer.write_all(state.apply_rule(&rule))?;
}
}
[(s1, r1), (s2, r2), (s3, r3)] => {
for cur in memchr3_iter(s1, s2, s3, bytes) {
writer.write_all(state.set_cur_and_read_prev(bytes, cur))?;
let rule = match bytes[cur] {
b if b == s1 => r1,
b if b == s2 => r2,
_ => r3,
};
writer.write_all(state.apply_rule(&rule))?;
}
}
_ => {
for (cur, cur_byte) in bytes.iter().enumerate() {
if let Some((_, rule)) = rules.iter().find(|(s, _)| s == cur_byte) {
writer.write_all(state.set_cur_and_read_prev(bytes, cur))?;
writer.write_all(state.apply_rule(&rule))?;
}
}
}
}
writer.write_all(&bytes[state.prev_end..])
}
impl SqlFormat {
fn write_bytes(&self, writer: &mut dyn Write, bytes: &ByteString) -> Result<(), Error> {
if bytes.encoding() == Encoding::Binary {
writer.write_all(b"X'")?;
for b in bytes.as_bytes() {
write!(writer, "{:02X}", b)?;
}
} else {
writer.write_all(b"'")?;
write_with_escape(
writer,
bytes.as_bytes(),
if self.escape_backslash {
&[
(b'\'', EscapeRule::Escape(b"''")),
(b'\\', EscapeRule::Escape(br"\\")),
(b'\0', EscapeRule::Escape(br"\0")),
]
} else {
&[(b'\'', EscapeRule::Escape(b"''"))]
},
)?;
}
writer.write_all(b"'")
}
}
impl Format for SqlFormat {
fn write_value(&self, writer: &mut dyn Write, value: &Value) -> Result<(), Error> {
match value {
Value::Null => writer.write_all(b"NULL"),
Value::Number(number) => write!(writer, "{}", number),
Value::Bytes(bytes) => self.write_bytes(writer, bytes),
Value::Timestamp(timestamp, tz) => write_timestamp(writer, "'", &tz.from_utc_datetime(timestamp)),
Value::Interval(interval) => write_interval(writer, "'", *interval),
Value::Array(array) => {
writer.write_all(b"ARRAY[")?;
for (i, item) in array.iter().enumerate() {
if i != 0 {
writer.write_all(b", ")?;
}
self.write_value(writer, item)?;
}
writer.write_all(b"]")
}
}
}
fn write_file_header(&self, _: &mut dyn Write, _: &Schema<'_>) -> Result<(), Error> {
Ok(())
}
fn write_header(&self, writer: &mut dyn Write, schema: &Schema<'_>) -> Result<(), Error> {
write!(writer, "INSERT INTO {} ", schema.name)?;
if self.headers {
writer.write_all(b"(")?;
for (i, col) in schema.column_names().enumerate() {
if i != 0 {
writer.write_all(b", ")?;
}
writer.write_all(col.as_bytes())?
}
writer.write_all(b") ")?;
}
writer.write_all(b"VALUES\n(")
}
fn write_value_separator(&self, writer: &mut dyn Write) -> Result<(), Error> {
writer.write_all(b", ")
}
fn write_row_separator(&self, writer: &mut dyn Write) -> Result<(), Error> {
writer.write_all(b"),\n(")
}
fn write_trailer(&self, writer: &mut dyn Write) -> Result<(), Error> {
writer.write_all(b");\n")
}
}
impl CsvFormat {
fn write_bytes(&self, writer: &mut dyn Write, bytes: &ByteString) -> Result<(), Error> {
writer.write_all(b"\"")?;
write_with_escape(
writer,
bytes.as_bytes(),
if self.escape_backslash {
&[(b'"', EscapeRule::Escape(b"\"\"")), (b'\\', EscapeRule::Escape(br"\\"))]
} else {
&[(b'"', EscapeRule::Escape(b"\"\""))]
},
)?;
writer.write_all(b"\"")
}
fn write_column_name(&self, writer: &mut dyn Write, name: &[u8]) -> Result<(), Error> {
writer.write_all(b"\"")?;
let (mut rules, name) = match name.first() {
Some(b'"') => (Vec::new(), &name[1..(name.len() - 1)]),
Some(b'`') => (
vec![(b'`', EscapeRule::Unescape(b'`')), (b'"', EscapeRule::Escape(b"\"\""))],
&name[1..(name.len() - 1)],
),
Some(b'[') => (vec![(b'"', EscapeRule::Escape(b"\"\""))], &name[1..(name.len() - 1)]),
_ => (Vec::new(), name),
};
if self.escape_backslash {
rules.push((b'\\', EscapeRule::Escape(br"\\")));
}
write_with_escape(writer, name, &rules)?;
writer.write_all(b"\"")
}
}
impl Format for CsvFormat {
fn write_value(&self, writer: &mut dyn Write, value: &Value) -> Result<(), Error> {
match value {
Value::Null => writer.write_all(br"\N"),
Value::Number(number) => write!(writer, "{}", number),
Value::Bytes(bytes) => self.write_bytes(writer, bytes),
Value::Timestamp(timestamp, tz) => write_timestamp(writer, "", &tz.from_utc_datetime(timestamp)),
Value::Interval(interval) => write_interval(writer, "", *interval),
Value::Array(array) => {
writer.write_all(b"{")?;
for (i, item) in array.iter().enumerate() {
if i != 0 {
writer.write_all(b",")?;
}
self.write_value(writer, item)?;
}
writer.write_all(b"}")
}
}
}
fn write_file_header(&self, writer: &mut dyn Write, schema: &Schema<'_>) -> Result<(), Error> {
if !self.headers {
return Ok(());
}
for (i, col) in schema.column_names().enumerate() {
if i != 0 {
self.write_value_separator(writer)?;
}
self.write_column_name(writer, col.as_bytes())?;
}
self.write_row_separator(writer)
}
fn write_header(&self, _: &mut dyn Write, _: &Schema<'_>) -> Result<(), Error> {
Ok(())
}
fn write_value_separator(&self, writer: &mut dyn Write) -> Result<(), Error> {
writer.write_all(b",")
}
fn write_row_separator(&self, writer: &mut dyn Write) -> Result<(), Error> {
writer.write_all(b"\n")
}
fn write_trailer(&self, writer: &mut dyn Write) -> Result<(), Error> {
writer.write_all(b"\n")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_write_with_escape() {
let test_cases: Vec<(&[u8], &[(u8, EscapeRule)], &[u8])> = vec![
(b"10 o'clock", &[], b"10 o'clock"),
(b"10 o'clock", &[(b'\'', EscapeRule::Escape(b"''"))], b"10 o''clock"),
(
b"<b>R&D</b>",
&[
(b'<', EscapeRule::Escape(b"<")),
(b'>', EscapeRule::Escape(b">")),
(b'&', EscapeRule::Escape(b"&")),
],
b"<b>R&D</b>",
),
(
br#"<b>"R&D"</b>"#,
&[
(b'<', EscapeRule::Escape(b"<")),
(b'>', EscapeRule::Escape(b">")),
(b'&', EscapeRule::Escape(b"&")),
(b'"', EscapeRule::Escape(b""")),
],
b"<b>"R&D"</b>",
),
(
b"`a'b``c`",
&[(b'\'', EscapeRule::Escape(b"''")), (b'`', EscapeRule::Unescape(b'`'))],
b"`a''b`c`",
),
];
for (src, rules, expected) in test_cases {
let mut writer = Vec::with_capacity(expected.len());
write_with_escape(&mut writer, src, rules).unwrap();
assert_eq!(&writer, expected);
}
}
}