#![allow(clippy::result_large_err)]
use crate::wit_bindgen;
use std::sync::Arc;
// Generated WIT bindings for the `spin-sdk-pg` world. The module is hidden
// from rustdoc; selected items are re-exported from the parent module.
#[doc(hidden)]
pub mod wit {
#![allow(missing_docs)]
use crate::wit_bindgen;
// Expand the bindings from the WIT definitions found under `wit`, pointing
// the generated code at this crate's `wit_bindgen::rt` runtime path.
wit_bindgen::generate!({
runtime_path: "crate::wit_bindgen::rt",
world: "spin-sdk-pg",
path: "wit",
generate_all,
});
// Expose the generated `spin::postgres::postgres` module as `wit::postgres`.
pub use spin::postgres::postgres;
}
#[doc(inline)]
pub use wit::postgres::{
Column, DbDataType, DbError, DbValue, Error as PgError, ParameterValue, QueryError,
RangeBoundKind,
};
pub use wit::postgres::Interval;
use chrono::{Datelike, Timelike};
/// An open connection to a PostgreSQL database.
///
/// This is a newtype over the generated `wit::postgres::Connection`; call
/// [`Connection::into_inner`] to recover the wrapped value.
pub struct Connection(wit::postgres::Connection);
/// Options accepted by [`Connection::open_with_options`].
#[derive(Default)]
pub struct OpenOptions {
/// CA root certificate to hand to the connection builder. When `None`
/// (the default), no CA root is set on the builder.
pub ca_root: Option<Certificate>,
}
/// A certificate, supplied either as a file to read or as inline text.
///
/// Derives `Debug` and `Clone` so the type can be logged and reused by
/// callers, as is conventional for public data types.
#[derive(Debug, Clone)]
pub enum Certificate {
    /// Path of a file whose contents are the certificate text.
    FilePath(String),
    /// The certificate text itself.
    Text(String),
}
impl Certificate {
fn load(self) -> Result<String, Error> {
match self {
Certificate::FilePath(path) => std::fs::read_to_string(path)
.map_err(|e| Error::PgError(PgError::Other(e.to_string()))),
Certificate::Text(text) => Ok(text),
}
}
}
impl Connection {
pub async fn open(address: impl Into<String>) -> Result<Self, Error> {
let inner = wit::postgres::Connection::open_async(address.into()).await?;
Ok(Self(inner))
}
pub async fn open_with_options(
address: impl AsRef<str>,
options: OpenOptions,
) -> Result<Self, Error> {
let builder = wit::postgres::ConnectionBuilder::new(address.as_ref());
let OpenOptions { ca_root } = options;
if let Some(ca_root) = ca_root {
let ca_root_text = ca_root.load()?;
builder.set_ca_root(&ca_root_text)?;
}
let inner = builder.build_async().await?;
Ok(Self(inner))
}
pub async fn query(
&self,
statement: impl Into<String>,
params: impl Into<Vec<ParameterValue>>,
) -> Result<QueryResult, Error> {
let (columns, rows, result) = self.0.query_async(statement.into(), params.into()).await?;
Ok(QueryResult {
columns: Arc::new(columns),
rows,
result,
})
}
pub async fn execute(
&self,
statement: impl Into<String>,
params: impl Into<Vec<ParameterValue>>,
) -> Result<u64, Error> {
self.0
.execute_async(statement.into(), params.into())
.await
.map_err(Error::PgError)
}
pub fn into_inner(self) -> wit::postgres::Connection {
self.0
}
}
/// The value returned by [`Connection::query`]: column metadata, a stream of
/// rows, and a future that resolves to the query's final status.
pub struct QueryResult {
// Held in an `Arc` so each `Row` produced by `next` can share it.
columns: Arc<Vec<Column>>,
// Stream whose items are the values of one row each.
rows: wit_bindgen::StreamReader<Vec<DbValue>>,
// Future resolving to `Ok(())` or a `PgError`.
result: wit_bindgen::FutureReader<Result<(), PgError>>,
}
impl QueryResult {
pub fn columns(&self) -> &[Column] {
&self.columns
}
pub async fn next(&mut self) -> Option<Row> {
self.rows.next().await.map(|r| Row {
columns: self.columns.clone(),
result: r,
})
}
pub async fn result(self) -> Result<(), Error> {
self.result.await.map_err(Error::PgError)
}
pub async fn collect(mut self) -> Result<Vec<Row>, Error> {
let mut rows = vec![];
while let Some(row) = self.next().await {
rows.push(row);
}
self.result.await.map_err(Error::PgError)?;
Ok(rows)
}
pub fn rows(&mut self) -> &mut wit_bindgen::StreamReader<Vec<DbValue>> {
&mut self.rows
}
#[allow(
clippy::type_complexity,
reason = "sorry clippy that's just what the inner bits are"
)]
pub fn into_inner(
self,
) -> (
Vec<Column>,
wit_bindgen::StreamReader<Vec<DbValue>>,
wit_bindgen::FutureReader<Result<(), PgError>>,
) {
((*self.columns).clone(), self.rows, self.result)
}
}
/// A single row of a query result.
pub struct Row {
// Column metadata shared with the `QueryResult` that produced this row.
columns: Arc<Vec<wit::postgres::Column>>,
// The row's values; `get` looks them up by the matching column's position.
result: Vec<DbValue>,
}
impl Row {
    /// Looks up the value in the first column named `column` and decodes it
    /// as `T`.
    ///
    /// Returns `None` if no column has that name, if the row has no value at
    /// that column's position, or if decoding fails.
    pub fn get<T: Decode>(&self, column: &str) -> Option<T> {
        let (index, _) = self
            .columns
            .iter()
            .enumerate()
            .find(|(_, col)| col.name == column)?;
        self.result
            .get(index)
            .and_then(|db_value| T::decode(db_value).ok())
    }
}
/// Positional access to the raw values of a row.
///
/// # Panics
///
/// Panics if `index` is out of bounds for the row's values.
impl std::ops::Index<usize> for Row {
    type Output = DbValue;

