use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
use core::cell::Cell;
use sqlx_core::{
database::Database,
encode::{Encode, IsNull},
error::BoxDynError,
types::Type,
};
pub struct PgBindIter<I>(Cell<Option<I>>);
pub trait PgBindIterExt: Iterator + Sized {
fn bind_iter(self) -> PgBindIter<Self>;
}
impl<I: Iterator + Sized> PgBindIterExt for I {
fn bind_iter(self) -> PgBindIter<I> {
PgBindIter(Cell::new(Some(self)))
}
}
impl<I> Type<Postgres> for PgBindIter<I>
where
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
{
fn type_info() -> <Postgres as Database>::TypeInfo {
<I as Iterator>::Item::array_type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
<I as Iterator>::Item::array_compatible(ty)
}
}
impl<'q, I> PgBindIter<I>
where
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_inner(
mut iter: I,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, BoxDynError> {
let lower_size_hint = iter.size_hint().0;
let first = iter.next();
let type_info = first
.as_ref()
.and_then(Encode::produces)
.unwrap_or_else(<I as Iterator>::Item::type_info);
buf.extend(&1_i32.to_be_bytes()); buf.extend(&0_i32.to_be_bytes());
match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}
let len_start = buf.len();
buf.extend(0_i32.to_be_bytes()); buf.extend(1_i32.to_be_bytes());
match first {
Some(first) => buf.encode(first)?,
None => return Ok(IsNull::No),
}
let mut count = 1_i32;
const MAX: usize = i32::MAX as usize - 1;
for value in (&mut iter).take(MAX) {
buf.encode(value)?;
count += 1;
}
const OVERFLOW: usize = i32::MAX as usize + 1;
if iter.next().is_some() {
let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
}
buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());
Ok(IsNull::No)
}
}
impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
where
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
}
fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
where
Self: Sized,
{
Self::encode_inner(
self.0.into_inner().expect("PgBindIter is only used once"),
buf,
)
}
}