use bytes::BytesMut;
use qail_core::ast::Value;
use crate::protocol::EncodeError;
/// Appends `value` to `buf` as one COPY text-format field, writing the NULL marker `\N`
/// instead when the value cannot be encoded.
///
/// This is the lenient counterpart of [`try_encode_copy_value`]. Expression values,
/// unresolved parameters, unsupported array elements and strings containing a NUL
/// character are silently written as NULL rather than reported. Use the `try_` variant
/// when that kind of data loss must be detected.
///
/// Bytes already present in `buf` are never modified.
#[inline]
pub fn encode_copy_value(buf: &mut BytesMut, value: &Value) {
    // Encode straight into `buf` and roll back on failure, rather than staging every
    // value in a freshly allocated temporary buffer.
    let start = buf.len();
    if try_encode_copy_value(buf, value).is_err() {
        buf.truncate(start);
        buf.extend_from_slice(b"\\N");
    }
}
/// Encodes a single value as one field of PostgreSQL's COPY text format, appending it
/// to `buf`.
///
/// Per-variant output:
/// - `Null` / `NullUuid`: the NULL marker `\N`
/// - `Bool`: `t` or `f`
/// - `Int` / `Float`: decimal text
/// - `String` / `Timestamp` / `Json`: COPY-escaped text
/// - `Uuid`: lowercase hyphenated form
/// - `Array`: a `{...}` array literal built by `write_copy_array_value`, then COPY-escaped
/// - `Interval`: `<amount> <unit>`
/// - `Bytes`: hex form `\x...`, with the leading backslash escaped for COPY
/// - `Vector`: `{...}` list of floats. NOTE(review): this is array-literal syntax. If the
///   target column is a pgvector `vector`, its text input expects `[...]`. Confirm the
///   intended column type.
///
/// # Errors
/// - `EncodeError::InvalidAst` for column references, function calls, parameters,
///   subqueries and expressions, and for unsupported array elements.
/// - `EncodeError::NullByte` when a text value or a text array element contains NUL.
///
/// Every error is detected before anything is written, so `buf` is untouched on failure.
#[inline]
pub fn try_encode_copy_value(buf: &mut BytesMut, value: &Value) -> Result<(), EncodeError> {
    match value {
        Value::Null | Value::NullUuid => buf.extend_from_slice(b"\\N"),
        Value::Bool(flag) => {
            let text: &[u8] = if *flag { b"t" } else { b"f" };
            buf.extend_from_slice(text);
        }
        Value::Int(number) => {
            let mut digits = itoa::Buffer::new();
            let text = digits.format(*number);
            buf.extend_from_slice(text.as_bytes());
        }
        Value::Float(number) => {
            let mut digits = ryu::Buffer::new();
            let text = digits.format(*number);
            buf.extend_from_slice(text.as_bytes());
        }
        Value::String(text) | Value::Timestamp(text) | Value::Json(text) => {
            write_copy_escaped_str(buf, text)?;
        }
        Value::Uuid(id) => {
            let mut hyphenated = [0u8; 36];
            id.hyphenated().encode_lower(&mut hyphenated);
            buf.extend_from_slice(&hyphenated);
        }
        Value::Column(_)
        | Value::Function(_)
        | Value::Param(_)
        | Value::NamedParam(_)
        | Value::Subquery(_)
        | Value::Expr(_) => {
            return Err(EncodeError::InvalidAst(
                "COPY data value cannot be an expression or unresolved parameter".to_string(),
            ));
        }
        Value::Array(items) => {
            // Build the complete array literal first, so a bad element fails before
            // `buf` is touched. Then COPY-escape the whole literal in one pass.
            let mut literal: Vec<u8> = Vec::with_capacity(items.len() * 8 + 2);
            literal.push(b'{');
            for (idx, item) in items.iter().enumerate() {
                if idx != 0 {
                    literal.push(b',');
                }
                write_copy_array_value(&mut literal, item)?;
            }
            literal.push(b'}');
            write_copy_escaped_bytes(buf, &literal);
        }
        Value::Interval { amount, unit } => {
            let mut digits = itoa::Buffer::new();
            buf.extend_from_slice(digits.format(*amount).as_bytes());
            buf.extend_from_slice(b" ");
            buf.extend_from_slice(unit.to_string().as_bytes());
        }
        Value::Bytes(raw) => {
            const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
            // Hex form `\x...`; the backslash itself must be escaped inside COPY text.
            buf.extend_from_slice(b"\\\\x");
            for octet in raw {
                buf.extend_from_slice(&[
                    HEX_DIGITS[usize::from(octet >> 4)],
                    HEX_DIGITS[usize::from(octet & 0x0f)],
                ]);
            }
        }
        Value::Vector(components) => {
            let mut digits = ryu::Buffer::new();
            buf.extend_from_slice(b"{");
            for (idx, component) in components.iter().enumerate() {
                if idx != 0 {
                    buf.extend_from_slice(b",");
                }
                buf.extend_from_slice(digits.format(*component).as_bytes());
            }
            buf.extend_from_slice(b"}");
        }
    }
    Ok(())
}
/// COPY-escapes a text value into `buf`.
///
/// # Errors
/// Returns `EncodeError::NullByte`, writing nothing, if `value` contains a NUL character.
fn write_copy_escaped_str(buf: &mut BytesMut, value: &str) -> Result<(), EncodeError> {
    if value.contains('\0') {
        Err(EncodeError::NullByte)
    } else {
        write_copy_escaped_bytes(buf, value.as_bytes());
        Ok(())
    }
}
/// Appends `value` to `buf`, backslash-escaping the bytes that are special in COPY text
/// format: backslash, tab (the field delimiter), newline and carriage return.
///
/// All other bytes are copied through unchanged.
fn write_copy_escaped_bytes(buf: &mut BytesMut, value: &[u8]) {
    // Escaping only ever grows the output, so the input length is a safe lower bound.
    buf.reserve(value.len());
    // Start of the current run of bytes that need no escaping.
    let mut pending = 0;
    for (idx, &byte) in value.iter().enumerate() {
        let escape: &[u8] = match byte {
            b'\\' => b"\\\\",
            b'\t' => b"\\t",
            b'\n' => b"\\n",
            b'\r' => b"\\r",
            _ => continue,
        };
        // Flush the plain run in one copy instead of pushing byte by byte.
        buf.extend_from_slice(&value[pending..idx]);
        buf.extend_from_slice(escape);
        pending = idx + 1;
    }
    buf.extend_from_slice(&value[pending..]);
}
/// Appends one element of a PostgreSQL array literal (`{...}`) to `buf`.
///
/// The output is array-literal syntax only. The caller COPY-escapes the finished literal.
///
/// Per-variant output:
/// - `Null` / `NullUuid`: bare `NULL`
/// - `Bool`, `Int`, `Float`, `Uuid`: bare text
/// - `String`, `Timestamp`, `Json`, `Interval`: double-quoted via
///   `write_quoted_array_element`
///
/// # Errors
/// - `EncodeError::InvalidAst` for expression-like values and for nested arrays, bytes
///   or vectors.
/// - `EncodeError::NullByte` when a quoted element contains NUL.
fn write_copy_array_value(buf: &mut Vec<u8>, value: &Value) -> Result<(), EncodeError> {
    match value {
        Value::String(text) | Value::Timestamp(text) | Value::Json(text) => {
            write_quoted_array_element(buf, text)
        }
        Value::Interval { amount, unit } => {
            write_quoted_array_element(buf, &format!("{amount} {unit}"))
        }
        Value::Null | Value::NullUuid => {
            buf.extend_from_slice(b"NULL");
            Ok(())
        }
        Value::Bool(flag) => {
            let text: &[u8] = if *flag { b"t" } else { b"f" };
            buf.extend_from_slice(text);
            Ok(())
        }
        Value::Int(number) => {
            buf.extend_from_slice(number.to_string().as_bytes());
            Ok(())
        }
        Value::Float(number) => {
            buf.extend_from_slice(number.to_string().as_bytes());
            Ok(())
        }
        Value::Uuid(id) => {
            buf.extend_from_slice(id.to_string().as_bytes());
            Ok(())
        }
        Value::Column(_)
        | Value::Function(_)
        | Value::Param(_)
        | Value::NamedParam(_)
        | Value::Array(_)
        | Value::Subquery(_)
        | Value::Bytes(_)
        | Value::Expr(_)
        | Value::Vector(_) => Err(EncodeError::InvalidAst(
            "COPY array value cannot contain expressions or nested binary/vector values"
                .to_string(),
        )),
    }
}
/// Appends `value` to `buf` as a double-quoted array element.
///
/// Embedded `"` and `\` are each preceded by a backslash, so delimiters, braces or the
/// word `NULL` inside the text are not misread.
///
/// # Errors
/// Returns `EncodeError::NullByte`, writing nothing, if `value` contains a NUL byte.
fn write_quoted_array_element(buf: &mut Vec<u8>, value: &str) -> Result<(), EncodeError> {
    let raw = value.as_bytes();
    if raw.iter().any(|&b| b == 0) {
        return Err(EncodeError::NullByte);
    }
    buf.push(b'"');
    for &ch in raw {
        if matches!(ch, b'"' | b'\\') {
            buf.push(b'\\');
        }
        buf.push(ch);
    }
    buf.push(b'"');
    Ok(())
}
/// Assumed average encoded width of one field; used only to pre-size batch buffers.
const ESTIMATED_FIELD_BYTES: usize = 50;

