use crate::protocol::text::ColumnType;
use crate::{MySql, MySqlTypeInfo, MySqlValueFormat};
use bytes::{Buf, BufMut};
use sqlx_core::database::Database;
use sqlx_core::decode::Decode;
use sqlx_core::encode::{Encode, IsNull};
use sqlx_core::error::BoxDynError;
use sqlx_core::types::Type;
use std::cmp::Ordering;
use std::fmt::{Debug, Display, Formatter, Write};
use std::time::Duration;
/// A value of MySQL's `TIME` type: a signed time span that may exceed 24 hours.
///
/// Values built through [`MySqlTime::new`] lie within [`MySqlTime::MIN`]..=[`MySqlTime::MAX`]
/// (`-838:59:59` to `+838:59:59`) and are never negative zero.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct MySqlTime {
// Sign of the value; `new()` rejects a negative sign paired with a zero magnitude.
pub(crate) sign: MySqlTimeSign,
// Absolute value of the span, split into clock fields.
pub(crate) magnitude: TimeMagnitude,
}
/// Unsigned magnitude of a [`MySqlTime`], split into clock fields.
///
/// The derived `Ord` compares fields in declaration order (hours, minutes, seconds,
/// microseconds), which is chronological order for in-range values; `MySqlTime::new()`
/// relies on this when comparing against `MAGNITUDE_MAX`, so do not reorder the fields.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub(crate) struct TimeMagnitude {
// Total whole hours; may exceed 23 (capped at `MySqlTime::HOURS_MAX` by `new()`).
pub(crate) hours: u32,
pub(crate) minutes: u8,
pub(crate) seconds: u8,
pub(crate) microseconds: u32,
}
// Magnitude of `MySqlTime::ZERO` (`0:00:00`).
const MAGNITUDE_ZERO: TimeMagnitude = TimeMagnitude {
hours: 0,
minutes: 0,
seconds: 0,
microseconds: 0,
};
// Largest magnitude accepted by `MySqlTime::new()`: `838:59:59.000000`. Any nonzero
// microseconds at `838:59:59` compares greater than this and is rejected as `SubsecondExcess`.
const MAGNITUDE_MAX: TimeMagnitude = TimeMagnitude {
hours: MySqlTime::HOURS_MAX,
minutes: 59,
seconds: 59,
microseconds: 0,
};
/// Sign of a [`MySqlTime`] value.
///
/// Variant order matters: the derived `Ord` places `Negative` before `Positive`, which
/// `Ord for MySqlTime` relies on when comparing values of differing sign.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub enum MySqlTimeSign {
Negative,
Positive,
}
/// Errors returned when constructing a [`MySqlTime`].
#[derive(Debug, thiserror::Error)]
pub enum MySqlTimeError {
/// A field exceeded its maximum allowed value.
#[error("`MySqlTime` field `{field}` cannot exceed {max}, got {value}")]
FieldRange {
/// Name of the offending field (e.g. `"hours"`).
field: &'static str,
/// Largest value the field accepts.
max: u32,
/// The value that was supplied.
value: u64,
},
/// The magnitude was `838:59:59` with nonzero microseconds, just past the `TIME` range.
#[error(
"`MySqlTime` cannot exceed +/-838:59:59.000000; got {sign}838:59:59.{microseconds:06}"
)]
SubsecondExcess {
/// Sign of the rejected value.
sign: MySqlTimeSign,
/// The microseconds component that pushed the value out of range.
microseconds: u32,
/// The nearest in-range value: `MySqlTime::MAX` or `MySqlTime::MIN`, matching `sign`.
truncated: MySqlTime,
},
/// A negative sign was combined with an all-zero magnitude.
#[error("attempted to construct a `MySqlTime` value of negative zero")]
NegativeZero,
}
impl MySqlTime {
    /// The value `0:00:00`, which is always positive.
    pub const ZERO: Self = MySqlTime {
        sign: MySqlTimeSign::Positive,
        magnitude: MAGNITUDE_ZERO,
    };

    /// The largest representable value, `+838:59:59`.
    pub const MAX: Self = MySqlTime {
        sign: MySqlTimeSign::Positive,
        magnitude: MAGNITUDE_MAX,
    };

