yamlbase 0.7.2

A lightweight SQL server that serves YAML-defined tables over standard SQL protocols
Documentation
use bytes::{BufMut, BytesMut};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;

use crate::YamlBaseError;
use crate::database::Value;
use crate::sql::executor::QueryResult;

/// PostgreSQL COPY protocol implementation.
///
/// Supports `COPY (SELECT ...) TO STDOUT` in text, CSV and binary formats. COPY is read-only
/// here: `COPY ... FROM` is rejected when the command is parsed.
#[derive(Debug, Clone, Copy, Default)]
pub struct PostgresCopyProtocol;

impl PostgresCopyProtocol {
    /// Buffered protocol bytes are written to the socket once they reach this size, so large
    /// results are streamed in batches instead of one `write` per row.
    const FLUSH_THRESHOLD: usize = 64 * 1024;

    /// Per-column format code for text data (CSV is text-based, so it uses this too).
    const TEXT_FORMAT: u8 = 0;

    /// Per-column format code for binary data.
    const BINARY_FORMAT: u8 = 1;

    /// Streams `result` to the client as the response to `COPY (...) TO STDOUT`.
    ///
    /// Sends a CopyOutResponse, one CopyData message per row (the binary format additionally
    /// carries the file header in its first CopyData message and a trailer message), then
    /// CopyDone.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to `stream` fails, or if the result cannot be represented in
    /// the protocol (e.g. more than `i16::MAX` columns, or an integer that does not fit in a
    /// 4-byte binary field).
    pub async fn handle_copy_to_stdout(
        stream: &mut TcpStream,
        result: &QueryResult,
        format: CopyFormat,
    ) -> crate::Result<()> {
        match format {
            CopyFormat::Text => Self::handle_copy_text_format(stream, result).await,
            CopyFormat::Binary => Self::handle_copy_binary_format(stream, result).await,
            CopyFormat::Csv => Self::handle_copy_csv_format(stream, result).await,
        }
    }

    /// Sends `result` in the text format: tab-separated fields, `\N` for NULL, and backslash
    /// escapes for special characters (see [`Self::escape_copy_text`]).
    async fn handle_copy_text_format(
        stream: &mut TcpStream,
        result: &QueryResult,
    ) -> crate::Result<()> {
        let mut buf = BytesMut::new();
        Self::put_copy_out_response(&mut buf, result.columns.len(), Self::TEXT_FORMAT)?;

        // Reused for every row so each row does not allocate a fresh buffer.
        let mut row_data: Vec<u8> = Vec::new();
        for row in &result.rows {
            row_data.clear();
            for (i, value) in row.iter().enumerate() {
                if i > 0 {
                    row_data.push(b'\t'); // Tab delimiter
                }
                match value {
                    Value::Null => row_data.extend_from_slice(b"\\N"),
                    // NOTE(review): values are rendered with `Value`'s `Display` impl; confirm
                    // it matches PostgreSQL's text output (e.g. `t`/`f` for booleans).
                    _ => {
                        let escaped = Self::escape_copy_text(&value.to_string());
                        row_data.extend_from_slice(escaped.as_bytes());
                    }
                }
            }
            row_data.push(b'\n'); // Newline at end of row

            Self::put_copy_data(&mut buf, &row_data)?;
            Self::flush_if_full(stream, &mut buf).await?;
        }

        Self::put_copy_done(&mut buf);
        stream.write_all(&buf).await?;
        Ok(())
    }

    /// Sends `result` in PostgreSQL's binary COPY format, encoding each value with
    /// [`Self::write_binary_value`].
    async fn handle_copy_binary_format(
        stream: &mut TcpStream,
        result: &QueryResult,
    ) -> crate::Result<()> {
        let mut buf = BytesMut::new();
        Self::put_copy_out_response(&mut buf, result.columns.len(), Self::BINARY_FORMAT)?;

        // Payload of the CopyData message being built. It starts out holding the binary file
        // header because PostgreSQL sends the header at the front of the first CopyData message
        // (the first row, or the trailer when there are no rows) rather than as a message of
        // its own; readers that parse the header and the first field count from one message
        // would otherwise fail.
        let mut message = BytesMut::new();
        message.put_slice(b"PGCOPY\n\xff\r\n\0"); // 11-byte signature
        message.put_u32(0); // flags field: no OIDs
        message.put_u32(0); // header extension length: none

        for row in &result.rows {
            // The per-tuple field count is a signed 16-bit integer on the wire.
            let field_count = i16::try_from(row.len()).map_err(|_| {
                YamlBaseError::Protocol(format!("COPY row has too many fields ({})", row.len()))
            })?;
            message.put_i16(field_count);
            for (i, value) in row.iter().enumerate() {
                Self::write_binary_value(&mut message, value, result.column_types.get(i))?;
            }

            Self::put_copy_data(&mut buf, &message)?;
            message.clear();
            Self::flush_if_full(stream, &mut buf).await?;
        }

        // File trailer: a field count of -1.
        message.put_i16(-1);
        Self::put_copy_data(&mut buf, &message)?;
        Self::put_copy_done(&mut buf);
        stream.write_all(&buf).await?;
        Ok(())
    }