    fn index(&self, index: usize) -> &Self::Output {
        let Self { result: values, .. } = self;
        &values[index]
    }
}
/// Errors returned by this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// A `DbValue` could not be converted to the requested Rust type; the
/// string holds the explanation.
#[error("error value decoding: {0}")]
Decode(String),
/// An error from the underlying Postgres interface. It is displayed
/// transparently and can be created from a `PgError` via `From`.
#[error(transparent)]
PgError(#[from] PgError),
}
/// Conversion from a database value into a Rust type.
pub trait Decode: Sized {
/// Attempts to build `Self` from `value`, returning an [`Error`] when the
/// value cannot be converted.
fn decode(value: &DbValue) -> Result<Self, Error>;
}
/// Decodes `DbValue::DbNull` as `None` and delegates every other value to
/// `T`, wrapping the outcome in `Some`.
impl<T: Decode> Decode for Option<T> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        if matches!(value, DbValue::DbNull) {
            return Ok(None);
        }
        T::decode(value).map(Some)
    }
}
impl Decode for bool {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Boolean(boolean) => Ok(*boolean),
_ => Err(Error::Decode(format_decode_err("BOOL", value))),
}
}
}
impl Decode for i16 {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Int16(n) => Ok(*n),
_ => Err(Error::Decode(format_decode_err("SMALLINT", value))),
}
}
}
impl Decode for i32 {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Int32(n) => Ok(*n),
_ => Err(Error::Decode(format_decode_err("INT", value))),
}
}
}
impl Decode for i64 {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Int64(n) => Ok(*n),
_ => Err(Error::Decode(format_decode_err("BIGINT", value))),
}
}
}
impl Decode for f32 {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Floating32(n) => Ok(*n),
_ => Err(Error::Decode(format_decode_err("REAL", value))),
}
}
}
impl Decode for f64 {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Floating64(n) => Ok(*n),
_ => Err(Error::Decode(format_decode_err("DOUBLE PRECISION", value))),
}
}
}
impl Decode for Vec<u8> {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Binary(n) => Ok(n.to_owned()),
_ => Err(Error::Decode(format_decode_err("BYTEA", value))),
}
}
}
impl Decode for String {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Str(s) => Ok(s.to_owned()),
_ => Err(Error::Decode(format_decode_err(
"CHAR, VARCHAR, TEXT",
value,
))),
}
}
}
impl Decode for chrono::NaiveDate {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Date((year, month, day)) => {
let naive_date =
chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
.ok_or_else(|| {
Error::Decode(format!(
"invalid date y={}, m={}, d={}",
year, month, day
))
})?;
Ok(naive_date)
}
_ => Err(Error::Decode(format_decode_err("DATE", value))),
}
}
}
impl Decode for chrono::NaiveTime {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Time((hour, minute, second, nanosecond)) => {
let naive_time = chrono::NaiveTime::from_hms_nano_opt(
(*hour).into(),
(*minute).into(),
(*second).into(),
*nanosecond,
)
.ok_or_else(|| {
Error::Decode(format!(
"invalid time {}:{}:{}:{}",
hour, minute, second, nanosecond
))
})?;
Ok(naive_time)
}
_ => Err(Error::Decode(format_decode_err("TIME", value))),
}
}
}
impl Decode for chrono::NaiveDateTime {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Datetime((year, month, day, hour, minute, second, nanosecond)) => {
let naive_date =
chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
.ok_or_else(|| {
Error::Decode(format!(
"invalid date y={}, m={}, d={}",
year, month, day
))
})?;
let naive_time = chrono::NaiveTime::from_hms_nano_opt(
(*hour).into(),
(*minute).into(),
(*second).into(),
*nanosecond,
)
.ok_or_else(|| {
Error::Decode(format!(
"invalid time {}:{}:{}:{}",
hour, minute, second, nanosecond
))
})?;
let dt = chrono::NaiveDateTime::new(naive_date, naive_time);
Ok(dt)
}
_ => Err(Error::Decode(format_decode_err("DATETIME", value))),
}
}
}
impl Decode for chrono::Duration {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Timestamp(n) => Ok(chrono::Duration::seconds(*n)),
_ => Err(Error::Decode(format_decode_err("BIGINT", value))),
}
}
}
/// Decodes `DbValue::Uuid` by parsing its string form; a parse failure or a
/// different variant is a decode error.
#[cfg(feature = "postgres4-types")]
impl Decode for uuid::Uuid {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::Uuid(text) = value else {
            return Err(Error::Decode(format_decode_err("UUID", value)));
        };
        match uuid::Uuid::parse_str(text) {
            Ok(id) => Ok(id),
            Err(parse_err) => Err(Error::Decode(parse_err.to_string())),
        }
    }
}
/// Decodes `DbValue::Jsonb` into a generic JSON value via [`from_jsonb`].
#[cfg(feature = "json")]
impl Decode for serde_json::Value {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let parsed: Self = from_jsonb(value)?;
        Ok(parsed)
    }
}
/// Deserializes the bytes of a `DbValue::Jsonb` into any `T` that implements
/// `serde::Deserialize`.
///
/// # Errors
///
/// Returns [`Error::Decode`] if `value` is not `Jsonb` or if the bytes do not
/// deserialize into `T`.
#[cfg(feature = "json")]
pub fn from_jsonb<'a, T: serde::Deserialize<'a>>(value: &'a DbValue) -> Result<T, Error> {
    let DbValue::Jsonb(bytes) = value else {
        return Err(Error::Decode(format_decode_err("JSONB", value)));
    };
    serde_json::from_slice(bytes).map_err(|json_err| Error::Decode(json_err.to_string()))
}
/// Decodes `DbValue::Decimal` by parsing its string form exactly; a parse
/// failure or a different variant is a decode error.
#[cfg(feature = "postgres4-types")]
impl Decode for rust_decimal::Decimal {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::Decimal(text) = value else {
            return Err(Error::Decode(format_decode_err("NUMERIC", value)));
        };
        rust_decimal::Decimal::from_str_exact(text)
            .map_err(|parse_err| Error::Decode(parse_err.to_string()))
    }
}
/// Maps a WIT range-bound kind onto the equivalent `postgres_range` type.
#[cfg(feature = "postgres4-types")]
fn bound_type_from_wit(kind: RangeBoundKind) -> postgres_range::BoundType {
    use postgres_range::BoundType;

    match kind {
        RangeBoundKind::Exclusive => BoundType::Exclusive,
        RangeBoundKind::Inclusive => BoundType::Inclusive,
    }
}
/// Decodes `DbValue::RangeInt32` into a `postgres_range::Range<i32>`; a
/// missing bound on either side stays unbounded.
#[cfg(feature = "postgres4-types")]
impl Decode for postgres_range::Range<i32> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        use postgres_range::RangeBound;

        let DbValue::RangeInt32((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("INT4RANGE", value)));
        };
        let lower = lbound.map(|(point, kind)| RangeBound::new(point, bound_type_from_wit(kind)));
        let upper = ubound.map(|(point, kind)| RangeBound::new(point, bound_type_from_wit(kind)));
        Ok(Self::new(lower, upper))
    }
}
/// Decodes `DbValue::RangeInt64` into a `postgres_range::Range<i64>`; a
/// missing bound on either side stays unbounded.
#[cfg(feature = "postgres4-types")]
impl Decode for postgres_range::Range<i64> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        use postgres_range::RangeBound;

        let DbValue::RangeInt64((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("INT8RANGE", value)));
        };
        let lower = lbound.map(|(point, kind)| RangeBound::new(point, bound_type_from_wit(kind)));
        let upper = ubound.map(|(point, kind)| RangeBound::new(point, bound_type_from_wit(kind)));
        Ok(Self::new(lower, upper))
    }
}
/// Decodes `DbValue::RangeDecimal` into a `(lower, upper)` pair of optional
/// `(Decimal, RangeBoundKind)` bounds.
///
/// Each present bound's string is parsed exactly; the lower bound is parsed
/// first and the first parse failure is returned as a decode error.
#[cfg(feature = "postgres4-types")]
impl Decode
    for (
        Option<(rust_decimal::Decimal, RangeBoundKind)>,
        Option<(rust_decimal::Decimal, RangeBoundKind)>,
    )
{
    fn decode(value: &DbValue) -> Result<Self, Error> {
        // Parses one side of the range, passing an absent bound through.
        fn parse_side(
            side: &Option<(String, RangeBoundKind)>,
        ) -> Result<Option<(rust_decimal::Decimal, RangeBoundKind)>, Error> {
            let Some((text, kind)) = side else {
                return Ok(None);
            };
            match rust_decimal::Decimal::from_str_exact(text) {
                Ok(dec) => Ok(Some((dec, *kind))),
                Err(parse_err) => Err(Error::Decode(parse_err.to_string())),
            }
        }

        let DbValue::RangeDecimal((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("NUMERICRANGE", value)));
        };
        let lower = parse_side(lbound)?;
        let upper = parse_side(ubound)?;
        Ok((lower, upper))
    }
}
impl Decode for Vec<Option<i32>> {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::ArrayInt32(a) => Ok(a.to_vec()),
_ => Err(Error::Decode(format_decode_err("INT4[]", value))),
}
}
}
impl Decode for Vec<Option<i64>> {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::ArrayInt64(a) => Ok(a.to_vec()),
_ => Err(Error::Decode(format_decode_err("INT8[]", value))),
}
}
}
impl Decode for Vec<Option<String>> {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::ArrayStr(a) => Ok(a.to_vec()),
_ => Err(Error::Decode(format_decode_err("TEXT[]", value))),
}
}
}
/// Parses an optional decimal string, passing `None` through unchanged and
/// turning a parse failure into [`Error::Decode`].
#[cfg(feature = "postgres4-types")]
fn map_decimal(s: &Option<String>) -> Result<Option<rust_decimal::Decimal>, Error> {
    let Some(text) = s else {
        return Ok(None);
    };
    match rust_decimal::Decimal::from_str_exact(text) {
        Ok(dec) => Ok(Some(dec)),
        Err(parse_err) => Err(Error::Decode(parse_err.to_string())),
    }
}
/// Decodes `DbValue::ArrayDecimal`, parsing each present element with
/// `map_decimal` and stopping at the first element that fails to parse.
#[cfg(feature = "postgres4-types")]
impl Decode for Vec<Option<rust_decimal::Decimal>> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::ArrayDecimal(items) = value else {
            return Err(Error::Decode(format_decode_err("NUMERIC[]", value)));
        };
        items.iter().map(map_decimal).collect()
    }
}
impl Decode for Interval {
fn decode(value: &DbValue) -> Result<Self, Error> {
match value {
DbValue::Interval(i) => Ok(*i),
_ => Err(Error::Decode(format_decode_err("INTERVAL", value))),
}
}
}
// Generates `From<source> for ParameterValue` for each `source => Variant`
// pair, wrapping the value directly in the named variant.
macro_rules! impl_parameter_value_conversions {
    ($($source:ty => $variant:ident),*) => {
        $(
            impl From<$source> for ParameterValue {
                fn from(value: $source) -> Self {
                    Self::$variant(value)
                }
            }
        )*
    };
}
// Direct conversions: each Rust type on the left is wrapped unchanged in the
// `ParameterValue` variant on the right.
impl_parameter_value_conversions! {
i8 => Int8,
i16 => Int16,
i32 => Int32,
i64 => Int64,
f32 => Floating32,
f64 => Floating64,
bool => Boolean,
String => Str,
Vec<u8> => Binary,
Vec<Option<i32>> => ArrayInt32,
Vec<Option<i64>> => ArrayInt64,
Vec<Option<String>> => ArrayStr
}
/// Splits a `NaiveDateTime` into the (year, month, day, hour, minute, second,
/// nanosecond) tuple carried by `ParameterValue::Datetime`.
impl From<chrono::NaiveDateTime> for ParameterValue {
    fn from(v: chrono::NaiveDateTime) -> Self {
        let date = v.date();
        let time = v.time();
        Self::Datetime((
            date.year(),
            date.month() as u8,
            date.day() as u8,
            time.hour() as u8,
            time.minute() as u8,
            time.second() as u8,
            time.nanosecond(),
        ))
    }
}
/// Splits a `NaiveTime` into the (hour, minute, second, nanosecond) tuple
/// carried by `ParameterValue::Time`.
impl From<chrono::NaiveTime> for ParameterValue {
    fn from(v: chrono::NaiveTime) -> Self {
        let hour = v.hour() as u8;
        let minute = v.minute() as u8;
        let second = v.second() as u8;
        Self::Time((hour, minute, second, v.nanosecond()))
    }
}
/// Splits a `NaiveDate` into the (year, month, day) tuple carried by
/// `ParameterValue::Date`.
impl From<chrono::NaiveDate> for ParameterValue {
    fn from(v: chrono::NaiveDate) -> Self {
        let month = v.month() as u8;
        let day = v.day() as u8;
        Self::Date((v.year(), month, day))
    }
}
/// Converts a `TimeDelta` to `ParameterValue::Timestamp` holding its whole
/// number of seconds.
impl From<chrono::TimeDelta> for ParameterValue {
    fn from(v: chrono::TimeDelta) -> Self {
        let whole_seconds = v.num_seconds();
        Self::Timestamp(whole_seconds)
    }
}
/// Converts a `Uuid` to `ParameterValue::Uuid` holding its string form.
#[cfg(feature = "postgres4-types")]
impl From<uuid::Uuid> for ParameterValue {
    fn from(v: uuid::Uuid) -> Self {
        let text = v.to_string();
        Self::Uuid(text)
    }
}
/// Serializes a JSON value into `ParameterValue::Jsonb` via [`jsonb`].
#[cfg(feature = "json")]
impl TryFrom<serde_json::Value> for ParameterValue {
    type Error = serde_json::Error;