    /// The smallest representable value, `-838:59:59`.
    pub const MIN: Self = MySqlTime {
        sign: MySqlTimeSign::Negative,
        magnitude: MAGNITUDE_MAX,
    };

    // Largest whole-hours value accepted by `new()`.
    pub(crate) const HOURS_MAX: u32 = 838;

    /// Construct a [`MySqlTime`] from a sign and its clock fields.
    ///
    /// `hours` is the total number of hours and may exceed 23.
    ///
    /// # Errors
    /// * [`MySqlTimeError::FieldRange`] if `hours > 838`, `minutes > 59`, `seconds > 59`,
    ///   or `microseconds > 999_999`.
    /// * [`MySqlTimeError::NegativeZero`] if `sign` is negative and every field is zero.
    /// * [`MySqlTimeError::SubsecondExcess`] if the value is `838:59:59` with nonzero
    ///   microseconds; the error carries [`MySqlTime::MAX`] or [`MySqlTime::MIN`] as the
    ///   truncated value.
    pub fn new(
        sign: MySqlTimeSign,
        hours: u32,
        minutes: u8,
        seconds: u8,
        microseconds: u32,
    ) -> Result<Self, MySqlTimeError> {
        // Reject each field that exceeds its own maximum, reporting the field by name.
        macro_rules! check_fields {
            ($($name:ident: $max:expr),+ $(,)?) => {
                $(
                    if $name > $max {
                        return Err(MySqlTimeError::FieldRange {
                            field: stringify!($name),
                            max: $max as u32,
                            value: $name as u64
                        })
                    }
                )+
            }
        }

        check_fields!(
            hours: Self::HOURS_MAX,
            minutes: 59,
            seconds: 59,
            microseconds: 999_999
        );

        let values = TimeMagnitude {
            hours,
            minutes,
            seconds,
            microseconds,
        };

        if sign.is_negative() && values == MAGNITUDE_ZERO {
            return Err(MySqlTimeError::NegativeZero);
        }

        // With every field individually in range, this can only trip for
        // `838:59:59` plus a nonzero microseconds component.
        if values > MAGNITUDE_MAX {
            return Err(MySqlTimeError::SubsecondExcess {
                sign,
                microseconds,
                truncated: if sign.is_positive() {
                    Self::MAX
                } else {
                    Self::MIN
                },
            });
        }

        Ok(Self {
            sign,
            magnitude: values,
        })
    }

    /// Return a copy of this value with the given sign.
    ///
    /// Zero has no negative form, so a zero value is returned as [`MySqlTime::ZERO`]
    /// regardless of `sign`; this keeps the invariant enforced by [`MySqlTime::new`].
    pub fn with_sign(self, sign: MySqlTimeSign) -> Self {
        if self.magnitude == MAGNITUDE_ZERO {
            return Self::ZERO;
        }

        Self { sign, ..self }
    }

    /// The sign of this value.
    pub fn sign(&self) -> MySqlTimeSign {
        self.sign
    }

    /// `true` if this value equals [`MySqlTime::ZERO`].
    pub fn is_zero(&self) -> bool {
        self == &Self::ZERO
    }

    /// `true` if the sign is positive (zero is positive).
    pub fn is_positive(&self) -> bool {
        self.sign.is_positive()
    }

    /// `true` if the sign is negative.
    pub fn is_negative(&self) -> bool {
        self.sign.is_negative()
    }

    /// `true` if this value is non-negative and under 24 hours, i.e. usable as a time of day.
    pub fn is_valid_time_of_day(&self) -> bool {
        self.sign.is_positive() && self.hours() < 24
    }

    /// Total whole hours of the magnitude (may exceed 23).
    pub fn hours(&self) -> u32 {
        self.magnitude.hours
    }

    /// Minutes component of the magnitude.
    pub fn minutes(&self) -> u8 {
        self.magnitude.minutes
    }

    /// Seconds component of the magnitude.
    pub fn seconds(&self) -> u8 {
        self.magnitude.seconds
    }

    /// Microseconds component of the magnitude.
    pub fn microseconds(&self) -> u32 {
        self.magnitude.microseconds
    }

    /// Convert to a [`Duration`], or `None` if this value is negative.
    pub fn to_duration(&self) -> Option<Duration> {
        self.is_positive()
            .then(|| Duration::new(u64::from(self.whole_seconds()), self.subsec_nanos()))
    }

