benzin 0.5.0

An async extension for Diesel the safe, extensible ORM and Query Builder
use std::borrow::Cow;

use diesel::{
    backend::Backend,
    mysql::{
        Mysql, MysqlType, MysqlValue,
        data_types::{MysqlTime, MysqlTimestampType},
    },
    row::{PartialRow, RowIndex, RowSealed},
};
use mysql_async::{
    Column, Row, Value,
    consts::{ColumnFlags, ColumnType},
};

pub struct MysqlRow(pub(super) Row);

impl mysql_async::prelude::FromRow for MysqlRow {
    fn from_row_opt(row: Row) -> Result<Self, mysql_async::FromRowError>
    where
        Self: Sized,
    {
        Ok(Self(row))
    }
}

impl RowIndex<usize> for MysqlRow {
    fn idx(&self, idx: usize) -> Option<usize> {
        if idx < self.0.columns_ref().len() {
            Some(idx)
        } else {
            None
        }
    }
}

impl<'a> RowIndex<&'a str> for MysqlRow {
    fn idx(&self, idx: &'a str) -> Option<usize> {
        self.0.columns().iter().position(|c| c.name_str() == idx)
    }
}

impl RowSealed for MysqlRow {}

impl<'a> diesel::row::Row<'a, Mysql> for MysqlRow {
    type InnerPartialRow = Self;
    type Field<'b>
        = MysqlField<'b>
    where
        Self: 'b,
        'a: 'b;

    fn field_count(&self) -> usize {
        self.0.columns_ref().len()
    }

    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
    where
        'a: 'b,
        Self: diesel::row::RowIndex<I>,
    {
        let idx = diesel::row::RowIndex::idx(self, idx)?;
        let value = self.0.as_ref(idx)?;
        let column = &self.0.columns_ref()[idx];
        let buffer = match value {
            Value::NULL => None,
            Value::Bytes(b) => {
                // deserialize gets the length prepended, so we just use that buffer
                // directly
                Some(Cow::Borrowed(b as &[_]))
            }
            Value::Time(neg, day, hour, minute, second, second_part) => {
                let date = MysqlTime::new(
                    0,
                    0,
                    *day as _,
                    *hour as _,
                    *minute as _,
                    *second as _,
                    *second_part as _,
                    *neg as _,
                    MysqlTimestampType::MYSQL_TIMESTAMP_TIME,
                    0,
                );
                let buffer = unsafe {
                    let ptr = &date as *const MysqlTime as *const u8;
                    let slice = std::slice::from_raw_parts(ptr, std::mem::size_of::<MysqlTime>());
                    slice.to_vec()
                };
                Some(Cow::Owned(buffer))
            }
            Value::Date(year, month, day, hour, minute, second, second_part) => {
                let date = MysqlTime::new(
                    *year as _,
                    *month as _,
                    *day as _,
                    *hour as _,
                    *minute as _,
                    *second as _,
                    *second_part as _,
                    false,
                    MysqlTimestampType::MYSQL_TIMESTAMP_DATETIME,
                    0,
                );
                let buffer = unsafe {
                    let ptr = &date as *const MysqlTime as *const u8;
                    let slice = std::slice::from_raw_parts(ptr, std::mem::size_of::<MysqlTime>());
                    slice.to_vec()
                };
                Some(Cow::Owned(buffer))
            }
            Value::Int(v) => Some(Cow::Owned(v.to_le_bytes().to_vec())),
            Value::UInt(v) => Some(Cow::Owned(v.to_le_bytes().to_vec())),
            Value::Float(v) => Some(Cow::Owned(v.to_le_bytes().to_vec())),
            Value::Double(v) => Some(Cow::Owned(v.to_le_bytes().to_vec())),
        };
        let field = MysqlField {
            value: buffer,
            column,
            name: column.name_str(),
        };
        Some(field)
    }

    fn partial_row(
        &'_ self,
        range: std::ops::Range<usize>,
    ) -> PartialRow<'_, Self::InnerPartialRow> {
        PartialRow::new(self, range)
    }
}

pub struct MysqlField<'a> {
    value: Option<Cow<'a, [u8]>>,
    column: &'a Column,
    name: Cow<'a, str>,
}

impl diesel::row::Field<'_, Mysql> for MysqlField<'_> {
    fn field_name(&self) -> Option<&str> {
        Some(&*self.name)
    }

    fn value(&self) -> Option<<Mysql as Backend>::RawValue<'_>> {
        self.value.as_ref().map(|v| {
            MysqlValue::new(
                v,
                convert_type(self.column.column_type(), self.column.flags()),
            )
        })
    }
}