    fn try_from(v: serde_json::Value) -> Result<Self, Self::Error> {
        let json_value = &v;
        jsonb(json_value)
    }
}
/// Serializes `value` to JSON bytes and wraps them in
/// `ParameterValue::Jsonb`.
///
/// # Errors
///
/// Returns the `serde_json::Error` if serialization fails.
#[cfg(feature = "json")]
pub fn jsonb<T: serde::Serialize>(value: &T) -> Result<ParameterValue, serde_json::Error> {
    serde_json::to_vec(value).map(ParameterValue::Jsonb)
}
/// Converts a `Decimal` to `ParameterValue::Decimal` holding its string form.
#[cfg(feature = "postgres4-types")]
impl From<rust_decimal::Decimal> for ParameterValue {
    fn from(v: rust_decimal::Decimal) -> Self {
        let text = v.to_string();
        Self::Decimal(text)
    }
}
/// Converts any `RangeBounds<T>` into the `(lower, upper)` pair of optional
/// `(value, kind)` bounds, mapping each endpoint through `f`.
#[allow(
    clippy::type_complexity,
    reason = "I sure hope 'blame Alex' works here too"
)]
fn range_bounds_to_wit<T, U>(
    range: impl std::ops::RangeBounds<T>,
    f: impl Fn(&T) -> U,
) -> (Option<(U, RangeBoundKind)>, Option<(U, RangeBoundKind)>) {
    let lower = range_bound_to_wit(range.start_bound(), &f);
    let upper = range_bound_to_wit(range.end_bound(), &f);
    (lower, upper)
}
/// Converts one `std::ops::Bound` into an optional `(value, kind)` pair.
///
/// An unbounded side becomes `None`; otherwise the endpoint is mapped through
/// `f` and tagged as inclusive or exclusive.
fn range_bound_to_wit<T, U>(
    bound: std::ops::Bound<&T>,
    f: &dyn Fn(&T) -> U,
) -> Option<(U, RangeBoundKind)> {
    use std::ops::Bound;

    let (endpoint, kind) = match bound {
        Bound::Unbounded => return None,
        Bound::Included(endpoint) => (endpoint, RangeBoundKind::Inclusive),
        Bound::Excluded(endpoint) => (endpoint, RangeBoundKind::Exclusive),
    };
    Some((f(endpoint), kind))
}
/// Converts a `postgres_range::RangeBound` into a `(value, kind)` pair,
/// copying the value out of the bound.
#[cfg(feature = "postgres4-types")]
fn pg_range_bound_to_wit<S: postgres_range::BoundSided, T: Copy>(
    bound: &postgres_range::RangeBound<S, T>,
) -> (T, RangeBoundKind) {
    use postgres_range::BoundType;

    match bound.type_ {
        BoundType::Inclusive => (bound.value, RangeBoundKind::Inclusive),
        BoundType::Exclusive => (bound.value, RangeBoundKind::Exclusive),
    }
}
/// Converts `a..b` into `ParameterValue::RangeInt32`.
impl From<std::ops::Range<i32>> for ParameterValue {
    fn from(v: std::ops::Range<i32>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i32| *endpoint);
        Self::RangeInt32(bounds)
    }
}
/// Converts `a..=b` into `ParameterValue::RangeInt32`.
impl From<std::ops::RangeInclusive<i32>> for ParameterValue {
    fn from(v: std::ops::RangeInclusive<i32>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i32| *endpoint);
        Self::RangeInt32(bounds)
    }
}
/// Converts `a..` into `ParameterValue::RangeInt32`.
impl From<std::ops::RangeFrom<i32>> for ParameterValue {
    fn from(v: std::ops::RangeFrom<i32>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i32| *endpoint);
        Self::RangeInt32(bounds)
    }
}
/// Converts `..b` into `ParameterValue::RangeInt32`.
impl From<std::ops::RangeTo<i32>> for ParameterValue {
    fn from(v: std::ops::RangeTo<i32>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i32| *endpoint);
        Self::RangeInt32(bounds)
    }
}
/// Converts `..=b` into `ParameterValue::RangeInt32`.
impl From<std::ops::RangeToInclusive<i32>> for ParameterValue {
    fn from(v: std::ops::RangeToInclusive<i32>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i32| *endpoint);
        Self::RangeInt32(bounds)
    }
}
/// Converts a `postgres_range::Range<i32>` into `ParameterValue::RangeInt32`,
/// leaving an absent bound as `None`.
#[cfg(feature = "postgres4-types")]
impl From<postgres_range::Range<i32>> for ParameterValue {
    fn from(v: postgres_range::Range<i32>) -> Self {
        let bounds = (
            v.lower().map(pg_range_bound_to_wit),
            v.upper().map(pg_range_bound_to_wit),
        );
        Self::RangeInt32(bounds)
    }
}
/// Converts `a..b` into `ParameterValue::RangeInt64`.
impl From<std::ops::Range<i64>> for ParameterValue {
    fn from(v: std::ops::Range<i64>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i64| *endpoint);
        Self::RangeInt64(bounds)
    }
}
/// Converts `a..=b` into `ParameterValue::RangeInt64`.
impl From<std::ops::RangeInclusive<i64>> for ParameterValue {
    fn from(v: std::ops::RangeInclusive<i64>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i64| *endpoint);
        Self::RangeInt64(bounds)
    }
}
/// Converts `a..` into `ParameterValue::RangeInt64`.
impl From<std::ops::RangeFrom<i64>> for ParameterValue {
    fn from(v: std::ops::RangeFrom<i64>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i64| *endpoint);
        Self::RangeInt64(bounds)
    }
}
/// Converts `..b` into `ParameterValue::RangeInt64`.
impl From<std::ops::RangeTo<i64>> for ParameterValue {
    fn from(v: std::ops::RangeTo<i64>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i64| *endpoint);
        Self::RangeInt64(bounds)
    }
}
/// Converts `..=b` into `ParameterValue::RangeInt64`.
impl From<std::ops::RangeToInclusive<i64>> for ParameterValue {
    fn from(v: std::ops::RangeToInclusive<i64>) -> Self {
        let bounds = range_bounds_to_wit(v, |endpoint: &i64| *endpoint);
        Self::RangeInt64(bounds)
    }
}
/// Converts a `postgres_range::Range<i64>` into `ParameterValue::RangeInt64`,
/// leaving an absent bound as `None`.
#[cfg(feature = "postgres4-types")]
impl From<postgres_range::Range<i64>> for ParameterValue {
    fn from(v: postgres_range::Range<i64>) -> Self {
        let bounds = (
            v.lower().map(pg_range_bound_to_wit),
            v.upper().map(pg_range_bound_to_wit),
        );
        Self::RangeInt64(bounds)
    }
}
/// Converts `a..b` over decimals into `ParameterValue::RangeDecimal`, with
/// each endpoint rendered as a string.
#[cfg(feature = "postgres4-types")]
impl From<std::ops::Range<rust_decimal::Decimal>> for ParameterValue {
    fn from(v: std::ops::Range<rust_decimal::Decimal>) -> Self {
        let bounds =
            range_bounds_to_wit(v, |endpoint: &rust_decimal::Decimal| endpoint.to_string());
        Self::RangeDecimal(bounds)
    }
}
impl From<Vec<i32>> for ParameterValue {
fn from(v: Vec<i32>) -> ParameterValue {
ParameterValue::ArrayInt32(v.into_iter().map(Some).collect())
}
}
impl From<Vec<i64>> for ParameterValue {
fn from(v: Vec<i64>) -> ParameterValue {
ParameterValue::ArrayInt64(v.into_iter().map(Some).collect())
}
}
impl From<Vec<String>> for ParameterValue {
fn from(v: Vec<String>) -> ParameterValue {
ParameterValue::ArrayStr(v.into_iter().map(Some).collect())
}
}
/// Converts optional decimals into `ParameterValue::ArrayDecimal`, rendering
/// each present element as a string and keeping `None` elements as `None`.
#[cfg(feature = "postgres4-types")]
impl From<Vec<Option<rust_decimal::Decimal>>> for ParameterValue {
    fn from(v: Vec<Option<rust_decimal::Decimal>>) -> Self {
        let mut rendered = Vec::with_capacity(v.len());
        for element in v {
            rendered.push(element.map(|dec| dec.to_string()));
        }
        Self::ArrayDecimal(rendered)
    }
}
/// Converts decimals into `ParameterValue::ArrayDecimal`, rendering each
/// element as a string wrapped in `Some`.
#[cfg(feature = "postgres4-types")]
impl From<Vec<rust_decimal::Decimal>> for ParameterValue {
    fn from(v: Vec<rust_decimal::Decimal>) -> Self {
        let mut rendered = Vec::with_capacity(v.len());
        for dec in v {
            rendered.push(Some(dec.to_string()));
        }
        Self::ArrayDecimal(rendered)
    }
}
impl From<Interval> for ParameterValue {
fn from(v: Interval) -> ParameterValue {
ParameterValue::Interval(v)
}
}
impl<T: Into<ParameterValue>> From<Option<T>> for ParameterValue {
fn from(o: Option<T>) -> ParameterValue {
match o {
Some(v) => v.into(),
None => ParameterValue::DbNull,
}
}
}
/// Builds the message used when a `DbValue` is not the variant a decoder
/// expected: `types` names the expected database type(s) and `value` is
/// rendered with its `Debug` form.
fn format_decode_err(types: &str, value: &DbValue) -> String {
    format!("Expected {types} from the DB but got {value:?}")
}
// Unit tests for the `Decode` implementations. Each test checks a successful
// decode, and most also check a mismatched variant and/or the `DbNull` ->
// `None` path through `Option<T>`.
#[cfg(test)]
mod tests {
    use chrono::NaiveDateTime;

