use sqlx_core::bytes::Buf;
use sqlx_core::types::Text;
use std::borrow::Cow;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::Oid;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
/// Associates an element type with the Postgres array type that holds it.
///
/// Implementing this for `T` is what lets `[T]`, `Vec<T>` and `[T; N]` report a
/// Postgres type through their `Type<Postgres>` implementations.
pub trait PgHasArrayType {
    /// Type information for a Postgres array whose elements are `Self`.
    fn array_type_info() -> PgTypeInfo;

    /// Reports whether `ty` may be treated as an array of `Self`.
    ///
    /// The default accepts exactly the type returned by
    /// [`array_type_info`](Self::array_type_info).
    fn array_compatible(ty: &PgTypeInfo) -> bool {
        let expected = Self::array_type_info();
        *ty == expected
    }
}
/// A borrowed element has the same Postgres array type as the owned element.
impl<T: PgHasArrayType> PgHasArrayType for &T {
    fn array_type_info() -> PgTypeInfo {
        <T as PgHasArrayType>::array_type_info()
    }

    fn array_compatible(ty: &PgTypeInfo) -> bool {
        <T as PgHasArrayType>::array_compatible(ty)
    }
}
/// A nullable element uses the array type of its inner type; nullability does not
/// change the Postgres array type.
impl<T: PgHasArrayType> PgHasArrayType for Option<T> {
    fn array_type_info() -> PgTypeInfo {
        <T as PgHasArrayType>::array_type_info()
    }

    fn array_compatible(ty: &PgTypeInfo) -> bool {
        <T as PgHasArrayType>::array_compatible(ty)
    }
}
/// `Text<T>` elements delegate to `String`, so an array of them is typed exactly
/// like an array of strings, whatever `T` is.
impl<T> PgHasArrayType for Text<T> {
    fn array_type_info() -> PgTypeInfo {
        <String as PgHasArrayType>::array_type_info()
    }

    fn array_compatible(ty: &PgTypeInfo) -> bool {
        <String as PgHasArrayType>::array_compatible(ty)
    }
}
/// A slice of `T` is typed as the Postgres array of `T`.
impl<T: PgHasArrayType> Type<Postgres> for [T] {
    fn type_info() -> PgTypeInfo {
        <T as PgHasArrayType>::array_type_info()
    }

    fn compatible(ty: &PgTypeInfo) -> bool {
        <T as PgHasArrayType>::array_compatible(ty)
    }
}
/// A `Vec<T>` is typed as the Postgres array of `T`.
impl<T: PgHasArrayType> Type<Postgres> for Vec<T> {
    fn type_info() -> PgTypeInfo {
        <T as PgHasArrayType>::array_type_info()
    }

    fn compatible(ty: &PgTypeInfo) -> bool {
        <T as PgHasArrayType>::array_compatible(ty)
    }
}
/// A fixed-size array `[T; N]` is typed as the Postgres array of `T`; the length
/// `N` is not part of the Postgres type.
impl<T: PgHasArrayType, const N: usize> Type<Postgres> for [T; N] {
    fn type_info() -> PgTypeInfo {
        <T as PgHasArrayType>::array_type_info()
    }