    // Absolute value in whole seconds, ignoring microseconds.
    // At most 838 * 3600 + 59 * 60 + 59 = 3_020_399, so it fits in `u32`.
    pub(crate) fn whole_seconds(&self) -> u32 {
        self.hours() * 3600 + self.minutes() as u32 * 60 + self.seconds() as u32
    }

    // Whole seconds with the sign applied.
    #[cfg_attr(not(any(feature = "time", feature = "chrono")), allow(dead_code))]
    pub(crate) fn whole_seconds_signed(&self) -> i64 {
        self.whole_seconds() as i64 * self.sign.signum() as i64
    }

    // Fractional part of the magnitude, in nanoseconds.
    pub(crate) fn subsec_nanos(&self) -> u32 {
        self.microseconds() * 1000
    }

    // Binary-protocol payload length, excluding the leading length byte:
    // 0 for zero, 8 without microseconds, 12 with them.
    fn encoded_len(&self) -> u8 {
        if self.is_zero() {
            0
        } else if self.microseconds() == 0 {
            8
        } else {
            12
        }
    }
}
impl PartialOrd<MySqlTime> for MySqlTime {
    /// Delegates to the total order from [`Ord`], so every pair of values is comparable.
    fn partial_cmp(&self, other: &MySqlTime) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for MySqlTime {
    /// Orders values chronologically.
    ///
    /// Every negative value sorts before every positive one. Among positive values a
    /// larger magnitude is greater; among negative values a larger magnitude is *smaller*.
    fn cmp(&self, other: &Self) -> Ordering {
        let by_magnitude = self.magnitude.cmp(&other.magnitude);

        match (self.sign, other.sign) {
            (MySqlTimeSign::Positive, MySqlTimeSign::Positive) => by_magnitude,
            (MySqlTimeSign::Negative, MySqlTimeSign::Negative) => by_magnitude.reverse(),
            (lhs, rhs) => lhs.cmp(&rhs),
        }
    }
}
impl Display for MySqlTime {
    /// Formats as `[sign]H:MM:SS[.fraction]`.
    ///
    /// The sign is written via [`MySqlTimeSign`]'s `Display`, so `{:+}` shows `+` for positive
    /// values. Without a precision, the fraction is written only when nonzero, with trailing
    /// zeros trimmed. With a precision `p > 0`, exactly `p` fractional digits are written
    /// (truncated, or zero-padded beyond six). A precision of `0` omits the fraction.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let TimeMagnitude {
            hours,
            minutes,
            seconds,
            microseconds,
        } = self.magnitude;

        Display::fmt(&self.sign(), f)?;
        write!(f, "{hours}:{minutes:02}:{seconds:02}")?;

        let show_fraction = match f.precision() {
            Some(precision) => precision > 0,
            None => microseconds > 0,
        };
        if !show_fraction {
            return Ok(());
        }

        f.write_char('.')?;

        // Emit one decimal digit per place value, most significant first, stopping early
        // once the rest of the fraction is zero or the requested precision is used up.
        let mut digits_left = f.precision();
        let mut rest = microseconds;
        for place in [100_000u32, 10_000, 1_000, 100, 10, 1] {
            if rest == 0 || digits_left == Some(0) {
                break;
            }
            let digit = rest / place;
            rest %= place;
            write!(f, "{digit}")?;
            digits_left = digits_left.map(|left| left.saturating_sub(1));
        }