    /// Sends `result` in the CSV format: comma-separated fields, NULL as an unquoted empty
    /// field, and quoting per [`Self::escape_csv_field`].
    async fn handle_copy_csv_format(
        stream: &mut TcpStream,
        result: &QueryResult,
    ) -> crate::Result<()> {
        let mut buf = BytesMut::new();
        // CSV is text-based, so every column is announced with the text format code.
        Self::put_copy_out_response(&mut buf, result.columns.len(), Self::TEXT_FORMAT)?;

        // The HEADER option is not supported yet, so no header line is sent.

        // Reused for every row so each row does not allocate a fresh buffer.
        let mut row_data: Vec<u8> = Vec::new();
        for row in &result.rows {
            row_data.clear();
            for (i, value) in row.iter().enumerate() {
                if i > 0 {
                    row_data.push(b','); // Comma delimiter
                }
                match value {
                    // NULL is written as an unquoted empty field.
                    Value::Null => {}
                    _ => {
                        let escaped = Self::escape_csv_field(&value.to_string());
                        row_data.extend_from_slice(escaped.as_bytes());
                    }
                }
            }
            row_data.push(b'\n'); // Newline at end of row

            Self::put_copy_data(&mut buf, &row_data)?;
            Self::flush_if_full(stream, &mut buf).await?;
        }

        Self::put_copy_done(&mut buf);
        stream.write_all(&buf).await?;
        Ok(())
    }

    /// Writes the buffered messages to `stream` and empties `buf` once it has grown past
    /// [`Self::FLUSH_THRESHOLD`]; smaller buffers are left to accumulate.
    async fn flush_if_full(stream: &mut TcpStream, buf: &mut BytesMut) -> crate::Result<()> {
        if buf.len() >= Self::FLUSH_THRESHOLD {
            stream.write_all(&buf[..]).await?;
            buf.clear();
        }
        Ok(())
    }

    /// Appends a CopyOutResponse (`'H'`) message announcing `column_count` columns that all
    /// use `format_code` (0 = text, 1 = binary).
    ///
    /// Fails if `column_count` does not fit the protocol's signed 16-bit column count.
    fn put_copy_out_response(
        buf: &mut BytesMut,
        column_count: usize,
        format_code: u8,
    ) -> Result<(), YamlBaseError> {
        let columns = i16::try_from(column_count).map_err(|_| {
            YamlBaseError::Protocol(format!("COPY result has too many columns ({column_count})"))
        })?;

        buf.put_u8(b'H');
        // Length covers itself (4), the overall format (1), the column count (2) and one
        // 2-byte format code per column; `column_count <= i16::MAX`, so the cast is lossless.
        buf.put_u32(4 + 1 + 2 + 2 * column_count as u32);
        buf.put_u8(format_code);
        buf.put_i16(columns);
        for _ in 0..columns {
            buf.put_i16(i16::from(format_code));
        }
        Ok(())
    }

    /// Appends a CopyData (`'d'`) message carrying `payload`.
    ///
    /// Fails if the message would exceed the protocol's signed 32-bit length field.
    fn put_copy_data(buf: &mut BytesMut, payload: &[u8]) -> Result<(), YamlBaseError> {
        let length = i32::try_from(payload.len() + 4).map_err(|_| {
            YamlBaseError::Protocol("COPY data message exceeds the protocol size limit".to_string())
        })?;
        buf.put_u8(b'd');
        buf.put_i32(length);
        buf.put_slice(payload);
        Ok(())
    }

    /// Appends a CopyDone (`'c'`) message.
    fn put_copy_done(buf: &mut BytesMut) {
        buf.put_u8(b'c');
        buf.put_u32(4); // length (just the length field itself)
    }

