use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc};
use crate::error::{Error, Result};
use crate::protocol::types::{Oid, oid};
use super::{FromWireValue, ToWireValue};
/// The PostgreSQL epoch date, 2000-01-01. Binary `date` values in this file
/// are decoded/encoded as a day count relative to this date.
const PG_EPOCH: NaiveDate = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap();
/// Midnight at the start of [`PG_EPOCH`]. Binary `timestamp`/`timestamptz`
/// values are encoded by this file as microseconds relative to this instant.
const PG_EPOCH_DT: NaiveDateTime =
NaiveDateTime::new(PG_EPOCH, NaiveTime::from_hms_opt(0, 0, 0).unwrap());
/// Microseconds in one second.
const USECS_PER_SEC: i64 = 1_000_000;
/// Microseconds in one day; only needed by the tests to build fixtures.
#[cfg(test)]
const USECS_PER_DAY: i64 = 86_400_000_000;
impl FromWireValue<'_> for NaiveDate {
fn from_text(oid: Oid, bytes: &[u8]) -> Result<Self> {
if oid != oid::DATE {
return Err(Error::Decode(format!(
"cannot decode oid {} as NaiveDate",
oid
)));
}
let s = simdutf8::compat::from_utf8(bytes)
.map_err(|e| Error::Decode(format!("invalid UTF-8: {}", e)))?;
NaiveDate::parse_from_str(s, "%Y-%m-%d")
.map_err(|e| Error::Decode(format!("invalid date: {}", e)))
}
fn from_binary(oid: Oid, bytes: &[u8]) -> Result<Self> {
if oid != oid::DATE {
return Err(Error::Decode(format!(
"cannot decode oid {} as NaiveDate",
oid
)));
}
let arr: [u8; 4] = bytes.try_into().map_err(|_unhelpful_err| {
Error::Decode(format!("invalid Date length: {}", bytes.len()))
})?;
let pg_days = i32::from_be_bytes(arr);
PG_EPOCH
.checked_add_days(chrono::Days::new(pg_days.max(0) as u64))
.or_else(|| PG_EPOCH.checked_sub_days(chrono::Days::new((-pg_days).max(0) as u64)))
.ok_or_else(|| Error::Decode("date overflow".into()))
}
}
impl ToWireValue for NaiveDate {
fn natural_oid(&self) -> Oid {
oid::DATE
}
fn encode(&self, target_oid: Oid, buf: &mut Vec<u8>) -> Result<()> {
match target_oid {
oid::DATE => {
let pg_days = self.signed_duration_since(PG_EPOCH).num_days() as i32;
buf.extend_from_slice(&4_i32.to_be_bytes());
buf.extend_from_slice(&pg_days.to_be_bytes());
Ok(())
}
_ => Err(Error::type_mismatch(self.natural_oid(), target_oid)),
}
}
}
impl FromWireValue<'_> for NaiveTime {
fn from_text(oid: Oid, bytes: &[u8]) -> Result<Self> {
if oid != oid::TIME {
return Err(Error::Decode(format!(
"cannot decode oid {} as NaiveTime",
oid
)));
}
let s = simdutf8::compat::from_utf8(bytes)
.map_err(|e| Error::Decode(format!("invalid UTF-8: {}", e)))?;
NaiveTime::parse_from_str(s, "%H:%M:%S%.f")
.or_else(|_| NaiveTime::parse_from_str(s, "%H:%M:%S"))
.map_err(|e| Error::Decode(format!("invalid time: {}", e)))
}
fn from_binary(oid: Oid, bytes: &[u8]) -> Result<Self> {
if oid != oid::TIME {
return Err(Error::Decode(format!(
"cannot decode oid {} as NaiveTime",
oid
)));
}
let arr: [u8; 8] = bytes.try_into().map_err(|_unhelpful_err| {
Error::Decode(format!("invalid Time length: {}", bytes.len()))
})?;
let usecs = i64::from_be_bytes(arr);
let secs = (usecs / USECS_PER_SEC) as u32;
let nano = ((usecs % USECS_PER_SEC) * 1000) as u32;
NaiveTime::from_num_seconds_from_midnight_opt(secs, nano)
.ok_or_else(|| Error::Decode("invalid time".into()))
}
}
impl ToWireValue for NaiveTime {
fn natural_oid(&self) -> Oid {
oid::TIME
}
fn encode(&self, target_oid: Oid, buf: &mut Vec<u8>) -> Result<()> {
match target_oid {
oid::TIME => {
let usecs = (self.num_seconds_from_midnight() as i64) * USECS_PER_SEC
+ (self.nanosecond() as i64) / 1000;
buf.extend_from_slice(&8_i32.to_be_bytes());
buf.extend_from_slice(&usecs.to_be_bytes());
Ok(())
}
_ => Err(Error::type_mismatch(self.natural_oid(), target_oid)),
}
}
}
impl FromWireValue<'_> for NaiveDateTime {
fn from_text(oid: Oid, bytes: &[u8]) -> Result<Self> {
if !matches!(oid, oid::TIMESTAMP | oid::TIMESTAMPTZ) {
return Err(Error::Decode(format!(
"cannot decode oid {} as NaiveDateTime",
oid
)));
}
let s = simdutf8::compat::from_utf8(bytes)
.map_err(|e| Error::Decode(format!("invalid UTF-8: {}", e)))?;
let s = s
.find(['+', '-'])
.filter(|&pos| pos > 10)
.map(|pos| &s[..pos])
.unwrap_or(s);
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f")
.or_else(|_| NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S"))
.map_err(|e| Error::Decode(format!("invalid timestamp: {}", e)))
}
fn from_binary(oid: Oid, bytes: &[u8]) -> Result<Self> {
if !matches!(oid, oid::TIMESTAMP | oid::TIMESTAMPTZ) {
return Err(Error::Decode(format!(
"cannot decode oid {} as NaiveDateTime",
oid
)));
}
let arr: [u8; 8] = bytes.try_into().map_err(|_unhelpful_err| {
Error::Decode(format!("invalid Timestamp length: {}", bytes.len()))
})?;
let usecs = i64::from_be_bytes(arr);
let pg_epoch_dt = PG_EPOCH
.and_hms_opt(0, 0, 0)
.ok_or_else(|| Error::Decode("invalid epoch".into()))?;
let secs = usecs / USECS_PER_SEC;
let nano = ((usecs % USECS_PER_SEC).abs() * 1000) as u32;
pg_epoch_dt
.checked_add_signed(chrono::Duration::seconds(secs))
.and_then(|dt| {
if usecs >= 0 {
dt.with_nanosecond(nano)
} else {
dt.checked_sub_signed(chrono::Duration::nanoseconds(nano as i64))
}
})
.ok_or_else(|| Error::Decode("timestamp overflow".into()))
}
}
impl ToWireValue for NaiveDateTime {
fn natural_oid(&self) -> Oid {
oid::TIMESTAMP
}
fn encode(&self, target_oid: Oid, buf: &mut Vec<u8>) -> Result<()> {
match target_oid {
oid::TIMESTAMP | oid::TIMESTAMPTZ => {
let duration = self.signed_duration_since(PG_EPOCH_DT);
let usecs = duration.num_microseconds().unwrap_or(i64::MAX);
buf.extend_from_slice(&8_i32.to_be_bytes());
buf.extend_from_slice(&usecs.to_be_bytes());
Ok(())
}
_ => Err(Error::type_mismatch(self.natural_oid(), target_oid)),
}
}
}
impl FromWireValue<'_> for DateTime<Utc> {
fn from_text(oid: Oid, bytes: &[u8]) -> Result<Self> {
if oid != oid::TIMESTAMPTZ {
return Err(Error::Decode(format!(
"cannot decode oid {} as DateTime<Utc>",
oid
)));
}
let s = simdutf8::compat::from_utf8(bytes)
.map_err(|e| Error::Decode(format!("invalid UTF-8: {}", e)))?;
DateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f%#z")
.or_else(|_| DateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%#z"))
.map(|dt| dt.with_timezone(&Utc))
.map_err(|e| Error::Decode(format!("invalid timestamptz: {}", e)))
}
fn from_binary(oid: Oid, bytes: &[u8]) -> Result<Self> {
if oid != oid::TIMESTAMPTZ {
return Err(Error::Decode(format!(
"cannot decode oid {} as DateTime<Utc>",
oid
)));
}
let arr: [u8; 8] = bytes.try_into().map_err(|_unhelpful_err| {
Error::Decode(format!("invalid Timestamp length: {}", bytes.len()))
})?;
let usecs = i64::from_be_bytes(arr);
let pg_epoch_utc = PG_EPOCH
.and_hms_opt(0, 0, 0)
.ok_or_else(|| Error::Decode("invalid epoch".into()))?
.and_utc();
let secs = usecs / USECS_PER_SEC;
let nano = ((usecs % USECS_PER_SEC).abs() * 1000) as u32;
pg_epoch_utc
.checked_add_signed(chrono::Duration::seconds(secs))
.and_then(|dt| {
if usecs >= 0 {
dt.with_nanosecond(nano)
} else {
dt.checked_sub_signed(chrono::Duration::nanoseconds(nano as i64))
}
})
.ok_or_else(|| Error::Decode("timestamp overflow".into()))
}
}
impl ToWireValue for DateTime<Utc> {
fn natural_oid(&self) -> Oid {
oid::TIMESTAMPTZ
}
fn encode(&self, target_oid: Oid, buf: &mut Vec<u8>) -> Result<()> {
match target_oid {
oid::TIMESTAMP | oid::TIMESTAMPTZ => {
let duration = self.signed_duration_since(PG_EPOCH_DT.and_utc());
let usecs = duration.num_microseconds().unwrap_or(i64::MAX);
buf.extend_from_slice(&8_i32.to_be_bytes());
buf.extend_from_slice(&usecs.to_be_bytes());
Ok(())
}
_ => Err(Error::type_mismatch(self.natural_oid(), target_oid)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Datelike, Timelike};

    /// 2024-01-15 expressed as days since 2000-01-01.
    const JAN_15_2024_DAYS: i32 = 8780;
    /// 10:30:45 expressed as seconds since midnight.
    const TEN_30_45_SECS: i64 = 10 * 3600 + 30 * 60 + 45;

    /// Encodes `value` under its natural OID and returns the payload with the
    /// leading 4-byte length prefix removed.
    fn encoded_payload<T: ToWireValue>(value: &T) -> Vec<u8> {
        let mut buf = Vec::new();
        value.encode(value.natural_oid(), &mut buf).unwrap();
        buf.split_off(4)
    }

    /// Binary timestamp payload for 2024-01-15 10:30:45.
    fn jan_15_2024_10_30_45_payload() -> [u8; 8] {
        let micros =
            i64::from(JAN_15_2024_DAYS) * USECS_PER_DAY + TEN_30_45_SECS * USECS_PER_SEC;
        micros.to_be_bytes()
    }

    fn assert_is_jan_15_2024(value: &impl Datelike) {
        assert_eq!(value.year(), 2024);
        assert_eq!(value.month(), 1);
        assert_eq!(value.day(), 15);
    }

    fn assert_is_10_30_45(value: &impl Timelike) {
        assert_eq!(value.hour(), 10);
        assert_eq!(value.minute(), 30);
        assert_eq!(value.second(), 45);
    }

    #[test]
    fn date_text() {
        let date = NaiveDate::from_text(oid::DATE, b"2024-01-15").unwrap();
        assert_is_jan_15_2024(&date);
    }

    #[test]
    fn date_binary() {
        let date = NaiveDate::from_binary(oid::DATE, &JAN_15_2024_DAYS.to_be_bytes()).unwrap();
        assert_is_jan_15_2024(&date);
    }

    #[test]
    fn date_roundtrip() {
        let original = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let decoded = NaiveDate::from_binary(oid::DATE, &encoded_payload(&original)).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn time_text() {
        let time = NaiveTime::from_text(oid::TIME, b"10:30:45").unwrap();
        assert_is_10_30_45(&time);
    }

    #[test]
    fn time_text_with_micros() {
        let time = NaiveTime::from_text(oid::TIME, b"10:30:45.123456").unwrap();
        assert_is_10_30_45(&time);
        assert_eq!(time.nanosecond(), 123456000);
    }

    #[test]
    fn time_binary() {
        let micros = TEN_30_45_SECS * USECS_PER_SEC;
        let time = NaiveTime::from_binary(oid::TIME, &micros.to_be_bytes()).unwrap();
        assert_is_10_30_45(&time);
    }

    #[test]
    fn time_roundtrip() {
        let original = NaiveTime::from_hms_micro_opt(10, 30, 45, 123456).unwrap();
        let decoded = NaiveTime::from_binary(oid::TIME, &encoded_payload(&original)).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn timestamp_text() {
        let ts = NaiveDateTime::from_text(oid::TIMESTAMP, b"2024-01-15 10:30:45").unwrap();
        assert_is_jan_15_2024(&ts);
        assert_is_10_30_45(&ts);
    }

    #[test]
    fn timestamp_binary() {
        let payload = jan_15_2024_10_30_45_payload();
        let ts = NaiveDateTime::from_binary(oid::TIMESTAMP, &payload).unwrap();
        assert_is_jan_15_2024(&ts);
        assert_is_10_30_45(&ts);
    }

    #[test]
    fn timestamp_roundtrip() {
        let original = NaiveDate::from_ymd_opt(2024, 1, 15)
            .and_then(|d| d.and_hms_micro_opt(10, 30, 45, 123456))
            .unwrap();
        let decoded =
            NaiveDateTime::from_binary(oid::TIMESTAMP, &encoded_payload(&original)).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn timestamptz_binary() {
        let payload = jan_15_2024_10_30_45_payload();
        let ts = DateTime::<Utc>::from_binary(oid::TIMESTAMPTZ, &payload).unwrap();
        assert_is_jan_15_2024(&ts);
        assert_is_10_30_45(&ts);
    }

    #[test]
    fn timestamptz_roundtrip() {
        // The wire format carries microseconds, so drop sub-microsecond
        // precision before comparing.
        let now = Utc::now();
        let nanos = now.nanosecond();
        let original = now.with_nanosecond(nanos - nanos % 1000).unwrap();
        let decoded =
            DateTime::<Utc>::from_binary(oid::TIMESTAMPTZ, &encoded_payload(&original)).unwrap();
        assert_eq!(original, decoded);
    }
}