//! rsql_driver_sqlserver 0.19.1
//!
//! rsql sqlserver driver
//! Documentation
use crate::metadata;
use async_trait::async_trait;
use chrono::{Datelike, Timelike};
use file_type::FileType;
use futures_util::stream::TryStreamExt;
use rsql_driver::Error::{InvalidUrl, IoError, UnsupportedColumnType};
use rsql_driver::{MemoryQueryResult, Metadata, QueryResult, Result, Value};
use sqlparser::dialect::{Dialect, MsSqlDialect};
use std::collections::HashMap;
use std::string::ToString;
use tiberius::{AuthMethod, Client, Column, Config, EncryptionLevel, QueryItem, Row};
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
use url::Url;

/// Stateless driver handle for the `sqlserver` URL scheme.
///
/// The type carries no data; all connection state lives in [`Connection`].
#[derive(Debug)]
pub struct Driver;

#[async_trait]
impl rsql_driver::Driver for Driver {
    fn identifier(&self) -> &'static str {
        "sqlserver"
    }

    async fn connect(&self, url: &str) -> Result<Box<dyn rsql_driver::Connection>> {
        let parsed_url = Url::parse(url).map_err(|error| InvalidUrl(error.to_string()))?;
        let password = parsed_url.password().map(ToString::to_string);
        let connection = Connection::new(url, password).await?;
        Ok(Box::new(connection))
    }

    fn supports_file_type(&self, _file_type: &FileType) -> bool {
        false
    }
}

/// An open connection to a SQL Server instance.
#[derive(Debug)]
pub struct Connection {
    // The URL string this connection was created from, returned verbatim by `url()`.
    url: String,
    // The tiberius client that statements and queries are sent through.
    client: Client<Compat<TcpStream>>,
}

impl Connection {
    pub(crate) async fn new(url: &str, password: Option<String>) -> Result<Connection> {
        let parsed_url = Url::parse(url)?;
        let mut params: HashMap<String, String> = parsed_url.query_pairs().into_owned().collect();
        let trust_server_certificate = params
            .remove("TrustServerCertificate")
            .is_some_and(|value| value == "true");
        let encryption = params
            .remove("encryption")
            .unwrap_or("required".to_string());

        let host = parsed_url.host_str().unwrap_or("localhost");
        let port = parsed_url.port().unwrap_or(1433);
        let database = parsed_url.path().replace('/', "");
        let username = parsed_url.username();

        let mut config = Config::new();
        config.host(host);
        config.port(port);

        if !database.is_empty() {
            config.database(database);
        }

        if !username.is_empty() {
            config.authentication(AuthMethod::sql_server(
                username,
                password.expect("password is required"),
            ));
        }

        if trust_server_certificate {
            config.trust_cert();
        }

        match encryption.as_str() {
            "off" => config.encryption(EncryptionLevel::Off),
            "on" => config.encryption(EncryptionLevel::On),
            "not_supported" => config.encryption(EncryptionLevel::NotSupported),
            _ => config.encryption(EncryptionLevel::Required),
        }

        let tcp = TcpStream::connect(config.get_addr()).await?;
        tcp.set_nodelay(true)?;

        let client = Client::connect(config, tcp.compat_write())
            .await
            .map_err(|error| IoError(error.to_string()))?;
        let connection = Connection {
            url: url.to_string(),
            client,
        };

        Ok(connection)
    }
}

#[async_trait]
impl rsql_driver::Connection for Connection {
    fn url(&self) -> &String {
        &self.url
    }

    async fn execute(&mut self, sql: &str) -> Result<u64> {
        let result = self
            .client
            .execute(sql, &[])
            .await
            .map_err(|error| IoError(error.to_string()))?;
        let rows = result.rows_affected()[0];
        Ok(rows)
    }

