// quack_protocol 0.1.0
//
// Rust client SDK for DuckDB's experimental Quack remote protocol.
use crate::errors::{QuackError, Result};
use crate::vector::{DateValue, TimeUnit, TimestampUnit, Value, decimal_to_string};

/// A single value bound to a SQL placeholder.
///
/// Scalars wrap a [`Value`]; a `List` is rendered as a bracketed list literal
/// (`[a, b, ...]`) whose elements may themselves be lists.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlParameter {
    /// One scalar value.
    Value(Value),
    /// An ordered list of nested parameters.
    List(Vec<SqlParameter>),
}

impl From<Value> for SqlParameter {
    /// Binds an already-constructed scalar [`Value`].
    fn from(value: Value) -> Self {
        Self::Value(value)
    }
}

impl From<&str> for SqlParameter {
    /// Binds a copy of `value` as a SQL string.
    fn from(value: &str) -> Self {
        Self::from(value.to_owned())
    }
}

impl From<String> for SqlParameter {
    /// Binds `value` as a SQL string.
    fn from(value: String) -> Self {
        Self::Value(Value::String(value))
    }
}

impl From<i32> for SqlParameter {
    /// Binds a 32-bit integer, widened to the 64-bit [`Value::Int`].
    fn from(value: i32) -> Self {
        // `i64::from` is compiler-checked to be lossless, unlike an `as` cast,
        // which would silently truncate if either type ever changed.
        Self::Value(Value::Int(i64::from(value)))
    }
}

impl From<i64> for SqlParameter {
    /// Binds a signed 64-bit integer.
    fn from(value: i64) -> Self {
        Self::Value(Value::Int(value))
    }
}

impl From<u64> for SqlParameter {
    /// Binds an unsigned 64-bit integer.
    fn from(value: u64) -> Self {
        Self::Value(Value::UInt(value))
    }
}

impl From<bool> for SqlParameter {
    /// Binds a boolean.
    fn from(value: bool) -> Self {
        Self::Value(Value::Bool(value))
    }
}

/// The full parameter set for one statement.
///
/// Parameters are bound either by position (`?` placeholders, consumed in
/// order) or by name (`:name` placeholders).
#[derive(Debug, Clone, PartialEq)]
pub enum SqlParameters {
    /// Values consumed left to right by `?` placeholders.
    Positional(Vec<SqlParameter>),
    /// Values looked up by name for `:name` placeholders (insertion-ordered map).
    Named(indexmap::IndexMap<String, SqlParameter>),
}

/// Inlines `params` into `sql` as SQL literals.
///
/// With no parameters, the SQL text is returned unchanged. Otherwise the
/// positional (`?`) or named (`:name`) formatter is chosen to match the
/// parameter kind.
///
/// # Errors
///
/// Returns whatever error the selected formatter reports, e.g. a placeholder
/// / parameter count mismatch or a missing named parameter.
pub fn format_sql(sql: &str, params: Option<&SqlParameters>) -> Result<String> {
    let Some(params) = params else {
        return Ok(sql.to_string());
    };
    match params {
        SqlParameters::Named(by_name) => format_named_sql(sql, by_name),
        SqlParameters::Positional(in_order) => format_positional_sql(sql, in_order),
    }
}

/// Renders one parameter as a SQL literal.
///
/// Scalars are rendered individually. Lists become `[a, b, ...]`, with each
/// element rendered recursively.
///
/// # Errors
///
/// Propagates the first error produced while rendering any scalar.
pub fn sql_literal(value: &SqlParameter) -> Result<String> {
    let items = match value {
        SqlParameter::Value(scalar) => return value_literal(scalar),
        SqlParameter::List(items) => items,
    };
    let mut rendered = Vec::with_capacity(items.len());
    for item in items {
        rendered.push(sql_literal(item)?);
    }
    Ok(format!("[{}]", rendered.join(", ")))
}