    use super::*;

    #[test]
    fn boolean() {
        assert!(bool::decode(&DbValue::Boolean(true)).unwrap());
        assert!(bool::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<bool>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int16() {
        assert_eq!(i16::decode(&DbValue::Int16(0)).unwrap(), 0);
        assert!(i16::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<i16>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int32() {
        assert_eq!(i32::decode(&DbValue::Int32(0)).unwrap(), 0);
        assert!(i32::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<i32>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int64() {
        assert_eq!(i64::decode(&DbValue::Int64(0)).unwrap(), 0);
        assert!(i64::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<i64>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn floating32() {
        assert!(f32::decode(&DbValue::Floating32(0.0)).is_ok());
        assert!(f32::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<f32>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn floating64() {
        assert!(f64::decode(&DbValue::Floating64(0.0)).is_ok());
        assert!(f64::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<f64>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn str() {
        assert_eq!(
            String::decode(&DbValue::Str(String::from("foo"))).unwrap(),
            String::from("foo")
        );
        assert!(String::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<String>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn binary() {
        assert!(Vec::<u8>::decode(&DbValue::Binary(vec![0, 0])).is_ok());
        assert!(Vec::<u8>::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<Vec<u8>>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn date() {
        assert_eq!(
            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
            chrono::NaiveDate::from_ymd_opt(1, 2, 4).unwrap()
        );
        assert_ne!(
            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
            chrono::NaiveDate::from_ymd_opt(1, 2, 5).unwrap()
        );
        assert!(Option::<chrono::NaiveDate>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn time() {
        assert_eq!(
            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
            chrono::NaiveTime::from_hms_nano_opt(1, 2, 3, 4).unwrap()
        );
        assert_ne!(
            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
            chrono::NaiveTime::from_hms_nano_opt(1, 2, 4, 5).unwrap()
        );
        assert!(Option::<chrono::NaiveTime>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn datetime() {
        let date = chrono::NaiveDate::from_ymd_opt(1, 2, 3).unwrap();
        let mut time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 7).unwrap();
        assert_eq!(
            chrono::NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 4, 5, 6, 7))).unwrap(),
            chrono::NaiveDateTime::new(date, time)
        );

        // Same date, nanosecond off by one: must not compare equal.
        time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 8).unwrap();
        assert_ne!(
            NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 4, 5, 6, 7))).unwrap(),
            chrono::NaiveDateTime::new(date, time)
        );
        assert!(Option::<chrono::NaiveDateTime>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn timestamp() {
        assert_eq!(
            chrono::Duration::decode(&DbValue::Timestamp(1)).unwrap(),
            chrono::Duration::seconds(1),
        );
        assert_ne!(
            chrono::Duration::decode(&DbValue::Timestamp(2)).unwrap(),
            chrono::Duration::seconds(1)
        );
        assert!(Option::<chrono::Duration>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    #[cfg(feature = "postgres4-types")]
    fn uuid() {
        let uuid_str = "12341234-1234-1234-1234-123412341234";
        assert_eq!(
            uuid::Uuid::try_parse(uuid_str).unwrap(),
            uuid::Uuid::decode(&DbValue::Uuid(uuid_str.to_owned())).unwrap(),
        );
        assert!(Option::<uuid::Uuid>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    // Only the `jsonb` test uses this type, so it carries the same feature
    // gate; without it the struct is unused when `json` is off.
    #[cfg(feature = "json")]
    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct JsonTest {
        hello: String,
    }

    #[test]
    #[cfg(feature = "json")]
    fn jsonb() {
        let json_val = serde_json::json!({
            "hello": "world"
        });
        let dbval = DbValue::Jsonb(r#"{"hello":"world"}"#.into());
        assert_eq!(json_val, serde_json::Value::decode(&dbval).unwrap(),);

        let json_struct = JsonTest {
            hello: "world".to_owned(),
        };
        assert_eq!(json_struct, from_jsonb(&dbval).unwrap());
    }

    #[test]
    #[cfg(feature = "postgres4-types")]
    fn ranges() {
        let i32_range = postgres_range::Range::<i32>::decode(&DbValue::RangeInt32((
            Some((45, RangeBoundKind::Inclusive)),
            Some((89, RangeBoundKind::Exclusive)),
        )))
        .unwrap();
        assert_eq!(45, i32_range.lower().unwrap().value);
        assert_eq!(
            postgres_range::BoundType::Inclusive,
            i32_range.lower().unwrap().type_
        );
        assert_eq!(89, i32_range.upper().unwrap().value);
        assert_eq!(
            postgres_range::BoundType::Exclusive,
            i32_range.upper().unwrap().type_
        );

        let i32_range_from = postgres_range::Range::<i32>::decode(&DbValue::RangeInt32((
            Some((45, RangeBoundKind::Inclusive)),
            None,
        )))
        .unwrap();
        assert!(i32_range_from.upper().is_none());

        let i64_range = postgres_range::Range::<i64>::decode(&DbValue::RangeInt64((
            Some((4567456745674567, RangeBoundKind::Inclusive)),
            Some((890189018901890189, RangeBoundKind::Exclusive)),
        )))
        .unwrap();
        assert_eq!(4567456745674567, i64_range.lower().unwrap().value);
        assert_eq!(890189018901890189, i64_range.upper().unwrap().value);

        #[allow(clippy::type_complexity)]
        let (dec_lbound, dec_ubound): (
            Option<(rust_decimal::Decimal, RangeBoundKind)>,
            Option<(rust_decimal::Decimal, RangeBoundKind)>,
        ) = Decode::decode(&DbValue::RangeDecimal((
            Some(("4567.8901".to_owned(), RangeBoundKind::Inclusive)),
            Some(("8901.2345678901".to_owned(), RangeBoundKind::Exclusive)),
        )))
        .unwrap();
        assert_eq!(
            rust_decimal::Decimal::from_i128_with_scale(45678901, 4),
            dec_lbound.unwrap().0
        );
        assert_eq!(
            rust_decimal::Decimal::from_i128_with_scale(89012345678901, 10),
            dec_ubound.unwrap().0
        );
    }

    #[test]
    #[cfg(feature = "postgres4-types")]
    fn arrays() {
        let v32 = vec![Some(123), None, Some(456)];
        let i32_arr = Vec::<Option<i32>>::decode(&DbValue::ArrayInt32(v32.clone())).unwrap();
        assert_eq!(v32, i32_arr);

        let v64 = vec![Some(123), None, Some(456)];
        let i64_arr = Vec::<Option<i64>>::decode(&DbValue::ArrayInt64(v64.clone())).unwrap();
        assert_eq!(v64, i64_arr);

        let vdec = vec![Some("1.23".to_owned()), None];
        let dec_arr =
            Vec::<Option<rust_decimal::Decimal>>::decode(&DbValue::ArrayDecimal(vdec)).unwrap();
        assert_eq!(
            vec![
                Some(rust_decimal::Decimal::from_i128_with_scale(123, 2)),
                None
            ],
            dec_arr
        );

        let vstr = vec![Some("alice".to_owned()), None, Some("bob".to_owned())];
        let str_arr = Vec::<Option<String>>::decode(&DbValue::ArrayStr(vstr.clone())).unwrap();
        assert_eq!(vstr, str_arr);
    }
}