    async fn query(&mut self, sql: &str) -> Result<Box<dyn QueryResult>> {
        let mut query_stream = self
            .client
            .query(sql, &[])
            .await
            .map_err(|error| IoError(error.to_string()))?;
        let mut columns: Vec<String> = Vec::new();

        let mut rows = Vec::new();
        while let Some(item) = query_stream
            .try_next()
            .await
            .map_err(|error| IoError(error.to_string()))?
        {
            if let QueryItem::Metadata(meta) = item {
                if meta.result_index() == 0 {
                    for column in meta.columns() {
                        columns.push(column.name().to_string());
                    }
                }
            } else if let QueryItem::Row(row) = item {
                let mut row_data = Vec::new();
                if row.result_index() == 0 {
                    for (index, column) in row.columns().iter().enumerate() {
                        let value = convert_to_value(&row, column, index)?;
                        row_data.push(value);
                    }
                }
                rows.push(row_data);
            }
        }

        let query_result = MemoryQueryResult::new(columns, rows);
        Ok(Box::new(query_result))
    }

    async fn metadata(&mut self) -> Result<Metadata> {
        metadata::get_metadata(self).await
    }

    fn dialect(&self) -> Box<dyn Dialect> {
        Box::new(MsSqlDialect {})
    }
}

#[expect(clippy::same_functions_in_if_condition)]
#[expect(clippy::too_many_lines)]
fn convert_to_value(row: &Row, column: &Column, index: usize) -> Result<Value> {
    let column_name = column.name();

    if let Ok(value) = row.try_get(index) {
        let value: Option<&str> = value;
        match value {
            Some(v) => Ok(Value::String(v.to_string())),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<&[u8]> = value;
        match value {
            Some(v) => Ok(Value::Bytes(v.to_vec())),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<u8> = value;
        match value {
            Some(v) => Ok(Value::U8(v)),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<i16> = value;
        match value {
            Some(v) => Ok(Value::I16(v)),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<i32> = value;
        match value {
            Some(v) => Ok(Value::I32(v)),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<i64> = value;
        match value {
            Some(v) => Ok(Value::I64(v)),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<f32> = value;
        match value {
            Some(v) => Ok(Value::F32(v)),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<f64> = value;
        match value {
            Some(v) => Ok(Value::F64(v)),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<bool> = value;
        match value {
            Some(v) => Ok(Value::Bool(v)),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<rust_decimal::Decimal> = value;
        match value {
            Some(v) => Ok(Value::Decimal(v)),
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<chrono::NaiveDate> = value;
        match value {
            Some(v) => {
                let year = i16::try_from(v.year())?;
                let month = i8::try_from(v.month())?;
                let day = i8::try_from(v.day())?;
                let date = jiff::civil::date(year, month, day);
                Ok(Value::Date(date))
            }
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<chrono::NaiveTime> = value;
        match value {
            Some(v) => {
                let hour = i8::try_from(v.hour())?;
                let minute = i8::try_from(v.minute())?;
                let second = i8::try_from(v.second())?;
                let nanosecond = i32::try_from(v.nanosecond())?;
                let time = jiff::civil::time(hour, minute, second, nanosecond);
                Ok(Value::Time(time))
            }
            None => Ok(Value::Null),
        }
    } else if let Ok(value) = row.try_get(column_name) {
        let value: Option<chrono::NaiveDateTime> = value;
        match value {
            Some(v) => {
                let year = i16::try_from(v.year())?;
                let month = i8::try_from(v.month())?;
                let day = i8::try_from(v.day())?;
                let hour = i8::try_from(v.hour())?;
                let minute = i8::try_from(v.minute())?;
                let second = i8::try_from(v.second())?;
                let nanosecond = i32::try_from(v.nanosecond())?;
                let date_time =
                    jiff::civil::datetime(year, month, day, hour, minute, second, nanosecond);
                Ok(Value::DateTime(date_time))
            }
            None => Ok(Value::Null),
        }
    } else {
        let column_type = format!("{:?}", column.column_type());
        let type_name = format!("{column_type:?}");
        return Err(UnsupportedColumnType {
            column_name: column_name.to_string(),
            column_type: type_name,
        });
    }
}