/// Renders a single scalar [`Value`] as a SQL literal.
///
/// * `NULL`, booleans and integers are written bare.
/// * Finite floats are written in their `Display` form.
/// * Strings are single-quoted, with embedded quotes doubled.
/// * Byte strings go through `from_hex('…')`.
/// * Dates, times, timestamps and intervals become typed literals.
///
/// # Errors
///
/// Returns a protocol error for:
/// * non-finite floats;
/// * `TIME WITH TIME ZONE` values;
/// * nested `List`/`Struct` values.
fn value_literal(value: &Value) -> Result<String> {
    let literal = match value {
        Value::Null => String::from("NULL"),
        Value::Bool(flag) => String::from(if *flag { "TRUE" } else { "FALSE" }),
        Value::Int(number) => number.to_string(),
        Value::UInt(number) => number.to_string(),
        Value::HugeInt(number) => number.to_string(),
        Value::UHugeInt(number) => number.to_string(),
        // NaN and the infinities have no plain numeric literal form, so they are rejected.
        Value::Float(number) => {
            if !number.is_finite() {
                return Err(QuackError::protocol(format!(
                    "cannot encode non-finite SQL number {number}"
                )));
            }
            number.to_string()
        }
        Value::Double(number) => {
            if !number.is_finite() {
                return Err(QuackError::protocol(format!(
                    "cannot encode non-finite SQL number {number}"
                )));
            }
            number.to_string()
        }
        Value::String(text) => format!("'{}'", text.replace('\'', "''")),
        Value::Bytes(raw) => format!("from_hex('{}')", bytes_to_hex(raw)),
        Value::Decimal(decimal) => decimal_to_string(decimal),
        Value::Date(date) => format!("DATE '{}'", date_from_days(*date)),
        Value::Time(time) => format!("TIME '{}'", format_time(time.value, time.unit)),
        Value::TimeTz(_) => {
            return Err(QuackError::protocol(
                "TIME WITH TIME ZONE parameters are not supported as SQL literals",
            ));
        }
        Value::Timestamp(stamp) => {
            let keyword = match stamp.timezone_utc {
                true => "TIMESTAMPTZ",
                false => "TIMESTAMP",
            };
            let body = format_timestamp(stamp.value, stamp.unit, stamp.timezone_utc);
            format!("{keyword} '{body}'")
        }
        Value::Interval(span) => format!(
            "INTERVAL '{} months {} days {} microseconds'",
            span.months, span.days, span.micros
        ),
        Value::List(_) | Value::Struct(_) => {
            return Err(QuackError::protocol(
                "object SQL parameters are not supported; pass scalar values or SqlParameter::List",
            ));
        }
    };
    Ok(literal)
}

/// Substitutes `params`, in order, for the `?` placeholders in `sql`.
///
/// `:name`-style tokens are passed through unchanged.
///
/// # Errors
///
/// Fails if any of the following holds:
/// * `sql` has more `?` placeholders than there are parameters;
/// * `sql` has fewer placeholders, so some parameters go unused;
/// * a parameter cannot be rendered as a literal.
fn format_positional_sql(sql: &str, params: &[SqlParameter]) -> Result<String> {
    let mut pending = params.iter();
    let formatted = scan_sql(sql, |token| {
        if token == "?" {
            let value = pending.next().ok_or_else(|| {
                QuackError::protocol("SQL has more positional placeholders than parameters")
            })?;
            sql_literal(value)
        } else {
            Ok(token.to_string())
        }
    })?;
    // Each `?` consumed exactly one parameter; anything left in `pending` was never bound.
    let index = params.len() - pending.len();
    if index != params.len() {
        return Err(QuackError::protocol(format!(
            "SQL has {index} positional placeholders but {} parameters were provided",
            params.len()
        )));
    }
    Ok(formatted)
}

/// Substitutes values from `params` for the `:name` placeholders in `sql`.
///
/// `?` tokens are passed through unchanged. Entries in `params` that the SQL
/// never references are ignored.
///
/// # Errors
///
/// Fails if a placeholder names a key absent from `params`, or if a value
/// cannot be rendered as a literal.
fn format_named_sql(
    sql: &str,
    params: &indexmap::IndexMap<String, SqlParameter>,
) -> Result<String> {
    scan_sql(sql, |token| match token.strip_prefix(':') {
        None => Ok(token.to_string()),
        Some(name) => {
            let value = params
                .get(name)
                .ok_or_else(|| QuackError::protocol(format!("missing SQL parameter :{name}")))?;
            sql_literal(value)
        }
    })
}

