// diesel_async/mysql/row.rs

1use diesel::backend::Backend;
2use diesel::mysql::data_types::{MysqlTime, MysqlTimestampType};
3use diesel::mysql::{Mysql, MysqlType, MysqlValue};
4use diesel::row::{PartialRow, RowIndex, RowSealed};
5use mysql_async::consts::{ColumnFlags, ColumnType};
6use mysql_async::{Column, Row, Value};
7use std::borrow::Cow;
8
9pub struct MysqlRow(pub(super) Row);
10
11impl mysql_async::prelude::FromRow for MysqlRow {
12    fn from_row_opt(row: Row) -> Result<Self, mysql_async::FromRowError>
13    where
14        Self: Sized,
15    {
16        Ok(Self(row))
17    }
18}
19
20impl RowIndex<usize> for MysqlRow {
21    fn idx(&self, idx: usize) -> Option<usize> {
22        if idx < self.0.columns_ref().len() {
23            Some(idx)
24        } else {
25            None
26        }
27    }
28}
29
30impl<'a> RowIndex<&'a str> for MysqlRow {
31    fn idx(&self, idx: &'a str) -> Option<usize> {
32        self.0.columns().iter().position(|c| c.name_str() == idx)
33    }
34}
35
36impl RowSealed for MysqlRow {}
37
38impl<'a> diesel::row::Row<'a, Mysql> for MysqlRow {
39    type InnerPartialRow = Self;
40    type Field<'b>
41        = MysqlField<'b>
42    where
43        Self: 'b,
44        'a: 'b;
45
46    fn field_count(&self) -> usize {
47        self.0.columns_ref().len()
48    }
49
50    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
51    where
52        'a: 'b,
53        Self: diesel::row::RowIndex<I>,
54    {
55        let idx = diesel::row::RowIndex::idx(self, idx)?;
56        let value = self.0.as_ref(idx)?;
57        let column = &self.0.columns_ref()[idx];
58        let buffer = match value {
59            Value::NULL => None,
60            Value::Bytes(b) => {
61                // deserialize gets the length prepended, so we just use that buffer
62                // directly
63                Some(Cow::Borrowed(b as &[_]))
64            }
65            Value::Time(neg, day, hour, minute, second, second_part) => {
66                let date = MysqlTime::new(
67                    0,
68                    0,
69                    *day as _,
70                    *hour as _,
71                    *minute as _,
72                    *second as _,
73                    *second_part as _,
74                    *neg as _,
75                    MysqlTimestampType::MYSQL_TIMESTAMP_TIME,
76                    0,
77                );
78                let buffer = unsafe {
79                    let ptr = &date as *const MysqlTime as *const u8;
80                    let slice = std::slice::from_raw_parts(ptr, std::mem::size_of::<MysqlTime>());
81                    slice.to_vec()
82                };
83                Some(Cow::Owned(buffer))
84            }
85            Value::Date(year, month, day, hour, minute, second, second_part) => {
86                let date = MysqlTime::new(
87                    *year as _,
88                    *month as _,
89                    *day as _,
90                    *hour as _,
91                    *minute as _,
92                    *second as _,
93                    *second_part as _,
94                    false,
95                    MysqlTimestampType::MYSQL_TIMESTAMP_DATETIME,
96                    0,
97                );
98                let buffer = unsafe {
99                    let ptr = &date as *const MysqlTime as *const u8;
100                    let slice = std::slice::from_raw_parts(ptr, std::mem::size_of::<MysqlTime>());
101                    slice.to_vec()
102                };
103                Some(Cow::Owned(buffer))
104            }
105            _t => {
106                let mut buffer = Vec::with_capacity(
107                    value
108                        .bin_len()
109                        .try_into()
110                        .expect("Failed to cast byte size to usize"),
111                );
112                mysql_common::proto::MySerialize::serialize(value, &mut buffer);
113                Some(Cow::Owned(buffer))
114            }
115        };
116        let field = MysqlField {
117            value: buffer,
118            column,
119            name: column.name_str(),
120        };
121        Some(field)
122    }
123
124    fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
125        PartialRow::new(self, range)
126    }
127}
128
129pub struct MysqlField<'a> {
130    value: Option<Cow<'a, [u8]>>,
131    column: &'a Column,
132    name: Cow<'a, str>,
133}
134
135impl diesel::row::Field<'_, Mysql> for MysqlField<'_> {
136    fn field_name(&self) -> Option<&str> {
137        Some(&*self.name)
138    }
139
140    fn value(&self) -> Option<<Mysql as Backend>::RawValue<'_>> {
141        self.value.as_ref().map(|v| {
142            MysqlValue::new(
143                v,
144                convert_type(self.column.column_type(), self.column.flags()),
145            )
146        })
147    }
148}
149
150fn convert_type(column_type: ColumnType, column_flags: ColumnFlags) -> MysqlType {
151    match column_type {
152        ColumnType::MYSQL_TYPE_NEWDECIMAL | ColumnType::MYSQL_TYPE_DECIMAL => MysqlType::Numeric,
153        ColumnType::MYSQL_TYPE_TINY if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) => {
154            MysqlType::UnsignedTiny
155        }
156        ColumnType::MYSQL_TYPE_TINY => MysqlType::Tiny,
157        ColumnType::MYSQL_TYPE_YEAR | ColumnType::MYSQL_TYPE_SHORT
158            if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) =>
159        {
160            MysqlType::UnsignedShort
161        }
162        ColumnType::MYSQL_TYPE_YEAR | ColumnType::MYSQL_TYPE_SHORT => MysqlType::Short,
163        ColumnType::MYSQL_TYPE_INT24 | ColumnType::MYSQL_TYPE_LONG
164            if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) =>
165        {
166            MysqlType::UnsignedLong
167        }
168        ColumnType::MYSQL_TYPE_INT24 | ColumnType::MYSQL_TYPE_LONG => MysqlType::Long,
169        ColumnType::MYSQL_TYPE_LONGLONG if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) => {
170            MysqlType::UnsignedLongLong
171        }
172        ColumnType::MYSQL_TYPE_LONGLONG => MysqlType::LongLong,
173        ColumnType::MYSQL_TYPE_FLOAT => MysqlType::Float,
174        ColumnType::MYSQL_TYPE_DOUBLE => MysqlType::Double,
175
176        ColumnType::MYSQL_TYPE_TIMESTAMP => MysqlType::Timestamp,
177        ColumnType::MYSQL_TYPE_DATE => MysqlType::Date,
178        ColumnType::MYSQL_TYPE_TIME => MysqlType::Time,
179        ColumnType::MYSQL_TYPE_DATETIME => MysqlType::DateTime,
180        ColumnType::MYSQL_TYPE_BIT => MysqlType::Bit,
181        ColumnType::MYSQL_TYPE_JSON => MysqlType::String,
182
183        ColumnType::MYSQL_TYPE_VAR_STRING
184        | ColumnType::MYSQL_TYPE_STRING
185        | ColumnType::MYSQL_TYPE_TINY_BLOB
186        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
187        | ColumnType::MYSQL_TYPE_LONG_BLOB
188        | ColumnType::MYSQL_TYPE_BLOB
189            if column_flags.contains(ColumnFlags::ENUM_FLAG) =>
190        {
191            MysqlType::Enum
192        }
193        ColumnType::MYSQL_TYPE_VAR_STRING
194        | ColumnType::MYSQL_TYPE_STRING
195        | ColumnType::MYSQL_TYPE_TINY_BLOB
196        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
197        | ColumnType::MYSQL_TYPE_LONG_BLOB
198        | ColumnType::MYSQL_TYPE_BLOB
199            if column_flags.contains(ColumnFlags::SET_FLAG) =>
200        {
201            MysqlType::Set
202        }
203
204        ColumnType::MYSQL_TYPE_VAR_STRING
205        | ColumnType::MYSQL_TYPE_STRING
206        | ColumnType::MYSQL_TYPE_TINY_BLOB
207        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
208        | ColumnType::MYSQL_TYPE_LONG_BLOB
209        | ColumnType::MYSQL_TYPE_BLOB
210            if column_flags.contains(ColumnFlags::BINARY_FLAG) =>
211        {
212            MysqlType::Blob
213        }
214
215        ColumnType::MYSQL_TYPE_VAR_STRING
216        | ColumnType::MYSQL_TYPE_STRING
217        | ColumnType::MYSQL_TYPE_TINY_BLOB
218        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
219        | ColumnType::MYSQL_TYPE_LONG_BLOB
220        | ColumnType::MYSQL_TYPE_BLOB => MysqlType::String,
221
222        ColumnType::MYSQL_TYPE_NULL
223        | ColumnType::MYSQL_TYPE_NEWDATE
224        | ColumnType::MYSQL_TYPE_VARCHAR
225        | ColumnType::MYSQL_TYPE_TIMESTAMP2
226        | ColumnType::MYSQL_TYPE_DATETIME2
227        | ColumnType::MYSQL_TYPE_TIME2
228        | ColumnType::MYSQL_TYPE_TYPED_ARRAY
229        | ColumnType::MYSQL_TYPE_UNKNOWN
230        | ColumnType::MYSQL_TYPE_ENUM
231        | ColumnType::MYSQL_TYPE_SET
232        | ColumnType::MYSQL_TYPE_VECTOR
233        | ColumnType::MYSQL_TYPE_GEOMETRY => {
234            unimplemented!("Hit an unsupported type: {:?}", column_type)
235        }
236    }
237}