use crate::{
pg_sys, AnyNumeric, Date, FromDatum, IntoDatum, Numeric, Timestamp, TimestampWithTimeZone,
};
use core::fmt::{Display, Formatter};
use pgx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
};
use std::ops::{Deref, DerefMut, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive};
/// One endpoint (lower or upper) of a [`Range`].
///
/// `Infinite` means the range is unbounded on this side; `Inclusive` and `Exclusive` carry the
/// bound value and say whether that value itself belongs to the range.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum RangeBound<T> {
    /// No bound on this side; `Display` for `Range` renders it as an empty endpoint, e.g. `(,5]`.
    Infinite,
    /// The value is part of the range; rendered with `[` / `]`.
    Inclusive(T),
    /// The value is not part of the range; rendered with `(` / `)`.
    Exclusive(T),
}
impl<T> RangeBound<T>
where
    T: RangeSubType,
{
    /// Returns the bound's value, or `None` when the bound is infinite.
    #[inline]
    pub fn get(&self) -> Option<&T> {
        if let RangeBound::Inclusive(value) | RangeBound::Exclusive(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// `true` when this side of the range is unbounded.
    #[inline]
    pub fn is_infinite(&self) -> bool {
        matches!(self, RangeBound::Infinite)
    }

    /// `true` when the bound value belongs to the range.
    #[inline]
    pub fn is_inclusive(&self) -> bool {
        matches!(self, RangeBound::Inclusive(_))
    }

    /// `true` when the bound value is excluded from the range.
    #[inline]
    pub fn is_exclusive(&self) -> bool {
        matches!(self, RangeBound::Exclusive(_))
    }

    /// Converts this bound into the Postgres `RangeBound` representation.
    ///
    /// The `lower` flag is always left `false`; the caller must set it on the lower bound.
    /// Panics if `T::into_datum` yields `None` for a finite bound.
    fn into_pg(self) -> pg_sys::RangeBound {
        let (val, infinite, inclusive) = match self {
            RangeBound::Infinite => (pg_sys::Datum::from(0), true, false),
            RangeBound::Inclusive(value) => (value.into_datum().unwrap(), false, true),
            RangeBound::Exclusive(value) => (value.into_datum().unwrap(), false, false),
        };
        pg_sys::RangeBound { val, infinite, inclusive, lower: false }
    }

    /// Converts a Postgres `RangeBound` into a Rust [`RangeBound`].
    ///
    /// # Safety
    ///
    /// Unless `range_bound.infinite` is set, `range_bound.val` must be a valid, non-null datum
    /// of the subtype `T`: it is decoded with `T::from_datum(val, false)` and the result is
    /// unwrapped (so a `None` from `from_datum` panics).
    pub unsafe fn from_pg(range_bound: pg_sys::RangeBound) -> RangeBound<T> {
        if range_bound.infinite {
            return RangeBound::Infinite;
        }
        // SAFETY: the caller guarantees `val` is a valid datum of `T` for a finite bound.
        let value = unsafe { T::from_datum(range_bound.val, false) }.unwrap();
        if range_bound.inclusive {
            RangeBound::Inclusive(value)
        } else {
            RangeBound::Exclusive(value)
        }
    }
}
/// A bare value converts into an inclusive bound.
impl<T> From<T> for RangeBound<T>
where
    T: RangeSubType,
{
    #[inline]
    fn from(value: T) -> Self {
        Self::Inclusive(value)
    }
}
/// `Some(value)` converts into an inclusive bound; `None` into an infinite one.
impl<T> From<Option<T>> for RangeBound<T>
where
    T: RangeSubType,
{
    #[inline]
    fn from(value: Option<T>) -> Self {
        value.map_or(Self::Infinite, Self::Inclusive)
    }
}
/// A Postgres range value whose elements are of type `T`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct Range<T: RangeSubType> {
    // `None` is the empty range; otherwise the `(lower, upper)` bound pair.
    inner: Option<(RangeBound<T>, RangeBound<T>)>,
}
/// Formats a range in Postgres' text syntax, e.g. `[1,5)`, `(,3]`, `(,)` or `empty`.
impl<T> Display for Range<T>
where
    T: RangeSubType + Display,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let (lower, upper) = match &self.inner {
            Some((lower, upper)) => (lower, upper),
            None => return write!(f, "empty"),
        };

        // Opening bracket plus optional lower value.
        match lower {
            RangeBound::Infinite => write!(f, "(")?,
            RangeBound::Inclusive(v) => write!(f, "[{}", v)?,
            RangeBound::Exclusive(v) => write!(f, "({}", v)?,
        }

        write!(f, ",")?;

        // Optional upper value plus closing bracket.
        match upper {
            RangeBound::Infinite => write!(f, ")"),
            RangeBound::Inclusive(v) => write!(f, "{}]", v),
            RangeBound::Exclusive(v) => write!(f, "{})", v),
        }
    }
}
impl<T> Range<T>
where
T: RangeSubType,
{
#[inline]
pub fn new<L, U>(lower: L, upper: U) -> Self
where
L: Into<RangeBound<T>>,
U: Into<RangeBound<T>>,
{
Self { inner: Some((lower.into(), upper.into())) }
}
#[inline]
pub fn empty() -> Self {
Self { inner: None }
}
#[inline]
pub fn infinite() -> Self {
Self::new(RangeBound::Infinite, RangeBound::Infinite)
}
#[inline]
pub fn lower(&self) -> Option<&RangeBound<T>> {
match &self.inner {
Some((l, _)) => Some(l),
None => None,
}
}
#[inline]
pub fn upper(&self) -> Option<&RangeBound<T>> {
match &self.inner {
Some((_, u)) => Some(u),
None => None,
}
}
#[inline]
pub fn is_empty(&self) -> bool {
self.inner.is_none()
}
#[inline]
pub fn is_infinite(&self) -> bool {
match (self.lower(), self.upper()) {
(Some(RangeBound::Infinite), Some(RangeBound::Infinite)) => true,
_ => false,
}
}
#[inline]
pub fn into_inner(self) -> Option<(RangeBound<T>, RangeBound<T>)> {
self.inner
}
#[inline]
pub fn take(&mut self) -> Option<(RangeBound<T>, RangeBound<T>)> {
self.inner.take()
}
#[inline]
pub fn replace(
&mut self,
new: Option<(RangeBound<T>, RangeBound<T>)>,
) -> Option<(RangeBound<T>, RangeBound<T>)> {
std::mem::replace(&mut self.inner, new)
}
}
impl<T> Deref for Range<T>
where
T: RangeSubType,
{
type Target = Option<(RangeBound<T>, RangeBound<T>)>;
#[inline]
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for Range<T>
where
T: RangeSubType,
{
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<T> FromDatum for Range<T>
where
    T: RangeSubType,
{
    /// Decodes a Postgres range datum into a [`Range`].
    ///
    /// Returns `None` for SQL NULL (or a null datum pointer). The datum is detoasted first; if
    /// detoasting produced a copy, that copy is freed after the bounds have been converted.
    ///
    /// # Safety
    ///
    /// Unless `is_null` is true, `datum` must point to a (possibly toasted) Postgres
    /// `RangeType` whose element type is represented by `T`.
    #[inline]
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        _: pg_sys::Oid,
    ) -> Option<Self>
    where
        Self: Sized,
    {
        if is_null || datum.is_null() {
            return None;
        }

        let ptr: *mut pg_sys::varlena = datum.cast_mut_ptr();
        // SAFETY: the caller guarantees `datum` is a non-null (possibly toasted) RangeType.
        let range_type =
            unsafe { pg_sys::pg_detoast_datum(datum.cast_mut_ptr()) as *mut pg_sys::RangeType };

        let mut lower_bound: pg_sys::RangeBound = Default::default();
        let mut upper_bound: pg_sys::RangeBound = Default::default();
        let mut is_empty = false;

        unsafe {
            // SAFETY: `range_type` was just detoasted, so its header is readable.
            let typecache = pg_sys::lookup_type_cache(
                (*(range_type)).rangetypid,
                pg_sys::TYPECACHE_RANGE_INFO as i32,
            );
            pg_sys::range_deserialize(
                typecache,
                range_type,
                &mut lower_bound,
                &mut upper_bound,
                &mut is_empty,
            );

            // An empty range stores no bound values: `range_deserialize` reports them as
            // finite, exclusive bounds holding a zero datum. Decoding those would hand a null
            // datum (with `is_null == false`) to `T::from_datum`, producing garbage values or
            // dereferencing NULL for pass-by-reference subtypes. So only decode bounds of
            // non-empty ranges. Decoding must also happen before the pfree below, because
            // by-reference bound values point into `range_type`.
            let inner = if is_empty {
                None
            } else {
                Some((RangeBound::from_pg(lower_bound), RangeBound::from_pg(upper_bound)))
            };

            // Only free the detoasted copy; the original datum is not ours to free.
            if !std::ptr::eq(ptr, range_type.cast()) {
                pg_sys::pfree(range_type.cast());
            }

            Some(Range { inner })
        }
    }
}
impl<T> IntoDatum for Range<T>
where
    T: RangeSubType,
{
    /// Builds a new Postgres range datum with `make_range`, using the range type given by
    /// `T::range_type_oid()`. An empty range is passed with default placeholder bounds and the
    /// empty flag set.
    #[inline]
    fn into_datum(self) -> Option<pg_sys::Datum> {
        // SAFETY: looking up the type cache for a range-type OID has no extra preconditions.
        let typecache = unsafe {
            pg_sys::lookup_type_cache(T::range_type_oid(), pg_sys::TYPECACHE_RANGE_INFO as i32)
        };

        let empty = self.is_empty();
        let (mut lower, mut upper) = match self.inner {
            Some((lo, hi)) => (lo.into_pg(), hi.into_pg()),
            None => (pg_sys::RangeBound::default(), pg_sys::RangeBound::default()),
        };
        // `into_pg` never flags a bound as the lower one; mark it here.
        lower.lower = true;

        // SAFETY: `typecache` describes `T`'s range type and both bounds are fully initialized.
        let range = unsafe { pg_sys::make_range(typecache, &mut lower, &mut upper, empty) };
        Some(pg_sys::Datum::from(range))
    }

    /// The OID of the Postgres range type over `T`.
    #[inline]
    fn type_oid() -> pg_sys::Oid {
        T::range_type_oid()
    }
}
/// `start..end` becomes the half-open range `[start,end)`.
impl<T> From<std::ops::Range<T>> for Range<T>
where
    T: RangeSubType,
{
    #[inline]
    fn from(value: std::ops::Range<T>) -> Self {
        let std::ops::Range { start, end } = value;
        Self::new(RangeBound::Inclusive(start), RangeBound::Exclusive(end))
    }
}
impl<T> From<std::ops::RangeFrom<T>> for Range<T>
where
T: RangeSubType,
{
#[inline]
fn from(value: RangeFrom<T>) -> Self {
Range::new(Some(value.start), None)
}
}
/// `..` becomes the unbounded range `(,)`.
impl<T> From<std::ops::RangeFull> for Range<T>
where
    T: RangeSubType,
{
    #[inline]
    fn from(_: std::ops::RangeFull) -> Self {
        Self::infinite()
    }
}
/// `start..=end` becomes the closed range `[start,end]`.
impl<T> From<std::ops::RangeInclusive<T>> for Range<T>
where
    T: RangeSubType,
{
    #[inline]
    fn from(value: RangeInclusive<T>) -> Self {
        // Move both endpoints out of the consumed range instead of cloning them; a clone of a
        // non-trivial subtype would be wasted work since `value` is dropped anyway.
        let (start, end) = value.into_inner();
        Range::new(RangeBound::Inclusive(start), RangeBound::Inclusive(end))
    }
}
impl<T> From<std::ops::RangeTo<T>> for Range<T>
where
T: RangeSubType,
{
#[inline]
fn from(value: RangeTo<T>) -> Self {
Range::new(RangeBound::Infinite, RangeBound::Exclusive(value.end))
}
}
impl<T> From<std::ops::RangeToInclusive<T>> for Range<T>
where
T: RangeSubType,
{
#[inline]
fn from(value: RangeToInclusive<T>) -> Self {
Range::new(RangeBound::Infinite, RangeBound::Inclusive(value.end))
}
}
/// A Rust type usable as the element ("subtype") of a Postgres range.
///
/// # Safety
///
/// `Range::<Self>::into_datum` looks up the type cache for `range_type_oid()` and passes it to
/// `make_range` together with datums produced by `Self`'s `IntoDatum` impl, and
/// `Range::<Self>` decodes range bounds with `Self`'s `FromDatum` impl. Implementors must
/// therefore return the OID of a range type whose element type is actually represented by
/// `Self`. NOTE(review): contract inferred from how `Range<T>` uses this trait; confirm.
pub unsafe trait RangeSubType: Clone + FromDatum + IntoDatum {
    /// OID of the Postgres range type over `Self` (e.g. `INT4RANGEOID` for `i32`).
    fn range_type_oid() -> pg_sys::Oid;
}
unsafe impl RangeSubType for i32 {
fn range_type_oid() -> pg_sys::Oid {
pg_sys::INT4RANGEOID
}
}
unsafe impl RangeSubType for i64 {
fn range_type_oid() -> pg_sys::Oid {
pg_sys::INT8RANGEOID
}
}
unsafe impl RangeSubType for AnyNumeric {
fn range_type_oid() -> pg_sys::Oid {
pg_sys::NUMRANGEOID
}
}
unsafe impl<const P: u32, const S: u32> RangeSubType for Numeric<P, S> {
fn range_type_oid() -> pg_sys::Oid {
pg_sys::NUMRANGEOID
}
}
unsafe impl RangeSubType for Date {
fn range_type_oid() -> pg_sys::Oid {
pg_sys::DATERANGEOID
}
}
unsafe impl RangeSubType for Timestamp {
fn range_type_oid() -> pg_sys::Oid {
pg_sys::TSRANGEOID
}
}
unsafe impl RangeSubType for TimestampWithTimeZone {
fn range_type_oid() -> pg_sys::Oid {
pg_sys::TSTZRANGEOID
}
}
unsafe impl SqlTranslatable for Range<i32> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::literal("int4range"))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::literal("int4range")))
}
}
unsafe impl SqlTranslatable for Range<i64> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::literal("int8range"))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::literal("int8range")))
}
}
unsafe impl SqlTranslatable for Range<AnyNumeric> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::literal("numrange"))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::literal("numrange")))
}
}
unsafe impl<const P: u32, const S: u32> SqlTranslatable for Range<Numeric<P, S>> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::literal("numrange"))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::literal("numrange")))
}
}
unsafe impl SqlTranslatable for Range<Date> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::literal("daterange"))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::literal("daterange")))
}
}
unsafe impl SqlTranslatable for Range<TimestampWithTimeZone> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::literal("tstzrange"))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::literal("tstzrange")))
}
}
unsafe impl SqlTranslatable for Range<Timestamp> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::literal("tsrange"))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::literal("tsrange")))
}
}