        // Zero-pad out to the requested precision, if any digits are still owed.
        match digits_left {
            Some(padding) if padding > 0 => write!(f, "{:0padding$}", 0),
            _ => Ok(()),
        }
    }
}
impl Type<MySql> for MySqlTime {
    /// Reports the binary `TIME` column type as this type's SQL type.
    fn type_info() -> MySqlTypeInfo {
        let column = ColumnType::Time;
        MySqlTypeInfo::binary(column)
    }
}
impl<'r> Decode<'r, MySql> for MySqlTime {
    /// Decodes a `TIME` value from either the text or the binary protocol.
    ///
    /// Binary layout: a length byte, then (when nonzero) a sign byte, days as `u32` LE,
    /// one byte each for hours, minutes and seconds, and an optional microseconds `u32` LE.
    /// A zero length means [`MySqlTime::ZERO`].
    fn decode(value: <MySql as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
        let mut payload = match value.format() {
            MySqlValueFormat::Text => return parse(value.as_str()?),
            MySqlValueFormat::Binary => value.as_bytes()?,
        };

        if payload.is_empty() {
            return Err("empty buffer".into());
        }

        // A zero length prefix stands for `00:00:00` with no further bytes.
        let length = payload.get_u8();
        if length == 0 {
            return Ok(Self::ZERO);
        }

        let body_len = payload.len();
        if body_len != 8 && body_len != 12 {
            return Err(format!(
                "expected 8 or 12 bytes for TIME value, got {}",
                body_len
            )
            .into());
        }

        let sign = MySqlTimeSign::from_byte(payload.get_u8())?;
        let days = payload.get_u32_le();
        let (hours, minutes, seconds) = (payload.get_u8(), payload.get_u8(), payload.get_u8());
        // The trailing microseconds field is present only in the 12-byte form.
        let microseconds = if payload.is_empty() {
            0
        } else {
            payload.get_u32_le()
        };

        let whole_hours = days
            .checked_mul(24)
            .and_then(|day_hours| day_hours.checked_add(u32::from(hours)))
            .ok_or("overflow calculating whole hours from `days * 24 + hours`")?;

        Self::new(sign, whole_hours, minutes, seconds, microseconds).map_err(Into::into)
    }
}
impl Encode<'_, MySql> for MySqlTime {
    /// Writes the binary-protocol `TIME` encoding.
    ///
    /// The output is a length byte, then (for nonzero values) the sign byte, days as `u32` LE,
    /// hours-of-day, minutes and seconds as single bytes, and microseconds as `u32` LE
    /// only when nonzero.
    fn encode_by_ref(
        &self,
        buf: &mut <MySql as Database>::ArgumentBuffer,
    ) -> Result<IsNull, BoxDynError> {
        let payload_len = self.encoded_len();
        buf.put_u8(payload_len);

        // Zero is sent as a bare zero-length marker.
        if payload_len == 0 {
            return Ok(IsNull::No);
        }

        let whole_hours = self.hours();
        buf.put_u8(self.sign.to_byte());
        buf.put_u32_le(whole_hours / 24);
        buf.put_u8((whole_hours % 24) as u8);
        buf.put_u8(self.minutes());
        buf.put_u8(self.seconds());

        let micros = self.microseconds();
        if micros != 0 {
            buf.put_u32_le(micros);
        }

        Ok(IsNull::No)
    }

    fn size_hint(&self) -> usize {
        // Payload plus the leading length byte.
        usize::from(self.encoded_len()) + 1
    }
}
impl TryFrom<Duration> for MySqlTime {
type Error = MySqlTimeError;
fn try_from(value: Duration) -> Result<Self, Self::Error> {
let hours = value.as_secs() / 3600;
let rem_seconds = value.as_secs() % 3600;
let minutes = (rem_seconds / 60) as u8;
let seconds = (rem_seconds % 60) as u8;
let microseconds = value.subsec_micros();
Self::new(
MySqlTimeSign::Positive,
hours.try_into().map_err(|_| MySqlTimeError::FieldRange {
field: "hours",
max: Self::HOURS_MAX,
value: hours,
})?,
minutes,
seconds,
microseconds,
)
}
}
impl MySqlTimeSign {
    // Decodes the binary-protocol sign byte: 0 is positive, 1 is negative.
    fn from_byte(b: u8) -> Result<Self, BoxDynError> {
        match b {
            0 => Ok(Self::Positive),
            1 => Ok(Self::Negative),
            other => Err(format!("expected 0 or 1 for TIME sign byte, got {other}").into()),
        }
    }

    // Inverse of `from_byte`: 1 for negative, 0 for positive.
    fn to_byte(self) -> u8 {
        u8::from(self.is_negative())
    }

    // -1 for negative, +1 for positive.
    fn signum(&self) -> i32 {
        if self.is_negative() {
            -1
        } else {
            1
        }
    }

    /// Returns `true` for [`MySqlTimeSign::Positive`].
    pub fn is_positive(&self) -> bool {
        !self.is_negative()
    }