/// Copies `sql` to a new string, passing each placeholder token through
/// `replace_token` and emitting whatever it returns in the token's place.
///
/// Regions where placeholders must not be recognised are copied verbatim:
/// * single-quoted strings (`'it''s'`);
/// * double-quoted identifiers (`"a""b"`);
/// * `--` line comments;
/// * `/* ... */` block comments.
///
/// An unterminated quote or block comment extends to the end of the input.
///
/// Outside those regions, two token shapes are handed to `replace_token`:
///
/// * a lone `?`;
/// * `:name`, where `name` starts with an ASCII letter or `_` and continues
///   with ASCII letters, digits or `_`. A colon directly preceded or followed
///   by another colon (the `::` cast operator) is never treated as a
///   placeholder.
///
/// All other text, including non-ASCII characters, is copied unchanged.
///
/// # Errors
///
/// Returns the first error produced by `replace_token`.
fn scan_sql(sql: &str, mut replace_token: impl FnMut(&str) -> Result<String>) -> Result<String> {
    let bytes = sql.as_bytes();
    let mut output = String::with_capacity(sql.len());
    // Byte offset into `sql`. It always sits on a UTF-8 character boundary,
    // because every recognised token begins and ends with an ASCII byte and
    // plain text is advanced one whole character at a time.
    let mut index = 0usize;
    while index < bytes.len() {
        let ch = bytes[index] as char;
        let next = bytes.get(index + 1).copied().map(char::from);
        if ch == '\'' {
            let end = read_single_quoted(sql, index);
            output.push_str(&sql[index..end]);
            index = end;
            continue;
        }
        if ch == '"' {
            let end = read_double_quoted(sql, index);
            output.push_str(&sql[index..end]);
            index = end;
            continue;
        }
        if ch == '-' && next == Some('-') {
            // Line comment: runs up to (not including) the next newline.
            let end = sql[index + 2..]
                .find('\n')
                .map(|relative| index + 2 + relative)
                .unwrap_or(sql.len());
            output.push_str(&sql[index..end]);
            index = end;
            continue;
        }
        if ch == '/' && next == Some('*') {
            // Block comment: runs through the closing `*/`.
            let end = sql[index + 2..]
                .find("*/")
                .map(|relative| index + 2 + relative + 2)
                .unwrap_or(sql.len());
            output.push_str(&sql[index..end]);
            index = end;
            continue;
        }
        if ch == '?' {
            output.push_str(&replace_token("?")?);
            index += 1;
            continue;
        }
        // Requiring `next` to be an identifier start also rules out a following
        // `:`, so together with the preceding-byte check this skips `::` casts.
        if ch == ':'
            && index.checked_sub(1).and_then(|i| bytes.get(i)).copied() != Some(b':')
            && next.is_some_and(is_identifier_start)
        {
            // The byte after the colon is already known to start the name.
            let mut end = index + 2;
            while end < bytes.len() && is_identifier_part(bytes[end] as char) {
                end += 1;
            }
            output.push_str(&replace_token(&sql[index..end])?);
            index = end;
            continue;
        }
        // Plain text: copy one complete character. Pushing `bytes[index] as char`
        // would turn each byte of a multi-byte UTF-8 sequence into its own
        // Latin-1 code point and garble non-ASCII text (e.g. `é` -> `Ã©`).
        let width = sql[index..].chars().next().map_or(1, char::len_utf8);
        output.push_str(&sql[index..index + width]);
        index += width;
    }
    Ok(output)
}

/// Returns the byte offset just past the single-quoted literal opening at
/// `start`. The caller is expected to point `start` at the opening `'`.
///
/// A doubled quote (`''`) is an escaped quote and does not close the
/// literal. If the literal is never closed, the end of the input is returned.
fn read_single_quoted(sql: &str, start: usize) -> usize {
    let bytes = sql.as_bytes();
    let mut cursor = start + 1;
    loop {
        match (bytes.get(cursor), bytes.get(cursor + 1)) {
            // Ran off the end: unterminated literal.
            (None, _) => return cursor,
            (Some(b'\''), Some(b'\'')) => cursor += 2,
            (Some(b'\''), _) => return cursor + 1,
            (Some(_), _) => cursor += 1,
        }
    }
}

/// Returns the byte offset just past the double-quoted identifier opening at
/// `start`. The caller is expected to point `start` at the opening `"`.
///
/// A doubled quote (`""`) is an escaped quote and does not close the
/// identifier. If the identifier is never closed, the end of the input is
/// returned.
fn read_double_quoted(sql: &str, start: usize) -> usize {
    let bytes = sql.as_bytes();
    let mut cursor = start + 1;
    loop {
        match (bytes.get(cursor), bytes.get(cursor + 1)) {
            // Ran off the end: unterminated identifier.
            (None, _) => return cursor,
            (Some(b'"'), Some(b'"')) => cursor += 2,
            (Some(b'"'), _) => return cursor + 1,
            (Some(_), _) => cursor += 1,
        }
    }
}