fn convert_type(column_type: ColumnType, column_flags: ColumnFlags) -> MysqlType {
    match column_type {
        ColumnType::MYSQL_TYPE_NEWDECIMAL | ColumnType::MYSQL_TYPE_DECIMAL => MysqlType::Numeric,
        ColumnType::MYSQL_TYPE_TINY if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) => {
            MysqlType::UnsignedTiny
        }
        ColumnType::MYSQL_TYPE_TINY => MysqlType::Tiny,
        ColumnType::MYSQL_TYPE_YEAR | ColumnType::MYSQL_TYPE_SHORT
            if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) =>
        {
            MysqlType::UnsignedShort
        }
        ColumnType::MYSQL_TYPE_YEAR | ColumnType::MYSQL_TYPE_SHORT => MysqlType::Short,
        ColumnType::MYSQL_TYPE_INT24 | ColumnType::MYSQL_TYPE_LONG
            if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) =>
        {
            MysqlType::UnsignedLong
        }
        ColumnType::MYSQL_TYPE_INT24 | ColumnType::MYSQL_TYPE_LONG => MysqlType::Long,
        ColumnType::MYSQL_TYPE_LONGLONG if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) => {
            MysqlType::UnsignedLongLong
        }
        ColumnType::MYSQL_TYPE_LONGLONG => MysqlType::LongLong,
        ColumnType::MYSQL_TYPE_FLOAT => MysqlType::Float,
        ColumnType::MYSQL_TYPE_DOUBLE => MysqlType::Double,

        ColumnType::MYSQL_TYPE_TIMESTAMP => MysqlType::Timestamp,
        ColumnType::MYSQL_TYPE_DATE => MysqlType::Date,
        ColumnType::MYSQL_TYPE_TIME => MysqlType::Time,
        ColumnType::MYSQL_TYPE_DATETIME => MysqlType::DateTime,
        ColumnType::MYSQL_TYPE_BIT => MysqlType::Bit,
        ColumnType::MYSQL_TYPE_JSON => MysqlType::String,

        ColumnType::MYSQL_TYPE_VAR_STRING
        | ColumnType::MYSQL_TYPE_STRING
        | ColumnType::MYSQL_TYPE_TINY_BLOB
        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
        | ColumnType::MYSQL_TYPE_LONG_BLOB
        | ColumnType::MYSQL_TYPE_BLOB
            if column_flags.contains(ColumnFlags::ENUM_FLAG) =>
        {
            MysqlType::Enum
        }
        ColumnType::MYSQL_TYPE_VAR_STRING
        | ColumnType::MYSQL_TYPE_STRING
        | ColumnType::MYSQL_TYPE_TINY_BLOB
        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
        | ColumnType::MYSQL_TYPE_LONG_BLOB
        | ColumnType::MYSQL_TYPE_BLOB
            if column_flags.contains(ColumnFlags::SET_FLAG) =>
        {
            MysqlType::Set
        }

        ColumnType::MYSQL_TYPE_VAR_STRING
        | ColumnType::MYSQL_TYPE_STRING
        | ColumnType::MYSQL_TYPE_TINY_BLOB
        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
        | ColumnType::MYSQL_TYPE_LONG_BLOB
        | ColumnType::MYSQL_TYPE_BLOB
            if column_flags.contains(ColumnFlags::BINARY_FLAG) =>
        {
            MysqlType::Blob
        }

        ColumnType::MYSQL_TYPE_VAR_STRING
        | ColumnType::MYSQL_TYPE_STRING
        | ColumnType::MYSQL_TYPE_TINY_BLOB
        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
        | ColumnType::MYSQL_TYPE_LONG_BLOB
        | ColumnType::MYSQL_TYPE_BLOB => MysqlType::String,

        ColumnType::MYSQL_TYPE_NULL
        | ColumnType::MYSQL_TYPE_NEWDATE
        | ColumnType::MYSQL_TYPE_VARCHAR
        | ColumnType::MYSQL_TYPE_VECTOR
        | ColumnType::MYSQL_TYPE_TIMESTAMP2
        | ColumnType::MYSQL_TYPE_DATETIME2
        | ColumnType::MYSQL_TYPE_TIME2
        | ColumnType::MYSQL_TYPE_TYPED_ARRAY
        | ColumnType::MYSQL_TYPE_UNKNOWN
        | ColumnType::MYSQL_TYPE_ENUM
        | ColumnType::MYSQL_TYPE_SET
        | ColumnType::MYSQL_TYPE_GEOMETRY => {
            unimplemented!("Hit an unsupported type: {:?}", column_type)
        }
    }
}