sqlx-sqlserver 0.0.1-beta.1

Independent Microsoft SQL Server driver crate for SQLx.
Documentation
use sqlx_core::Error;

use super::col_meta_data::ColMetaData;
use super::done::{Done, Status};
use super::packet::{encode_message, PacketFrameError, PacketType};
use super::read::{read_len_prefixed, read_u32_le, read_u8};
use super::return_value::ReturnValue;
use super::row::Row;
use super::token::{parse_env_change, parse_server_error, EnvChange, TokenParseError};
use crate::{error::server_error, MssqlColumn, MssqlQueryResult, MssqlRow, MssqlValue};

// Token-type bytes. `parse_query_response` reads one of these at the start of every token
// and dispatches on it; any other value is rejected as a protocol error.
// Replaces the current column list with the output of `ColMetaData::get`.
const TOKEN_COL_METADATA: u8 = 0x81;
// Parsed with `parse_server_error`; parsing stops and the error is returned.
const TOKEN_ERROR: u8 = 0xaa;
// Length-prefixed body is read and discarded.
const TOKEN_INFO: u8 = 0xab;
// A 32-bit little-endian value is read and discarded.
const TOKEN_RETURN_STATUS: u8 = 0x79;
// Decoded with `ReturnValue::get` and collected into `QueryOutput::return_values`.
const TOKEN_RETURN_VALUE: u8 = 0xac;
// Row decoded with `Row::get(.., false, ..)` against the current column list.
const TOKEN_ROW: u8 = 0xd1;
// Row decoded with `Row::get(.., true, ..)`; the tests feed it a leading null bitmap.
const TOKEN_NBCROW: u8 = 0xd2;
// Parsed with `parse_env_change` and collected into `QueryOutput::env_changes`.
const TOKEN_ENVCHANGE: u8 = 0xe3;
// The three DONE variants are handled identically: decoded with `Done::get`, and their
// affected-row count is added to the total when `Status::DONE_COUNT` is set.
const TOKEN_DONE: u8 = 0xfd;
const TOKEN_DONEPROC: u8 = 0xfe;
const TOKEN_DONEINPROC: u8 = 0xff;

/// Everything `parse_query_response` collects from one response token stream.
#[derive(Debug)]
pub(crate) struct QueryOutput {
    /// Column list from the most recent column-metadata token (empty if none was seen).
    pub(crate) columns: Vec<MssqlColumn>,
    /// Every decoded row, in stream order, from both plain and null-bitmap row tokens.
    pub(crate) rows: Vec<MssqlRow>,
    /// Built with `MssqlQueryResult::new` from the affected-row counts accumulated over
    /// the DONE-family tokens whose status has `Status::DONE_COUNT` set.
    pub(crate) result: MssqlQueryResult,
    /// One value per return-value token, in stream order.
    pub(crate) return_values: Vec<MssqlValue>,
    /// One entry per environment-change token, in stream order.
    pub(crate) env_changes: Vec<EnvChange>,
}

/// Builds the framed bytes for a SQL batch request.
///
/// The payload is the header block written by [`write_all_headers`] (22 bytes) followed by
/// `sql` encoded as UTF-16 little-endian. Framing is delegated to [`encode_message`] with
/// [`PacketType::SQL_BATCH`] and the caller's `packet_size`.
///
/// # Errors
///
/// Propagates any [`PacketFrameError`] reported by `encode_message`.
pub(crate) fn build_sql_batch_packet(
    sql: &str,
    packet_size: usize,
    transaction_descriptor: u64,
) -> Result<Vec<u8>, PacketFrameError> {
    // `sql.len()` counts UTF-8 bytes, which is never smaller than the number of UTF-16
    // code units, so this reservation always covers the encoded text.
    let mut body = Vec::with_capacity(22 + sql.len() * 2);
    write_all_headers(&mut body, transaction_descriptor);
    body.extend(sql.encode_utf16().flat_map(u16::to_le_bytes));

    encode_message(PacketType::SQL_BATCH, &body, packet_size)
}

