use bytes::{BufMut, BytesMut};
use crate::error::{Error, Result};
use crate::types::{FromSql, Oid, ToSql};
// Bits of the leading flag byte in PostgreSQL's binary range encoding. The
// values mirror the RANGE_* flag definitions in PostgreSQL's `rangetypes.h`.

/// The range is empty; only the flag byte is present.
const RANGE_EMPTY: u8 = 0b0000_0001;
/// The lower bound is inclusive.
const RANGE_LB_INC: u8 = 0b0000_0010;
/// The upper bound is inclusive.
const RANGE_UB_INC: u8 = 0b0000_0100;
/// The lower bound is unbounded (no lower value is encoded).
const RANGE_LB_INF: u8 = 0b0000_1000;
/// The upper bound is unbounded (no upper value is encoded).
const RANGE_UB_INF: u8 = 0b0001_0000;
/// One endpoint of a [`PgRange`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RangeBound<T> {
    /// The endpoint value belongs to the range (encoded with the `*_INC` flag).
    Inclusive(T),
    /// The endpoint value is excluded from the range (no `*_INC` flag).
    Exclusive(T),
    /// This side of the range is infinite; no value is stored or encoded.
    Unbounded,
}
/// A PostgreSQL range value over element type `T`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PgRange<T> {
    /// Lower endpoint; ignored when `is_empty` is set.
    pub lower: RangeBound<T>,
    /// Upper endpoint; ignored when `is_empty` is set.
    pub upper: RangeBound<T>,
    /// Marks the empty range. When true, encoding writes only the empty flag.
    pub is_empty: bool,
    /// OID reported by `ToSql::oid` for this value.
    pub range_oid: Oid,
    /// OID carried alongside the range; presumably the element type's OID —
    /// confirm with callers.
    pub element_oid: Oid,
}
impl<T> PgRange<T> {
    /// Builds an empty range tagged with the given OIDs.
    ///
    /// Both bounds are filled with `RangeBound::Unbounded` as placeholders;
    /// `is_empty` is what actually marks the value as empty.
    pub fn empty(range_oid: Oid, element_oid: Oid) -> Self {
        Self {
            range_oid,
            element_oid,
            is_empty: true,
            lower: RangeBound::Unbounded,
            upper: RangeBound::Unbounded,
        }
    }
}
impl<T: ToSql> ToSql for PgRange<T> {
fn oid(&self) -> Oid {
self.range_oid
}
fn to_sql(&self, buf: &mut BytesMut) -> Result<()> {
if self.is_empty {
buf.put_u8(RANGE_EMPTY);
return Ok(());
}
let mut flags: u8 = 0;
match &self.lower {
RangeBound::Inclusive(_) => flags |= RANGE_LB_INC,
RangeBound::Exclusive(_) => {}
RangeBound::Unbounded => flags |= RANGE_LB_INF,
}
match &self.upper {
RangeBound::Inclusive(_) => flags |= RANGE_UB_INC,
RangeBound::Exclusive(_) => {}
RangeBound::Unbounded => flags |= RANGE_UB_INF,
}
buf.put_u8(flags);
match &self.lower {
RangeBound::Inclusive(v) | RangeBound::Exclusive(v) => {
let len_pos = buf.len();
buf.put_i32(0); let data_start = buf.len();
v.to_sql(buf)?;
let data_len = (buf.len() - data_start) as i32;
buf[len_pos..len_pos + 4].copy_from_slice(&data_len.to_be_bytes());
}
RangeBound::Unbounded => {}
}
match &self.upper {
RangeBound::Inclusive(v) | RangeBound::Exclusive(v) => {
let len_pos = buf.len();
buf.put_i32(0); let data_start = buf.len();
v.to_sql(buf)?;
let data_len = (buf.len() - data_start) as i32;
buf[len_pos..len_pos + 4].copy_from_slice(&data_len.to_be_bytes());
}
RangeBound::Unbounded => {}
}
Ok(())
}
}
impl<T: FromSql> PgRange<T> {
pub fn from_sql_with_oids(buf: &[u8], range_oid: Oid, element_oid: Oid) -> Result<Self> {
if buf.is_empty() {
return Err(Error::Decode("range: empty buffer".into()));
}
let flags = buf[0];
if flags & RANGE_EMPTY != 0 {
return Ok(PgRange::empty(range_oid, element_oid));
}
let mut offset = 1;
let lower = if flags & RANGE_LB_INF != 0 {
RangeBound::Unbounded
} else {
if offset + 4 > buf.len() {
return Err(Error::Decode("range: lower bound truncated".into()));
}
let len = i32::from_be_bytes([
buf[offset],
buf[offset + 1],
buf[offset + 2],
buf[offset + 3],
]) as usize;
offset += 4;
if offset + len > buf.len() {
return Err(Error::Decode("range: lower bound data truncated".into()));
}
let val = T::from_sql(&buf[offset..offset + len])?;
offset += len;
if flags & RANGE_LB_INC != 0 {
RangeBound::Inclusive(val)
} else {
RangeBound::Exclusive(val)
}
};
let upper = if flags & RANGE_UB_INF != 0 {
RangeBound::Unbounded
} else {
if offset + 4 > buf.len() {
return Err(Error::Decode("range: upper bound truncated".into()));
}
let len = i32::from_be_bytes([
buf[offset],
buf[offset + 1],
buf[offset + 2],
buf[offset + 3],
]) as usize;
offset += 4;
if offset + len > buf.len() {
return Err(Error::Decode("range: upper bound data truncated".into()));
}
let val = T::from_sql(&buf[offset..offset + len])?;
if flags & RANGE_UB_INC != 0 {
RangeBound::Inclusive(val)
} else {
RangeBound::Exclusive(val)
}
};
Ok(PgRange {
lower,
upper,
is_empty: false,
range_oid,
element_oid,
})
}
}