    fn compatible(ty: &PgTypeInfo) -> bool {
        <T as PgHasArrayType>::array_compatible(ty)
    }
}
impl<'q, T> Encode<'q, Postgres> for Vec<T>
where
    for<'a> &'a [T]: Encode<'q, Postgres>,
    T: Encode<'q, Postgres>,
{
    /// Encodes the vector's contents through the `&[T]` encoder.
    #[inline]
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
        let elements: &[T] = self;
        <&[T] as Encode<'q, Postgres>>::encode_by_ref(&elements, buf)
    }
}
impl<'q, T, const N: usize> Encode<'q, Postgres> for [T; N]
where
    for<'a> &'a [T]: Encode<'q, Postgres>,
    T: Encode<'q, Postgres>,
{
    /// Encodes the array's contents through the `&[T]` encoder.
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
        let elements: &[T] = &self[..];
        <&[T] as Encode<'q, Postgres>>::encode_by_ref(&elements, buf)
    }
}
impl<'q, T> Encode<'q, Postgres> for &'_ [T]
where
    T: Encode<'q, Postgres> + Type<Postgres>,
{
    /// Encodes the slice as a Postgres array by binding its elements through
    /// `PgBindIterExt::bind_iter`.
    ///
    /// # Errors
    ///
    /// Fails before writing anything if the element count does not fit in an `i32`.
    /// Otherwise it returns whatever error the element encoding produces.
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
        let count = self.len();

        if i32::try_from(count).is_err() {
            return Err(format!(
                "encoded array length is too large for Postgres: {}",
                count
            )
            .into());
        }

        crate::PgBindIterExt::bind_iter(self.iter()).encode(buf)
    }
}
impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N]
where
    T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
    /// Decodes a Postgres array into a fixed-size Rust array.
    ///
    /// The value is first decoded as a `Vec<T>`. If that vector does not hold exactly
    /// `N` elements, the conversion fails with an error.
    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
        let elements = <Vec<T> as Decode<'r, Postgres>>::decode(value)?;

        match <[T; N]>::try_from(elements) {
            Ok(array) => Ok(array),
            Err(_) => Err("wrong number of elements".into()),
        }
    }
}
impl<'r, T> Decode<'r, Postgres> for Vec<T>
where
    T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
    /// Decodes a one-dimensional Postgres array into a `Vec<T>`.
    ///
    /// * **Binary format**: the header (dimension count, flags, element OID, axis length,
    ///   lower bound) is validated. Each element is then decoded with the type resolved
    ///   from the element OID. If `PgTypeInfo::try_from_oid` cannot resolve that OID, the
    ///   element type of the value's own array type info is used instead.
    /// * **Text format**: the `{a,b,"c"}` literal is split on unquoted `,`. Double quotes
    ///   and backslash escapes are removed, and each element is decoded with
    ///   `T::type_info()`. Only an *unquoted* `NULL` is treated as SQL NULL; a quoted
    ///   `"NULL"` is the four-character string.
    ///
    /// # Errors
    ///
    /// Returns an error (never panics on malformed input handled here) when:
    /// * the value is truncated or not a brace-enclosed array literal;
    /// * the array has more than one dimension or a lower bound other than `1`;
    /// * the element type cannot be resolved;
    /// * any element fails to decode.
    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
        let format = value.format();

        match format {
            PgValueFormat::Binary => {
                let mut buf = value.as_bytes()?;

                // `Buf::get_*` panics on a short buffer, so every header read is preceded
                // by an explicit length check.
                if buf.len() < 4 {
                    return Err("truncated binary array: missing the dimension count".into());
                }
                let ndim = buf.get_i32();

                if ndim == 0 {
                    // Zero dimensions is how Postgres encodes an empty array.
                    return Ok(Vec::new());
                }

                if ndim != 1 {
                    return Err(format!("encountered an array of {ndim} dimensions; only one-dimensional arrays are supported").into());
                }

                // The rest of a one-dimensional header holds four 4-byte fields:
                // flags, element OID, axis length and lower bound.
                if buf.len() < 16 {
                    return Err("truncated binary array: incomplete array header".into());
                }

                // NOTE(review): the flags word is read and ignored here.
                let _flags = buf.get_i32();

                let element_type_oid = Oid(buf.get_u32());
                let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid)
                    .or_else(|| value.type_info.try_array_element().map(Cow::into_owned))
                    .ok_or_else(|| {
                        BoxDynError::from(format!(
                            "failed to resolve array element type for oid {}",
                            element_type_oid.0
                        ))
                    })?;

                let len = buf.get_i32();
                let len = usize::try_from(len)
                    .map_err(|_| format!("overflow converting array len ({len}) to usize"))?;

                let lower = buf.get_i32();
                if lower != 1 {
                    return Err(format!("encountered an array with a lower bound of {lower} in the first dimension; only arrays starting at one are supported").into());
                }

                // Every element starts with a 4-byte length prefix (-1 for NULL). A header
                // declaring more elements than the remaining bytes could hold is corrupt.
                // Rejecting it here also stops `with_capacity` from trusting an arbitrary
                // value and attempting a huge allocation.
                if len > buf.len() / 4 {
                    return Err(format!(
                        "array header declares {len} elements but only {} bytes of element data remain",
                        buf.len()
                    )
                    .into());
                }

                let mut elements = Vec::with_capacity(len);

                for _ in 0..len {
                    let value_ref = PgValueRef::get(&mut buf, format, element_type_info.clone())?;
                    elements.push(T::decode(value_ref)?);
                }

                Ok(elements)
            }

            PgValueFormat::Text => {
                // Text-format values carry no element OID, so elements are decoded with
                // `T`'s own type info.
                let element_type_info = T::type_info();

                let s = value.as_str()?;

                // A one-dimensional array with default bounds prints as `{...}`. Anything
                // else, such as `[0:2]={...}` for non-default bounds, is rejected rather
                // than mis-parsed. `strip_*` also cannot panic on empty input or on a
                // multi-byte final character.
                let s = s
                    .strip_prefix('{')
                    .and_then(|inner| inner.strip_suffix('}'))
                    .ok_or("expected a text-format array enclosed in braces; arrays with explicit bounds are not supported")?;

                if s.is_empty() {
                    return Ok(Vec::new());
                }

                // Postgres double-quotes any element containing `{`. An unquoted leading
                // brace therefore means nested (multi-dimensional) arrays, which the binary
                // path rejects too.
                if s.starts_with('{') {
                    let ndim = 1 + s.chars().take_while(|&c| c == '{').count();
                    return Err(format!("encountered an array of {ndim} dimensions; only one-dimensional arrays are supported").into());
                }

                // Almost every Postgres type uses ',' as its array delimiter. `box` uses ';'
                // and is not handled here.
                let delimiter = ',';
                let mut done = false;
                let mut in_quotes = false;
                let mut in_escape = false;
                // Set if the current element contained a double quote. A quoted "NULL"
                // is a string, not SQL NULL.
                let mut quoted = false;
                let mut element = String::with_capacity(10);
                let mut chars = s.chars();
                let mut elements = Vec::with_capacity(4);

                while !done {
                    // Accumulate one element, up to an unquoted delimiter or end of input.
                    loop {
                        match chars.next() {
                            Some(ch) => match ch {
                                _ if in_escape => {
                                    element.push(ch);
                                    in_escape = false;
                                }

                                '"' => {
                                    in_quotes = !in_quotes;
                                    quoted = true;
                                }

                                '\\' => {
                                    in_escape = true;
                                }

                                _ if ch == delimiter && !in_quotes => {
                                    break;
                                }

                                _ => {
                                    element.push(ch);
                                }
                            },

                            None => {
                                done = true;
                                break;
                            }
                        }
                    }

                    let element_bytes = if !quoted && element == "NULL" {
                        None
                    } else {
                        Some(element.as_bytes())
                    };

                    elements.push(T::decode(PgValueRef {
                        value: element_bytes,
                        row: None,
                        type_info: element_type_info.clone(),
                        format,
                    })?);

                    element.clear();
                    quoted = false;
                }

                Ok(elements)
            }
        }
    }
}