    /// Appends one binary-format field: a signed 32-bit byte length followed by `bytes`.
    fn put_binary_field(buf: &mut BytesMut, bytes: &[u8]) -> Result<(), YamlBaseError> {
        let length = i32::try_from(bytes.len()).map_err(|_| {
            YamlBaseError::Protocol("Value too large for a binary COPY field".to_string())
        })?;
        buf.put_i32(length);
        buf.put_slice(bytes);
        Ok(())
    }

    /// Appends `value` to `buf` as one binary-format field (length prefix + data).
    ///
    /// `col_type` decides the width of integers: `BigInt` columns are sent as 8-byte values,
    /// everything else as 4-byte values.
    ///
    /// # Errors
    ///
    /// Returns an error if an integer sent as a 4-byte field is outside the `i32` range
    /// (rather than silently truncating it), or a value is longer than `i32::MAX` bytes.
    fn write_binary_value(
        buf: &mut BytesMut,
        value: &Value,
        col_type: Option<&crate::yaml::schema::SqlType>,
    ) -> crate::Result<()> {
        match value {
            Value::Null => buf.put_i32(-1), // NULL indicator: length -1, no data
            Value::Boolean(b) => {
                buf.put_i32(1);
                buf.put_u8(u8::from(*b));
            }
            Value::Integer(i) => match col_type {
                Some(crate::yaml::schema::SqlType::BigInt) => {
                    buf.put_i32(8);
                    buf.put_i64(*i);
                }
                // INTEGER columns and columns of unknown type are sent as 4-byte integers.
                _ => {
                    let narrowed = i32::try_from(*i).map_err(|_| {
                        YamlBaseError::Protocol(format!(
                            "Integer value {i} does not fit in a 4-byte binary COPY field"
                        ))
                    })?;
                    buf.put_i32(4);
                    buf.put_i32(narrowed);
                }
            },
            Value::Float(f) => {
                buf.put_i32(4);
                buf.put_f32(*f);
            }
            Value::Double(d) => {
                buf.put_i32(8);
                buf.put_f64(*d);
            }
            Value::Text(s) => Self::put_binary_field(buf, s.as_bytes())?,
            // NOTE(review): the values below are sent as their text representation. That is
            // not PostgreSQL's binary encoding for numeric/date/time/timestamp/uuid (or jsonb),
            // so a client decoding these columns with binary codecs will misread them.
            Value::Decimal(d) => Self::put_binary_field(buf, d.to_string().as_bytes())?,
            Value::Json(j) => Self::put_binary_field(buf, j.to_string().as_bytes())?,
            Value::Date(d) => Self::put_binary_field(buf, d.to_string().as_bytes())?,
            Value::Time(t) => Self::put_binary_field(buf, t.to_string().as_bytes())?,
            Value::Timestamp(ts) => Self::put_binary_field(buf, ts.to_string().as_bytes())?,
            Value::Uuid(u) => Self::put_binary_field(buf, u.to_string().as_bytes())?,
        }
        Ok(())
    }

