postgres_query/
extract.rs

1//! Extract typed values from rows.
2
3use postgres_types::FromSql;
4use std::collections::{BTreeSet, HashSet};
5use std::fmt::{Display, Write};
6use std::hash::Hash;
7use std::error::Error as StdError;
8use std::iter;
9use std::ops::Range;
10use thiserror::Error;
11use postgres_types::WasNull;
12use tokio_postgres::{error::Error as SqlError, row::RowIndex, Column};
13
14/// An error that can occur while extracting values from a row.
15#[derive(Debug, Error)]
16pub enum Error {
17    #[error("{msg}")]
18    Custom { msg: String },
19
20    #[error("invalid number of columns, found {found} but expected {expected}")]
21    ColumnCount { found: usize, expected: usize },
22
23    #[error("failed to get column: `{index}` (columns were: {columns})")]
24    SliceLookup { index: String, columns: String },
25
26    #[error("failed to split on: `{split}` (columns were: {columns})")]
27    InvalidSplit { split: String, columns: String },
28
29    #[error(
30        "failed to slice row on: `{start}..{end}` (len was: {len})", 
31        start = range.start,
32        end = range.end
33    )]
34    SliceIndex { range: Range<usize>, len: usize },
35
36    /// An error occured within postgres itself.
37    #[error("internal postgres error")]
38    Sql(#[from] SqlError),
39}
40
41impl Error {
42    /// Construct a new error message with a custom message.
43    pub fn new<D>(msg: D) -> Error
44    where
45        D: Display,
46    {
47        Error::Custom {
48            msg: msg.to_string(),
49        }
50    }
51
52    /// A soft error is an error that can be converted into an `Option::None`.
53    fn is_soft(&self) -> bool {
54        match self {
55            Error::Sql(sql) => {
56                let mut error: &dyn StdError = sql;
57                loop {
58                    if let Some(WasNull) = error.downcast_ref() {
59                        break true;
60                    }
61
62                    match error.source() {
63                        Some(source) => error = source,
64                        None => break false,
65                    }
66                }
67            }
68
69            _ => false,
70        }
71    }
72}
73
mod private {
    /// Sealing support for the public `Row` trait.
    pub mod row {
        /// Supertrait of `Row`. Because it lives in a private module, code outside this crate
        /// cannot name it, and therefore cannot implement `Row`.
        pub trait Sealed {}
    }
}
79
80/// Anything that provides a row-like interface.
81///
82/// This trait is sealed and cannot be implemented for types outside of this crate.
83pub trait Row: private::row::Sealed {
84    /// Return the name and type of each column.
85    fn columns(&self) -> &[Column];
86
87    /// Attempt to get a cell in the row by the column name or index.
88    fn try_get<'a, I, T>(&'a self, index: I) -> Result<T, Error>
89    where
90        I: RowIndex + Display,
91        T: FromSql<'a>;
92
93    /// The number of values (columns) in the row.
94    fn len(&self) -> usize {
95        self.columns().len()
96    }
97
98    /// `true` if the value did not contain any values, `false` otherwise.
99    fn is_empty(&self) -> bool {
100        self.len() == 0
101    }
102
103    /// Attempt to get a cell in the row by the column name or index.
104    ///
105    /// # Panics
106    ///
107    /// - If no cell was found with the given index.
108    fn get<'a, I, T>(&'a self, index: I) -> T
109    where
110        I: RowIndex + Display,
111        T: FromSql<'a>,
112    {
113        match self.try_get::<I, T>(index) {
114            Ok(value) => value,
115            Err(err) => panic!("failed to retrieve column: {}", err),
116        }
117    }
118
119    /// Return a subslice of this row's columns.
120    fn slice(&self, range: Range<usize>) -> Result<RowSlice<Self>, Error>
121    where
122        Self: Sized,
123    {
124        if range.end > self.len() {
125            Err(Error::SliceIndex {
126                range,
127                len: self.len(),
128            })
129        } else {
130            let slice = RowSlice { row: self, range };
131            Ok(slice)
132        }
133    }
134}
135
136/// A contiguous subset of columns in a row.
137pub struct RowSlice<'a, R>
138where
139    R: Row,
140{
141    row: &'a R,
142    range: Range<usize>,
143}
144
145/// Extract values from a row.
146///
147/// May be derived for `struct`s using `#[derive(FromSqlRow)]`.
148///
149/// # Example
150///
151/// ```
152/// # use postgres_query_macro::FromSqlRow;
153/// # use postgres_types::Date;
154/// #[derive(FromSqlRow)]
155/// struct Person {
156///     age: i32,
157///     name: String,
158///     birthday: Option<Date<String>>,
159/// }
160/// ```
161pub trait FromSqlRow: Sized {
162    /// Number of columns required to construct this type.
163    ///
164    /// IMPORTANT: if not set correctly, extractors which depend on this value may produce errors.
165    const COLUMN_COUNT: usize;
166
167    /// Extract values from a single row.
168    fn from_row<R>(row: &R) -> Result<Self, Error>
169    where
170        R: Row;
171
172    /// Extract values from multiple rows.
173    ///
174    /// Implementors of this trait may override this method to enable optimizations not possible in
175    /// [`from_row`] by, for example, only looking up the indices of columns with a specific name
176    /// once.
177    ///
178    /// [`from_row`]: #tymethod.from_row
179    fn from_row_multi<R>(rows: &[R]) -> Result<Vec<Self>, Error>
180    where
181        R: Row,
182    {
183        rows.iter().map(Self::from_row).collect()
184    }
185}
186
/// For collections that can be built from single elements.
///
/// Used by `#[derive(FromSqlRow)]` when a field is tagged with the attribute `#[row(merge)]`.
pub trait Merge {
    /// The type of item being merged.
    type Item;