/// Decodes a query response token stream into a [`QueryOutput`].
///
/// `input` is consumed token by token until it is empty. Each token starts with a one-byte
/// type that selects how the rest of the token is read:
///
/// - column metadata replaces the current column list; rows decoded afterwards use it,
///   while rows already collected are kept;
/// - row tokens (plain and null-bitmap) are appended to `rows`;
/// - return values and environment changes are collected in stream order;
/// - return-status and info tokens are read and discarded;
/// - DONE-family tokens contribute their affected-row count when `Status::DONE_COUNT` is
///   set.
///
/// # Errors
///
/// - The first server error token ends parsing and is returned via `server_error`;
///   anything collected up to that point (including environment changes) is dropped.
/// - An unknown token type yields [`Error::Protocol`].
/// - Failures from the individual token readers are propagated.
pub(crate) fn parse_query_response(input: &[u8]) -> Result<QueryOutput, Error> {
    let mut input = input;
    let mut columns = Vec::new();
    let mut rows = Vec::new();
    let mut return_values = Vec::new();
    let mut env_changes = Vec::new();
    let mut rows_affected = 0;

    while !input.is_empty() {
        let token = read_u8(&mut input)?;

        match token {
            TOKEN_COL_METADATA => columns = ColMetaData::get(&mut input)?,
            TOKEN_ROW => rows.push(Row::get(&mut input, false, &columns)?),
            TOKEN_NBCROW => rows.push(Row::get(&mut input, true, &columns)?),
            TOKEN_RETURN_VALUE => {
                return_values.push(ReturnValue::get(&mut input)?.into_value());
            }
            TOKEN_RETURN_STATUS => {
                // The status value is not surfaced; read it only to advance the cursor.
                let _ = read_u32_le(&mut input)?;
            }
            TOKEN_DONE | TOKEN_DONEPROC | TOKEN_DONEINPROC => {
                let done = Done::get(&mut input)?;
                if done.status.contains(Status::DONE_COUNT) {
                    // Counts come straight off the wire, so a malformed stream must not
                    // be able to overflow the total (panic in debug builds, silent
                    // wrap-around in release builds). Saturate instead.
                    rows_affected = done.affected_rows.saturating_add(rows_affected);
                }
            }
            TOKEN_ERROR => {
                let error = parse_server_error(read_len_prefixed(&mut input)?)
                    .map_err(token_parse_error)?;
                return Err(server_error(error));
            }
            TOKEN_ENVCHANGE => {
                env_changes.push(
                    parse_env_change(read_len_prefixed(&mut input)?).map_err(token_parse_error)?,
                );
            }
            TOKEN_INFO => {
                // Informational messages are skipped; the length prefix says how far.
                let _ = read_len_prefixed(&mut input)?;
            }
            other => {
                return Err(Error::Protocol(format!(
                    "unsupported SQL Server query token 0x{other:02x}"
                )));
            }
        }
    }

    Ok(QueryOutput {
        columns,
        rows,
        result: MssqlQueryResult::new(rows_affected),
        return_values,
        env_changes,
    })
}

/// Appends the fixed 22-byte header block that starts a batch payload.
///
/// Bytes appended to `out`, all little-endian, in this order:
///
/// | offset | width | value                    |
/// |--------|-------|--------------------------|
/// | 0      | 4     | `22`                     |
/// | 4      | 4     | `18`                     |
/// | 8      | 2     | `2`                      |
/// | 10     | 8     | `transaction_descriptor` |
/// | 18     | 4     | `1`                      |
///
/// NOTE(review): this lines up with the MS-TDS `ALL_HEADERS` layout (total length, header
/// length, transaction-descriptor header type, descriptor, outstanding request count) —
/// confirm against the specification before changing any constant.
pub(crate) fn write_all_headers(out: &mut Vec<u8>, transaction_descriptor: u64) {
    let mut block = [0_u8; 22];
    block[..4].copy_from_slice(&22_u32.to_le_bytes());
    block[4..8].copy_from_slice(&18_u32.to_le_bytes());
    block[8..10].copy_from_slice(&2_u16.to_le_bytes());
    block[10..18].copy_from_slice(&transaction_descriptor.to_le_bytes());
    block[18..].copy_from_slice(&1_u32.to_le_bytes());

    out.extend_from_slice(&block);
}

