Skip to main content

sqlx_postgres/
bind_iter.rs

1use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
2use core::cell::Cell;
3use sqlx_core::{
4    database::Database,
5    encode::{Encode, IsNull},
6    error::BoxDynError,
7    types::Type,
8};
9
// not exported but pub because it is used in the extension trait
/// Wrapper returned by [`PgBindIterExt::bind_iter`] that encodes an iterator as a
/// one-dimensional Postgres array parameter.
///
/// The iterator lives in a `Cell<Option<_>>` so that `Encode::encode_by_ref`, which
/// only receives `&self`, can still move the iterator out and drain it. As a
/// consequence, a given `PgBindIter` can be encoded only once.
pub struct PgBindIter<I>(Cell<Option<I>>);
12
/// Iterator extension trait that lets any iterator be bound as a single
/// Postgres array parameter.
///
/// The item type must implement `PgHasArrayType` so that the array's Postgres
/// type can be determined. Because `PgHasArrayType` has a blanket impl for all
/// references, the iterator may yield borrowed items (e.g. `&p.name` below)
/// instead of cloning or copying them.
///
/// Without this trait, the example below would need three separate arrays.
/// Each one is built by iterating over `people` to collect one field, and is
/// then iterated over again while it is being encoded.
///
/// With `bind_iter`, each field is read in a single pass over `people` at
/// encode time, with no intermediate collection. This saves both time and
/// memory, and the data can come from any iterator rather than from a
/// particular collection type.
///
/// ```rust,no_run
/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
/// # use sqlx::types::chrono::{DateTime, Utc};
/// # use sqlx::Connection;
/// # fn people() -> &'static [Person] {
/// #   &[]
/// # }
/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
/// use sqlx::postgres::PgBindIterExt;
///
/// #[derive(sqlx::FromRow)]
/// struct Person {
///     id: i64,
///     name: String,
///     birthdate: DateTime<Utc>,
/// }
///
/// # let people: &[Person] = people();
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
///     .bind(people.iter().map(|p| p.id).bind_iter())
///     .bind(people.iter().map(|p| &p.name).bind_iter())
///     .bind(people.iter().map(|p| &p.birthdate).bind_iter())
///     .execute(&mut conn)
///     .await?;
///
/// # Ok(())
/// # }
/// ```
pub trait PgBindIterExt: Iterator + Sized {
    /// Wraps this iterator so it can be passed to `bind` as one array parameter.
    ///
    /// The iterator is not advanced until the returned value is encoded. The
    /// returned value can be encoded only once; encoding it a second time panics.
    fn bind_iter(self) -> PgBindIter<Self>;
}
58
59impl<I: Iterator + Sized> PgBindIterExt for I {
60    fn bind_iter(self) -> PgBindIter<I> {
61        PgBindIter(Cell::new(Some(self)))
62    }
63}
64
65impl<I> Type<Postgres> for PgBindIter<I>
66where
67    I: Iterator,
68    <I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
69{
70    fn type_info() -> <Postgres as Database>::TypeInfo {
71        <I as Iterator>::Item::array_type_info()
72    }
73    fn compatible(ty: &PgTypeInfo) -> bool {
74        <I as Iterator>::Item::array_compatible(ty)
75    }
76}
77
78impl<'q, I> PgBindIter<I>
79where
80    I: Iterator,
81    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
82{
83    fn encode_inner(
84        // need ownership to iterate
85        mut iter: I,
86        buf: &mut PgArgumentBuffer,
87    ) -> Result<IsNull, BoxDynError> {
88        let lower_size_hint = iter.size_hint().0;
89        let first = iter.next();
90        let type_info = first
91            .as_ref()
92            .and_then(Encode::produces)
93            .unwrap_or_else(<I as Iterator>::Item::type_info);
94
95        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
96        buf.extend(&0_i32.to_be_bytes()); // flags
97
98        if let Some(oid) = type_info.oid() {
99            buf.extend(oid.0.to_be_bytes());
100        } else {
101            buf.push_hole(type_info);
102        }
103
104        let len_start = buf.len();
105        buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
106        buf.extend(1_i32.to_be_bytes()); // lower bound
107
108        match first {
109            Some(first) => buf.encode(first)?,
110            None => return Ok(IsNull::No),
111        }
112
113        let mut count = 1_i32;
114        const MAX: usize = i32::MAX as usize - 1;
115
116        for value in (&mut iter).take(MAX) {
117            buf.encode(value)?;
118            count += 1;
119        }
120
121        const OVERFLOW: usize = i32::MAX as usize + 1;
122        if iter.next().is_some() {
123            let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
124            return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
125        }
126
127        // set the length now that we know what it is.
128        buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());
129
130        Ok(IsNull::No)
131    }
132}
133
134impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
135where
136    I: Iterator,
137    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
138{
139    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
140        Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
141    }
142    fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
143    where
144        Self: Sized,
145    {
146        Self::encode_inner(
147            self.0.into_inner().expect("PgBindIter is only used once"),
148            buf,
149        )
150    }
151}