    /// Insert one item into this collection.
    fn insert(&mut self, item: Self::Item);
}

/// Appends each item, so duplicates are kept in insertion order.
impl<T> Merge for Vec<T> {
    type Item = T;

    fn insert(&mut self, item: T) {
        Vec::push(self, item);
    }
}

/// Adds each item to the set; an item that is already present is not stored twice.
impl<T> Merge for HashSet<T>
where
    T: Hash + Eq,
{
    type Item = T;

    fn insert(&mut self, item: T) {
        // The inherent `insert` reports whether the item was new; merging does not need that.
        let _newly_added = HashSet::insert(self, item);
    }
}

/// Adds each item to the ordered set; an item that is already present is not stored twice.
impl<T> Merge for BTreeSet<T>
where
    T: Ord,
{
    type Item = T;

    fn insert(&mut self, item: T) {
        // The inherent `insert` reports whether the item was new; merging does not need that.
        let _newly_added = BTreeSet::insert(self, item);
    }
}
224
225impl private::row::Sealed for tokio_postgres::Row {}
226
227impl Row for tokio_postgres::Row {
228    fn columns(&self) -> &[Column] {
229        tokio_postgres::Row::columns(self)
230    }
231
232    fn try_get<'a, I, T>(&'a self, index: I) -> Result<T, Error>
233    where
234        I: RowIndex + Display,
235        T: FromSql<'a>,
236    {
237        tokio_postgres::Row::try_get(self, index).map_err(Error::from)
238    }
239
240    fn len(&self) -> usize {
241        tokio_postgres::Row::len(self)
242    }
243
244    fn is_empty(&self) -> bool {
245        tokio_postgres::Row::is_empty(self)
246    }
247
248    fn get<'a, I, T>(&'a self, index: I) -> T
249    where
250        I: RowIndex + Display,
251        T: FromSql<'a>,
252    {
253        tokio_postgres::Row::get(self, index)
254    }
255}
256
257impl<R> private::row::Sealed for RowSlice<'_, R> where R: Row {}
258
259impl<R> Row for RowSlice<'_, R>
260where
261    R: Row,
262{
263    fn columns(&self) -> &[Column] {
264        &self.row.columns()[self.range.clone()]
265    }
266
267    fn try_get<'a, I, T>(&'a self, index: I) -> Result<T, Error>
268    where
269        I: RowIndex + Display,
270        T: FromSql<'a>,
271    {
272        if let Some(index) = index.__idx(self.columns()) {
273            self.row.try_get(self.range.start + index)
274        } else {
275            Err(Error::SliceLookup {
276                index: index.to_string(),
277                columns: format_columns(self.columns()),
278            })
279        }
280    }
281}
282
283impl<R> RowSlice<'_, R>
284where
285    R: Row,
286{
287    /// Return a subslice of this row's columns.
288    ///
289    /// This is an optimized version of `Row::slice` which reduces the number of
290    /// pointer-indirections.
291    pub fn slice(&self, range: Range<usize>) -> Result<RowSlice<R>, Error>
292    where
293        Self: Sized,
294    {
295        if range.end > self.range.end {
296            Err(Error::SliceIndex {
297                range,
298                len: self.range.end,
299            })
300        } else {
301            let slice = RowSlice {
302                row: self.row,
303                range,
304            };
305            Ok(slice)
306        }
307    }
308}
309
310/// Split a row's columns into multiple partitions based on some split-points.
311///
312/// # Split
313///
314/// Given a list of column labels, a split is made right before the first column with a matching
315/// name following the previous split:
316///
317/// ```text
318/// Labels:       a,    a,      c,  a
319/// Indices:      0 1 2 3 4 5 6 7 8 9 10
320/// Columns:      a b c a b a b c b a c
321/// Splits:      |     |       |   |   
322/// Partitions: + +---+ +-----+ +-+ +-+
323/// Ranges:     [0..0, 0..3, 3..7, 7..9, 9..11]`
324/// ```
325///
326/// The first partition always contains the leading columns (zero or more):
327///
328/// ```text
329/// Labels:         b,  a
330/// Indices:    0 1 2 3 4 5
331/// Columns:    d a b c a b
332/// Splits:        |   |
333/// Partitions: +-+ +-+ +-+
334/// Ranges:     [0..2, 2..4, 4..6]
335/// ```
336///
337/// # Errors
338///
339/// Will return an error if the columns could not be split (ie. no column with a matching name was
340/// found in the remaining columns).
341pub fn split_columns_many<'a, S>(
342    columns: &'a [Column],
343    splits: &'a [S],
344) -> impl Iterator<Item = Result<Range<usize>, Error>> + 'a
345where
346    S: AsRef<str>,
347{
348    let column_names = columns.iter().map(|col| col.name());
349    partition_many(column_names, splits.iter()).map(move |split| match split {
350        SplitResult::Range(range) => Ok(range),
351        SplitResult::NotFound { split, start } => Err(Error::InvalidSplit {
352            split,
353            columns: format_columns(&columns[start..]),
354        }),
355    })
356}
357
/// Outcome of attempting one split in [`partition_many`].
#[cfg_attr(test, derive(Debug, PartialEq))]
enum SplitResult {
    /// No column named `split` was found; `start` is where the unclosed partition begins.
    NotFound { split: String, start: usize },
    /// The column indices covered by one partition.
    Range(Range<usize>),
}

