Skip to main content

diesel_async/mysql/
row.rs

1use diesel::backend::Backend;
2use diesel::mysql::data_types::{MysqlTime, MysqlTimestampType};
3use diesel::mysql::{Mysql, MysqlType, MysqlValue};
4use diesel::row::{PartialRow, RowIndex, RowSealed};
5use mysql_async::consts::{ColumnFlags, ColumnType};
6use mysql_async::{Column, Row, Value};
7use std::borrow::Cow;
8
9pub struct MysqlRow(pub(super) Row);
10
11impl mysql_async::prelude::FromRow for MysqlRow {
12    fn from_row_opt(row: Row) -> Result<Self, mysql_async::FromRowError>
13    where
14        Self: Sized,
15    {
16        Ok(Self(row))
17    }
18}
19
20impl RowIndex<usize> for MysqlRow {
21    fn idx(&self, idx: usize) -> Option<usize> {
22        if idx < self.0.columns_ref().len() {
23            Some(idx)
24        } else {
25            None
26        }
27    }
28}
29
30impl<'a> RowIndex<&'a str> for MysqlRow {
31    fn idx(&self, idx: &'a str) -> Option<usize> {
32        self.0.columns().iter().position(|c| c.name_str() == idx)
33    }
34}
35
36impl RowSealed for MysqlRow {}
37
38impl<'a> diesel::row::Row<'a, Mysql> for MysqlRow {
39    type InnerPartialRow = Self;
40    type Field<'b>
41        = MysqlField<'b>
42    where
43        Self: 'b,
44        'a: 'b;
45
46    fn field_count(&self) -> usize {
47        self.0.columns_ref().len()
48    }
49
50    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
51    where
52        'a: 'b,
53        Self: diesel::row::RowIndex<I>,
54    {
55        let idx = diesel::row::RowIndex::idx(self, idx)?;
56        let value = self.0.as_ref(idx)?;
57        let column = &self.0.columns_ref()[idx];
58        let buffer = match value {
59            Value::NULL => None,
60            Value::Bytes(b) => {
61                // deserialize gets the length prepended, so we just use that buffer
62                // directly
63                Some(Cow::Borrowed(b as &[_]))
64            }
65            Value::Time(neg, day, hour, minute, second, second_part) => {
66                let date = MysqlTime::new(
67                    0,
68                    0,
69                    *day as _,
70                    *hour as _,
71                    *minute as _,
72                    *second as _,
73                    *second_part as _,
74                    *neg as _,
75                    MysqlTimestampType::MYSQL_TIMESTAMP_TIME,
76                    0,
77                );
78                let buffer = date.serialize().to_vec();
79                Some(Cow::Owned(buffer))
80            }
81            Value::Date(year, month, day, hour, minute, second, second_part) => {
82                let date = MysqlTime::new(
83                    *year as _,
84                    *month as _,
85                    *day as _,
86                    *hour as _,
87                    *minute as _,
88                    *second as _,
89                    *second_part as _,
90                    false,
91                    MysqlTimestampType::MYSQL_TIMESTAMP_DATETIME,
92                    0,
93                );
94                let buffer = date.serialize().to_vec();
95                Some(Cow::Owned(buffer))
96            }
97            _t => {
98                let mut buffer = Vec::with_capacity(
99                    value
100                        .bin_len()
101                        .try_into()
102                        .expect("Failed to cast byte size to usize"),
103                );
104                mysql_common::proto::MySerialize::serialize(value, &mut buffer);
105                Some(Cow::Owned(buffer))
106            }
107        };
108        let field = MysqlField {
109            value: buffer,
110            column,
111            name: column.name_str(),
112        };
113        Some(field)
114    }
115
116    fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
117        PartialRow::new(self, range)
118    }
119}
120
121pub struct MysqlField<'a> {
122    value: Option<Cow<'a, [u8]>>,
123    column: &'a Column,
124    name: Cow<'a, str>,
125}
126
127impl diesel::row::Field<'_, Mysql> for MysqlField<'_> {
128    fn field_name(&self) -> Option<&str> {
129        Some(&*self.name)
130    }
131
132    fn value(&self) -> Option<<Mysql as Backend>::RawValue<'_>> {
133        self.value.as_ref().map(|v| {
134            MysqlValue::new(
135                v,
136                convert_type(self.column.column_type(), self.column.flags()),
137            )
138        })
139    }
140}
141
142fn convert_type(column_type: ColumnType, column_flags: ColumnFlags) -> MysqlType {
143    match column_type {
144        ColumnType::MYSQL_TYPE_NEWDECIMAL | ColumnType::MYSQL_TYPE_DECIMAL => MysqlType::Numeric,
145        ColumnType::MYSQL_TYPE_TINY if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) => {
146            MysqlType::UnsignedTiny
147        }
148        ColumnType::MYSQL_TYPE_TINY => MysqlType::Tiny,
149        ColumnType::MYSQL_TYPE_YEAR | ColumnType::MYSQL_TYPE_SHORT
150            if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) =>
151        {
152            MysqlType::UnsignedShort
153        }
154        ColumnType::MYSQL_TYPE_YEAR | ColumnType::MYSQL_TYPE_SHORT => MysqlType::Short,
155        ColumnType::MYSQL_TYPE_INT24 | ColumnType::MYSQL_TYPE_LONG
156            if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) =>
157        {
158            MysqlType::UnsignedLong
159        }
160        ColumnType::MYSQL_TYPE_INT24 | ColumnType::MYSQL_TYPE_LONG => MysqlType::Long,
161        ColumnType::MYSQL_TYPE_LONGLONG if column_flags.contains(ColumnFlags::UNSIGNED_FLAG) => {
162            MysqlType::UnsignedLongLong
163        }
164        ColumnType::MYSQL_TYPE_LONGLONG => MysqlType::LongLong,
165        ColumnType::MYSQL_TYPE_FLOAT => MysqlType::Float,
166        ColumnType::MYSQL_TYPE_DOUBLE => MysqlType::Double,
167
168        ColumnType::MYSQL_TYPE_TIMESTAMP => MysqlType::Timestamp,
169        ColumnType::MYSQL_TYPE_DATE => MysqlType::Date,
170        ColumnType::MYSQL_TYPE_TIME => MysqlType::Time,
171        ColumnType::MYSQL_TYPE_DATETIME => MysqlType::DateTime,
172        ColumnType::MYSQL_TYPE_BIT => MysqlType::Bit,
173        ColumnType::MYSQL_TYPE_JSON => MysqlType::String,
174
175        ColumnType::MYSQL_TYPE_VAR_STRING
176        | ColumnType::MYSQL_TYPE_STRING
177        | ColumnType::MYSQL_TYPE_TINY_BLOB
178        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
179        | ColumnType::MYSQL_TYPE_LONG_BLOB
180        | ColumnType::MYSQL_TYPE_BLOB
181            if column_flags.contains(ColumnFlags::ENUM_FLAG) =>
182        {
183            MysqlType::Enum
184        }
185        ColumnType::MYSQL_TYPE_VAR_STRING
186        | ColumnType::MYSQL_TYPE_STRING
187        | ColumnType::MYSQL_TYPE_TINY_BLOB
188        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
189        | ColumnType::MYSQL_TYPE_LONG_BLOB
190        | ColumnType::MYSQL_TYPE_BLOB
191            if column_flags.contains(ColumnFlags::SET_FLAG) =>
192        {
193            MysqlType::Set
194        }
195
196        ColumnType::MYSQL_TYPE_VAR_STRING
197        | ColumnType::MYSQL_TYPE_STRING
198        | ColumnType::MYSQL_TYPE_TINY_BLOB
199        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
200        | ColumnType::MYSQL_TYPE_LONG_BLOB
201        | ColumnType::MYSQL_TYPE_BLOB
202            if column_flags.contains(ColumnFlags::BINARY_FLAG) =>
203        {
204            MysqlType::Blob
205        }
206
207        ColumnType::MYSQL_TYPE_VAR_STRING
208        | ColumnType::MYSQL_TYPE_STRING
209        | ColumnType::MYSQL_TYPE_TINY_BLOB
210        | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
211        | ColumnType::MYSQL_TYPE_LONG_BLOB
212        | ColumnType::MYSQL_TYPE_BLOB => MysqlType::String,
213
214        ColumnType::MYSQL_TYPE_NULL
215        | ColumnType::MYSQL_TYPE_NEWDATE
216        | ColumnType::MYSQL_TYPE_VARCHAR
217        | ColumnType::MYSQL_TYPE_TIMESTAMP2
218        | ColumnType::MYSQL_TYPE_DATETIME2
219        | ColumnType::MYSQL_TYPE_TIME2
220        | ColumnType::MYSQL_TYPE_TYPED_ARRAY
221        | ColumnType::MYSQL_TYPE_UNKNOWN
222        | ColumnType::MYSQL_TYPE_ENUM
223        | ColumnType::MYSQL_TYPE_SET
224        | ColumnType::MYSQL_TYPE_VECTOR
225        | ColumnType::MYSQL_TYPE_GEOMETRY => {
226            unimplemented!("Hit an unsupported type: {:?}", column_type)
227        }
228    }
229}