use bytes::Buf;
use chrono::{
DateTime, Datelike, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc,
};
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::{BoxDynError, UnexpectedNullError};
use crate::protocol::text::ColumnType;
use crate::type_info::MySqlTypeInfo;
use crate::types::Type;
use crate::{MySql, MySqlValueFormat, MySqlValueRef};
/// Maps [`DateTime<Utc>`] to the MySQL `TIMESTAMP` type, while also accepting
/// `DATETIME` columns when decoding.
impl Type<MySql> for DateTime<Utc> {
    fn type_info() -> MySqlTypeInfo {
        let column = ColumnType::Timestamp;
        MySqlTypeInfo::binary(column)
    }

    fn compatible(ty: &MySqlTypeInfo) -> bool {
        let column = &ty.r#type;
        matches!(column, ColumnType::Datetime | ColumnType::Timestamp)
    }
}
/// Encodes a [`DateTime<Utc>`] as its UTC wall-clock [`NaiveDateTime`] in the
/// binary `DATETIME`/`TIMESTAMP` layout.
impl Encode<'_, MySql> for DateTime<Utc> {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
        Encode::<MySql>::encode(&self.naive_utc(), buf)
    }

    // Report the size of what is actually written (the `NaiveDateTime` encoding)
    // rather than relying on the trait's default estimate.
    fn size_hint(&self) -> usize {
        Encode::<MySql>::size_hint(&self.naive_utc())
    }
}
/// Decodes a MySQL `DATETIME`/`TIMESTAMP` value, interpreting the stored
/// wall-clock time as UTC.
impl<'r> Decode<'r, MySql> for DateTime<Utc> {
    fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
        <NaiveDateTime as Decode<'r, MySql>>::decode(value)
            .map(|naive| Utc.from_utc_datetime(&naive))
    }
}
impl Type<MySql> for DateTime<Local> {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Timestamp)
}
fn compatible(ty: &MySqlTypeInfo) -> bool {
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
}
}
/// Encodes a [`DateTime<Local>`] as its UTC wall-clock [`NaiveDateTime`], the same
/// representation the `Decode` impl reads back before converting to local time.
///
/// NOTE(review): for `TIMESTAMP` columns this presumably relies on the session
/// `time_zone` being UTC — confirm against connection setup.
impl Encode<'_, MySql> for DateTime<Local> {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
        Encode::<MySql>::encode(&self.naive_utc(), buf)
    }

    // Report the size of what is actually written (the `NaiveDateTime` encoding)
    // rather than relying on the trait's default estimate.
    fn size_hint(&self) -> usize {
        Encode::<MySql>::size_hint(&self.naive_utc())
    }
}
/// Decodes the value as UTC (via the [`DateTime<Utc>`] decoder) and then converts
/// it into the local time zone.
impl<'r> Decode<'r, MySql> for DateTime<Local> {
    fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
        <DateTime<Utc> as Decode<'r, MySql>>::decode(value).map(|utc| utc.with_timezone(&Local))
    }
}
/// Maps [`NaiveTime`] to the MySQL `TIME` column type.
impl Type<MySql> for NaiveTime {
    fn type_info() -> MySqlTypeInfo {
        let column = ColumnType::Time;
        MySqlTypeInfo::binary(column)
    }
}
/// Encodes a [`NaiveTime`] in the binary `TIME` layout: a length byte, a zero
/// sign byte, four zero day bytes, then hour/minute/second and, when the value
/// has a sub-second part, four bytes of microseconds.
impl Encode<'_, MySql> for NaiveTime {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
        // Payload length excludes the length byte itself: 8, or 12 with microseconds.
        let payload_len = Encode::<MySql>::size_hint(self) - 1;
        let with_micros = payload_len > 9;

        buf.push(payload_len as u8);
        // Sign byte (never negative) followed by a 4-byte day count (always zero).
        buf.extend_from_slice(&[0_u8; 5]);
        encode_time(self, with_micros, buf);

        IsNull::No
    }

    fn size_hint(&self) -> usize {
        match self.nanosecond() {
            0 => 9,
            _ => 13,
        }
    }
}
/// Decodes a MySQL `TIME` value into a [`NaiveTime`].
///
/// Binary layout: a length byte (0, 8, or 12), then — unless the length is 0 — a
/// sign byte, a 4-byte little-endian day count, and the hour/minute/second
/// (+ optional microseconds) handled by `decode_time`. A zero length means all
/// fields are zero.
///
/// # Errors
/// Returns an error for a malformed binary payload and for values a `NaiveTime`
/// cannot represent (negative, or carrying a non-zero day count). Such values used
/// to have their sign and days silently dropped.
impl<'r> Decode<'r, MySql> for NaiveTime {
    fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
        match value.format() {
            MySqlValueFormat::Binary => {
                let mut buf = value.as_bytes()?;

                if buf.is_empty() {
                    return Err("expected at least 1 byte for time, got 0".into());
                }

                // Payload length, not counting this byte.
                let len = buf.get_u8();

                // All fields zero: nothing else is sent.
                if len == 0 {
                    return Ok(NaiveTime::from_hms_micro_opt(0, 0, 0, 0)
                        .expect("expected NaiveTime to construct from all zeroes"));
                }

                // Guards the `len - 5` below and the fixed-width reads that follow.
                if len < 8 || buf.len() < usize::from(len) {
                    return Err(format!(
                        "invalid binary time: length {len}, {} bytes available",
                        buf.len()
                    )
                    .into());
                }

                let is_negative = buf.get_u8() != 0;
                let days = buf.get_u32_le();

                // A `NaiveTime` is a time of day: report negative or multi-day
                // durations instead of silently discarding the sign and day count.
                if is_negative || days != 0 {
                    return Err(format!(
                        "time value out of range for NaiveTime (negative: {is_negative}, days: {days})"
                    )
                    .into());
                }

                decode_time(len - 5, buf)
            }
            MySqlValueFormat::Text => {
                let s = value.as_str()?;
                NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Into::into)
            }
        }
    }
}
/// Maps [`NaiveDate`] to the MySQL `DATE` column type.
impl Type<MySql> for NaiveDate {
    fn type_info() -> MySqlTypeInfo {
        let column = ColumnType::Date;
        MySqlTypeInfo::binary(column)
    }
}
/// Encodes a [`NaiveDate`] as a length byte (always 4) followed by the 4-byte
/// date payload written by `encode_date`.
impl Encode<'_, MySql> for NaiveDate {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
        const DATE_PAYLOAD_LEN: u8 = 4;
        buf.push(DATE_PAYLOAD_LEN);
        encode_date(self, buf);
        IsNull::No
    }

    fn size_hint(&self) -> usize {
        // length byte + year (2) + month (1) + day (1)
        1 + 4
    }
}
/// Decodes a MySQL `DATE` value into a [`NaiveDate`].
///
/// In the binary format the leading length byte is skipped. An empty payload
/// (length 0) is reported as `UnexpectedNullError`.
impl<'r> Decode<'r, MySql> for NaiveDate {
    fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
        match value.format() {
            MySqlValueFormat::Text => {
                NaiveDate::parse_from_str(value.as_str()?, "%Y-%m-%d").map_err(Into::into)
            }
            MySqlValueFormat::Binary => {
                let bytes = value.as_bytes()?;
                match decode_date(&bytes[1..])? {
                    Some(date) => Ok(date),
                    None => Err(UnexpectedNullError.into()),
                }
            }
        }
    }
}
/// Maps [`NaiveDateTime`] to the MySQL `DATETIME` column type.
impl Type<MySql> for NaiveDateTime {
    fn type_info() -> MySqlTypeInfo {
        let column = ColumnType::Datetime;
        MySqlTypeInfo::binary(column)
    }
}
/// Encodes a [`NaiveDateTime`] in the binary `DATETIME` layout.
///
/// A length byte comes first, then the date. The payload length is chosen by
/// `size_hint` as the shortest form that represents the value:
/// - 4: date only (midnight with no sub-second part)
/// - 7: date plus hour/minute/second
/// - 11: date, time, and 4 bytes of microseconds
impl Encode<'_, MySql> for NaiveDateTime {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
        let len = Encode::<MySql>::size_hint(self) - 1;
        buf.push(len as u8);

        encode_date(&self.date(), buf);

        if len > 4 {
            encode_time(&self.time(), len > 8, buf);
        }

        IsNull::No
    }

    fn size_hint(&self) -> usize {
        // `Timelike::nanosecond` replaces the deprecated
        // `NaiveDateTime::timestamp_subsec_nanos`; both yield the sub-second nanoseconds.
        match (self.hour(), self.minute(), self.second(), self.nanosecond()) {
            (0, 0, 0, 0) => 5,
            (_, _, _, 0) => 8,
            (_, _, _, _) => 12,
        }
    }
}
/// Decodes a MySQL `DATETIME`/`TIMESTAMP` value into a [`NaiveDateTime`].
///
/// In the binary format, byte 0 is the payload length and the date follows. A
/// length of 4 or less means the time part is omitted, so the result is midnight.
/// An empty date payload is reported as `UnexpectedNullError`.
impl<'r> Decode<'r, MySql> for NaiveDateTime {
    fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
        match value.format() {
            MySqlValueFormat::Text => {
                let text = value.as_str()?;
                NaiveDateTime::parse_from_str(text, "%Y-%m-%d %H:%M:%S%.f").map_err(Into::into)
            }
            MySqlValueFormat::Binary => {
                let bytes = value.as_bytes()?;
                let payload_len = bytes[0];

                let date = match decode_date(&bytes[1..])? {
                    Some(date) => date,
                    None => return Err(UnexpectedNullError.into()),
                };

                if payload_len <= 4 {
                    return Ok(date
                        .and_hms_opt(0, 0, 0)
                        .expect("expected `NaiveDate::and_hms_opt(0, 0, 0)` to be valid"));
                }

                let time = decode_time(payload_len - 4, &bytes[5..])?;
                Ok(date.and_time(time))
            }
        }
    }
}
/// Appends the 4-byte binary date payload to `buf`: the year as a little-endian
/// `u16`, then the month and the day as single bytes.
///
/// # Panics
/// Panics if the year is negative or does not fit in a `u16`. The `Encode`
/// methods that call this return `IsNull` and have no way to report an error.
fn encode_date(date: &NaiveDate, buf: &mut Vec<u8>) {
    // Shared by the `NaiveDate` and `NaiveDateTime` encoders, so the message names
    // neither type specifically.
    let year = u16::try_from(date.year())
        .unwrap_or_else(|_| panic!("date out of range for MySQL: {date}"));

    buf.extend_from_slice(&year.to_le_bytes());
    buf.push(date.month() as u8);
    buf.push(date.day() as u8);
}
/// Parses a binary date payload (little-endian `u16` year, month byte, day byte).
///
/// Returns `Ok(None)` for an empty payload and `Ok(Some(date))` when at least four
/// bytes are present; any extra bytes are ignored.
///
/// # Errors
/// Fails when 1–3 bytes are present, or when the fields do not form a valid date.
fn decode_date(buf: &[u8]) -> Result<Option<NaiveDate>, BoxDynError> {
    match *buf {
        [] => Ok(None),
        [year_lo, year_hi, month, day, ..] => {
            let year = i32::from(u16::from_le_bytes([year_lo, year_hi]));
            let month = u32::from(month);
            let day = u32::from(day);

            NaiveDate::from_ymd_opt(year, month, day)
                .map(Some)
                .ok_or_else(|| format!("server returned invalid date: {year}/{month}/{day}").into())
        }
        _ => {
            let len = buf.len();
            Err(format!("expected at least 4 bytes for date, got {len}").into())
        }
    }
}
/// Appends hour, minute, and second (one byte each) to `buf`. When `include_micros`
/// is set, four little-endian bytes of microseconds follow.
fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
    buf.extend_from_slice(&[time.hour() as u8, time.minute() as u8, time.second() as u8]);

    if include_micros {
        let micros: u32 = time.nanosecond() / 1000;
        buf.extend_from_slice(&micros.to_le_bytes());
    }
}
/// Decodes the time-of-day part of a binary MySQL `TIME`/`DATETIME` value.
///
/// `len` is the number of bytes of `buf` that belong to the time part. There are
/// 3 for hour, minute, and second. When fractional seconds are present, a
/// little-endian microsecond field follows and fills the rest of `len`. Bytes
/// beyond `len` are not read.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `len` is below 3.
/// - `buf` holds fewer than `len` bytes.
/// - The microsecond field is wider than 8 bytes.
/// - The fields do not form a valid [`NaiveTime`].
fn decode_time(len: u8, mut buf: &[u8]) -> Result<NaiveTime, BoxDynError> {
    let len = usize::from(len);
    if len < 3 {
        return Err(format!("expected at least 3 bytes for time, got {len}").into());
    }
    if buf.len() < len {
        return Err(format!("expected {len} bytes for time, got {}", buf.len()).into());
    }

    let hour = buf.get_u8();
    let minute = buf.get_u8();
    let seconds = buf.get_u8();

    // microseconds : int<len - 3>. Read exactly the declared width: trailing bytes
    // are not folded in, and `get_uint_le` is never asked for more than 8 bytes.
    let micros_len = len - 3;
    if micros_len > 8 {
        return Err(format!(
            "expected at most 8 bytes of microseconds for time, got {micros_len}"
        )
        .into());
    }
    let micros = if micros_len > 0 {
        buf.get_uint_le(micros_len)
    } else {
        0
    };

    // Convert with `try_from` so an oversized value is rejected instead of truncated.
    u32::try_from(micros)
        .ok()
        .and_then(|us| {
            NaiveTime::from_hms_micro_opt(
                u32::from(hour),
                u32::from(minute),
                u32::from(seconds),
                us,
            )
        })
        .ok_or_else(|| {
            format!("server returned invalid time: {hour:02}:{minute:02}:{seconds:02}; micros: {micros}").into()
        })
}