/// Partition `columns` by the labels in `splits` (see [`split_columns_many`] for the rules).
///
/// On success, this yields one range per split followed by one trailing range that extends to
/// the last column, i.e. `splits + 1` ranges in total. After a split fails, every remaining split
/// yields [`SplitResult::NotFound`] and no trailing range is produced.
fn partition_many<'a>(
    columns: impl Iterator<Item = impl AsRef<str> + 'a> + 'a,
    splits: impl Iterator<Item = impl AsRef<str> + 'a> + 'a,
) -> impl Iterator<Item = SplitResult> + 'a {
    let mut columns = columns.enumerate();
    let mut splits = splits;

    // Start index of the partition currently being built.
    let mut previous_end = 0;
    // Number of columns consumed from `columns` by successful splits.
    let mut consumed = 0;
    // Set once a split fails: the columns are exhausted and a trailing range would be bogus.
    let mut failed = false;
    // Set once the iterator has produced its final item.
    let mut finished = false;

    iter::from_fn(move || -> Option<_> {
        if finished {
            return None;
        }

        match splits.next() {
            Some(split) => {
                let split = split.as_ref();
                match columns.find(|(_, name)| name.as_ref() == split) {
                    Some((end, _)) => {
                        consumed = end + 1;
                        let range = previous_end..end;
                        previous_end = end;
                        Some(SplitResult::Range(range))
                    }
                    None => {
                        failed = true;
                        Some(SplitResult::NotFound {
                            split: split.to_owned(),
                            start: previous_end,
                        })
                    }
                }
            }
            None => {
                finished = true;
                if failed {
                    return None;
                }
                // Count the remaining columns instead of inspecting the last one: the final split
                // may have matched the very last column, leaving nothing behind, yet the trailing
                // partition (containing that column) must still be produced.
                let len = consumed + columns.by_ref().count();
                Some(SplitResult::Range(previous_end..len))
            }
        }
    })
}
393
394fn format_columns(columns: &[Column]) -> String {
395    let mut total = String::with_capacity(16 * columns.len());
396    for col in columns {
397        if !total.is_empty() {
398            total.push_str(", ");
399        }
400        write!(total, "`{}`", col.name()).unwrap();
401    }
402    total
403}
404
405mod from_row_sql_impls {
406    use super::*;
407
408    use std::rc::Rc;
409    use std::sync::Arc;
410
411    macro_rules! impl_from_row_for_tuple {
412        (($($elem:ident),+)) => {
413            impl<$($elem),+> FromSqlRow for ($($elem,)+)
414                where $($elem: for<'a> FromSql<'a> + std::fmt::Display),+
415                {
416                    const COLUMN_COUNT: usize = impl_from_row_for_tuple!(@count ($($elem),*));
417
418                    fn from_row<R>(row: &R) -> Result<Self, Error>
419                        where R: Row {
420                            if row.len() != Self::COLUMN_COUNT {
421                                Err(Error::ColumnCount {
422                                    expected: Self::COLUMN_COUNT,
423                                    found: row.len(),
424                                })
425                            } else {
426                                let result = (
427                                    $(
428                                        row.try_get::<usize, $elem>(
429                                            impl_from_row_for_tuple!(@index $elem)
430                                        )?,
431                                    )+
432                                );
433
434                                Ok(result)
435                            }
436                        }
437                }
438        };
439
440        (@index A) => { 0 };
441        (@index B) => { 1 };
442        (@index C) => { 2 };
443        (@index D) => { 3 };
444        (@index E) => { 4 };
445        (@index F) => { 5 };
446        (@index G) => { 6 };
447        (@index H) => { 7 };
448
449        (@count ()) => { 0 };
450        (@count ($head:ident $(, $tail:ident)*)) => {{
451            1 + impl_from_row_for_tuple!(@count ($($tail),*))
452        }};
453    }
454
455    impl_from_row_for_tuple!((A));
456    impl_from_row_for_tuple!((A, B));
457    impl_from_row_for_tuple!((A, B, C));
458    impl_from_row_for_tuple!((A, B, C, D));
459    impl_from_row_for_tuple!((A, B, C, D, E));
460    impl_from_row_for_tuple!((A, B, C, D, E, F));
461    impl_from_row_for_tuple!((A, B, C, D, E, F, G));
462    impl_from_row_for_tuple!((A, B, C, D, E, F, G, H));
463
464    impl<T> FromSqlRow for Option<T>
465    where
466        T: FromSqlRow,
467    {
468        const COLUMN_COUNT: usize = T::COLUMN_COUNT;
469
470        fn from_row<R>(row: &R) -> Result<Self, Error>
471        where
472            R: Row,
473        {
474            match T::from_row(row) {
475                Ok(value) => Ok(Some(value)),
476                Err(error) if error.is_soft() => Ok(None),
477                Err(error) => Err(error),
478            }
479        }
480    }
481
482    impl<T, E> FromSqlRow for Result<T, E>
483    where
484        T: FromSqlRow,
485        E: From<Error>,
486    {
487        const COLUMN_COUNT: usize = T::COLUMN_COUNT;
488
489        fn from_row<R>(row: &R) -> Result<Self, Error>
490        where
491            R: Row,
492        {
493            match T::from_row(row) {
494                Ok(value) => Ok(Ok(value)),
495                Err(error) => Ok(Err(E::from(error))),
496            }
497        }
498    }
499
500    macro_rules! impl_from_row_for_wrapper {
501        ($wrapper:ident, $constructor:expr) => {
502            impl<T> FromSqlRow for $wrapper<T>
503            where
504                T: FromSqlRow,
505            {
506                const COLUMN_COUNT: usize = T::COLUMN_COUNT;
507
508                fn from_row<R>(row: &R) -> Result<Self, Error>
509                where
510                    R: Row,
511                {
512                    let value = T::from_row(row)?;
513                    Ok($constructor(value))
514                }
515            }
516        };
517    }
518
519    impl_from_row_for_wrapper!(Box, Box::new);
520    impl_from_row_for_wrapper!(Rc, Rc::new);
521    impl_from_row_for_wrapper!(Arc, Arc::new);
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Partition the characters of `columns` on the characters of `splits`, reporting failed
    /// splits as `SplitResult::NotFound`.
    fn split_chars_fallible<'a>(
        columns: &'a str,
        splits: &'a str,
    ) -> impl Iterator<Item = SplitResult> + 'a {
        let column_names = columns.chars().map(|ch| ch.to_string());
        let split_names = splits.chars().map(|ch| ch.to_string());
        partition_many(column_names, split_names)
    }

    /// Like `split_chars_fallible`, but panics as soon as a split cannot be made.
    fn split_chars<'a>(
        columns: &'a str,
        splits: &'a str,
    ) -> impl Iterator<Item = Range<usize>> + 'a {
        split_chars_fallible(columns, splits).map(move |outcome| match outcome {
            SplitResult::NotFound { split, start } => panic!(
                "failed to split {:?} on {:?}",
                columns.chars().skip(start).collect::<String>(),
                split,
            ),
            SplitResult::Range(range) => range,
        })
    }

    #[test]
    fn split_columns_many_no_excess() {
        let partitions: Vec<_> = split_chars("abcabdab", "aaa").collect();
        assert_eq!(partitions, vec![0..0, 0..3, 3..6, 6..8]);
    }

    #[test]
    fn split_columns_many_leading_columns() {
        let partitions: Vec<_> = split_chars("deabcabdab", "aaa").collect();
        assert_eq!(partitions, vec![0..2, 2..5, 5..8, 8..10]);
    }

    #[test]
    fn split_columns_many_too_many_splits() {
        let partitions: Vec<_> = split_chars_fallible("abcabc", "aaa").collect();
        let expected = vec![
            SplitResult::Range(0..0),
            SplitResult::Range(0..3),
            SplitResult::NotFound {
                split: "a".to_owned(),
                start: 3,
            },
        ];
        assert_eq!(partitions, expected);
    }
}