/// True for characters that may begin a `:name` placeholder: ASCII letters and `_`.
fn is_identifier_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}

/// True for characters that may continue a `:name` placeholder: ASCII
/// letters, ASCII digits and `_`.
fn is_identifier_part(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_')
}

/// Lower-case hexadecimal encoding of `bytes`, two digits per byte
/// (`[0xab, 0x01]` becomes `"ab01"`).
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}

/// Formats a day count relative to 1970-01-01 as `YYYY-MM-DD`.
///
/// # Panics
///
/// Panics (inside chrono) if the resulting date is outside the range
/// `chrono::NaiveDate` can represent.
fn date_from_days(value: DateValue) -> String {
    let unix_epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("valid epoch");
    let offset = chrono::Duration::days(value.days as i64);
    let date = unix_epoch + offset;
    date.format("%Y-%m-%d").to_string()
}

/// Renders an epoch-relative timestamp as `YYYY-MM-DD HH:MM:SS.ffffff`,
/// appending `+00` when `timezone_utc` is set.
///
/// `value` counts `unit`s since 1970-01-01 00:00:00. The fractional part is
/// always six digits (microseconds), the precision of DuckDB's
/// `TIMESTAMP`/`TIMESTAMPTZ`. Nanosecond inputs are truncated toward zero to
/// whole microseconds.
///
/// # Panics
///
/// Panics if the instant lies outside the range `chrono::DateTime` supports.
fn format_timestamp(value: i64, unit: TimestampUnit, timezone_utc: bool) -> String {
    const MICROS_PER_SECOND: i64 = 1_000_000;
    // Split into whole seconds plus a non-negative microsecond remainder using
    // division only. Scaling everything up to microseconds first could overflow
    // i64 for large second/millisecond inputs (and wraps silently in release).
    let (seconds, micros) = match unit {
        TimestampUnit::Seconds => (value, 0),
        TimestampUnit::Millis => (value.div_euclid(1_000), value.rem_euclid(1_000) * 1_000),
        TimestampUnit::Micros => (
            value.div_euclid(MICROS_PER_SECOND),
            value.rem_euclid(MICROS_PER_SECOND),
        ),
        TimestampUnit::Nanos => {
            let total_micros = value / 1_000;
            (
                total_micros.div_euclid(MICROS_PER_SECOND),
                total_micros.rem_euclid(MICROS_PER_SECOND),
            )
        }
    };
    // `micros` is in 0..1_000_000, so both the cast and the scale to nanoseconds fit in u32.
    let date = chrono::DateTime::from_timestamp(seconds, micros as u32 * 1_000)
        .expect("timestamp in range");
    let suffix = if timezone_utc { "+00" } else { "" };
    // Six fractional digits: the previous `{:03}` of milliseconds dropped the
    // microsecond part of every bound timestamp.
    format!("{}.{micros:06}{suffix}", date.format("%Y-%m-%d %H:%M:%S"))
}

/// Formats a time-of-day offset as `HH:MM:SS[.fraction]`.
///
/// The fraction is written at nanosecond resolution with trailing zeros
/// removed, and is omitted entirely when zero. Hours are not wrapped at 24.
fn format_time(value: i64, unit: TimeUnit) -> String {
    const NANOS_PER_SECOND: i128 = 1_000_000_000;
    let scale: i128 = match unit {
        TimeUnit::Micros => 1_000,
        TimeUnit::Nanos => 1,
    };
    // Widen before scaling so the multiplication cannot overflow.
    let nanos = i128::from(value) * scale;
    let (whole_seconds, sub_second) = (nanos / NANOS_PER_SECOND, nanos % NANOS_PER_SECOND);
    let hours = whole_seconds / 3_600;
    let minutes = (whole_seconds % 3_600) / 60;
    let seconds = whole_seconds % 60;

    let mut text = format!("{hours:02}:{minutes:02}:{seconds:02}");
    if sub_second != 0 {
        let digits = format!("{sub_second:09}");
        text.push('.');
        text.push_str(digits.trim_end_matches('0'));
    }
    text
}