/// Returns an initial buffer capacity for encoding `rows`.
///
/// The estimate is based on the actual number of fields across all rows, plus one
/// newline byte per row. It is only a sizing hint. Arithmetic saturates instead of
/// overflowing.
fn estimate_copy_batch_capacity(rows: &[Vec<Value>]) -> usize {
    let fields: usize = rows.iter().map(Vec::len).sum();
    fields
        .saturating_mul(ESTIMATED_FIELD_BYTES)
        .saturating_add(rows.len())
}

/// Encodes `rows` as COPY text-format data. Fields are separated by tabs and each row
/// ends with a newline.
///
/// This is lenient: any value that cannot be encoded is written as NULL (`\N`) by
/// [`encode_copy_value`]. Use [`try_encode_copy_batch`] to report such values instead.
#[inline]
pub fn encode_copy_batch(rows: &[Vec<Value>]) -> BytesMut {
    let mut buf = BytesMut::with_capacity(estimate_copy_batch_capacity(rows));
    for row in rows {
        for (i, val) in row.iter().enumerate() {
            if i > 0 {
                buf.extend_from_slice(b"\t");
            }
            encode_copy_value(&mut buf, val);
        }
        buf.extend_from_slice(b"\n");
    }
    buf
}

/// Encodes `rows` as COPY text-format data, failing on the first value that cannot be
/// encoded.
///
/// # Errors
/// Propagates the first error from [`try_encode_copy_value`], such as
/// `EncodeError::InvalidAst` or `EncodeError::NullByte`.
#[inline]
pub fn try_encode_copy_batch(rows: &[Vec<Value>]) -> Result<BytesMut, EncodeError> {
    let mut buf = BytesMut::with_capacity(estimate_copy_batch_capacity(rows));
    for row in rows {
        for (i, val) in row.iter().enumerate() {
            if i > 0 {
                buf.extend_from_slice(b"\t");
            }
            try_encode_copy_value(&mut buf, val)?;
        }
        buf.extend_from_slice(b"\n");
    }
    Ok(buf)
}
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    #[test]
    fn test_encode_int() {
        let mut buf = BytesMut::new();
        encode_copy_value(&mut buf, &Value::Int(12345));
        assert_eq!(&buf[..], b"12345");
    }

    #[test]
    fn test_encode_float() {
        let mut buf = BytesMut::new();
        encode_copy_value(&mut buf, &Value::Float(9.87654));
        assert!(buf.starts_with(b"9.87"));
    }

    #[test]
    fn test_encode_string_escaping() {
        let mut buf = BytesMut::new();
        encode_copy_value(&mut buf, &Value::String("hello\tworld\n".to_string()));
        assert_eq!(&buf[..], b"hello\\tworld\\n");
    }

    #[test]
    fn test_encode_string_escapes_backslash_and_carriage_return() {
        let mut buf = BytesMut::new();
        encode_copy_value(&mut buf, &Value::String("a\\b\rc".to_string()));
        assert_eq!(&buf[..], b"a\\\\b\\rc");
    }

    #[test]
    fn test_encode_falls_back_to_null_and_keeps_prior_content() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"1\t");
        encode_copy_value(&mut buf, &Value::String("bad\0value".to_string()));
        assert_eq!(&buf[..], b"1\t\\N");
    }

    #[test]
    fn test_encode_expression_falls_back_to_null() {
        let mut buf = BytesMut::new();
        encode_copy_value(&mut buf, &Value::Column("users.id".to_string()));
        assert_eq!(&buf[..], b"\\N");
    }

    #[test]
    fn test_try_encode_rejects_expression_values() {
        let mut buf = BytesMut::new();
        let err = try_encode_copy_value(&mut buf, &Value::Function("now()\n1\t2".to_string()))
            .unwrap_err();
        assert!(
            matches!(err, EncodeError::InvalidAst(ref message) if message.contains("COPY data value cannot be an expression")),
            "{err}"
        );
        assert!(buf.is_empty());
    }

    #[test]
    fn test_try_encode_rejects_null_bytes() {
        let mut buf = BytesMut::new();
        let err =
            try_encode_copy_value(&mut buf, &Value::String("bad\0value".to_string())).unwrap_err();
        assert_eq!(err, EncodeError::NullByte);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_try_encode_array_quotes_delimiter_values() {
        let mut buf = BytesMut::new();
        try_encode_copy_value(
            &mut buf,
            &Value::Array(vec![
                Value::String("a,b".to_string()),
                Value::String("line\nnext".to_string()),
            ]),
        )
        .unwrap();
        assert_eq!(&buf[..], br#"{"a,b","line\nnext"}"#);
    }

    #[test]
    fn test_try_encode_array_double_escapes_backslashes() {
        // Array literal `{"a\"b\\c",NULL}`; COPY escaping then doubles every backslash.
        let mut buf = BytesMut::new();
        try_encode_copy_value(
            &mut buf,
            &Value::Array(vec![Value::String("a\"b\\c".to_string()), Value::Null]),
        )
        .unwrap();
        assert_eq!(&buf[..], br#"{"a\\"b\\\\c",NULL}"#);
    }

    #[test]
    fn test_try_encode_rejects_nested_arrays() {
        let mut buf = BytesMut::new();
        let err = try_encode_copy_value(&mut buf, &Value::Array(vec![Value::Array(vec![])]))
            .unwrap_err();
        assert!(matches!(err, EncodeError::InvalidAst(_)), "{err}");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_try_encode_bool_and_empty_array() {
        let mut buf = BytesMut::new();
        try_encode_copy_value(&mut buf, &Value::Bool(true)).unwrap();
        buf.extend_from_slice(b"\t");
        try_encode_copy_value(&mut buf, &Value::Array(vec![])).unwrap();
        assert_eq!(&buf[..], b"t\t{}");
    }

    #[test]
    fn test_try_encode_batch_rejects_expression_values() {
        let rows = vec![vec![Value::Int(1), Value::Column("users.id".to_string())]];
        let err = try_encode_copy_batch(&rows).unwrap_err();
        assert!(
            matches!(err, EncodeError::InvalidAst(ref message) if message.contains("COPY data value cannot be an expression")),
            "{err}"
        );
    }

    #[test]
    fn test_try_encode_batch_matches_lenient_batch_for_valid_rows() {
        let rows = vec![
            vec![Value::Int(1), Value::String("a\tb".to_string()), Value::Null],
            vec![Value::Int(2), Value::Bool(false), Value::String(String::new())],
        ];
        let strict = try_encode_copy_batch(&rows).unwrap();
        let lenient = encode_copy_batch(&rows);
        assert_eq!(&strict[..], &lenient[..]);
        assert_eq!(&strict[..], b"1\ta\\tb\t\\N\n2\tf\t\n");
    }

    #[test]
    fn test_encode_empty_batch() {
        assert!(encode_copy_batch(&[]).is_empty());
        assert!(try_encode_copy_batch(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_encode_null() {
        let mut buf = BytesMut::new();
        encode_copy_value(&mut buf, &Value::Null);
        assert_eq!(&buf[..], b"\\N");
    }

    #[test]
    fn test_encode_batch() {
        let rows = vec![
            vec![Value::Int(1), Value::String("foo".to_string())],
            vec![Value::Int(2), Value::String("bar".to_string())],
        ];
        let buf = encode_copy_batch(&rows);
        assert_eq!(&buf[..], b"1\tfoo\n2\tbar\n");
    }

    #[test]
    fn test_encode_uuid() {
        let mut buf = BytesMut::new();
        let uuid = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        encode_copy_value(&mut buf, &Value::Uuid(uuid));
        assert_eq!(&buf[..], b"550e8400-e29b-41d4-a716-446655440000");
    }
}