fn token_parse_error(error: TokenParseError) -> Error {
    Error::Protocol(error.to_string())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Mssql;
    use sqlx_core::row::Row;
    use sqlx_core::value::{Value, ValueRef};

    /// The payload must open with the five header fields (offsets 0..22) and continue with
    /// "SELECT 1" encoded as UTF-16 little-endian.
    #[test]
    fn sql_batch_packet_starts_with_all_headers_and_utf16_sql() {
        let framed = build_sql_batch_packet("SELECT 1", 512, 0).unwrap();
        // Skip the first 8 bytes, which `encode_message` puts in front of the payload.
        let body = &framed[8..];
        let (headers, sql_text) = body.split_at(22);

        assert_eq!(22, u32::from_le_bytes(headers[0..4].try_into().unwrap()));
        assert_eq!(18, u32::from_le_bytes(headers[4..8].try_into().unwrap()));
        assert_eq!(2, u16::from_le_bytes(headers[8..10].try_into().unwrap()));
        assert_eq!(0, u64::from_le_bytes(headers[10..18].try_into().unwrap()));
        assert_eq!(1, u32::from_le_bytes(headers[18..22].try_into().unwrap()));
        assert_eq!(
            &[b'S', 0, b'E', 0, b'L', 0, b'E', 0, b'C', 0, b'T', 0, b' ', 0, b'1', 0],
            sql_text
        );
    }

    /// The descriptor passed in must appear, little-endian, at payload offsets 10..18.
    #[test]
    fn sql_batch_packet_writes_transaction_descriptor() {
        let framed = build_sql_batch_packet("SELECT 1", 512, 0x0102_0304_0506_0708).unwrap();
        // Skip the first 8 bytes, which `encode_message` puts in front of the payload.
        let body = &framed[8..];
        let descriptor = u64::from_le_bytes(body[10..18].try_into().unwrap());

        assert_eq!(0x0102_0304_0506_0708, descriptor);
    }

    /// One fixed-width int column, one row, and a DONE token carrying a count of 1.
    #[test]
    fn parses_select_one_response() {
        let mut stream = col_metadata_int("");
        stream.extend(row_int(1));
        stream.extend(done(0x10, 0, 1));

        let parsed = parse_query_response(&stream).unwrap();

        assert_eq!(1, parsed.rows.len());
        assert_eq!(1, parsed.result.rows_affected());
        assert_eq!(1_i32, parsed.rows[0].try_get::<i32, _>(0).unwrap());
    }

    /// A length-prefixed (`IntN`) column value must decode to the same `i32`.
    #[test]
    fn parses_variable_length_int_response() {
        let mut stream = col_metadata_intn("");
        stream.extend(row_intn(7));
        stream.extend(done(0x10, 0, 1));

        let parsed = parse_query_response(&stream).unwrap();
        let first_row = &parsed.rows[0];

        assert_eq!(7_i32, first_row.try_get::<i32, _>(0).unwrap());
    }

    /// A column declared with the `Null` data type must surface as a null raw value.
    #[test]
    fn parses_null_typed_value_as_null() {
        let mut stream = col_metadata_null("value");
        stream.extend(row_null());
        stream.extend(done(0x10, 0, 1));

        let parsed = parse_query_response(&stream).unwrap();
        let raw = parsed.rows[0].try_get_raw(0).unwrap();

        assert!(raw.is_null());
    }

    /// A null-bitmap row whose bitmap marks the only column must surface it as null.
    #[test]
    fn parses_nbcrow_null_bitmap() {
        let mut stream = col_metadata_intn("value");
        stream.extend(nbcrow_null(1));
        stream.extend(done(0x10, 0, 1));

        let parsed = parse_query_response(&stream).unwrap();
        let raw = parsed.rows[0].try_get_raw(0).unwrap();

        assert!(raw.is_null());
    }

    /// A return-status token is skipped and the following return value is collected and
    /// decodable as `i32`.
    #[test]
    fn parses_return_value_response() {
        let mut stream = return_status(0);
        stream.extend(return_value_int(42));
        stream.extend(done(0x10, 0, 1));

        let parsed = parse_query_response(&stream).unwrap();

        assert_eq!(1, parsed.return_values.len());

        let decoded =
            <i32 as sqlx_core::decode::Decode<Mssql>>::decode(parsed.return_values[0].as_ref())
                .unwrap();
        assert_eq!(42_i32, decoded);
    }

    /// Environment-change tokens must be collected in the order they appear in the stream.
    #[test]
    fn collects_envchange_tokens_from_query_response() {
        // Type 4 carries the UTF-16 text "8192"; type 8 carries eight descriptor bytes.
        let mut stream = env_change(4, &[4, b'8', 0, b'1', 0, b'9', 0, b'2', 0]);
        stream.extend(env_change(8, &[8, 8, 7, 6, 5, 4, 3, 2, 1]));
        stream.extend(done(0, 0, 0));

        let parsed = parse_query_response(&stream).unwrap();

        let expected = vec![
            EnvChange::PacketSize(8192),
            EnvChange::BeginTransaction(0x0102_0304_0506_0708),
        ];
        assert_eq!(parsed.env_changes, expected);
    }

    /// A server error token must come back as a database error exposing the number,
    /// message, server name and line number that were put on the wire.
    #[test]
    fn parses_error_token_as_database_error() {
        let stream = error(208, 1, 16, "Invalid object name", "dbhost", "", 3);

        let failure = parse_query_response(&stream).unwrap_err();
        let database_error = failure.as_database_error().unwrap();
        let details = database_error
            .as_error()
            .downcast_ref::<crate::MssqlDatabaseError>()
            .unwrap();

        assert_eq!(208, details.number());
        assert_eq!("Invalid object name", details.message());
        assert_eq!("dbhost", details.server_name());
        assert_eq!(3, details.line_number());
    }

    /// Builds a column-metadata token describing one `DataType::Int` column called `name`.
    ///
    /// Layout: token byte, `u16` 1, `u32` 0, `u16` 0, the data-type byte, then the name as
    /// a one-byte-length UTF-16 string. (Presumably column count, user type and flags in
    /// MS-TDS terms — confirm against `ColMetaData::get`.)
    fn col_metadata_int(name: &str) -> Vec<u8> {
        let mut token = vec![TOKEN_COL_METADATA];
        token.extend(1_u16.to_le_bytes());
        token.extend(0_u32.to_le_bytes());
        token.extend(0_u16.to_le_bytes());
        token.push(crate::protocol::type_info::DataType::Int as u8);
        push_b_varchar(&mut token, name);
        token
    }

    /// Builds a column-metadata token describing one `DataType::IntN` column called `name`.
    ///
    /// Same layout as the fixed-width variant, except that the data-type byte is followed
    /// by a `4` (presumably the maximum value width in bytes — confirm against
    /// `ColMetaData::get`).
    fn col_metadata_intn(name: &str) -> Vec<u8> {
        let mut token = vec![TOKEN_COL_METADATA];
        token.extend(1_u16.to_le_bytes());
        token.extend(0_u32.to_le_bytes());
        token.extend(0_u16.to_le_bytes());
        token.extend([crate::protocol::type_info::DataType::IntN as u8, 4]);
        push_b_varchar(&mut token, name);
        token
    }

    /// Builds a column-metadata token describing one `DataType::Null` column called `name`.
    ///
    /// Layout: token byte, `u16` 1, `u32` 0, `u16` 0, the data-type byte, then the name as
    /// a one-byte-length UTF-16 string.
    fn col_metadata_null(name: &str) -> Vec<u8> {
        let mut token = vec![TOKEN_COL_METADATA];
        token.extend(1_u16.to_le_bytes());
        token.extend(0_u32.to_le_bytes());
        token.extend(0_u16.to_le_bytes());
        token.push(crate::protocol::type_info::DataType::Null as u8);
        push_b_varchar(&mut token, name);
        token
    }

    /// Builds a row token holding a single 4-byte little-endian integer with no length
    /// prefix.
    fn row_int(value: i32) -> Vec<u8> {
        let mut token = vec![TOKEN_ROW];
        token.extend(value.to_le_bytes());
        token
    }

    /// Builds a row token holding a single integer preceded by its one-byte length (`4`).
    fn row_intn(value: i32) -> Vec<u8> {
        let mut token = vec![TOKEN_ROW, 4];
        token.extend(value.to_le_bytes());
        token
    }

    /// Builds a row token with no column bytes at all; paired in the tests with
    /// `col_metadata_null`.
    fn row_null() -> Vec<u8> {
        Vec::from([TOKEN_ROW])
    }

    /// Builds a null-bitmap row token whose bitmap has every bit set and which carries no
    /// column data. The bitmap is `ceil(column_count / 8)` bytes long.
    fn nbcrow_null(column_count: usize) -> Vec<u8> {
        let bitmap_len = column_count.div_ceil(8);

        let mut token = Vec::with_capacity(1 + bitmap_len);
        token.push(TOKEN_NBCROW);
        token.extend(std::iter::repeat(0xff_u8).take(bitmap_len));
        token
    }

    /// Builds a return-value token carrying one `DataType::IntN` value.
    ///
    /// Layout: token byte, `u16` 1, the bytes `0` and `1`, `u32` 0, `u16` 0, the data-type
    /// byte, `4`, `4`, then the value little-endian. (Presumably ordinal, empty name,
    /// status, user type, flags, maximum width and actual width in MS-TDS terms — confirm
    /// against `ReturnValue::get`.)
    fn return_value_int(value: i32) -> Vec<u8> {
        let mut token = vec![TOKEN_RETURN_VALUE];
        token.extend(1_u16.to_le_bytes());
        token.extend([0_u8, 1]);
        token.extend(0_u32.to_le_bytes());
        token.extend(0_u16.to_le_bytes());
        token.extend([crate::protocol::type_info::DataType::IntN as u8, 4, 4]);
        token.extend(value.to_le_bytes());
        token
    }

    /// Builds a return-status token: the token byte followed by `value` as four
    /// little-endian bytes.
    fn return_status(value: i32) -> Vec<u8> {
        let mut token = vec![TOKEN_RETURN_STATUS];
        token.extend(value.to_le_bytes());
        token
    }

    /// Builds an environment-change token: token byte, `u16` little-endian body length,
    /// then the body (`change_type` followed by `data`).
    ///
    /// Panics if the body does not fit a `u16` length.
    fn env_change(change_type: u8, data: &[u8]) -> Vec<u8> {
        // The body length counts the change-type byte as well as `data`.
        let body_len = u16::try_from(1 + data.len()).unwrap();

        let mut token = vec![TOKEN_ENVCHANGE];
        token.extend(body_len.to_le_bytes());
        token.push(change_type);
        token.extend_from_slice(data);
        token
    }

    /// Builds a server error token: token byte, `u16` little-endian body length, then the
    /// body.
    ///
    /// The body holds `number` (4 bytes little-endian), `state`, `class`, `message` as a
    /// two-byte-length UTF-16 string, `server` and `procedure` as one-byte-length UTF-16
    /// strings, and `line` (4 bytes little-endian).
    ///
    /// Panics if the body does not fit a `u16` length.
    fn error(
        number: i32,
        state: u8,
        class: u8,
        message: &str,
        server: &str,
        procedure: &str,
        line: u32,
    ) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend(number.to_le_bytes());
        body.extend([state, class]);
        push_us_varchar(&mut body, message);
        push_b_varchar(&mut body, server);
        push_b_varchar(&mut body, procedure);
        body.extend(line.to_le_bytes());

        let body_len = u16::try_from(body.len()).unwrap();

        let mut token = vec![TOKEN_ERROR];
        token.extend(body_len.to_le_bytes());
        token.append(&mut body);
        token
    }

    /// Builds a DONE token: token byte, then `status` (2 bytes), `current_command`
    /// (2 bytes) and `row_count` (8 bytes), each little-endian.
    fn done(status: u16, current_command: u16, row_count: u64) -> Vec<u8> {
        let mut token = vec![TOKEN_DONE];
        token.extend(status.to_le_bytes());
        token.extend(current_command.to_le_bytes());
        token.extend(row_count.to_le_bytes());
        token
    }

    /// Appends `value` as a UTF-16 little-endian string prefixed with its length in code
    /// units as a little-endian `u16`.
    ///
    /// Panics, before writing anything, if the code-unit count does not fit a `u16`.
    fn push_us_varchar(out: &mut Vec<u8>, value: &str) {
        let unit_count = u16::try_from(value.encode_utf16().count()).unwrap();

        out.extend_from_slice(&unit_count.to_le_bytes());
        out.extend(value.encode_utf16().flat_map(u16::to_le_bytes));
    }

    /// Appends `value` as a UTF-16 little-endian string prefixed with its length in code
    /// units as a single byte.
    ///
    /// Panics, before writing anything, if the code-unit count does not fit a `u8`.
    fn push_b_varchar(out: &mut Vec<u8>, value: &str) {
        let unit_count = u8::try_from(value.encode_utf16().count()).unwrap();

        out.push(unit_count);
        out.extend(value.encode_utf16().flat_map(u16::to_le_bytes));
    }
}