/// A value that is either explicitly provided or absent.
///
/// Within this module, [`Auto::Unset`] is the [`Default`], serializes as `null`, and
/// converts to SQL `NULL`; [`Auto::Set`] carries the concrete value.
///
/// NOTE(review): the name suggests values the database may generate itself
/// (e.g. auto-increment keys) — confirm intended usage with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Auto<T> {
    /// No value has been provided.
    Unset,
    /// An explicitly provided value.
    Set(T),
}
impl<T: serde::Serialize> serde::Serialize for Auto<T> {
    /// Writes a set value exactly as `T` would serialize it (no wrapper or tag),
    /// and an unset value via `serialize_none` (`null` in JSON).
    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        if let Self::Set(inner) = self {
            inner.serialize(ser)
        } else {
            ser.serialize_none()
        }
    }
}
impl<'de, T> serde::Deserialize<'de> for Auto<T>
where
    T: serde::de::DeserializeOwned,
{
    /// Deserializes an [`Auto`], buffering the input through [`serde_json::Value`].
    ///
    /// Because of that buffering, the underlying format must be able to produce a
    /// `serde_json::Value` (i.e. support `deserialize_any`).
    ///
    /// Accepted shapes, in priority order:
    /// 1. `null` → [`Auto::Unset`].
    /// 2. Any value `T` itself accepts → [`Auto::Set`]. This is the shape the
    ///    `Serialize` impl emits, so it is tried before the legacy shapes. Otherwise a
    ///    value such as `Auto::<String>::Set("Unset".into())`, or a set map whose only
    ///    key is `"Set"`, would not survive a serialize/deserialize round trip.
    /// 3. Legacy externally-tagged forms, used only when `T` rejects the value:
    ///    the string `"Unset"` → [`Auto::Unset`], and `{"Set": x}` → `Auto::Set(x)`.
    ///
    /// # Errors
    /// If no shape matches, returns the error from deserializing the whole value as
    /// `T`. For a legacy `{"Set": x}` wrapper, returns the error from deserializing `x`.
    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        use serde::de::Error as _;

        let value = <serde_json::Value as serde::Deserialize>::deserialize(de)?;
        if value.is_null() {
            return Ok(Self::Unset);
        }

        // Current (bare) shape first. Deserialize from a borrow so `value` is still
        // available for the legacy fallback below.
        let bare_err = match T::deserialize(&value) {
            Ok(v) => return Ok(Self::Set(v)),
            Err(e) => e,
        };

        match value {
            serde_json::Value::String(s) if s == "Unset" => Ok(Self::Unset),
            serde_json::Value::Object(mut map) if map.len() == 1 => match map.remove("Set") {
                // Move the inner value out instead of cloning it.
                Some(inner) => serde_json::from_value(inner)
                    .map(Self::Set)
                    .map_err(D::Error::custom),
                None => Err(D::Error::custom(bare_err)),
            },
            _ => Err(D::Error::custom(bare_err)),
        }
    }
}
impl<T> Default for Auto<T> {
fn default() -> Self {
Self::Unset
}
}
impl<T> From<T> for Auto<T> {
fn from(v: T) -> Self {
Self::Set(v)
}
}
impl<T> Auto<T> {
    /// Borrows the contained value, or returns `None` when unset.
    #[must_use]
    pub fn get(&self) -> Option<&T> {
        if let Self::Set(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Returns `true` when no value is present.
    #[must_use]
    pub fn is_unset(&self) -> bool {
        !self.is_set()
    }

    /// Returns `true` when a value is present.
    #[must_use]
    pub fn is_set(&self) -> bool {
        self.get().is_some()
    }

    /// Moves the value out and leaves [`Auto::Unset`] in its place.
    ///
    /// Returns `None` when no value was present (`self` stays unset).
    pub fn take(&mut self) -> Option<T> {
        std::mem::replace(self, Self::Unset).into_inner()
    }

    /// Consumes `self` and converts it into an `Option`: `Set(v)` → `Some(v)`,
    /// `Unset` → `None`.
    #[must_use]
    pub fn into_inner(self) -> Option<T> {
        if let Self::Set(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl<T> From<Auto<T>> for crate::core::SqlValue
where
    T: Into<crate::core::SqlValue>,
{
    /// Converts an [`Auto`] into a SQL value.
    ///
    /// [`Auto::Unset`] becomes `SqlValue::Null`; [`Auto::Set`] is converted through
    /// `T`'s own `Into<SqlValue>` impl.
    fn from(auto: Auto<T>) -> Self {
        if let Auto::Set(inner) = auto {
            inner.into()
        } else {
            Self::Null
        }
    }
}
#[cfg(feature = "postgres")]
impl<'r, T> sqlx::Decode<'r, sqlx::Postgres> for Auto<T>
where
    T: sqlx::Decode<'r, sqlx::Postgres>,
{
    /// Decodes a Postgres column value into an [`Auto`].
    ///
    /// SQL `NULL` decodes to [`Auto::Unset`]. This mirrors `From<Auto<T>> for SqlValue`,
    /// which writes `Unset` as `NULL`, and the serde impls, where `null` means `Unset`.
    /// Any other value is decoded by `T` and wrapped in [`Auto::Set`].
    ///
    /// # Errors
    /// Propagates any error returned by `T::decode`.
    fn decode(
        value: <sqlx::Postgres as sqlx::Database>::ValueRef<'r>,
    ) -> Result<Self, sqlx::error::BoxDynError> {
        // Same check sqlx's `Option<T>` decoder performs. Without it, a NULL would be
        // passed straight to `T::decode`.
        if sqlx::ValueRef::is_null(&value) {
            return Ok(Auto::Unset);
        }
        T::decode(value).map(Auto::Set)
    }
}
#[cfg(feature = "postgres")]
impl<T> sqlx::Type<sqlx::Postgres> for Auto<T>
where
    T: sqlx::Type<sqlx::Postgres>,
{
    /// Reports `T`'s Postgres type; `Auto` contributes no type information of its own.
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        <T as sqlx::Type<sqlx::Postgres>>::type_info()
    }

    /// Accepts exactly the column types that `T` accepts.
    fn compatible(info: &sqlx::postgres::PgTypeInfo) -> bool {
        <T as sqlx::Type<sqlx::Postgres>>::compatible(info)
    }
}
#[cfg(feature = "mysql")]
impl<'r, T> sqlx::Decode<'r, sqlx::MySql> for Auto<T>
where
    T: sqlx::Decode<'r, sqlx::MySql>,
{
    /// Decodes a MySQL column value into an [`Auto`].
    ///
    /// SQL `NULL` decodes to [`Auto::Unset`]. This mirrors `From<Auto<T>> for SqlValue`,
    /// which writes `Unset` as `NULL`, and the serde impls, where `null` means `Unset`.
    /// Any other value is decoded by `T` and wrapped in [`Auto::Set`].
    ///
    /// # Errors
    /// Propagates any error returned by `T::decode`.
    fn decode(
        value: <sqlx::MySql as sqlx::Database>::ValueRef<'r>,
    ) -> Result<Self, sqlx::error::BoxDynError> {
        // Same check sqlx's `Option<T>` decoder performs. Without it, a NULL would be
        // passed straight to `T::decode`.
        if sqlx::ValueRef::is_null(&value) {
            return Ok(Auto::Unset);
        }
        T::decode(value).map(Auto::Set)
    }
}
#[cfg(feature = "mysql")]
impl<T> sqlx::Type<sqlx::MySql> for Auto<T>
where
    T: sqlx::Type<sqlx::MySql>,
{
    /// Reports `T`'s MySQL type; `Auto` contributes no type information of its own.
    fn type_info() -> sqlx::mysql::MySqlTypeInfo {
        <T as sqlx::Type<sqlx::MySql>>::type_info()
    }

    /// Accepts exactly the column types that `T` accepts.
    fn compatible(info: &sqlx::mysql::MySqlTypeInfo) -> bool {
        <T as sqlx::Type<sqlx::MySql>>::compatible(info)
    }
}
#[cfg(feature = "sqlite")]
impl<'r, T> sqlx::Decode<'r, sqlx::Sqlite> for Auto<T>
where
    T: sqlx::Decode<'r, sqlx::Sqlite>,
{
    /// Decodes a SQLite column value into an [`Auto`].
    ///
    /// SQL `NULL` decodes to [`Auto::Unset`]. This mirrors `From<Auto<T>> for SqlValue`,
    /// which writes `Unset` as `NULL`, and the serde impls, where `null` means `Unset`.
    /// Any other value is decoded by `T` and wrapped in [`Auto::Set`].
    ///
    /// # Errors
    /// Propagates any error returned by `T::decode`.
    fn decode(
        value: <sqlx::Sqlite as sqlx::Database>::ValueRef<'r>,
    ) -> Result<Self, sqlx::error::BoxDynError> {
        // Same check sqlx's `Option<T>` decoder performs. Without it, a NULL would be
        // passed straight to `T::decode`.
        if sqlx::ValueRef::is_null(&value) {
            return Ok(Auto::Unset);
        }
        T::decode(value).map(Auto::Set)
    }
}
#[cfg(feature = "sqlite")]
impl<T> sqlx::Type<sqlx::Sqlite> for Auto<T>
where
    T: sqlx::Type<sqlx::Sqlite>,
{
    /// Reports `T`'s SQLite type; `Auto` contributes no type information of its own.
    fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
        <T as sqlx::Type<sqlx::Sqlite>>::type_info()
    }

    /// Accepts exactly the column types that `T` accepts.
    fn compatible(info: &sqlx::sqlite::SqliteTypeInfo) -> bool {
        <T as sqlx::Type<sqlx::Sqlite>>::compatible(info)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` as JSON text and decodes it again, panicking on any failure.
    fn json_round_trip<V>(value: &V) -> V
    where
        V: serde::Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// Parses `text` as an `Auto<i64>`, panicking on failure.
    fn parse_i64(text: &str) -> Auto<i64> {
        serde_json::from_str(text).unwrap()
    }

    #[test]
    fn default_is_unset() {
        let unset = Auto::<i64>::default();
        assert!(unset.is_unset());
        assert!(!unset.is_set());
        assert_eq!(unset.get(), None);
    }

    #[test]
    fn from_t_is_set() {
        let set = Auto::<i64>::from(42_i64);
        assert!(set.is_set());
        assert_eq!(set.get().copied(), Some(42));
    }

    #[test]
    fn into_inner_returns_value_or_none() {
        let cases = [(Auto::Set(7_i64), Some(7)), (Auto::Unset, None)];
        for (auto, expected) in cases {
            assert_eq!(auto.into_inner(), expected);
        }
    }

    #[test]
    fn take_leaves_unset_behind() {
        let mut slot = Auto::Set(99_i64);
        assert_eq!(slot.take(), Some(99));
        assert!(slot.is_unset());
        assert_eq!(slot.take(), None);
    }

    #[test]
    fn serde_round_trip_set() {
        assert_eq!(json_round_trip(&Auto::Set(42_i64)), Auto::Set(42));
    }

    #[test]
    fn serde_round_trip_unset() {
        assert_eq!(json_round_trip(&Auto::<i64>::Unset), Auto::Unset);
    }

    #[test]
    fn serialize_set_emits_bare_value() {
        let number = serde_json::to_string(&Auto::Set(42_i64)).unwrap();
        assert_eq!(number, "42");
        let text = serde_json::to_string(&Auto::Set(String::from("hello"))).unwrap();
        assert_eq!(text, "\"hello\"");
    }

    #[test]
    fn serialize_unset_emits_null() {
        let encoded = serde_json::to_string(&Auto::<i64>::Unset).unwrap();
        assert_eq!(encoded, "null");
    }

    #[test]
    fn deserialize_accepts_bare_shape() {
        assert_eq!(parse_i64("42"), Auto::Set(42));
        assert_eq!(parse_i64("null"), Auto::Unset);
    }

    #[test]
    fn deserialize_accepts_legacy_tagged_enum() {
        assert_eq!(parse_i64("\"Unset\""), Auto::Unset);
        assert_eq!(parse_i64(r#"{"Set": 42}"#), Auto::Set(42));
    }

    #[test]
    fn within_a_struct_round_trip_is_clean() {
        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct Row {
            id: Auto<i64>,
            name: String,
        }

        let original = Row {
            id: Auto::Set(7),
            name: String::from("x"),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#"{"id":7,"name":"x"}"#);
        let decoded: Row = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }
}