    /// Returns `true` for [`MySqlTimeSign::Negative`].
    pub fn is_negative(&self) -> bool {
        *self == Self::Negative
    }
}
impl Display for MySqlTimeSign {
    /// Writes `-` for negative. Positive is written as `+` only when the `+` flag is set
    /// (e.g. `{:+}`), and as nothing otherwise.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Negative => f.write_char('-'),
            Self::Positive if f.sign_plus() => f.write_char('+'),
            Self::Positive => Ok(()),
        }
    }
}
impl Type<MySql> for Duration {
fn type_info() -> MySqlTypeInfo {
MySqlTime::type_info()
}
}
impl<'r> Decode<'r, MySql> for Duration {
fn decode(value: <MySql as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let time = MySqlTime::decode(value)?;
time.to_duration().ok_or_else(|| {
format!("`std::time::Duration` can only decode positive TIME values; got {time}").into()
})
}
}
/// Parses the text-protocol representation of a `TIME` value: `[-]H+:MM:SS[.fraction]`.
///
/// The fraction is interpreted by `parse_microseconds`. A `-` sign on an all-zero value is
/// normalized to [`MySqlTime::ZERO`].
///
/// # Errors
/// Fails if a segment is missing, an extra `:`-separated segment is present, a segment does
/// not parse, or the fields are out of range for [`MySqlTime::new`].
fn parse(text: &str) -> Result<MySqlTime, BoxDynError> {
    let mut segments = text.split(':');

    let hours = segments
        .next()
        .ok_or("expected hours segment, got nothing")?;

    let minutes = segments
        .next()
        .ok_or("expected minutes segment, got nothing")?;

    let seconds = segments
        .next()
        .ok_or("expected seconds segment, got nothing")?;

    if segments.next().is_some() {
        return Err(format!("expected exactly 3 `:`-separated segments in {text:?}").into());
    }

    // Take the sign from the text, not from the parsed number. MySQL renders small negative
    // values with a zero hours field (e.g. `SEC_TO_TIME(-1)` is `-00:00:01`), and `-00`
    // parses to a non-negative `0`, which would silently drop the sign.
    let sign = if hours.starts_with('-') {
        MySqlTimeSign::Negative
    } else {
        MySqlTimeSign::Positive
    };

    // `i32` accepts the leading sign and comfortably covers the allowed hours range.
    let hours: i32 = hours
        .parse()
        .map_err(|e| format!("error parsing hours from {text:?} (segment {hours:?}): {e}"))?;
    let hours = hours.unsigned_abs();

    let minutes: u8 = minutes
        .parse()
        .map_err(|e| format!("error parsing minutes from {text:?} (segment {minutes:?}): {e}"))?;

    let (seconds, microseconds): (u8, u32) = if let Some((seconds, microseconds)) =
        seconds.split_once('.')
    {
        (
            seconds.parse().map_err(|e| {
                format!("error parsing seconds from {text:?} (segment {seconds:?}): {e}")
            })?,
            parse_microseconds(microseconds).map_err(|e| {
                format!("error parsing microseconds from {text:?} (segment {microseconds:?}): {e}")
            })?,
        )
    } else {
        (
            seconds.parse().map_err(|e| {
                format!("error parsing seconds from {text:?} (segment {seconds:?}): {e}")
            })?,
            0,
        )
    };

    // Zero has no negative form (`MySqlTime::new` rejects it), so treat `-00:00:00` as zero.
    let sign = if hours == 0 && minutes == 0 && seconds == 0 && microseconds == 0 {
        MySqlTimeSign::Positive
    } else {
        sign
    };

    Ok(MySqlTime::new(sign, hours, minutes, seconds, microseconds)?)
}
/// Parses the fractional-seconds digits of a text `TIME` value into microseconds.
///
/// Fewer than six digits are scaled up (`"89"` → `890_000`). More than six are truncated,
/// not rounded (`"123456789"` → `123_456`).
///
/// # Errors
/// Fails if `micros` is empty or contains anything other than ASCII digits.
fn parse_microseconds(micros: &str) -> Result<u32, BoxDynError> {
    const EXPECTED_DIGITS: usize = 6;

    if micros.is_empty() {
        return Err("empty string".into());
    }

    // Validate up front. This rejects signs and trailing garbage, and it also guarantees the
    // byte slice below lands on a char boundary (a multi-byte char would otherwise panic).
    if !micros.bytes().all(|b| b.is_ascii_digit()) {
        return Err(format!("expected only ASCII digits, got {micros:?}").into());
    }

    let significant = &micros[..micros.len().min(EXPECTED_DIGITS)];
    let value: u32 = significant.parse()?;

    #[allow(clippy::cast_possible_truncation)] // at most 5, so the cast is lossless
    let scale = 10u32.pow((EXPECTED_DIGITS - significant.len()) as u32);

    Ok(value * scale)
}
// Unit tests for `MySqlTime`'s `Display` formatting and for `parse_microseconds`.
#[cfg(test)]
mod tests {
use super::MySqlTime;
use crate::types::MySqlTimeSign;
use super::parse_microseconds;
// Covers the default format (fraction shown only when nonzero), explicit precisions
// 0/3/6/9 (truncated or zero-padded), and the `+` flag, for ZERO, MAX, MIN and
// arbitrary positive/negative values.
#[test]
fn test_display() {
assert_eq!(MySqlTime::ZERO.to_string(), "0:00:00");
assert_eq!(format!("{:.0}", MySqlTime::ZERO), "0:00:00");
assert_eq!(format!("{:.3}", MySqlTime::ZERO), "0:00:00.000");
assert_eq!(format!("{:.6}", MySqlTime::ZERO), "0:00:00.000000");
assert_eq!(format!("{:.9}", MySqlTime::ZERO), "0:00:00.000000000");
assert_eq!(format!("{:.0}", MySqlTime::MAX), "838:59:59");
assert_eq!(format!("{:.3}", MySqlTime::MAX), "838:59:59.000");
assert_eq!(format!("{:.6}", MySqlTime::MAX), "838:59:59.000000");
assert_eq!(format!("{:.9}", MySqlTime::MAX), "838:59:59.000000000");
assert_eq!(format!("{:+.0}", MySqlTime::MAX), "+838:59:59");
assert_eq!(format!("{:+.3}", MySqlTime::MAX), "+838:59:59.000");
assert_eq!(format!("{:+.6}", MySqlTime::MAX), "+838:59:59.000000");
assert_eq!(format!("{:+.9}", MySqlTime::MAX), "+838:59:59.000000000");
assert_eq!(format!("{:.0}", MySqlTime::MIN), "-838:59:59");
assert_eq!(format!("{:.3}", MySqlTime::MIN), "-838:59:59.000");
assert_eq!(format!("{:.6}", MySqlTime::MIN), "-838:59:59.000000");
assert_eq!(format!("{:.9}", MySqlTime::MIN), "-838:59:59.000000000");
let positive = MySqlTime::new(MySqlTimeSign::Positive, 123, 45, 56, 890011).unwrap();
assert_eq!(positive.to_string(), "123:45:56.890011");
assert_eq!(format!("{positive:.0}"), "123:45:56");
assert_eq!(format!("{positive:.3}"), "123:45:56.890");
assert_eq!(format!("{positive:.6}"), "123:45:56.890011");
assert_eq!(format!("{positive:.9}"), "123:45:56.890011000");
assert_eq!(format!("{positive:+.0}"), "+123:45:56");
assert_eq!(format!("{positive:+.3}"), "+123:45:56.890");
assert_eq!(format!("{positive:+.6}"), "+123:45:56.890011");
assert_eq!(format!("{positive:+.9}"), "+123:45:56.890011000");
let negative = MySqlTime::new(MySqlTimeSign::Negative, 123, 45, 56, 890011).unwrap();
assert_eq!(negative.to_string(), "-123:45:56.890011");
assert_eq!(format!("{negative:.0}"), "-123:45:56");
assert_eq!(format!("{negative:.3}"), "-123:45:56.890");
assert_eq!(format!("{negative:.6}"), "-123:45:56.890011");
assert_eq!(format!("{negative:.9}"), "-123:45:56.890011000");
}
// Fractions shorter than six digits are scaled up (right-padded with zeros); longer
// fractions are truncated, not rounded, to six digits.
#[test]
fn test_parse_microseconds() {
assert_eq!(parse_microseconds("010").unwrap(), 10_000);
assert_eq!(parse_microseconds("0100000000").unwrap(), 10_000);
assert_eq!(parse_microseconds("890").unwrap(), 890_000);
assert_eq!(parse_microseconds("0890").unwrap(), 89_000);
assert_eq!(
parse_microseconds("123456789").unwrap(),
123456,
);
}
}