    /// Escapes `input` for the text format in a single pass.
    ///
    /// Backslash, tab (the delimiter), newline and carriage return must be escaped. Backspace,
    /// form feed and vertical tab are escaped too, as PostgreSQL does on output.
    fn escape_copy_text(input: &str) -> String {
        let mut escaped = String::with_capacity(input.len());
        for ch in input.chars() {
            match ch {
                '\\' => escaped.push_str("\\\\"),
                '\t' => escaped.push_str("\\t"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\u{08}' => escaped.push_str("\\b"),
                '\u{0C}' => escaped.push_str("\\f"),
                '\u{0B}' => escaped.push_str("\\v"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Quotes `input` as a CSV field when needed, doubling any embedded quotes.
    ///
    /// A field is quoted when it contains the delimiter, a quote or a line break; when it is
    /// empty (an unquoted empty field means NULL); or when it is exactly `\.`, which a reader
    /// could mistake for the end-of-data marker.
    fn escape_csv_field(input: &str) -> String {
        let needs_quoting =
            input.is_empty() || input == "\\." || input.contains([',', '"', '\n', '\r']);
        if needs_quoting {
            format!("\"{}\"", input.replace('"', "\"\""))
        } else {
            input.to_string()
        }
    }

    /// Returns `true` if `query` starts (after leading whitespace) with the `COPY` keyword,
    /// case-insensitively. `COPY` must be a whole word, so e.g. `COPYRIGHT` does not match.
    pub fn is_copy_command(query: &str) -> bool {
        let query = query.trim_start();
        match query.get(..4) {
            Some(keyword) if keyword.eq_ignore_ascii_case("COPY") => {
                !query[4..].starts_with(|c: char| c.is_alphanumeric() || c == '_')
            }
            _ => false,
        }
    }

    /// Parses `COPY (<query>) TO STDOUT [options]`, case-insensitively.
    ///
    /// The query text between the parentheses is returned trimmed. The output format is taken
    /// only from the options after `TO STDOUT` (see [`Self::parse_copy_format`]).
    ///
    /// # Errors
    ///
    /// Returns [`YamlBaseError::Protocol`] if the statement is not a COPY command, is not
    /// `TO STDOUT`, does not have `(` right after `COPY`, lacks the closing `) TO STDOUT`, or
    /// has an empty query.
    pub fn parse_copy_command(query: &str) -> Result<CopyCommand, YamlBaseError> {
        // Marks the end of the parenthesised query.
        const TO_STDOUT: &str = ") TO STDOUT";

        let query = query.trim();
        // ASCII-only case folding keeps every byte offset identical to `query`, so positions
        // found in `upper_query` are valid for slicing `query` (`to_uppercase` can change the
        // length of some non-ASCII characters).
        let upper_query = query.to_ascii_uppercase();

        if !Self::is_copy_command(query) {
            return Err(YamlBaseError::Protocol("Not a COPY command".to_string()));
        }

        if !upper_query.contains("TO STDOUT") {
            return Err(YamlBaseError::Protocol(
                "Only COPY TO STDOUT is supported".to_string(),
            ));
        }

        // Only whitespace may separate the COPY keyword from the opening parenthesis. The
        // first four bytes are ASCII (checked by `is_copy_command`), so slicing at 4 is safe.
        let after_keyword = &query["COPY".len()..];
        let copy_start = query.len() - after_keyword.trim_start().len();
        if !query[copy_start..].starts_with('(') {
            return Err(YamlBaseError::Protocol(
                "Expected '(' after COPY".to_string(),
            ));
        }

        // The last `) TO STDOUT` closes the query. Requiring it after the opening parenthesis
        // keeps the slice below in bounds for malformed input.
        let copy_end = upper_query
            .rfind(TO_STDOUT)
            .filter(|&end| end > copy_start)
            .ok_or_else(|| YamlBaseError::Protocol("Expected ') TO STDOUT'".to_string()))?;

        let select_query = query[copy_start + 1..copy_end].trim();
        if select_query.is_empty() {
            return Err(YamlBaseError::Protocol(
                "Expected a query inside COPY (...)".to_string(),
            ));
        }

        // Only the options after `TO STDOUT` select the format. Scanning the whole statement
        // would misfire on words such as a CTE's `WITH` or a `csv_*` table in the query itself.
        let options = &upper_query[copy_end + TO_STDOUT.len()..];
        let format = Self::parse_copy_format(options);

        Ok(CopyCommand {
            select_query: select_query.to_string(),
            format,
        })
    }

    /// Chooses the output format from the ASCII-uppercased option text after `TO STDOUT`.
    ///
    /// Recognises both the legacy `[WITH] CSV` / `[WITH] BINARY` spelling and the option-list
    /// form `[WITH] (FORMAT csv)`. CSV wins if both keywords appear; text is the default.
    fn parse_copy_format(options: &str) -> CopyFormat {
        let has_word = |keyword: &str| {
            options
                .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
                .any(|word| word == keyword)
        };
        if has_word("CSV") {
            CopyFormat::Csv
        } else if has_word("BINARY") {
            CopyFormat::Binary
        } else {
            CopyFormat::Text
        }
    }
}

/// A parsed `COPY (<query>) TO STDOUT [options]` statement.
#[derive(Clone, Debug)]
pub struct CopyCommand {
    /// Query text taken from between the parentheses; executing it yields the rows to send.
    pub select_query: String,
    /// Output format requested by the statement's options.
    pub format: CopyFormat,
}

/// Output format of a `COPY ... TO STDOUT` response.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CopyFormat {
    /// Tab-separated text with `\N` for NULL; used when no format option is given.
    Text,
    /// PostgreSQL's binary COPY format.
    Binary,
    /// Comma-separated values, with NULL as an unquoted empty field.
    Csv,
}